Compare commits
58 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 593604dfdc | |||
| b8f888f864 | |||
| 192b2ae621 | |||
| b7f8cb5094 | |||
| a23da6eb57 | |||
| 4c3aa40564 | |||
| 84e2c07a7e | |||
| 680af28bcc | |||
| d94db42ffe | |||
| 93cd83c55c | |||
| 5565fca3ac | |||
| d625ab8d92 | |||
| a3f82c140b | |||
| 5c97299e7b | |||
| 671c1a5a7b | |||
| 52c0196e0f | |||
| 3201a68a04 | |||
| 3ac94ad20e | |||
| 60355bf74a | |||
| 9b2ed244e2 | |||
| eeb72297f7 | |||
| eabfe70cc6 | |||
| 29cd98878d | |||
| b3d331da0d | |||
| 62275e078d | |||
| 88916059e1 | |||
| 082d5d0fc5 | |||
| 53338938bd | |||
| af653347ae | |||
| 1e25b44a06 | |||
| 0815bb4cc3 | |||
| 7187cfe52e | |||
| 24089d2d9c | |||
| ebabe55ff3 | |||
| 41a338297c | |||
| 7e3353efeb | |||
| 4ed58fb173 | |||
| f5a2be698d | |||
| f5e6ec3b7a | |||
| 3f462da146 | |||
| 48bd766536 | |||
| 8d319da4dd | |||
| be7c502448 | |||
| 92336f00bf | |||
| ed2a50d9a6 | |||
| 0acfdb9f78 | |||
| 96a8ea0241 | |||
| f20f2c9b7a | |||
| 7a97c38828 | |||
| 4885132565 | |||
| 8b46a0b7f1 | |||
| 1b6736ec6f | |||
| ddc1ce031e | |||
| 11d024bbaa | |||
| 43e23c16dc | |||
| f9c8e763ba | |||
| d7e1bb9f7c | |||
| ab93460a8b |
@@ -0,0 +1,23 @@
|
||||
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
|
||||
name: Close inactive issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "32 1 * * *"
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -0,0 +1,46 @@
|
||||
name: Build Containers
|
||||
|
||||
on:
|
||||
# time has no specific meaning, trying to time it after
|
||||
# the llama.cpp daily packages are published
|
||||
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Run build-container
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
||||
|
||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||
delete-untagged-containers:
|
||||
needs: build-and-push
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/delete-package-versions@v5
|
||||
with:
|
||||
package-name: 'llama-swap'
|
||||
package-type: 'container'
|
||||
delete-only-untagged-versions: 'true'
|
||||
@@ -0,0 +1,32 @@
|
||||
# This workflow will build a golang project
|
||||
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
run: make simple-responder
|
||||
|
||||
- name: Test all
|
||||
run: make test-all
|
||||
@@ -5,6 +5,9 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
|
||||
+16
-1
@@ -7,9 +7,24 @@ builds:
|
||||
- linux
|
||||
- darwin
|
||||
- freebsd
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
ignore:
|
||||
- goos: freebsd
|
||||
goarch: arm64
|
||||
goarch: arm64
|
||||
- goos: windows
|
||||
goarch: arm64
|
||||
|
||||
# use zip format for windows
|
||||
archives:
|
||||
- id: default
|
||||
format: tar.gz
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
builds_info:
|
||||
group: root
|
||||
owner: root
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
format: zip
|
||||
@@ -35,6 +35,11 @@ linux:
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
|
||||
# Build Windows binary
|
||||
windows:
|
||||
@echo "Building Windows binary..."
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||
|
||||
# for testing proxy.Process
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
@@ -60,4 +65,4 @@ release:
|
||||
git tag "$$new_tag";
|
||||
|
||||
# Phony targets
|
||||
.PHONY: all clean mac linux
|
||||
.PHONY: all clean mac linux windows simple-responder
|
||||
|
||||
@@ -1,66 +1,94 @@
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
# llama-swap
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
|
||||
|
||||
Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or build it yourself from source with `make clean all`.
|
||||
|
||||
## How does it work?
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If a server is already running it will stop it and start the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
||||
|
||||
## Features:
|
||||
|
||||
- ✅ Easy to deploy: single binary with no dependencies
|
||||
- ✅ Easy to config: single yaml file
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Full control over server settings per model
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/embeddings`
|
||||
- `v1/rerank`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- ✅ Multiple GPU support
|
||||
- ✅ Docker and Podman support
|
||||
- ✅ Run multiple models at once with `profiles`
|
||||
- ✅ Remote log monitoring at `/log`
|
||||
- ✅ Automatic unloading of models from GPUs after timeout
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- ✅ llama-swap custom API endpoints
|
||||
- `/log` - remote log monitoring
|
||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- ✅ Docker and Podman support
|
||||
- ✅ Full control over server settings per model
|
||||
|
||||
## How does llama-swap work?
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## config.yaml
|
||||
|
||||
llama-swap's configuration is purposefully simple.
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"qwen2.5":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
"smollm2":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>But also very powerful ...</summary>
|
||||
|
||||
```yaml
|
||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||
# Default (and minimum) is 15 seconds
|
||||
healthCheckTimeout: 60
|
||||
|
||||
# Write HTTP logs (useful for troubleshooting), defaults to false
|
||||
logRequests: true
|
||||
# Valid log levels: debug, info (default), warn, error
|
||||
logLevel: info
|
||||
|
||||
# define valid model values and the upstream server start
|
||||
models:
|
||||
"llama":
|
||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
# multiline for readability
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
|
||||
# where to reach the server started by cmd, make sure the ports match
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases names to use this model for
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# check this path for an HTTP 200 OK before serving requests
|
||||
# default: /health to match llama.cpp
|
||||
@@ -73,22 +101,15 @@ models:
|
||||
# default: 0 = never unload model
|
||||
ttl: 60
|
||||
|
||||
"qwen":
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
|
||||
# multiline for readability
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
# `useModelName` overrides the model name in the request
|
||||
# and sends a specific name to the upstream server
|
||||
useModelName: "qwen:qwq"
|
||||
|
||||
# unlisted models do not show up in /v1/models or /upstream lists
|
||||
# but they can still be requested as normal
|
||||
"qwen-unlisted":
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
unlisted: true
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker Support (v26.1.4+ required!)
|
||||
"docker-llama":
|
||||
@@ -99,7 +120,7 @@ models:
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||
# profiles eliminates swapping by running multiple models at the same time
|
||||
#
|
||||
# Tips:
|
||||
# - each model must be listening on a unique address and port
|
||||
@@ -107,21 +128,78 @@ models:
|
||||
# - the profile will load and unload all models in the profile at the same time
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
```
|
||||
|
||||
### Advanced Examples
|
||||
### Use Case Examples
|
||||
|
||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
||||
|
||||
### Installation
|
||||
## Configuration
|
||||
|
||||
llama-s
|
||||
|
||||
</details>
|
||||
|
||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Docker is the quickest way to try out llama-swap:
|
||||
|
||||
```
|
||||
# use CPU inference
|
||||
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||
|
||||
|
||||
# qwen2.5 0.5B
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer no-key" \
|
||||
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
|
||||
|
||||
# SmolLM2 135M
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer no-key" \
|
||||
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Docker images are nightly ...</summary>
|
||||
|
||||
They include:
|
||||
|
||||
- `ghcr.io/mostlygeek/llama-swap:cpu`
|
||||
- `ghcr.io/mostlygeek/llama-swap:cuda`
|
||||
- `ghcr.io/mostlygeek/llama-swap:intel`
|
||||
- `ghcr.io/mostlygeek/llama-swap:vulkan`
|
||||
- ROCm disabled until fixed in llama.cpp container
|
||||
|
||||
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
|
||||
|
||||
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
||||
|
||||
```
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||
|
||||
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
* _Note: Windows currently untested._
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||
|
||||
### Building from source
|
||||
@@ -141,9 +219,15 @@ Of course, CLI access is also supported:
|
||||
# sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
|
||||
# streams logs
|
||||
# streams combined logs
|
||||
curl -Ns 'http://host/logs/stream'
|
||||
|
||||
# just llama-swap's logs
|
||||
curl -Ns 'http://host/logs/stream/proxy'
|
||||
|
||||
# just upstream's logs
|
||||
curl -Ns 'http://host/logs/stream/upstream'
|
||||
|
||||
# stream and filter logs with linux pipes
|
||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
|
||||
@@ -151,11 +235,18 @@ curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
```
|
||||
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
|
||||
## Systemd Unit Files
|
||||
|
||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||
|
||||
`/etc/systemd/system/llama-swap.service`
|
||||
|
||||
```
|
||||
[Unit]
|
||||
Description=llama-swap
|
||||
@@ -175,3 +266,11 @@ StartLimitInterval=30
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
## Star History
|
||||
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||
</picture>
|
||||
|
||||
+3
-3
@@ -1,9 +1,9 @@
|
||||
# Seconds to wait for llama.cpp to be available to serve requests
|
||||
# Default (and minimum): 15 seconds
|
||||
healthCheckTimeout: 15
|
||||
healthCheckTimeout: 90
|
||||
|
||||
# Log HTTP requests helpful for troubleshoot, defaults to False
|
||||
logRequests: true
|
||||
# valid log levels: debug, info (default), warn, error
|
||||
logLevel: info
|
||||
|
||||
models:
|
||||
"llama":
|
||||
|
||||
Executable
+55
@@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
ARCH=$1
|
||||
PUSH_IMAGES=${2:-false}
|
||||
|
||||
# List of allowed architectures
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
||||
|
||||
# Check if ARCH is in the allowed list
|
||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if GITHUB_TOKEN is set and not empty
|
||||
if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||
echo "Error: GITHUB_TOKEN is not set or is empty."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# the most recent llama-swap tag
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
# cpu only containers just use the latest available
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
|
||||
echo "Building ${CONTAINER_LATEST} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
else
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
fi
|
||||
@@ -0,0 +1,17 @@
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
"smollm2":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
@@ -0,0 +1,16 @@
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
ARG LS_VER=89
|
||||
|
||||
WORKDIR /app
|
||||
RUN \
|
||||
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
|
||||
|
||||
COPY config.example.yaml /app/config.yaml
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
@@ -0,0 +1,51 @@
|
||||
# Restart llama-swap on config change
|
||||
|
||||
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#
|
||||
# A simple watch and restart llama-swap when its configuration
|
||||
# file changes. Useful for trying out configuration changes
|
||||
# without manually restarting the server each time.
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <path to config.yaml>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
while true; do
|
||||
# Start the process again
|
||||
./llama-swap-linux-amd64 -config $1 -listen :1867 &
|
||||
PID=$!
|
||||
echo "Started llama-swap with PID $PID"
|
||||
|
||||
# Wait for modifications in the specified directory or file
|
||||
inotifywait -e modify "$1"
|
||||
|
||||
# Check if process exists before sending signal
|
||||
if kill -0 $PID 2>/dev/null; then
|
||||
echo "Sending SIGTERM to $PID"
|
||||
kill -SIGTERM $PID
|
||||
wait $PID
|
||||
else
|
||||
echo "Process $PID no longer exists"
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
```
|
||||
|
||||
## Usage and output example
|
||||
|
||||
```bash
|
||||
$ ./watch-and-restart.sh config.yaml
|
||||
Started llama-swap with PID 495455
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
Sending SIGTERM to 495455
|
||||
Shutting down llama-swap
|
||||
Started llama-swap with PID 495486
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
```
|
||||
@@ -3,7 +3,11 @@ module github.com/mostlygeek/llama-swap
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -15,12 +19,10 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/gin-gonic/gin v1.10.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
@@ -29,12 +31,14 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.37.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -57,6 +57,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
@@ -68,20 +78,28 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
@@ -12,12 +12,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Define a command-line flag for the port
|
||||
port := flag.String("port", "8080", "port to listen on")
|
||||
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
|
||||
|
||||
// Define a command-line flag for the response message
|
||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||
@@ -41,11 +43,70 @@ func main() {
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
// for issue #62 to check model name strips profile slug
|
||||
// has to be one of the openAI API endpoints that llama-swap proxies
|
||||
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
|
||||
r.POST("/v1/audio/speech", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if modelName != *expectedModel {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
// issue #41
|
||||
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||
// Parse the multipart form
|
||||
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the model from the form values
|
||||
model := c.Request.FormValue("model")
|
||||
|
||||
if model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the file from the form
|
||||
file, _, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Read the file content to get its size
|
||||
fileBytes, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
fileSize := len(fileBytes)
|
||||
|
||||
// Return a JSON response with the model and transcription text including file size
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||
"model": model,
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
|
||||
@@ -17,6 +17,7 @@ type ModelConfig struct {
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
@@ -26,6 +27,7 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
|
||||
|
||||
+188
-75
@@ -12,32 +12,65 @@
|
||||
flex-direction: column;
|
||||
font-family: "Courier New", Courier, monospace;
|
||||
}
|
||||
#log-controls {
|
||||
margin: 0.5em;
|
||||
.log-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between; /* Spaces out elements evenly */
|
||||
}
|
||||
#log-controls input {
|
||||
flex: 1;
|
||||
}
|
||||
#log-controls input:focus {
|
||||
outline: none; /* Ensures no outline is shown when the input is focused */
|
||||
}
|
||||
#log-stream {
|
||||
flex: 1;
|
||||
gap: 0.5em;
|
||||
margin: 0.5em;
|
||||
min-height: 0;
|
||||
}
|
||||
.log-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
transition: flex 0.3s ease;
|
||||
}
|
||||
.log-column.minimized {
|
||||
flex: 0.1;
|
||||
max-width: 50px;
|
||||
border: 1px solid #777;
|
||||
color: green;
|
||||
}
|
||||
.log-controls {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr auto;
|
||||
gap: 0.5em;
|
||||
margin-bottom: 0.5em;
|
||||
}
|
||||
.log-controls input {
|
||||
width: 100%;
|
||||
padding: 4px;
|
||||
}
|
||||
.log-controls input:focus {
|
||||
outline: none;
|
||||
}
|
||||
.log-stream {
|
||||
flex: 1;
|
||||
padding: 1em;
|
||||
background: #f4f4f4;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap; /* Ensures line wrapping */
|
||||
word-wrap: break-word; /* Ensures long words wrap */
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.regex-error {
|
||||
background-color: #ff0000 !important;
|
||||
}
|
||||
|
||||
/* Make headers clickable and show pointer cursor */
|
||||
h2 {
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
margin: 0 0 0.5em 0;
|
||||
padding: 0.5em;
|
||||
}
|
||||
|
||||
h2:hover {
|
||||
background-color: rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
/* Dark mode styles */
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body {
|
||||
@@ -45,101 +78,181 @@
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#log-stream {
|
||||
.log-stream {
|
||||
background: #444;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#log-controls input {
|
||||
.log-controls input {
|
||||
background: #555;
|
||||
color: #fff;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
|
||||
#log-controls button {
|
||||
.log-controls button {
|
||||
background: #555;
|
||||
color: #fff;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
|
||||
h2:hover {
|
||||
background-color: rgba(255, 255, 255, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
/* Hide content when minimized */
|
||||
.log-column.minimized .log-controls,
|
||||
.log-column.minimized .log-stream {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.log-column.minimized h2 {
|
||||
writing-mode: vertical-rl;
|
||||
text-orientation: mixed;
|
||||
transform: rotate(180deg);
|
||||
white-space: nowrap;
|
||||
margin: auto;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<pre id="log-stream">Waiting for logs...</pre>
|
||||
<div id="log-controls">
|
||||
<input type="text" id="filter-input" placeholder="regex filter">
|
||||
<button id="clear-button">clear</button>
|
||||
<div class="log-container">
|
||||
<div class="log-column">
|
||||
<h2>Proxy Logs</h2>
|
||||
<div class="log-controls">
|
||||
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
|
||||
<button id="proxy-clear-button">clear</button>
|
||||
</div>
|
||||
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
|
||||
</div>
|
||||
<div class="log-column minimized">
|
||||
<h2>Upstream Logs</h2>
|
||||
<div class="log-controls">
|
||||
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
|
||||
<button id="upstream-clear-button">clear</button>
|
||||
</div>
|
||||
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
const logStream = document.getElementById('log-stream');
|
||||
const filterInput = document.getElementById('filter-input');
|
||||
var logData = "";
|
||||
let regexFilter = null;
|
||||
class LogStream {
|
||||
constructor(streamElement, filterInput, clearButton, endpoint) {
|
||||
this.streamElement = streamElement;
|
||||
this.filterInput = filterInput;
|
||||
this.clearButton = clearButton;
|
||||
this.endpoint = endpoint;
|
||||
this.logData = "";
|
||||
this.regexFilter = null;
|
||||
this.eventSource = null;
|
||||
|
||||
function setupEventSource() {
|
||||
if (typeof(EventSource) !== "undefined") {
|
||||
const eventSource = new EventSource("/logs/streamSSE");
|
||||
|
||||
eventSource.onmessage = function(event) {
|
||||
logData += event.data;
|
||||
render()
|
||||
};
|
||||
|
||||
eventSource.onerror = function(err) {
|
||||
logData = "EventSource failed: " + err.message;
|
||||
};
|
||||
} else {
|
||||
logData = "SSE Not supported by this browser."
|
||||
this.initialize();
|
||||
}
|
||||
}
|
||||
|
||||
// poor-ai's react ¯\_(ツ)_/¯
|
||||
function render() {
|
||||
if (regexFilter) {
|
||||
const lines = logData.split('\n');
|
||||
const filteredLines = lines.filter(line => {
|
||||
return regexFilter === null || regexFilter.test(line);
|
||||
initialize() {
|
||||
this.filterInput.addEventListener('input', () => this.updateFilter());
|
||||
this.clearButton.addEventListener('click', () => {
|
||||
this.filterInput.value = "";
|
||||
this.regexFilter = null;
|
||||
this.render();
|
||||
});
|
||||
|
||||
if (filteredLines.length > 0) {
|
||||
logStream.textContent = filteredLines.join('\n') + '\n';
|
||||
} else {
|
||||
logStream.textContent = "";
|
||||
}
|
||||
} else {
|
||||
logStream.textContent = logData;
|
||||
this.setupEventSource();
|
||||
}
|
||||
|
||||
logStream.scrollTop = logStream.scrollHeight;
|
||||
}
|
||||
setupEventSource() {
|
||||
if (typeof(EventSource) === "undefined") {
|
||||
this.logData = "SSE Not supported by this browser.";
|
||||
this.render();
|
||||
return;
|
||||
}
|
||||
|
||||
const connect = () => {
|
||||
this.eventSource = new EventSource(this.endpoint);
|
||||
|
||||
this.eventSource.onmessage = (event) => {
|
||||
this.logData += event.data;
|
||||
this.render();
|
||||
};
|
||||
|
||||
this.eventSource.onerror = (err) => {
|
||||
// Close the current connection
|
||||
this.eventSource.close();
|
||||
|
||||
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
|
||||
this.render();
|
||||
|
||||
// Attempt to reconnect after 5 seconds
|
||||
setTimeout(() => {
|
||||
this.logData += "Attempting to reconnect...\n";
|
||||
this.render();
|
||||
connect();
|
||||
}, 5000);
|
||||
};
|
||||
};
|
||||
|
||||
// Initial connection
|
||||
connect();
|
||||
}
|
||||
|
||||
render() {
|
||||
let content = this.logData;
|
||||
|
||||
if (this.regexFilter) {
|
||||
const lines = content.split('\n');
|
||||
const filteredLines = lines.filter(line => this.regexFilter.test(line));
|
||||
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
|
||||
}
|
||||
|
||||
this.streamElement.textContent = content;
|
||||
this.streamElement.scrollTop = this.streamElement.scrollHeight;
|
||||
}
|
||||
|
||||
updateFilter() {
|
||||
const pattern = this.filterInput.value.trim();
|
||||
this.filterInput.classList.remove('regex-error');
|
||||
|
||||
if (!pattern) {
|
||||
this.regexFilter = null;
|
||||
this.render();
|
||||
return;
|
||||
}
|
||||
|
||||
function updateFilter() {
|
||||
const pattern = filterInput.value.trim();
|
||||
filterInput.classList.remove('regex-error');
|
||||
if (pattern) {
|
||||
try {
|
||||
regexFilter = new RegExp(pattern);
|
||||
this.regexFilter = new RegExp(pattern);
|
||||
} catch (e) {
|
||||
console.error("Invalid regex pattern:", e);
|
||||
regexFilter = null;
|
||||
filterInput.classList.add('regex-error');
|
||||
return
|
||||
this.regexFilter = null;
|
||||
this.filterInput.classList.add('regex-error');
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
regexFilter = null;
|
||||
}
|
||||
|
||||
render();
|
||||
this.render();
|
||||
}
|
||||
}
|
||||
|
||||
filterInput.addEventListener('input', updateFilter);
|
||||
document.getElementById('clear-button').addEventListener('click', () => {
|
||||
filterInput.value = "";
|
||||
regexFilter = null;
|
||||
render();
|
||||
// Initialize both log streams
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
new LogStream(
|
||||
document.getElementById('proxy-log-stream'),
|
||||
document.getElementById('proxy-filter-input'),
|
||||
document.getElementById('proxy-clear-button'),
|
||||
"/logs/streamSSE/proxy"
|
||||
);
|
||||
|
||||
new LogStream(
|
||||
document.getElementById('upstream-log-stream'),
|
||||
document.getElementById('upstream-filter-input'),
|
||||
document.getElementById('upstream-clear-button'),
|
||||
"/logs/streamSSE/upstream"
|
||||
);
|
||||
|
||||
// Initialize clickable headers
|
||||
document.querySelectorAll('h2').forEach(header => {
|
||||
header.addEventListener('click', () => {
|
||||
const column = header.closest('.log-column');
|
||||
column.classList.toggle('minimized');
|
||||
});
|
||||
});
|
||||
});
|
||||
setupEventSource();
|
||||
updateFilter();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -2,11 +2,21 @@ package proxy
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
LevelDebug LogLevel = iota
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
)
|
||||
|
||||
type LogMonitor struct {
|
||||
clients map[chan []byte]bool
|
||||
mu sync.RWMutex
|
||||
@@ -15,6 +25,10 @@ type LogMonitor struct {
|
||||
|
||||
// typically this can be os.Stdout
|
||||
stdout io.Writer
|
||||
|
||||
// logging levels
|
||||
level LogLevel
|
||||
prefix string
|
||||
}
|
||||
|
||||
func NewLogMonitor() *LogMonitor {
|
||||
@@ -26,6 +40,8 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
clients: make(map[chan []byte]bool),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,3 +110,77 @@ func (w *LogMonitor) broadcast(msg []byte) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.prefix = prefix
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetLogLevel(level LogLevel) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.level = level
|
||||
}
|
||||
|
||||
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
||||
prefix := ""
|
||||
if w.prefix != "" {
|
||||
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) log(level LogLevel, msg string) {
|
||||
if level < w.level {
|
||||
return
|
||||
}
|
||||
w.Write(w.formatMessage(level.String(), msg))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Debug(msg string) {
|
||||
w.log(LevelDebug, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Info(msg string) {
|
||||
w.log(LevelInfo, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Warn(msg string) {
|
||||
w.log(LevelWarn, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Error(msg string) {
|
||||
w.log(LevelError, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
|
||||
w.log(LevelDebug, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Infof(format string, args ...interface{}) {
|
||||
w.log(LevelInfo, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
|
||||
w.log(LevelWarn, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
|
||||
w.log(LevelError, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l LogLevel) String() string {
|
||||
switch l {
|
||||
case LevelDebug:
|
||||
return "DEBUG"
|
||||
case LevelInfo:
|
||||
return "INFO"
|
||||
case LevelWarn:
|
||||
return "WARN"
|
||||
case LevelError:
|
||||
return "ERROR"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
+125
-108
@@ -30,11 +30,15 @@ const (
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
logMonitor *LogMonitor
|
||||
healthCheckTimeout int
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
|
||||
processLogger *LogMonitor
|
||||
proxyLogger *LogMonitor
|
||||
|
||||
healthCheckTimeout int
|
||||
healthCheckLoopInterval time.Duration
|
||||
|
||||
lastRequestHandled time.Time
|
||||
|
||||
@@ -51,54 +55,68 @@ type Process struct {
|
||||
shutdownCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
logMonitor: logMonitor,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
state: StateStopped,
|
||||
shutdownCtx: ctx,
|
||||
shutdownCancel: cancel,
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||
state: StateStopped,
|
||||
shutdownCtx: ctx,
|
||||
shutdownCancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) setState(newState ProcessState) error {
|
||||
// enforce valid state transitions
|
||||
invalidTransition := false
|
||||
if p.state == StateStopped {
|
||||
// stopped -> starting
|
||||
if newState != StateStarting {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateStarting {
|
||||
// starting -> ready | failed | stopping
|
||||
if newState != StateReady && newState != StateFailed && newState != StateStopping {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateReady {
|
||||
// ready -> stopping
|
||||
if newState != StateStopping {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateStopping {
|
||||
// stopping -> stopped | shutdown
|
||||
if newState != StateStopped && newState != StateShutdown {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateFailed || p.state == StateShutdown {
|
||||
invalidTransition = true
|
||||
// LogMonitor returns the log monitor associated with the process.
|
||||
func (p *Process) LogMonitor() *LogMonitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
// custom error types for swapping state
|
||||
var (
|
||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||
ErrInvalidStateTransition = errors.New("invalid state transition")
|
||||
)
|
||||
|
||||
// swapState performs a compare and swap of the state atomically. It returns the current state
|
||||
// and an error if the swap failed.
|
||||
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != expectedState {
|
||||
return p.state, ErrExpectedStateMismatch
|
||||
}
|
||||
|
||||
if invalidTransition {
|
||||
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
|
||||
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
|
||||
if !isValidTransition(p.state, newState) {
|
||||
p.proxyLogger.Warnf("Invalid state transition from %s to %s", p.state, newState)
|
||||
return p.state, ErrInvalidStateTransition
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("State transition from %s to %s", expectedState, newState)
|
||||
p.state = newState
|
||||
return nil
|
||||
return p.state, nil
|
||||
}
|
||||
|
||||
// Helper function to encapsulate transition rules
|
||||
func isValidTransition(from, to ProcessState) bool {
|
||||
switch from {
|
||||
case StateStopped:
|
||||
return to == StateStarting
|
||||
case StateStarting:
|
||||
return to == StateReady || to == StateFailed || to == StateStopping
|
||||
case StateReady:
|
||||
return to == StateStopping
|
||||
case StateStopping:
|
||||
return to == StateStopped || to == StateShutdown
|
||||
case StateFailed, StateShutdown:
|
||||
return false // No transitions allowed from these states
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
@@ -116,47 +134,48 @@ func (p *Process) start() error {
|
||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||
}
|
||||
|
||||
// wait for the other start() to complete
|
||||
curState := p.CurrentState()
|
||||
|
||||
if curState == StateReady {
|
||||
return nil
|
||||
}
|
||||
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
|
||||
if state := p.CurrentState(); state != StateReady {
|
||||
return fmt.Errorf("start() failed current state: %v", state)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if err := p.setState(StateStarting); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.waitStarting.Add(1)
|
||||
defer p.waitStarting.Done()
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
|
||||
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||
if err == ErrExpectedStateMismatch {
|
||||
// already starting, just wait for it to complete and expect
|
||||
// it to be be in the Ready start after. If not, return an error
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
if state := p.CurrentState(); state == StateReady {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("process was already starting but wound up in state %v", state)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("processes was in state %v when start() was called", curState)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
}
|
||||
|
||||
p.waitStarting.Add(1)
|
||||
defer p.waitStarting.Done()
|
||||
|
||||
p.cmd = exec.Command(args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.logMonitor
|
||||
p.cmd.Stderr = p.logMonitor
|
||||
p.cmd.Stdout = p.processLogger
|
||||
p.cmd.Stderr = p.processLogger
|
||||
p.cmd.Env = p.config.Env
|
||||
|
||||
err = p.cmd.Start()
|
||||
|
||||
// Set process state to failed
|
||||
if err != nil {
|
||||
p.setState(StateFailed)
|
||||
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
err, curState, swapErr,
|
||||
)
|
||||
}
|
||||
return fmt.Errorf("start() failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -191,31 +210,35 @@ func (p *Process) start() error {
|
||||
)
|
||||
defer cancelHealthCheck()
|
||||
|
||||
// Health check loop
|
||||
loop:
|
||||
// Ready Check loop
|
||||
for {
|
||||
select {
|
||||
case <-checkDeadline.Done():
|
||||
p.setState(StateFailed)
|
||||
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
|
||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
|
||||
} else {
|
||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||
}
|
||||
case <-p.shutdownCtx.Done():
|
||||
return errors.New("health check interrupted due to shutdown")
|
||||
default:
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
p.proxyLogger.Infof("Health check passed on %s", healthURL)
|
||||
cancelHealthCheck()
|
||||
break loop
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
endTime, _ := checkDeadline.Deadline()
|
||||
ttl := time.Until(endTime)
|
||||
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
|
||||
p.proxyLogger.Infof("Connection refused on %s, retrying in %.0fs", healthURL, ttl.Seconds())
|
||||
} else {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
|
||||
p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
<-time.After(time.Second)
|
||||
<-time.After(p.healthCheckLoopInterval)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -226,7 +249,7 @@ func (p *Process) start() error {
|
||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
|
||||
for range time.Tick(time.Second) {
|
||||
if p.state != StateReady {
|
||||
if p.CurrentState() != StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -234,7 +257,8 @@ func (p *Process) start() error {
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
||||
|
||||
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
@@ -242,26 +266,28 @@ func (p *Process) start() error {
|
||||
}()
|
||||
}
|
||||
|
||||
return p.setState(StateReady)
|
||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) Stop() {
|
||||
// wait for any inflight requests before proceeding
|
||||
p.inFlightRequests.Wait()
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
// calling Stop() when state is invalid is a no-op
|
||||
if err := p.setState(StateStopping); err != nil {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
|
||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||
p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState)
|
||||
return
|
||||
}
|
||||
|
||||
// stop the process with a graceful exit timeout
|
||||
p.stopCommand(5 * time.Second)
|
||||
|
||||
if err := p.setState(StateStopped); err != nil {
|
||||
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,19 +295,9 @@ func (p *Process) Stop() {
|
||||
// of time for any inflight requests to complete before shutting down. If the Process
|
||||
// is in the state of starting, it will cancel it and shut it down
|
||||
func (p *Process) Shutdown() {
|
||||
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
|
||||
p.shutdownCancel()
|
||||
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
p.setState(StateStopping)
|
||||
|
||||
// 5 seconds to stop the process
|
||||
p.stopCommand(5 * time.Second)
|
||||
if err := p.setState(StateShutdown); err != nil {
|
||||
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
|
||||
}
|
||||
p.setState(StateShutdown)
|
||||
p.state = StateShutdown
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
@@ -296,31 +312,32 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||
}()
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
|
||||
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
if err := p.terminateProcess(); err != nil {
|
||||
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-sigtermTimeout.Done():
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
|
||||
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
|
||||
p.cmd.Process.Kill()
|
||||
case err := <-sigtermNormal:
|
||||
if err != nil {
|
||||
if errno, ok := err.(syscall.Errno); ok {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
|
||||
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)
|
||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
|
||||
p.proxyLogger.Infof("Process [%s] stopped OK", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
|
||||
p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID)
|
||||
} else {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
|
||||
p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
|
||||
} else {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
|
||||
p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import "syscall"
|
||||
|
||||
func (p *Process) terminateProcess() error {
|
||||
return p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func (p *Process) terminateProcess() error {
|
||||
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
||||
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
||||
return cmd.Run()
|
||||
}
|
||||
+47
-39
@@ -5,7 +5,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,13 +12,17 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
discardLogger = NewLogMonitorWriter(io.Discard)
|
||||
)
|
||||
|
||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// Create a process
|
||||
process := NewProcess("test-process", 5, config, logMonitor)
|
||||
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -52,11 +55,10 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
// are all handled successfully, even though they all may ask for the process to .start()
|
||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("test-process", 5, config, logMonitor)
|
||||
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
|
||||
defer process.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
@@ -84,7 +86,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||
process := NewProcess("broken", 1, config, discardLogger, discardLogger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -109,7 +111,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
|
||||
process := NewProcess("ttl_test", 2, config, discardLogger, discardLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
@@ -151,7 +153,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
||||
config.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
|
||||
process := NewProcess("ttl", 2, config, discardLogger, discardLogger)
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
@@ -169,6 +171,8 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
||||
}
|
||||
|
||||
// issue #19
|
||||
// This test makes sure using Process.Stop() does not affect pending HTTP
|
||||
// requests. All HTTP requests in this test should complete successfully.
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
@@ -176,7 +180,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
|
||||
expectedMessage := "12345"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
|
||||
process := NewProcess("t", 10, config, discardLogger, discardLogger)
|
||||
defer process.Stop()
|
||||
|
||||
results := map[string]string{
|
||||
@@ -192,8 +196,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
// send a request that should take 5 * 200ms (1 second) to complete
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
|
||||
// send a request where simple-responder is will wait 300ms before responding
|
||||
// this will simulate an in-progress request.
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
process.ProxyRequest(w, req)
|
||||
@@ -209,9 +214,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
}(key)
|
||||
}
|
||||
|
||||
// stop the requests in the middle
|
||||
// Stop the process while requests are still being processed
|
||||
go func() {
|
||||
<-time.After(500 * time.Millisecond)
|
||||
<-time.After(150 * time.Millisecond)
|
||||
process.Stop()
|
||||
}()
|
||||
|
||||
@@ -222,39 +227,40 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetState(t *testing.T) {
|
||||
func TestProcess_SwapState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentState ProcessState
|
||||
expectedState ProcessState
|
||||
newState ProcessState
|
||||
expectedError error
|
||||
expectedResult ProcessState
|
||||
}{
|
||||
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
|
||||
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
|
||||
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
|
||||
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
|
||||
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
|
||||
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
|
||||
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
|
||||
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
|
||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
|
||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
|
||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
|
||||
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
|
||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
p := &Process{
|
||||
state: test.currentState,
|
||||
}
|
||||
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), discardLogger, discardLogger)
|
||||
p.state = test.currentState
|
||||
|
||||
err := p.setState(test.newState)
|
||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||
if err != nil && test.expectedError == nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
} else if err == nil && test.expectedError != nil {
|
||||
@@ -265,8 +271,8 @@ func TestSetState(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if p.state != test.expectedResult {
|
||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
|
||||
if resultState != test.expectedResult {
|
||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -277,7 +283,6 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
t.Skip("skipping long shutdown test")
|
||||
}
|
||||
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
expectedMessage := "testing91931"
|
||||
|
||||
// make a config where the healthcheck will always fail because port is wrong
|
||||
@@ -285,13 +290,16 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
config.Proxy = "http://localhost:9998/test"
|
||||
|
||||
healthCheckTTLSeconds := 30
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, discardLogger, discardLogger)
|
||||
|
||||
// make it a lot faster
|
||||
process.healthCheckLoopInterval = time.Second
|
||||
|
||||
// start a goroutine to simulate a shutdown
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-time.After(time.Second * 2)
|
||||
<-time.After(time.Millisecond * 500)
|
||||
process.Shutdown()
|
||||
}()
|
||||
wg.Add(1)
|
||||
|
||||
+257
-58
@@ -5,7 +5,9 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -13,6 +15,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,59 +28,96 @@ type ProxyManager struct {
|
||||
|
||||
config *Config
|
||||
currentProcesses map[string]*Process
|
||||
logMonitor *LogMonitor
|
||||
ginEngine *gin.Engine
|
||||
|
||||
// logging
|
||||
proxyLogger *LogMonitor
|
||||
upstreamLogger *LogMonitor
|
||||
muxLogger *LogMonitor
|
||||
}
|
||||
|
||||
func New(config *Config) *ProxyManager {
|
||||
// set up loggers
|
||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
|
||||
if config.LogRequests {
|
||||
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
||||
case "debug":
|
||||
proxyLogger.SetLogLevel(LevelDebug)
|
||||
case "info":
|
||||
proxyLogger.SetLogLevel(LevelInfo)
|
||||
case "warn":
|
||||
proxyLogger.SetLogLevel(LevelWarn)
|
||||
case "error":
|
||||
proxyLogger.SetLogLevel(LevelError)
|
||||
default:
|
||||
proxyLogger.SetLogLevel(LevelInfo)
|
||||
}
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
currentProcesses: make(map[string]*Process),
|
||||
logMonitor: NewLogMonitor(),
|
||||
ginEngine: gin.New(),
|
||||
|
||||
proxyLogger: proxyLogger,
|
||||
muxLogger: stdoutLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
}
|
||||
|
||||
if config.LogRequests {
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
|
||||
// capture these because /upstream/:model rewrites them in c.Next()
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
// capture these because /upstream/:model rewrites them in c.Next()
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
// Process request
|
||||
c.Next()
|
||||
|
||||
// Stop timer
|
||||
duration := time.Since(start)
|
||||
// Stop timer
|
||||
duration := time.Since(start)
|
||||
|
||||
statusCode := c.Writer.Status()
|
||||
bodySize := c.Writer.Size()
|
||||
statusCode := c.Writer.Status()
|
||||
bodySize := c.Writer.Size()
|
||||
|
||||
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
|
||||
clientIP,
|
||||
time.Now().Format("2006-01-02 15:04:05"),
|
||||
method,
|
||||
path,
|
||||
c.Request.Proto,
|
||||
statusCode,
|
||||
bodySize,
|
||||
c.Request.UserAgent(),
|
||||
duration,
|
||||
)
|
||||
})
|
||||
}
|
||||
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
|
||||
clientIP,
|
||||
method,
|
||||
path,
|
||||
c.Request.Proto,
|
||||
statusCode,
|
||||
bodySize,
|
||||
c.Request.UserAgent(),
|
||||
duration,
|
||||
)
|
||||
})
|
||||
|
||||
// see: https://github.com/mostlygeek/llama-swap/issues/42
|
||||
// see: issue: #81, #77 and #42 for CORS issues
|
||||
// respond with permissive OPTIONS for any endpoint
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
c.AbortWithStatus(204)
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
|
||||
// allow whatever the client requested by default
|
||||
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
||||
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
|
||||
c.Header("Access-Control-Allow-Headers", sanitized)
|
||||
} else {
|
||||
c.Header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Content-Type, Authorization, Accept, X-Requested-With",
|
||||
)
|
||||
}
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
@@ -93,6 +134,7 @@ func New(config *Config) *ProxyManager {
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
|
||||
@@ -100,10 +142,16 @@ func New(config *Config) *ProxyManager {
|
||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
|
||||
|
||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
|
||||
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||
// Set the Content-Type header to text/html
|
||||
c.Header("Content-Type", "text/html")
|
||||
@@ -222,11 +270,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
defer pm.Unlock()
|
||||
|
||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||
profileName, modelName := "", requestedModel
|
||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||
profileName = requestedModel[:idx]
|
||||
modelName = requestedModel[idx+1:]
|
||||
}
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
|
||||
if profileName != "" {
|
||||
if _, found := pm.config.Profiles[profileName]; !found {
|
||||
@@ -259,19 +303,20 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||
|
||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
|
||||
return process, nil
|
||||
}
|
||||
|
||||
// stop all running models
|
||||
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
|
||||
pm.stopProcesses()
|
||||
|
||||
if profileName == "" {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
} else {
|
||||
@@ -282,7 +327,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
}
|
||||
@@ -342,29 +387,147 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var requestBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
model, ok := requestBody["model"].(string)
|
||||
if !ok {
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(model); err != nil {
|
||||
process, err := pm.swapModel(requestedModel)
|
||||
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
if process.config.UseModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
if profileName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||
// Create a new buffer for the reconstructed request
|
||||
var requestBuffer bytes.Buffer
|
||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||
|
||||
// Parse multipart form
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get model parameter from the form
|
||||
requestedModel := c.Request.FormValue("model")
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
||||
return
|
||||
}
|
||||
|
||||
// Swap to the requested model
|
||||
process, err := pm.swapModel(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get profile name and model name from the requested model
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
|
||||
// Copy all form values
|
||||
for key, values := range c.Request.MultipartForm.Value {
|
||||
for _, value := range values {
|
||||
fieldValue := value
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
if process.config.UseModelName != "" {
|
||||
fieldValue = process.config.UseModelName
|
||||
} else if profileName != "" {
|
||||
fieldValue = modelName
|
||||
}
|
||||
}
|
||||
field, err := multipartWriter.CreateFormField(key)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
||||
return
|
||||
}
|
||||
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy all files from the original request
|
||||
for key, fileHeaders := range c.Request.MultipartForm.File {
|
||||
for _, fileHeader := range fileHeaders {
|
||||
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = io.Copy(formFile, file); err != nil {
|
||||
file.Close()
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
||||
return
|
||||
}
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close the multipart writer to finalize the form
|
||||
if err := multipartWriter.Close(); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new request with the reconstructed form data
|
||||
modifiedReq, err := http.NewRequestWithContext(
|
||||
c.Request.Context(),
|
||||
c.Request.Method,
|
||||
c.Request.URL.String(),
|
||||
&requestBuffer,
|
||||
)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the headers from the original request
|
||||
modifiedReq.Header = c.Request.Header.Clone()
|
||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||
|
||||
// Use the modified request for proxying
|
||||
process.ProxyRequest(c.Writer, modifiedReq)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||
@@ -377,6 +540,42 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||
pm.StopProcesses()
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
context.Header("Content-Type", "application/json")
|
||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||
|
||||
for _, process := range pm.currentProcesses {
|
||||
|
||||
// Append the process ID and State (multiple entries if profiles are being used).
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// Put the results under the `running` key.
|
||||
response := gin.H{
|
||||
"running": runningProcesses,
|
||||
}
|
||||
|
||||
context.JSON(http.StatusOK, response) // Always return 200 OK
|
||||
}
|
||||
|
||||
func ProcessKeyName(groupName, modelName string) string {
|
||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||
}
|
||||
|
||||
func splitRequestedModel(requestedModel string) (string, string) {
|
||||
profileName, modelName := "", requestedModel
|
||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||
profileName = requestedModel[:idx]
|
||||
modelName = requestedModel[idx+1:]
|
||||
}
|
||||
return profileName, modelName
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
)
|
||||
|
||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "text/html") {
|
||||
// Set the Content-Type header to text/html
|
||||
@@ -28,7 +27,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
history := pm.logMonitor.GetHistory()
|
||||
history := pm.muxLogger.GetHistory()
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
@@ -42,8 +41,14 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
logMonitorId := c.Param("logMonitorID")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
ch := logger.Subscribe()
|
||||
defer logger.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
@@ -56,7 +61,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
// Send history first if not skipped
|
||||
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
history := logger.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.Writer.Write(history)
|
||||
flusher.Flush()
|
||||
@@ -85,15 +90,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
logMonitorId := c.Param("logMonitorID")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
ch := logger.Subscribe()
|
||||
defer logger.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
|
||||
// Send history first if not skipped
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
history := logger.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.SSEvent("message", string(history))
|
||||
c.Writer.Flush()
|
||||
@@ -111,3 +122,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
||||
var logger *LogMonitor
|
||||
|
||||
if logMonitorId == "" {
|
||||
// maintain the default
|
||||
logger = pm.muxLogger
|
||||
} else if logMonitorId == "proxy" {
|
||||
logger = pm.proxyLogger
|
||||
} else if logMonitorId == "upstream" {
|
||||
logger = pm.upstreamLogger
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
|
||||
}
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
@@ -304,3 +306,405 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestProxyManager_Unload(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
proc, err := proxy.swapModel("model1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, proc)
|
||||
|
||||
assert.Len(t, proxy.currentProcesses, 1)
|
||||
req := httptest.NewRequest("GET", "/unload", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, w.Body.String(), "OK")
|
||||
assert.Len(t, proxy.currentProcesses, 0)
|
||||
}
|
||||
|
||||
// issue 62, strip profile slug from model name
|
||||
func TestProxyManager_StripProfileSlug(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
|
||||
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "ok")
|
||||
}
|
||||
|
||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
// Shared configuration
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
}
|
||||
|
||||
// Define a helper struct to parse the JSON response.
|
||||
type RunningResponse struct {
|
||||
Running []struct {
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
} `json:"running"`
|
||||
}
|
||||
|
||||
// Create proxy once for all tests
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
t.Run("no models loaded", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/running", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response RunningResponse
|
||||
|
||||
// Check if this is a valid JSON object.
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// We should have an empty running array here.
|
||||
assert.Empty(t, response.Running, "expected no running models")
|
||||
})
|
||||
|
||||
t.Run("single model loaded", func(t *testing.T) {
|
||||
// Load just a model.
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Simulate browser call for the `/running` endpoint.
|
||||
req = httptest.NewRequest("GET", "/running", nil)
|
||||
w = httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
var response RunningResponse
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// Check if we have a single array element.
|
||||
assert.Len(t, response.Running, 1)
|
||||
|
||||
// Is this the right model?
|
||||
assert.Equal(t, "model1", response.Running[0].Model)
|
||||
|
||||
// Is the model loaded?
|
||||
assert.Equal(t, "ready", response.Running[0].State)
|
||||
})
|
||||
|
||||
t.Run("multiple models via profile", func(t *testing.T) {
|
||||
// Load more than one model.
|
||||
for _, model := range []string{"model1", "model2"} {
|
||||
profileModel := ProcessKeyName("test", model)
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
// Simulate the browser call.
|
||||
req := httptest.NewRequest("GET", "/running", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
var response RunningResponse
|
||||
|
||||
// The JSON response must be valid.
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// The response should contain 2 models.
|
||||
assert.Len(t, response.Running, 2)
|
||||
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
}
|
||||
|
||||
// Iterate through the models and check their states as well.
|
||||
for _, entry := range response.Running {
|
||||
_, exists := expectedModels[entry.Model]
|
||||
assert.True(t, exists, "unexpected model %s", entry.Model)
|
||||
assert.Equal(t, "ready", entry.State)
|
||||
delete(expectedModels, entry.Model)
|
||||
}
|
||||
|
||||
// Since we deleted each model while testing for its validity we should have no more models in the response.
|
||||
assert.Empty(t, expectedModels, "unexpected additional models in response")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"TheExpectedModel"},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
modelInput string
|
||||
expectModel string
|
||||
}{
|
||||
{
|
||||
name: "With Profile Prefix",
|
||||
modelInput: "test:TheExpectedModel",
|
||||
expectModel: "TheExpectedModel", // Profile prefix should be stripped
|
||||
},
|
||||
{
|
||||
name: "Without Profile Prefix",
|
||||
modelInput: "TheExpectedModel",
|
||||
expectModel: "TheExpectedModel", // Should remain the same
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(tc.modelInput))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
// Generate random content length between 10 and 20
|
||||
contentLength := rand.Intn(11) + 10 // 10 to 20
|
||||
content := make([]byte, contentLength)
|
||||
_, err = fw.Write(content)
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectModel, response["model"])
|
||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_SplitRequestedModel(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
expectedProfile string
|
||||
expectedModel string
|
||||
}{
|
||||
{"no profile", "gpt-4", "", "gpt-4"},
|
||||
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
|
||||
{"only profile", "profile1:", "profile1", ""},
|
||||
{"empty model", ":gpt-4", "", "gpt-4"},
|
||||
{"empty profile", ":", "", ""},
|
||||
{"no split char", "gpt-4", "", "gpt-4"},
|
||||
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
profileName, modelName := splitRequestedModel(tt.requestedModel)
|
||||
if profileName != tt.expectedProfile {
|
||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||
}
|
||||
if modelName != tt.expectedModel {
|
||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||
func TestProxyManager_UseModelName(t *testing.T) {
|
||||
|
||||
upstreamModelName := "upstreamModel"
|
||||
|
||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||
modelConfig.UseModelName = upstreamModelName
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1"},
|
||||
},
|
||||
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": modelConfig,
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
requestedModel string
|
||||
}{
|
||||
{"useModelName over rides requested model", "model1"},
|
||||
{"useModelName over rides requested profile:model", "test:model1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(tt.requestedModel))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, upstreamModelName, response["model"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogRequests: true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
requestHeaders map[string]string
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "OPTIONS with no headers",
|
||||
method: "OPTIONS",
|
||||
expectedStatus: http.StatusNoContent,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OPTIONS with specific headers",
|
||||
method: "OPTIONS",
|
||||
requestHeaders: map[string]string{
|
||||
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
|
||||
},
|
||||
expectedStatus: http.StatusNoContent,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Non-OPTIONS request",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
||||
for k, v := range tt.requestHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
proxy.ginEngine.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
|
||||
for header, expectedValue := range tt.expectedHeaders {
|
||||
assert.Equal(t, expectedValue, w.Header().Get(header))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isTokenChar(r rune) bool {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
||||
parts := strings.Split(headerValues, ",")
|
||||
valid := make([]string, 0, len(parts))
|
||||
|
||||
for _, p := range parts {
|
||||
v := strings.TrimSpace(p)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
validPart := true
|
||||
for _, c := range v {
|
||||
if !isTokenChar(c) {
|
||||
validPart = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if validPart {
|
||||
valid = append(valid, v)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(valid, ", ")
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package proxy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: " ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single valid value",
|
||||
input: "content-type",
|
||||
expected: "content-type",
|
||||
},
|
||||
{
|
||||
name: "multiple valid values",
|
||||
input: "content-type, authorization, x-requested-with",
|
||||
expected: "content-type, authorization, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "values with extra spaces",
|
||||
input: " content-type , authorization ",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "values with tabs",
|
||||
input: "content-type,\tauthorization",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "values with invalid characters",
|
||||
input: "content-type, auth\n, x-requested-with\r",
|
||||
expected: "content-type, auth, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "empty values in list",
|
||||
input: "content-type,,authorization",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "leading and trailing commas",
|
||||
input: ",content-type,authorization,",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid values",
|
||||
input: "content-type, \x00invalid, x-requested-with",
|
||||
expected: "content-type, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "mixed case values",
|
||||
input: "Content-Type, my-Valid-Header, Another-hEader",
|
||||
expected: "Content-Type, my-Valid-Header, Another-hEader",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SanitizeAccessControlRequestHeaderValues(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
|
||||
tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user