Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| db6715bec3 | |||
| da5d9e8a6a | |||
| 84b667ca7a | |||
| 29657106fc | |||
| 9c8860471e | |||
| 9b4e3f307e | |||
| 6fe37c3abf | |||
| 7f45493a37 | |||
| 891f6a5b5a | |||
| 7183f6b43d | |||
| d89bfeb441 | |||
| 9a0c6bed40 | |||
| d6ca535939 | |||
| 27302c0c02 | |||
| d4e22cceaa | |||
| 4c94927658 | |||
| a955a4a5c0 | |||
| 22d3f1a4f9 | |||
| e2443251ad | |||
| 5fbd53c616 | |||
| 97dae50dc4 | |||
| cb978f760f | |||
| 387f0ef6c4 | |||
| 18c134624d | |||
| da2326bdc7 | |||
| da46545630 |
@@ -30,4 +30,4 @@ jobs:
|
||||
version: '~> v2'
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
+2
-1
@@ -2,4 +2,5 @@
|
||||
.env
|
||||
build/
|
||||
dist/
|
||||
.vscode
|
||||
.vscode
|
||||
.DS_Store
|
||||
|
||||
@@ -2,6 +2,16 @@
|
||||
APP_NAME = llama-swap
|
||||
BUILD_DIR = build
|
||||
|
||||
# Get the current Git hash
|
||||
GIT_HASH := $(shell git rev-parse --short HEAD)
|
||||
ifneq ($(shell git status --porcelain),)
|
||||
# There are untracked changes
|
||||
GIT_HASH := $(GIT_HASH)+
|
||||
endif
|
||||
|
||||
# Capture the current build date in RFC3339 format
|
||||
BUILD_DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
# Default target: Builds binaries for both OSX and Linux
|
||||
all: mac linux simple-responder
|
||||
|
||||
@@ -18,12 +28,12 @@ test-all:
|
||||
# Build OSX binary
|
||||
mac:
|
||||
@echo "Building Mac binary..."
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||
|
||||
# Build Linux binary
|
||||
linux:
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
|
||||
# for testing proxy.Process
|
||||
simple-responder:
|
||||
@@ -35,5 +45,19 @@ simple-responder:
|
||||
$(BUILD_DIR):
|
||||
mkdir -p $(BUILD_DIR)
|
||||
|
||||
# Create a new release tag
|
||||
release:
|
||||
@echo "Checking for unstaged changes..."
|
||||
@if [ -n "$(shell git status --porcelain)" ]; then \
|
||||
echo "Error: There are unstaged changes. Please commit or stash your changes before creating a release tag." >&2; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
# Get the highest tag in v{number} format, increment it, and create a new tag
|
||||
@highest_tag=$$(git tag --sort=-v:refname | grep -E '^v[0-9]+$$' | head -n 1 || echo "v0"); \
|
||||
new_tag="v$$(( $${highest_tag#v} + 1 ))"; \
|
||||
echo "tagging new version: $$new_tag"; \
|
||||
git tag "$$new_tag";
|
||||
|
||||
# Phony targets
|
||||
.PHONY: all clean osx linux
|
||||
|
||||
@@ -2,19 +2,33 @@
|
||||
|
||||

