Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cebf9c4d34 |
@@ -1,23 +0,0 @@
|
|||||||
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
|
|
||||||
name: Close inactive issues
|
|
||||||
on:
|
|
||||||
schedule:
|
|
||||||
- cron: "32 1 * * *"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
close-issues:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
issues: write
|
|
||||||
pull-requests: write
|
|
||||||
steps:
|
|
||||||
- uses: actions/stale@v9
|
|
||||||
with:
|
|
||||||
days-before-issue-stale: 30
|
|
||||||
days-before-issue-close: 14
|
|
||||||
stale-issue-label: "stale"
|
|
||||||
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
|
||||||
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
|
||||||
days-before-pr-stale: -1
|
|
||||||
days-before-pr-close: -1
|
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
@@ -16,7 +16,6 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
platform: [intel, cuda, vulkan, cpu, musa]
|
platform: [intel, cuda, vulkan, cpu, musa]
|
||||||
fail-fast: false
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
# This workflow will build a golang project
|
|
||||||
|
|
||||||
name: CI
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ "main" ]
|
|
||||||
|
|
||||||
pull_request:
|
|
||||||
branches: [ "main" ]
|
|
||||||
|
|
||||||
# Allows manual triggering of the workflow
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
|
|
||||||
run-tests:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up Go
|
|
||||||
uses: actions/setup-go@v4
|
|
||||||
with:
|
|
||||||
go-version: '1.23'
|
|
||||||
|
|
||||||
# necessary for testing proxy/Process swapping
|
|
||||||
- name: Create simple-responder
|
|
||||||
run: make simple-responder
|
|
||||||
|
|
||||||
- name: Test all
|
|
||||||
run: make test-all
|
|
||||||
@@ -16,15 +16,3 @@ builds:
|
|||||||
goarch: arm64
|
goarch: arm64
|
||||||
- goos: windows
|
- goos: windows
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
|
|
||||||
# use zip format for windows
|
|
||||||
archives:
|
|
||||||
- id: default
|
|
||||||
format: tar.gz
|
|
||||||
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
|
||||||
builds_info:
|
|
||||||
group: root
|
|
||||||
owner: root
|
|
||||||
format_overrides:
|
|
||||||
- goos: windows
|
|
||||||
format: zip
|
|
||||||
@@ -1,9 +1,4 @@
|
|||||||

|

|
||||||

|
|
||||||

|
|
||||||