|
||||
|
||||
llama-swap is a golang server that automatically swaps the llama.cpp server on demand. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
||||
# Introduction
|
||||
llama-swap is an OpenAI API compatible server that gives you complete control over how you use your hardware. It automatically swaps to the configuration of your choice for serving a model. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
||||
|
||||
Features:
|
||||
|
||||
- ✅ Easy to deploy: single binary with no dependencies
|
||||
- ✅ Single yaml configuration file
|
||||
- ✅ Automatic switching between models
|
||||
- ✅ Full control over llama.cpp server settings per model
|
||||
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
||||
- ✅ Easy to config: single yaml file
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Full control over server settings per model
|
||||
- ✅ OpenAI API support (`v1/completions`, `v1/chat/completions`, `v1/embeddings` and `v1/rerank`)
|
||||
- ✅ Multiple GPU support
|
||||
- ✅ Run multiple models at once with `profiles`
|
||||
- ✅ Remote log monitoring at `/log`
|
||||
- ✅ Automatic unloading of models from GPUs after timeout
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabblyAPI, etc)
|
||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
|
||||
## Releases
|
||||
|
||||
Builds for Linux and OSX are available on the [Releases](https://github.com/mostlygeek/llama-swap/releases) page.
|
||||
|
||||
### Building from source
|
||||
|
||||
1. Install golang for your system
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
|
||||
## config.yaml
|
||||
|
||||
@@ -25,6 +39,9 @@ llama-swap's configuration is purposefully simple.
|
||||
# Default (and minimum) is 15 seconds
|
||||
healthCheckTimeout: 60
|
||||
|
||||
# Write HTTP logs (useful for troubleshooting), defaults to false
|
||||
logRequests: true
|
||||
|
||||
# define valid model values and the upstream server start
|
||||
models:
|
||||
"llama":
|
||||
@@ -60,11 +77,17 @@ models:
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# unlisted models do not show up in /v1/models or /upstream lists
|
||||
# but they can still be requested as normal
|
||||
"qwen-unlisted":
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
unlisted: true
|
||||
|
||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||
#
|
||||
# Tips:
|
||||
# - each model must be listening on a unique address and port
|
||||
# - the model name is in this format: "profile_name/model", like "coding/qwen"
|
||||
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
||||
# - the profile will load and unload all models in the profile at the same time
|
||||
profiles:
|
||||
coding:
|
||||
@@ -72,7 +95,11 @@ profiles:
|
||||
- "llama"
|
||||
```
|
||||
|
||||
More [examples](examples/README.md) are available for different use cases.
|
||||
**Advanced examples**
|
||||
|
||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -83,22 +110,22 @@ More [examples](examples/README.md) are available for different use cases.
|
||||
|
||||
## Monitoring Logs
|
||||
|
||||
The `/logs` endpoint is available to monitor what llama-swap is doing. It will send the last 10KB of logs. Useful for monitoring the output of llama-server. It also supports streaming of logs.
|
||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
||||
|
||||
Usage:
|
||||
Of course, CLI access is also supported:
|
||||
|
||||
```
|
||||
# sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
|
||||
# streams logs using chunk encoding
|
||||
# streams logs
|
||||
curl -Ns 'http://host/logs/stream'
|
||||
|
||||
# stream and filter logs with linux pipes
|
||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
|
||||
# skips history and just streams new log entries
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
|
||||
# streams logs using Server Sent Events
|
||||
curl -Ns 'http://host/logs/streamSSE'
|
||||
```
|
||||
|
||||
## Systemd Unit Files
|
||||
@@ -125,9 +152,3 @@ StartLimitInterval=30
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
## Building from Source
|
||||
|
||||
1. Install golang for your system
|
||||
1. run `make clean all`
|
||||
1. binaries will be built into `build/` directory
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
# Default (and minimum): 15 seconds
|
||||
healthCheckTimeout: 15
|
||||
|
||||
# Log HTTP requests helpful for troubleshoot, defaults to False
|
||||
logRequests: true
|
||||
|
||||
models:
|
||||
"llama":
|
||||
cmd: >
|
||||
@@ -26,6 +29,31 @@ models:
|
||||
aliases:
|
||||
- gpt-3.5-turbo
|
||||
|
||||
# Embedding example with Nomic
|
||||
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
||||
"nomic":
|
||||
proxy: http://127.0.0.1:9005
|
||||
cmd: >
|
||||
models/llama-server-osx --port 9005
|
||||
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
||||
--ctx-size 8192
|
||||
--batch-size 8192
|
||||
--rope-scaling yarn
|
||||
--rope-freq-scale 0.75
|
||||
-ngl 99
|
||||
--embeddings
|
||||
|
||||
# Reranking example with bge-reranker
|
||||
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
||||
"bge-reranker":
|
||||
proxy: http://127.0.0.1:9006
|
||||
cmd: >
|
||||
models/llama-server-osx --port 9006
|
||||
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
||||
--ctx-size 8192
|
||||
--reranking
|
||||
|
||||
|
||||
"simple":
|
||||
# example of setting environment variables
|
||||
env:
|
||||
@@ -33,6 +61,7 @@ models:
|
||||
- env1=hello
|
||||
cmd: build/simple-responder --port 8999
|
||||
proxy: http://127.0.0.1:8999
|
||||
unlisted: true
|
||||
|
||||
# use "none" to skip check. Caution this may cause some requests to fail
|
||||
# until the upstream server is ready for traffic
|
||||
@@ -42,9 +71,11 @@ models:
|
||||
"broken":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
unlisted: true
|
||||
"broken_timeout":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9000
|
||||
unlisted: true
|
||||
|
||||
# creating a coding profile with models for code generation and general questions
|
||||
profiles:
|
||||
|
||||
+3
-6
@@ -1,9 +1,6 @@
|
||||
# Example Configurations
|
||||
# Example Configs and Use Cases
|
||||
|
||||
Learning by example is best.
|
||||
|
||||
Here in the `examples/` folder are llama-swap configurations that can be used on your local LLM server.
|
||||
|
||||
## List
|
||||
A collections of usecases and examples for getting the most out of llama-swap.
|
||||
|
||||
* [Speculative Decoding](speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||
* [Optimizing Code Generation](benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||
@@ -0,0 +1,123 @@
|
||||
# Optimizing Code Generation with llama-swap
|
||||
|
||||
Finding the best mix of settings for your hardware can be time consuming. This example demonstrates using a custom configuration file to automate testing different scenarios to find the an optimal configuration.
|
||||
|
||||
The benchmark writes a snake game in Python, TypeScript, and Swift using the Qwen 2.5 Coder models. The experiments were done using a 3090 and a P40.
|
||||
|
||||
**Benchmark Scenarios**
|
||||
|
||||
Three scenarios are tested:
|
||||
|
||||
- 3090-only: Just the main model on the 3090
|
||||
- 3090-with-draft: the main and draft models on the 3090
|
||||
- 3090-P40-draft: the main model on the 3090 with the draft model offloaded to the P40
|
||||
|
||||
**Available Devices**
|
||||
|
||||
Use the following command to list available devices IDs for the configuration:
|
||||
|
||||
```
|
||||
$ /mnt/nvme/llama-server/llama-server-f3252055 --list-devices
|
||||
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
|
||||
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
|
||||
ggml_cuda_init: found 4 CUDA devices:
|
||||
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
|
||||
Device 1: Tesla P40, compute capability 6.1, VMM: yes
|
||||
Device 2: Tesla P40, compute capability 6.1, VMM: yes
|
||||
Device 3: Tesla P40, compute capability 6.1, VMM: yes
|
||||
Available devices:
|
||||
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 406 MiB free)
|
||||
CUDA1: Tesla P40 (24438 MiB, 22942 MiB free)
|
||||
CUDA2: Tesla P40 (24438 MiB, 24144 MiB free)
|
||||
CUDA3: Tesla P40 (24438 MiB, 24144 MiB free)
|
||||
```
|
||||
|
||||
**Configuration**
|
||||
|
||||
The configuration file, `benchmark-config.yaml`, defines the three scenarios:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"3090-only":
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: >
|
||||
/mnt/nvme/llama-server/llama-server-f3252055
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn
|
||||
--slots
|
||||
|
||||
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
-ngl 99
|
||||
--device CUDA0
|
||||
|
||||
--ctx-size 32768
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
|
||||
"3090-with-draft":
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
# --ctx-size 28500 max that can fit on 3090 after draft model
|
||||
cmd: >
|
||||
/mnt/nvme/llama-server/llama-server-f3252055
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn
|
||||
--slots
|
||||
|
||||
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
-ngl 99
|
||||
--device CUDA0
|
||||
|
||||
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
|
||||
-ngld 99
|
||||
--draft-max 16
|
||||
--draft-min 4
|
||||
--draft-p-min 0.4
|
||||
--device-draft CUDA0
|
||||
|
||||
--ctx-size 28500
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
|
||||
"3090-P40-draft":
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: >
|
||||
/mnt/nvme/llama-server/llama-server-f3252055
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn --metrics
|
||||
--slots
|
||||
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
-ngl 99
|
||||
--device CUDA0
|
||||
|
||||
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
|
||||
-ngld 99
|
||||
--draft-max 16
|
||||
--draft-min 4
|
||||
--draft-p-min 0.4
|
||||
--device-draft CUDA1
|
||||
|
||||
--ctx-size 32768
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
```
|
||||
|
||||
> Note in the `3090-with-draft` scenario the `--ctx-size` had to be reduced from 32768 to to accommodate the draft model.
|
||||
|
||||
|
||||
**Running the Benchmark**
|
||||
|
||||
To run the benchmark, execute the following commands:
|
||||
|
||||
1. `llama-swap -config benchmark-config.yaml`
|
||||
1. `./run-benchmark.sh http://localhost:8080 "3090-only" "3090-with-draft" "3090-P40-draft"`
|
||||
|
||||
The [benchmark script](run-benchmark.sh) generates a CSV output of the results, which can be converted to a Markdown table for readability.
|
||||
|
||||
**Results (tokens/second)**
|
||||
|
||||
| model | python | typescript | swift |
|
||||
|-----------------|--------|------------|-------|
|
||||
| 3090-only | 34.03 | 34.01 | 34.01 |
|
||||
| 3090-with-draft | 106.65 | 70.48 | 57.89 |
|
||||
| 3090-P40-draft | 81.54 | 60.35 | 46.50 |
|
||||
|
||||
Many different factors, like the programming language, can have big impacts on the performance gains. However, with a custom configuration file for benchmarking it is easy to test the different variations to discover what's best for your hardware.
|
||||
|
||||
Happy coding!
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This script generates a CSV file showing the token/second for generating a Snake Game in python, typescript and swift
|
||||
# It was created to test the effects of speculative decoding and the various draft settings on performance.
|
||||
#
|
||||
# Writing code with a low temperature seems to provide fairly consistent logic.
|
||||
#
|
||||
# Usage: ./benchmark.sh <url> <model1> [model2 ...]
|
||||
# Example: ./benchmark.sh http://localhost:8080 model1 model2
|
||||
|
||||
if [ "$#" -lt 2 ]; then
|
||||
echo "Usage: $0 <url> <model1> [model2 ...]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
url=$1; shift
|
||||
|
||||
echo "model,python,typescript,swift"
|
||||
|
||||
for model in "$@"; do
|
||||
|
||||
echo -n "$model,"
|
||||
|
||||
for lang in "python" "typescript" "swift"; do
|
||||
# expects a llama.cpp after PR https://github.com/ggerganov/llama.cpp/pull/10548
|
||||
# (Dec 3rd/2024)
|
||||
time=$(curl -s --url "$url/v1/chat/completions" -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"top_k\": 1, \"timings_per_token\":true, \"model\":\"$model\"}" | jq -r .timings.predicted_per_second)
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
time="error"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "$lang" != "swift" ]; then
|
||||
printf "%0.2f tps," $time
|
||||
else
|
||||
printf "%0.2f tps\n" $time
|
||||
fi
|
||||
done
|
||||
done
|
||||
@@ -32,9 +32,9 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.23.0 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -66,14 +66,22 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
@@ -9,13 +9,23 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
)
|
||||
|
||||
var version string = "0"
|
||||
var commit string = "abcd1234"
|
||||
var date = "unknown"
|
||||
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
if *showVersion {
|
||||
fmt.Printf("version: %s (%s), built at %s\n", version, commit, date)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
config, err := proxy.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 51 KiB |
@@ -0,0 +1,4 @@
|
||||
The rerank-test.json data is from https://github.com/ggerganov/llama.cpp/pull/9510
|
||||
|
||||
To run it:
|
||||
> curl http://127.0.0.1:8080/v1/rerank -H "Content-Type: application/json" -d @reranker-test.json -v | jq .
|
||||
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"model": "bge-reranker",
|
||||
"query": "Organic skincare products for sensitive skin",
|
||||
"top_n": 3,
|
||||
"documents": [
|
||||
"Organic skincare for sensitive skin with aloe vera and chamomile: Imagine the soothing embrace of nature with our organic skincare range, crafted specifically for sensitive skin. Infused with the calming properties of aloe vera and chamomile, each product provides gentle nourishment and protection. Say goodbye to irritation and hello to a glowing, healthy complexion.",
|
||||
"New makeup trends focus on bold colors and innovative techniques: Step into the world of cutting-edge beauty with this seasons makeup trends. Bold, vibrant colors and groundbreaking techniques are redefining the art of makeup. From neon eyeliners to holographic highlighters, unleash your creativity and make a statement with every look.",
|
||||
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille: Erleben Sie die wohltuende Wirkung unserer Bio-Hautpflege, speziell für empfindliche Haut entwickelt. Mit den beruhigenden Eigenschaften von Aloe Vera und Kamille pflegen und schützen unsere Produkte Ihre Haut auf natürliche Weise. Verabschieden Sie sich von Hautirritationen und genießen Sie einen strahlenden Teint.",
|
||||
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken: Tauchen Sie ein in die Welt der modernen Schönheit mit den neuesten Make-up-Trends. Kräftige, lebendige Farben und innovative Techniken setzen neue Maßstäbe. Von auffälligen Eyelinern bis hin zu holografischen Highlightern – lassen Sie Ihrer Kreativität freien Lauf und setzen Sie jedes Mal ein Statement.",
|
||||
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla: Descubre el poder de la naturaleza con nuestra línea de cuidado de la piel orgánico, diseñada especialmente para pieles sensibles. Enriquecidos con aloe vera y manzanilla, estos productos ofrecen una hidratación y protección suave. Despídete de las irritaciones y saluda a una piel radiante y saludable.",
|
||||
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras: Entra en el fascinante mundo del maquillaje con las tendencias más actuales. Colores vivos y técnicas innovadoras están revolucionando el arte del maquillaje. Desde delineadores neón hasta iluminadores holográficos, desata tu creatividad y destaca en cada look.",
|
||||
"针对敏感肌专门设计的天然有机护肤产品:体验由芦荟和洋甘菊提取物带来的自然呵护。我们的护肤产品特别为敏感肌设计,温和滋润,保护您的肌肤不受刺激。让您的肌肤告别不适,迎来健康光彩。",
|
||||
"新的化妆趋势注重鲜艳的颜色和创新的技巧:进入化妆艺术的新纪元,本季的化妆趋势以大胆的颜色和创新的技巧为主。无论是霓虹眼线还是全息高光,每一款妆容都能让您脱颖而出,展现独特魅力。",
|
||||
"敏感肌のために特別に設計された天然有機スキンケア製品: アロエベラとカモミールのやさしい力で、自然の抱擁を感じてください。敏感肌用に特別に設計された私たちのスキンケア製品は、肌に優しく栄養を与え、保護します。肌トラブルにさようなら、輝く健康な肌にこんにちは。",
|
||||
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています: 今シーズンのメイクアップトレンドは、大胆な色彩と革新的な技術に注目しています。ネオンアイライナーからホログラフィックハイライターまで、クリエイティビティを解き放ち、毎回ユニークなルックを演出しましょう。"
|
||||
]
|
||||
}
|
||||
@@ -16,6 +16,7 @@ type ModelConfig struct {
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
@@ -24,6 +25,7 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
@@ -0,0 +1,53 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Logs</title>
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
font-family: "Courier New", Courier, monospace;
|
||||
}
|
||||
#log-stream {
|
||||
flex: 1;
|
||||
margin: 1em;
|
||||
padding: 10px;
|
||||
background: #f4f4f4;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap; /* Ensures line wrapping */
|
||||
word-wrap: break-word; /* Ensures long words wrap */
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<pre id="log-stream">Waiting for logs...
|
||||
</pre>
|
||||
|
||||
<script>
|
||||
// Establish an EventSource connection to the SSE endpoint
|
||||
if (typeof(EventSource) !== "undefined") {
|
||||
const eventSource = new EventSource("/logs/streamSSE");
|
||||
|
||||
eventSource.onmessage = function(event) {
|
||||
// Append the new log message to the <pre> element
|
||||
const logStream = document.getElementById('log-stream');
|
||||
|
||||
logStream.textContent += event.data;
|
||||
|
||||
// Auto-scroll to the bottom
|
||||
logStream.scrollTop = logStream.scrollHeight;
|
||||
};
|
||||
|
||||
eventSource.onerror = function(err) {
|
||||
console.error("EventSource failed:", err);
|
||||
};
|
||||
} else {
|
||||
console.error("SSE not supported by this browser.");
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
+1
-1
@@ -46,7 +46,7 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
w.buffer = w.buffer.Next()
|
||||
w.bufferMu.Unlock()
|
||||
|
||||
w.broadcast(p)
|
||||
w.broadcast(bufferCopy)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
||||
+22
-17
@@ -122,17 +122,20 @@ func (p *Process) start() error {
|
||||
// start a goroutine to check every second if
|
||||
// the process should be stopped
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
|
||||
for {
|
||||
<-ticker.C
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
for range time.Tick(time.Second) {
|
||||
if p.state != StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
// wait for all inflight requests to complete and ticker
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -163,25 +166,25 @@ func (p *Process) Stop() {
|
||||
// will be a source of pain in the future.
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
sigtermNormal := make(chan error, 1)
|
||||
go func() {
|
||||
done <- p.cmd.Wait()
|
||||
sigtermNormal <- p.cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Printf("!!! process for %s timed out waiting to stop\n", p.ID)
|
||||
case <-sigtermTimeout.Done():
|
||||
fmt.Fprintf(p.logMonitor, "!!! process for %s timed out waiting to stop\n", p.ID)
|
||||
p.cmd.Process.Kill()
|
||||
p.cmd.Wait()
|
||||
case err := <-done:
|
||||
case err := <-sigtermNormal:
|
||||
if err != nil {
|
||||
if err.Error() != "wait: no child processes" {
|
||||
// possible that simple-responder for testing is just not
|
||||
// existing right, so suppress those errors.
|
||||
fmt.Printf("!!! process for %s stopped with error > %v\n", p.ID, err)
|
||||
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -275,7 +278,11 @@ func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
defer p.inFlightRequests.Done()
|
||||
|
||||
defer func() {
|
||||
p.lastRequestHandled = time.Now()
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
if p.CurrentState() != StateReady {
|
||||
if err := p.start(); err != nil {
|
||||
@@ -285,8 +292,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
p.lastRequestHandled = time.Now()
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
|
||||
+18
-5
@@ -82,18 +82,31 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
// this should take 4 seconds
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil)
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Proxy the request (auto start)
|
||||
process.ProxyRequest(w, req)
|
||||
// Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter
|
||||
process.ProxyRequest(w, req1)
|
||||
|
||||
t.Log("sending slow first request (4 seconds)")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "1234")
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// ensure the TTL timeout does not race slow requests (see issue #25)
|
||||
t.Log("sending second request (1 second)")
|
||||
time.Sleep(time.Second)
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req2)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// wait 5 seconds
|
||||
t.Log("sleep 5 seconds and check if unloaded")
|
||||
time.Sleep(5 * time.Second)
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
}
|
||||
@@ -101,7 +114,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
// issue #19
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long test")
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
expectedMessage := "12345"
|
||||
|
||||
+124
-21
@@ -2,10 +2,12 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -18,6 +20,15 @@ const (
|
||||
PROFILE_SPLIT_CHAR = ":"
|
||||
)
|
||||
|
||||
//go:embed html/favicon.ico
|
||||
var faviconData []byte
|
||||
|
||||
//go:embed html/logs.html
|
||||
var logsHTML []byte
|
||||
|
||||
// make sure embed is kept there by the IDE auto-package importer
|
||||
var _ = embed.FS{}
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
@@ -35,11 +46,47 @@ func New(config *Config) *ProxyManager {
|
||||
ginEngine: gin.New(),
|
||||
}
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyChatRequestHandler)
|
||||
if config.LogRequests {
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
|
||||
// capture these because /upstream/:model rewrites them in c.Next()
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
|
||||
// Stop timer
|
||||
duration := time.Since(start)
|
||||
|
||||
statusCode := c.Writer.Status()
|
||||
bodySize := c.Writer.Size()
|
||||
|
||||
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
|
||||
clientIP,
|
||||
time.Now().Format("2006-01-02 15:04:05"),
|
||||
method,
|
||||
path,
|
||||
c.Request.Proto,
|
||||
statusCode,
|
||||
bodySize,
|
||||
c.Request.UserAgent(),
|
||||
duration,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyChatRequestHandler)
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
||||
|
||||
// Support embeddings
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
|
||||
@@ -48,7 +95,12 @@ func New(config *Config) *ProxyManager {
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||
|
||||
pm.ginEngine.NoRoute(pm.proxyNoRouteHandler)
|
||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||
c.Data(http.StatusOK, "image/x-icon", faviconData)
|
||||
})
|
||||
|
||||
// Disable console color for testing
|
||||
gin.DisableConsoleColor()
|
||||
@@ -86,7 +138,11 @@ func (pm *ProxyManager) stopProcesses() {
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := []interface{}{}
|
||||
for id := range pm.config.Models {
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
data = append(data, map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "model",
|
||||
@@ -98,9 +154,13 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
// Set the Content-Type header to application/json
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -109,7 +169,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// Check if requestedModel contains a /
|
||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||
profileName, modelName := "", requestedModel
|
||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||
profileName = requestedModel[:idx]
|
||||
@@ -166,25 +226,68 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
return pm.currentProcesses[requestedProcessKey], nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON"))
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
requestedModel := c.Param("model_id")
|
||||
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
} else {
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
||||
var html strings.Builder
|
||||
|
||||
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
|
||||
|
||||
// Extract keys and sort them
|
||||
var modelIDs []string
|
||||
for modelID, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
modelIDs = append(modelIDs, modelID)
|
||||
}
|
||||
sort.Strings(modelIDs)
|
||||
|
||||
// Iterate over sorted keys
|
||||
for _, modelID := range modelIDs {
|
||||
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
|
||||
}
|
||||
html.WriteString("</ul></body></html>")
|
||||
c.Header("Content-Type", "text/html")
|
||||
c.String(http.StatusOK, html.String())
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
return
|
||||
}
|
||||
|
||||
var requestBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON"))
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
model, ok := requestBody["model"].(string)
|
||||
if !ok {
|
||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("missing or invalid 'model' key"))
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(model); err != nil {
|
||||
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
@@ -197,14 +300,14 @@ func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
|
||||
// since maps are unordered, just use the first available process if one exists
|
||||
for _, process := range pm.currentProcesses {
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
return
|
||||
}
|
||||
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||
acceptHeader := c.GetHeader("Accept")
|
||||
|
||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
||||
if strings.Contains(acceptHeader, "application/json") {
|
||||
c.JSON(statusCode, gin.H{"error": message})
|
||||
} else {
|
||||
c.String(statusCode, message)
|
||||
}
|
||||
}
|
||||
|
||||
func ProcessKeyName(groupName, modelName string) string {
|
||||
|
||||
@@ -3,17 +3,32 @@ package proxy
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
history := pm.logMonitor.GetHistory()
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
return
|
||||
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "text/html") {
|
||||
// Set the Content-Type header to text/html
|
||||
c.Header("Content-Type", "text/html")
|
||||
|
||||
// Write the embedded HTML content to the response
|
||||
_, err := c.Writer.Write(logsHTML)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to write response: %v", err))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
history := pm.logMonitor.GetHistory()
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -141,3 +142,71 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
req.Header.Add("Origin", "i-am-the-origin")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the listModelsHandler
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
// Check the response status code
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check for Access-Control-Allow-Origin
|
||||
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
|
||||
|
||||
// Parse the JSON response
|
||||
var response struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Check the number of models returned
|
||||
assert.Len(t, response.Data, 3)
|
||||
|
||||
// Check the details of each model
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
}
|
||||
|
||||
for _, model := range response.Data {
|
||||
modelID, ok := model["id"].(string)
|
||||
assert.True(t, ok, "model ID should be a string")
|
||||
_, exists := expectedModels[modelID]
|
||||
assert.True(t, exists, "unexpected model ID: %s", modelID)
|
||||
delete(expectedModels, modelID)
|
||||
|
||||
object, ok := model["object"].(string)
|
||||
assert.True(t, ok, "object should be a string")
|
||||
assert.Equal(t, "model", object)
|
||||
|
||||
created, ok := model["created"].(float64)
|
||||
assert.True(t, ok, "created should be a number")
|
||||
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
|
||||
|
||||
ownedBy, ok := model["owned_by"].(string)
|
||||
assert.True(t, ok, "owned_by should be a string")
|
||||
assert.Equal(t, "llama-swap", ownedBy)
|
||||
}
|
||||
|
||||
// Ensure all expected models were returned
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user