|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# llama-swap
|
# llama-swap
|
||||||
|
|
||||||
@@ -16,23 +11,20 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
|||||||
- ✅ Easy to deploy: single binary with no dependencies
|
- ✅ Easy to deploy: single binary with no dependencies
|
||||||
- ✅ Easy to config: single yaml file
|
- ✅ Easy to config: single yaml file
|
||||||
- ✅ On-demand model switching
|
- ✅ On-demand model switching
|
||||||
|
- ✅ Full control over server settings per model
|
||||||
- ✅ OpenAI API supported endpoints:
|
- ✅ OpenAI API supported endpoints:
|
||||||
- `v1/completions`
|
- `v1/completions`
|
||||||
- `v1/chat/completions`
|
- `v1/chat/completions`
|
||||||
- `v1/embeddings`
|
- `v1/embeddings`
|
||||||
- `v1/rerank`
|
- `v1/rerank`
|
||||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
- ✅ Multiple GPU support
|
||||||
- ✅ llama-swap custom API endpoints
|
|
||||||
- `/log` - remote log monitoring
|
|
||||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
|
||||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
|
||||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
|
||||||
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
|
||||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
|
||||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
|
||||||
- ✅ Docker and Podman support
|
- ✅ Docker and Podman support
|
||||||
- ✅ Full control over server settings per model
|
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
||||||
|
- ✅ Remote log monitoring at `/log`
|
||||||
|
- ✅ Automatic unloading of models from GPUs after timeout
|
||||||
|
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||||
|
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
|
|
||||||
## How does llama-swap work?
|
## How does llama-swap work?
|
||||||
|
|
||||||
@@ -75,14 +67,7 @@ logRequests: true
|
|||||||
# define valid model values and the upstream server start
|
# define valid model values and the upstream server start
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
# multiline for readability
|
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||||
cmd: >
|
|
||||||
llama-server --port 8999
|
|
||||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
|
||||||
|
|
||||||
# environment variables to pass to the command
|
|
||||||
env:
|
|
||||||
- "CUDA_VISIBLE_DEVICES=0"
|
|
||||||
|
|
||||||
# where to reach the server started by cmd, make sure the ports match
|
# where to reach the server started by cmd, make sure the ports match
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
@@ -103,9 +88,16 @@ models:
|
|||||||
# default: 0 = never unload model
|
# default: 0 = never unload model
|
||||||
ttl: 60
|
ttl: 60
|
||||||
|
|
||||||
# `useModelName` overrides the model name in the request
|
"qwen":
|
||||||
# and sends a specific name to the upstream server
|
# environment variables to pass to the command
|
||||||
useModelName: "qwen:qwq"
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
|
|
||||||
|
# multiline for readability
|
||||||
|
cmd: >
|
||||||
|
llama-server --port 8999
|
||||||
|
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||||
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
# unlisted models do not show up in /v1/models or /upstream lists
|
# unlisted models do not show up in /v1/models or /upstream lists
|
||||||
# but they can still be requested as normal
|
# but they can still be requested as normal
|
||||||
@@ -122,7 +114,7 @@ models:
|
|||||||
ghcr.io/ggerganov/llama.cpp:server
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
# profiles eliminates swapping by running multiple models at the same time
|
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||||
#
|
#
|
||||||
# Tips:
|
# Tips:
|
||||||
# - each model must be listening on a unique address and port
|
# - each model must be listening on a unique address and port
|
||||||
@@ -130,20 +122,15 @@ models:
|
|||||||
# - the profile will load and unload all models in the profile at the same time
|
# - the profile will load and unload all models in the profile at the same time
|
||||||
profiles:
|
profiles:
|
||||||
coding:
|
coding:
|
||||||
|
- "qwen"
|
||||||
- "llama"
|
- "llama"
|
||||||
- "qwen-unlisted"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Use Case Examples
|
### Advanced Examples
|
||||||
|
|
||||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||||
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
llama-s
|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -262,11 +249,3 @@ StartLimitInterval=30
|
|||||||
[Install]
|
[Install]
|
||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
```
|
```
|
||||||
|
|
||||||
## Star History
|
|
||||||
|
|
||||||
<picture>
|
|
||||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
|
|
||||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
|
||||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
|
||||||
</picture>
|
|
||||||
|
|||||||
@@ -38,12 +38,6 @@ else
|
|||||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||||
|
|
||||||
# Abort if LCPP_TAG is empty.
|
|
||||||
if [[ -z "$LCPP_TAG" ]]; then
|
|
||||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||||
echo "Building ${CONTAINER_TAG} $LS_VER"
|
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||||
|
|||||||
@@ -1,51 +0,0 @@
|
|||||||
# Restart llama-swap on config change
|
|
||||||
|
|
||||||
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
#!/bin/bash
|
|
||||||
#
|
|
||||||
# A simple watch and restart llama-swap when its configuration
|
|
||||||
# file changes. Useful for trying out configuration changes
|
|
||||||
# without manually restarting the server each time.
|
|
||||||
if [ -z "$1" ]; then
|
|
||||||
echo "Usage: $0 <path to config.yaml>"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
while true; do
|
|
||||||
# Start the process again
|
|
||||||
./llama-swap-linux-amd64 -config $1 -listen :1867 &
|
|
||||||
PID=$!
|
|
||||||
echo "Started llama-swap with PID $PID"
|
|
||||||
|
|
||||||
# Wait for modifications in the specified directory or file
|
|
||||||
inotifywait -e modify "$1"
|
|
||||||
|
|
||||||
# Check if process exists before sending signal
|
|
||||||
if kill -0 $PID 2>/dev/null; then
|
|
||||||
echo "Sending SIGTERM to $PID"
|
|
||||||
kill -SIGTERM $PID
|
|
||||||
wait $PID
|
|
||||||
else
|
|
||||||
echo "Process $PID no longer exists"
|
|
||||||
fi
|
|
||||||
sleep 1
|
|
||||||
done
|
|
||||||
```
|
|
||||||
|
|
||||||
## Usage and output example
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ ./watch-and-restart.sh config.yaml
|
|
||||||
Started llama-swap with PID 495455
|
|
||||||
Setting up watches.
|
|
||||||
Watches established.
|
|
||||||
llama-swap listening on :1867
|
|
||||||
Sending SIGTERM to 495455
|
|
||||||
Shutting down llama-swap
|
|
||||||
Started llama-swap with PID 495486
|
|
||||||
Setting up watches.
|
|
||||||
Watches established.
|
|
||||||
llama-swap listening on :1867
|
|
||||||
```
|
|
||||||
@@ -3,11 +3,7 @@ module github.com/mostlygeek/llama-swap
|
|||||||
go 1.23.0
|
go 1.23.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/gin-gonic/gin v1.10.0
|
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
|
||||||
github.com/tidwall/sjson v1.2.5
|
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,10 +15,12 @@ require (
|
|||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
|
github.com/gin-gonic/gin v1.10.0 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
@@ -31,14 +29,12 @@ require (
|
|||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
|
||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/crypto v0.36.0 // indirect
|
golang.org/x/crypto v0.31.0 // indirect
|
||||||
golang.org/x/net v0.37.0 // indirect
|
golang.org/x/net v0.33.0 // indirect
|
||||||
golang.org/x/sys v0.31.0 // indirect
|
golang.org/x/sys v0.28.0 // indirect
|
||||||
golang.org/x/text v0.23.0 // indirect
|
golang.org/x/text v0.21.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.1 // indirect
|
google.golang.org/protobuf v1.34.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,16 +57,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
|
||||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
|
||||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
|
||||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
|
||||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
|
||||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
|
||||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
|
||||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
|
||||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
|
||||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
@@ -78,28 +68,20 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
|||||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
|
||||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
|
||||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
|
||||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
|
||||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
|
||||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
|
||||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
|
||||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
|||||||
@@ -12,14 +12,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
// Define a command-line flag for the port
|
// Define a command-line flag for the port
|
||||||
port := flag.String("port", "8080", "port to listen on")
|
port := flag.String("port", "8080", "port to listen on")
|
||||||
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
|
|
||||||
|
|
||||||
// Define a command-line flag for the response message
|
// Define a command-line flag for the response message
|
||||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||||
@@ -43,70 +41,11 @@ func main() {
|
|||||||
c.String(200, *responseMessage)
|
c.String(200, *responseMessage)
|
||||||
})
|
})
|
||||||
|
|
||||||
// for issue #62 to check model name strips profile slug
|
|
||||||
// has to be one of the openAI API endpoints that llama-swap proxies
|
|
||||||
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
|
|
||||||
r.POST("/v1/audio/speech", func(c *gin.Context) {
|
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer c.Request.Body.Close()
|
|
||||||
modelName := gjson.GetBytes(body, "model").String()
|
|
||||||
if modelName != *expectedModel {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
r.POST("/v1/completions", func(c *gin.Context) {
|
r.POST("/v1/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
c.String(200, *responseMessage)
|
c.String(200, *responseMessage)
|
||||||
})
|
})
|
||||||
|
|
||||||
// issue #41
|
|
||||||
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
|
||||||
// Parse the multipart form
|
|
||||||
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the model from the form values
|
|
||||||
model := c.Request.FormValue("model")
|
|
||||||
|
|
||||||
if model == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the file from the form
|
|
||||||
file, _, err := c.Request.FormFile("file")
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
// Read the file content to get its size
|
|
||||||
fileBytes, err := io.ReadAll(file)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fileSize := len(fileBytes)
|
|
||||||
|
|
||||||
// Return a JSON response with the model and transcription text including file size
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
|
||||||
"model": model,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
r.GET("/slow-respond", func(c *gin.Context) {
|
r.GET("/slow-respond", func(c *gin.Context) {
|
||||||
echo := c.Query("echo")
|
echo := c.Query("echo")
|
||||||
delay := c.Query("delay")
|
delay := c.Query("delay")
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ type ModelConfig struct {
|
|||||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||||
UnloadAfter int `yaml:"ttl"`
|
UnloadAfter int `yaml:"ttl"`
|
||||||
Unlisted bool `yaml:"unlisted"`
|
Unlisted bool `yaml:"unlisted"`
|
||||||
UseModelName string `yaml:"useModelName"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
|
|||||||
+75
-81
@@ -34,9 +34,7 @@ type Process struct {
|
|||||||
config ModelConfig
|
config ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
logMonitor *LogMonitor
|
logMonitor *LogMonitor
|
||||||
|
|
||||||
healthCheckTimeout int
|
healthCheckTimeout int
|
||||||
healthCheckLoopInterval time.Duration
|
|
||||||
|
|
||||||
lastRequestHandled time.Time
|
lastRequestHandled time.Time
|
||||||
|
|
||||||
@@ -61,52 +59,46 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito
|
|||||||
cmd: nil,
|
cmd: nil,
|
||||||
logMonitor: logMonitor,
|
logMonitor: logMonitor,
|
||||||
healthCheckTimeout: healthCheckTimeout,
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
|
||||||
state: StateStopped,
|
state: StateStopped,
|
||||||
shutdownCtx: ctx,
|
shutdownCtx: ctx,
|
||||||
shutdownCancel: cancel,
|
shutdownCancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// custom error types for swapping state
|
func (p *Process) setState(newState ProcessState) error {
|
||||||
var (
|
// enforce valid state transitions
|
||||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
invalidTransition := false
|
||||||
ErrInvalidStateTransition = errors.New("invalid state transition")
|
if p.state == StateStopped {
|
||||||
)
|
// stopped -> starting
|
||||||
|
if newState != StateStarting {
|
||||||
// swapState performs a compare and swap of the state atomically. It returns the current state
|
invalidTransition = true
|
||||||
// and an error if the swap failed.
|
}
|
||||||
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
|
} else if p.state == StateStarting {
|
||||||
p.stateMutex.Lock()
|
// starting -> ready | failed | stopping
|
||||||
defer p.stateMutex.Unlock()
|
if newState != StateReady && newState != StateFailed && newState != StateStopping {
|
||||||
|
invalidTransition = true
|
||||||
if p.state != expectedState {
|
}
|
||||||
return p.state, ErrExpectedStateMismatch
|
} else if p.state == StateReady {
|
||||||
|
// ready -> stopping
|
||||||
|
if newState != StateStopping {
|
||||||
|
invalidTransition = true
|
||||||
|
}
|
||||||
|
} else if p.state == StateStopping {
|
||||||
|
// stopping -> stopped | shutdown
|
||||||
|
if newState != StateStopped && newState != StateShutdown {
|
||||||
|
invalidTransition = true
|
||||||
|
}
|
||||||
|
} else if p.state == StateFailed || p.state == StateShutdown {
|
||||||
|
invalidTransition = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isValidTransition(p.state, newState) {
|
if invalidTransition {
|
||||||
return p.state, ErrInvalidStateTransition
|
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
|
||||||
|
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.state = newState
|
p.state = newState
|
||||||
return p.state, nil
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to encapsulate transition rules
|
|
||||||
func isValidTransition(from, to ProcessState) bool {
|
|
||||||
switch from {
|
|
||||||
case StateStopped:
|
|
||||||
return to == StateStarting
|
|
||||||
case StateStarting:
|
|
||||||
return to == StateReady || to == StateFailed || to == StateStopping
|
|
||||||
case StateReady:
|
|
||||||
return to == StateStopping
|
|
||||||
case StateStopping:
|
|
||||||
return to == StateStopped || to == StateShutdown
|
|
||||||
case StateFailed, StateShutdown:
|
|
||||||
return false // No transitions allowed from these states
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) CurrentState() ProcessState {
|
func (p *Process) CurrentState() ProcessState {
|
||||||
@@ -124,33 +116,38 @@ func (p *Process) start() error {
|
|||||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
args, err := p.config.SanitizedCommand()
|
// wait for the other start() to complete
|
||||||
if err != nil {
|
curState := p.CurrentState()
|
||||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
|
||||||
|
if curState == StateReady {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
|
||||||
if err == ErrExpectedStateMismatch {
|
|
||||||
// already starting, just wait for it to complete and expect
|
|
||||||
// it to be be in the Ready start after. If not, return an error
|
|
||||||
if curState == StateStarting {
|
if curState == StateStarting {
|
||||||
p.waitStarting.Wait()
|
p.waitStarting.Wait()
|
||||||
if state := p.CurrentState(); state == StateReady {
|
|
||||||
|
if state := p.CurrentState(); state != StateReady {
|
||||||
|
return fmt.Errorf("start() failed current state: %v", state)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
} else {
|
|
||||||
return fmt.Errorf("process was already starting but wound up in state %v", state)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("processes was in state %v when start() was called", curState)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.stateMutex.Lock()
|
||||||
|
defer p.stateMutex.Unlock()
|
||||||
|
|
||||||
|
if err := p.setState(StateStarting); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.waitStarting.Add(1)
|
p.waitStarting.Add(1)
|
||||||
defer p.waitStarting.Done()
|
defer p.waitStarting.Done()
|
||||||
|
|
||||||
|
args, err := p.config.SanitizedCommand()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
p.cmd = exec.Command(args[0], args[1:]...)
|
p.cmd = exec.Command(args[0], args[1:]...)
|
||||||
p.cmd.Stdout = p.logMonitor
|
p.cmd.Stdout = p.logMonitor
|
||||||
p.cmd.Stderr = p.logMonitor
|
p.cmd.Stderr = p.logMonitor
|
||||||
@@ -158,14 +155,8 @@ func (p *Process) start() error {
|
|||||||
|
|
||||||
err = p.cmd.Start()
|
err = p.cmd.Start()
|
||||||
|
|
||||||
// Set process state to failed
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
|
p.setState(StateFailed)
|
||||||
return fmt.Errorf(
|
|
||||||
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
|
||||||
err, curState, swapErr,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("start() failed: %v", err)
|
return fmt.Errorf("start() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -200,16 +191,13 @@ func (p *Process) start() error {
|
|||||||
)
|
)
|
||||||
defer cancelHealthCheck()
|
defer cancelHealthCheck()
|
||||||
|
|
||||||
|
// Health check loop
|
||||||
loop:
|
loop:
|
||||||
// Ready Check loop
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-checkDeadline.Done():
|
case <-checkDeadline.Done():
|
||||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
p.setState(StateFailed)
|
||||||
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
|
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
|
||||||
} else {
|
|
||||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
|
||||||
}
|
|
||||||
case <-p.shutdownCtx.Done():
|
case <-p.shutdownCtx.Done():
|
||||||
return errors.New("health check interrupted due to shutdown")
|
return errors.New("health check interrupted due to shutdown")
|
||||||
default:
|
default:
|
||||||
@@ -227,7 +215,7 @@ func (p *Process) start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
<-time.After(p.healthCheckLoopInterval)
|
<-time.After(5 * time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,7 +226,7 @@ func (p *Process) start() error {
|
|||||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||||
|
|
||||||
for range time.Tick(time.Second) {
|
for range time.Tick(time.Second) {
|
||||||
if p.CurrentState() != StateReady {
|
if p.state != StateReady {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -254,28 +242,26 @@ func (p *Process) start() error {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
return p.setState(StateReady)
|
||||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
|
||||||
} else {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
// wait for any inflight requests before proceeding
|
// wait for any inflight requests before proceeding
|
||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
|
p.stateMutex.Lock()
|
||||||
|
defer p.stateMutex.Unlock()
|
||||||
|
|
||||||
// calling Stop() when state is invalid is a no-op
|
// calling Stop() when state is invalid is a no-op
|
||||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
if err := p.setState(StateStopping); err != nil {
|
||||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState)
|
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop the process with a graceful exit timeout
|
// stop the process with a graceful exit timeout
|
||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(5 * time.Second)
|
||||||
|
|
||||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
if err := p.setState(StateStopped); err != nil {
|
||||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState)
|
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,9 +269,19 @@ func (p *Process) Stop() {
|
|||||||
// of time for any inflight requests to complete before shutting down. If the Process
|
// of time for any inflight requests to complete before shutting down. If the Process
|
||||||
// is in the state of starting, it will cancel it and shut it down
|
// is in the state of starting, it will cancel it and shut it down
|
||||||
func (p *Process) Shutdown() {
|
func (p *Process) Shutdown() {
|
||||||
|
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
|
||||||
p.shutdownCancel()
|
p.shutdownCancel()
|
||||||
|
|
||||||
|
p.stateMutex.Lock()
|
||||||
|
defer p.stateMutex.Unlock()
|
||||||
|
p.setState(StateStopping)
|
||||||
|
|
||||||
|
// 5 seconds to stop the process
|
||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(5 * time.Second)
|
||||||
p.state = StateShutdown
|
if err := p.setState(StateShutdown); err != nil {
|
||||||
|
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
|
||||||
|
}
|
||||||
|
p.setState(StateShutdown)
|
||||||
}
|
}
|
||||||
|
|
||||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||||
@@ -304,9 +300,7 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.terminateProcess(); err != nil {
|
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||||
fmt.Fprintf(p.logMonitor, "!!! failed to gracefully terminate process [%s]: %v\n", p.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-sigtermTimeout.Done():
|
case <-sigtermTimeout.Done():
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import "syscall"
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
return p.cmd.Process.Signal(syscall.SIGTERM)
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
|
||||||
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
+25
-33
@@ -169,8 +169,6 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// issue #19
|
// issue #19
|
||||||
// This test makes sure using Process.Stop() does not affect pending HTTP
|
|
||||||
// requests. All HTTP requests in this test should complete successfully.
|
|
||||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skipping slow test")
|
t.Skip("skipping slow test")
|
||||||
@@ -194,9 +192,8 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(key string) {
|
go func(key string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
// send a request where simple-responder is will wait 300ms before responding
|
// send a request that should take 5 * 200ms (1 second) to complete
|
||||||
// this will simulate an in-progress request.
|
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
@@ -212,9 +209,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
}(key)
|
}(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the process while requests are still being processed
|
// stop the requests in the middle
|
||||||
go func() {
|
go func() {
|
||||||
<-time.After(150 * time.Millisecond)
|
<-time.After(500 * time.Millisecond)
|
||||||
process.Stop()
|
process.Stop()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -225,32 +222,30 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcess_SwapState(t *testing.T) {
|
func TestSetState(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
currentState ProcessState
|
currentState ProcessState
|
||||||
expectedState ProcessState
|
|
||||||
newState ProcessState
|
newState ProcessState
|
||||||
expectedError error
|
expectedError error
|
||||||
expectedResult ProcessState
|
expectedResult ProcessState
|
||||||
}{
|
}{
|
||||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
|
||||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
|
||||||
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
|
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
|
||||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
|
||||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
|
||||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
|
||||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
|
||||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
|
||||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
|
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
|
||||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
|
||||||
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
|
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
|
||||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
|
||||||
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
|
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
|
||||||
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
|
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
|
||||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
|
||||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
|
||||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -259,7 +254,7 @@ func TestProcess_SwapState(t *testing.T) {
|
|||||||
state: test.currentState,
|
state: test.currentState,
|
||||||
}
|
}
|
||||||
|
|
||||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
err := p.setState(test.newState)
|
||||||
if err != nil && test.expectedError == nil {
|
if err != nil && test.expectedError == nil {
|
||||||
t.Errorf("Unexpected error: %v", err)
|
t.Errorf("Unexpected error: %v", err)
|
||||||
} else if err == nil && test.expectedError != nil {
|
} else if err == nil && test.expectedError != nil {
|
||||||
@@ -270,8 +265,8 @@ func TestProcess_SwapState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if resultState != test.expectedResult {
|
if p.state != test.expectedResult {
|
||||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
|
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -292,14 +287,11 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
|||||||
healthCheckTTLSeconds := 30
|
healthCheckTTLSeconds := 30
|
||||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
||||||
|
|
||||||
// make it a lot faster
|
|
||||||
process.healthCheckLoopInterval = time.Second
|
|
||||||
|
|
||||||
// start a goroutine to simulate a shutdown
|
// start a goroutine to simulate a shutdown
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
<-time.After(time.Millisecond * 500)
|
<-time.After(time.Second * 2)
|
||||||
process.Shutdown()
|
process.Shutdown()
|
||||||
}()
|
}()
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|||||||
+19
-191
@@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -14,8 +13,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -72,27 +69,14 @@ func New(config *Config) *ProxyManager {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// see: issue: #81, #77 and #42 for CORS issues
|
// see: https://github.com/mostlygeek/llama-swap/issues/42
|
||||||
// respond with permissive OPTIONS for any endpoint
|
// respond with permissive OPTIONS for any endpoint
|
||||||
pm.ginEngine.Use(func(c *gin.Context) {
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
|
|
||||||
// set this for all requests
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
|
|
||||||
if c.Request.Method == "OPTIONS" {
|
if c.Request.Method == "OPTIONS" {
|
||||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||||
// allow whatever the client requested by default
|
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||||
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
c.AbortWithStatus(204)
|
||||||
c.Header("Access-Control-Allow-Headers", headers)
|
|
||||||
} else {
|
|
||||||
c.Header(
|
|
||||||
"Access-Control-Allow-Headers",
|
|
||||||
"Content-Type, Authorization, Accept, X-Requested-With",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
c.Header("Access-Control-Max-Age", "86400")
|
|
||||||
c.AbortWithStatus(http.StatusNoContent)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
@@ -109,7 +93,6 @@ func New(config *Config) *ProxyManager {
|
|||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
|
||||||
|
|
||||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||||
|
|
||||||
@@ -121,10 +104,6 @@ func New(config *Config) *ProxyManager {
|
|||||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||||
|
|
||||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
|
||||||
|
|
||||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
|
||||||
|
|
||||||
pm.ginEngine.GET("/", func(c *gin.Context) {
|
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||||
// Set the Content-Type header to text/html
|
// Set the Content-Type header to text/html
|
||||||
c.Header("Content-Type", "text/html")
|
c.Header("Content-Type", "text/html")
|
||||||
@@ -243,7 +222,11 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||||
profileName, modelName := splitRequestedModel(requestedModel)
|
profileName, modelName := "", requestedModel
|
||||||
|
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||||
|
profileName = requestedModel[:idx]
|
||||||
|
modelName = requestedModel[idx+1:]
|
||||||
|
}
|
||||||
|
|
||||||
if profileName != "" {
|
if profileName != "" {
|
||||||
if _, found := pm.config.Profiles[profileName]; !found {
|
if _, found := pm.config.Profiles[profileName]; !found {
|
||||||
@@ -359,37 +342,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
var requestBody map[string]interface{}
|
||||||
if requestedModel == "" {
|
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model, ok := requestBody["model"].(string)
|
||||||
|
if !ok {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
process, err := pm.swapModel(requestedModel)
|
if process, err := pm.swapModel(model); err != nil {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
|
||||||
|
|
||||||
// issue #69 allow custom model names to be sent to upstream
|
|
||||||
if process.config.UseModelName != "" {
|
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
profileName, modelName := splitRequestedModel(requestedModel)
|
|
||||||
if profileName != "" {
|
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
@@ -397,110 +364,7 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
|
||||||
process.ProxyRequest(c.Writer, c.Request)
|
process.ProxyRequest(c.Writer, c.Request)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|
||||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
|
||||||
// Create a new buffer for the reconstructed request
|
|
||||||
var requestBuffer bytes.Buffer
|
|
||||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
|
||||||
|
|
||||||
// Parse multipart form
|
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get model parameter from the form
|
|
||||||
requestedModel := c.Request.FormValue("model")
|
|
||||||
if requestedModel == "" {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Swap to the requested model
|
|
||||||
process, err := pm.swapModel(requestedModel)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get profile name and model name from the requested model
|
|
||||||
profileName, modelName := splitRequestedModel(requestedModel)
|
|
||||||
|
|
||||||
// Copy all form values
|
|
||||||
for key, values := range c.Request.MultipartForm.Value {
|
|
||||||
for _, value := range values {
|
|
||||||
fieldValue := value
|
|
||||||
// If this is the model field and we have a profile, use just the model name
|
|
||||||
if key == "model" {
|
|
||||||
if process.config.UseModelName != "" {
|
|
||||||
fieldValue = process.config.UseModelName
|
|
||||||
} else if profileName != "" {
|
|
||||||
fieldValue = modelName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
field, err := multipartWriter.CreateFormField(key)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy all files from the original request
|
|
||||||
for key, fileHeaders := range c.Request.MultipartForm.File {
|
|
||||||
for _, fileHeader := range fileHeaders {
|
|
||||||
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
file, err := fileHeader.Open()
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = io.Copy(formFile, file); err != nil {
|
|
||||||
file.Close()
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
file.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the multipart writer to finalize the form
|
|
||||||
if err := multipartWriter.Close(); err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new request with the reconstructed form data
|
|
||||||
modifiedReq, err := http.NewRequestWithContext(
|
|
||||||
c.Request.Context(),
|
|
||||||
c.Request.Method,
|
|
||||||
c.Request.URL.String(),
|
|
||||||
&requestBuffer,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy the headers from the original request
|
|
||||||
modifiedReq.Header = c.Request.Header.Clone()
|
|
||||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
|
||||||
|
|
||||||
// Use the modified request for proxying
|
|
||||||
process.ProxyRequest(c.Writer, modifiedReq)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||||
@@ -513,42 +377,6 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
|
||||||
pm.StopProcesses()
|
|
||||||
c.String(http.StatusOK, "OK")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|
||||||
context.Header("Content-Type", "application/json")
|
|
||||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
|
||||||
|
|
||||||
for _, process := range pm.currentProcesses {
|
|
||||||
|
|
||||||
// Append the process ID and State (multiple entries if profiles are being used).
|
|
||||||
runningProcesses = append(runningProcesses, gin.H{
|
|
||||||
"model": process.ID,
|
|
||||||
"state": process.state,
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put the results under the `running` key.
|
|
||||||
response := gin.H{
|
|
||||||
"running": runningProcesses,
|
|
||||||
}
|
|
||||||
|
|
||||||
context.JSON(http.StatusOK, response) // Always return 200 OK
|
|
||||||
}
|
|
||||||
|
|
||||||
func ProcessKeyName(groupName, modelName string) string {
|
func ProcessKeyName(groupName, modelName string) string {
|
||||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
func splitRequestedModel(requestedModel string) (string, string) {
|
|
||||||
profileName, modelName := "", requestedModel
|
|
||||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
|
||||||
profileName = requestedModel[:idx]
|
|
||||||
modelName = requestedModel[idx+1:]
|
|
||||||
}
|
|
||||||
return profileName, modelName
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"mime/multipart"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -306,428 +304,3 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_Unload(t *testing.T) {
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
proc, err := proxy.swapModel("model1")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, proc)
|
|
||||||
|
|
||||||
assert.Len(t, proxy.currentProcesses, 1)
|
|
||||||
req := httptest.NewRequest("GET", "/unload", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
assert.Equal(t, w.Body.String(), "OK")
|
|
||||||
assert.Len(t, proxy.currentProcesses, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// issue 62, strip profile slug from model name
|
|
||||||
func TestProxyManager_StripProfileSlug(t *testing.T) {
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
|
|
||||||
},
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
|
|
||||||
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
assert.Contains(t, w.Body.String(), "ok")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
|
||||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|
||||||
|
|
||||||
// Shared configuration
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
|
||||||
},
|
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"model1", "model2"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Define a helper struct to parse the JSON response.
|
|
||||||
type RunningResponse struct {
|
|
||||||
Running []struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
State string `json:"state"`
|
|
||||||
} `json:"running"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create proxy once for all tests
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
t.Run("no models loaded", func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("GET", "/running", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
|
|
||||||
var response RunningResponse
|
|
||||||
|
|
||||||
// Check if this is a valid JSON object.
|
|
||||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
||||||
|
|
||||||
// We should have an empty running array here.
|
|
||||||
assert.Empty(t, response.Running, "expected no running models")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("single model loaded", func(t *testing.T) {
|
|
||||||
// Load just a model.
|
|
||||||
reqBody := `{"model":"model1"}`
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
|
|
||||||
// Simulate browser call for the `/running` endpoint.
|
|
||||||
req = httptest.NewRequest("GET", "/running", nil)
|
|
||||||
w = httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
|
|
||||||
var response RunningResponse
|
|
||||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
||||||
|
|
||||||
// Check if we have a single array element.
|
|
||||||
assert.Len(t, response.Running, 1)
|
|
||||||
|
|
||||||
// Is this the right model?
|
|
||||||
assert.Equal(t, "model1", response.Running[0].Model)
|
|
||||||
|
|
||||||
// Is the model loaded?
|
|
||||||
assert.Equal(t, "ready", response.Running[0].State)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("multiple models via profile", func(t *testing.T) {
|
|
||||||
// Load more than one model.
|
|
||||||
for _, model := range []string{"model1", "model2"} {
|
|
||||||
profileModel := ProcessKeyName("test", model)
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate the browser call.
|
|
||||||
req := httptest.NewRequest("GET", "/running", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
|
|
||||||
var response RunningResponse
|
|
||||||
|
|
||||||
// The JSON response must be valid.
|
|
||||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
||||||
|
|
||||||
// The response should contain 2 models.
|
|
||||||
assert.Len(t, response.Running, 2)
|
|
||||||
|
|
||||||
expectedModels := map[string]struct{}{
|
|
||||||
"model1": {},
|
|
||||||
"model2": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate through the models and check their states as well.
|
|
||||||
for _, entry := range response.Running {
|
|
||||||
_, exists := expectedModels[entry.Model]
|
|
||||||
assert.True(t, exists, "unexpected model %s", entry.Model)
|
|
||||||
assert.Equal(t, "ready", entry.State)
|
|
||||||
delete(expectedModels, entry.Model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since we deleted each model while testing for its validity we should have no more models in the response.
|
|
||||||
assert.Empty(t, expectedModels, "unexpected additional models in response")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"TheExpectedModel"},
|
|
||||||
},
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
modelInput string
|
|
||||||
expectModel string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "With Profile Prefix",
|
|
||||||
modelInput: "test:TheExpectedModel",
|
|
||||||
expectModel: "TheExpectedModel", // Profile prefix should be stripped
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Without Profile Prefix",
|
|
||||||
modelInput: "TheExpectedModel",
|
|
||||||
expectModel: "TheExpectedModel", // Should remain the same
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
// Create a buffer with multipart form data
|
|
||||||
var b bytes.Buffer
|
|
||||||
w := multipart.NewWriter(&b)
|
|
||||||
|
|
||||||
// Add the model field
|
|
||||||
fw, err := w.CreateFormField("model")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, err = fw.Write([]byte(tc.modelInput))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
// Add a file field
|
|
||||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
// Generate random content length between 10 and 20
|
|
||||||
contentLength := rand.Intn(11) + 10 // 10 to 20
|
|
||||||
content := make([]byte, contentLength)
|
|
||||||
_, err = fw.Write(content)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
w.Close()
|
|
||||||
|
|
||||||
// Create the request with the multipart form data
|
|
||||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(rec, req)
|
|
||||||
|
|
||||||
// Verify the response
|
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
|
||||||
var response map[string]string
|
|
||||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tc.expectModel, response["model"])
|
|
||||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_SplitRequestedModel(t *testing.T) {
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
requestedModel string
|
|
||||||
expectedProfile string
|
|
||||||
expectedModel string
|
|
||||||
}{
|
|
||||||
{"no profile", "gpt-4", "", "gpt-4"},
|
|
||||||
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
|
|
||||||
{"only profile", "profile1:", "profile1", ""},
|
|
||||||
{"empty model", ":gpt-4", "", "gpt-4"},
|
|
||||||
{"empty profile", ":", "", ""},
|
|
||||||
{"no split char", "gpt-4", "", "gpt-4"},
|
|
||||||
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
profileName, modelName := splitRequestedModel(tt.requestedModel)
|
|
||||||
if profileName != tt.expectedProfile {
|
|
||||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
|
||||||
}
|
|
||||||
if modelName != tt.expectedModel {
|
|
||||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
|
||||||
func TestProxyManager_UseModelName(t *testing.T) {
|
|
||||||
|
|
||||||
upstreamModelName := "upstreamModel"
|
|
||||||
|
|
||||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
|
||||||
modelConfig.UseModelName = upstreamModelName
|
|
||||||
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"model1"},
|
|
||||||
},
|
|
||||||
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": modelConfig,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
description string
|
|
||||||
requestedModel string
|
|
||||||
}{
|
|
||||||
{"useModelName over rides requested model", "model1"},
|
|
||||||
{"useModelName over rides requested profile:model", "test:model1"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
|
|
||||||
// Create a buffer with multipart form data
|
|
||||||
var b bytes.Buffer
|
|
||||||
w := multipart.NewWriter(&b)
|
|
||||||
|
|
||||||
// Add the model field
|
|
||||||
fw, err := w.CreateFormField("model")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, err = fw.Write([]byte(tt.requestedModel))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
// Add a file field
|
|
||||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, err = fw.Write([]byte("test"))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
w.Close()
|
|
||||||
|
|
||||||
// Create the request with the multipart form data
|
|
||||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(rec, req)
|
|
||||||
|
|
||||||
// Verify the response
|
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
|
||||||
var response map[string]string
|
|
||||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, upstreamModelName, response["model"])
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
},
|
|
||||||
LogRequests: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
method string
|
|
||||||
requestHeaders map[string]string
|
|
||||||
expectedStatus int
|
|
||||||
expectedHeaders map[string]string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "OPTIONS with no headers",
|
|
||||||
method: "OPTIONS",
|
|
||||||
expectedStatus: http.StatusNoContent,
|
|
||||||
expectedHeaders: map[string]string{
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "OPTIONS with specific headers",
|
|
||||||
method: "OPTIONS",
|
|
||||||
requestHeaders: map[string]string{
|
|
||||||
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
|
|
||||||
},
|
|
||||||
expectedStatus: http.StatusNoContent,
|
|
||||||
expectedHeaders: map[string]string{
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Non-OPTIONS request",
|
|
||||||
method: "GET",
|
|
||||||
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
|
||||||
for k, v := range tt.requestHeaders {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.ginEngine.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
|
||||||
|
|
||||||
for header, expectedValue := range tt.expectedHeaders {
|
|
||||||
assert.Equal(t, expectedValue, w.Header().Get(header))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_CORSHeadersInRegularRequest(t *testing.T) {
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
},
|
|
||||||
LogRequests: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
// Test that CORS headers are present in regular POST requests
|
|
||||||
reqBody := `{"model":"model1"}`
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ginEngine.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user