Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9b2ed244e2 | |||
| eeb72297f7 | |||
| eabfe70cc6 | |||
| 29cd98878d | |||
| b3d331da0d | |||
| 62275e078d | |||
| 88916059e1 | |||
| 082d5d0fc5 | |||
| 53338938bd | |||
| af653347ae |
@@ -0,0 +1,23 @@
|
|||||||
|
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
|
||||||
|
name: Close inactive issues
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: "32 1 * * *"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
close-issues:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
steps:
|
||||||
|
- uses: actions/stale@v9
|
||||||
|
with:
|
||||||
|
days-before-issue-stale: 30
|
||||||
|
days-before-issue-close: 14
|
||||||
|
stale-issue-label: "stale"
|
||||||
|
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
||||||
|
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
||||||
|
days-before-pr-stale: -1
|
||||||
|
days-before-pr-close: -1
|
||||||
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
@@ -16,6 +16,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
platform: [intel, cuda, vulkan, cpu, musa]
|
platform: [intel, cuda, vulkan, cpu, musa]
|
||||||
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
# This workflow will build a golang project
|
||||||
|
|
||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
|
||||||
|
# Allows manual triggering of the workflow
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
run-tests:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: '1.23'
|
||||||
|
|
||||||
|
# necessary for testing proxy/Process swapping
|
||||||
|
- name: Create simple-responder
|
||||||
|
run: make simple-responder
|
||||||
|
|
||||||
|
- name: Test all
|
||||||
|
run: make test-all
|
||||||
@@ -18,13 +18,13 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
|||||||
- `v1/embeddings`
|
- `v1/embeddings`
|
||||||
- `v1/rerank`
|
- `v1/rerank`
|
||||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
- ✅ Multiple GPU support
|
|
||||||
- ✅ Docker and Podman support
|
|
||||||
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
||||||
- ✅ Remote log monitoring at `/log`
|
- ✅ Remote log monitoring at `/log`
|
||||||
- ✅ Automatic unloading of models from GPUs after timeout
|
|
||||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
|
||||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
|
- ✅ Manually unload models via `/unload` endpoint ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||||
|
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||||
|
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||||
|
- ✅ Docker and Podman support
|
||||||
|
|
||||||
## How does llama-swap work?
|
## How does llama-swap work?
|
||||||
|
|
||||||
@@ -126,11 +126,16 @@ profiles:
|
|||||||
- "llama"
|
- "llama"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Advanced Examples
|
### Use Case Examples
|
||||||
|
|
||||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||||
|
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
llama-s
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -249,3 +254,11 @@ StartLimitInterval=30
|
|||||||
[Install]
|
[Install]
|
||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Star History
|
||||||
|
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
|
||||||
|
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||||
|
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||||
|
</picture>
|
||||||
|
|||||||
@@ -38,6 +38,12 @@ else
|
|||||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||||
|
|
||||||
|
# Abort if LCPP_TAG is empty.
|
||||||
|
if [[ -z "$LCPP_TAG" ]]; then
|
||||||
|
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||||
echo "Building ${CONTAINER_TAG} $LS_VER"
|
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
# Restart llama-swap on config change
|
||||||
|
|
||||||
|
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# A simple watch and restart llama-swap when its configuration
|
||||||
|
# file changes. Useful for trying out configuration changes
|
||||||
|
# without manually restarting the server each time.
|
||||||
|
if [ -z "$1" ]; then
|
||||||
|
echo "Usage: $0 <path to config.yaml>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
while true; do
|
||||||
|
# Start the process again
|
||||||
|
./llama-swap-linux-amd64 -config $1 -listen :1867 &
|
||||||
|
PID=$!
|
||||||
|
echo "Started llama-swap with PID $PID"
|
||||||
|
|
||||||
|
# Wait for modifications in the specified directory or file
|
||||||
|
inotifywait -e modify "$1"
|
||||||
|
|
||||||
|
# Check if process exists before sending signal
|
||||||
|
if kill -0 $PID 2>/dev/null; then
|
||||||
|
echo "Sending SIGTERM to $PID"
|
||||||
|
kill -SIGTERM $PID
|
||||||
|
wait $PID
|
||||||
|
else
|
||||||
|
echo "Process $PID no longer exists"
|
||||||
|
fi
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage and output example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ ./watch-and-restart.sh config.yaml
|
||||||
|
Started llama-swap with PID 495455
|
||||||
|
Setting up watches.
|
||||||
|
Watches established.
|
||||||
|
llama-swap listening on :1867
|
||||||
|
Sending SIGTERM to 495455
|
||||||
|
Shutting down llama-swap
|
||||||
|
Started llama-swap with PID 495486
|
||||||
|
Setting up watches.
|
||||||
|
Watches established.
|
||||||
|
llama-swap listening on :1867
|
||||||
|
```
|
||||||
@@ -29,6 +29,10 @@ require (
|
|||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/tidwall/gjson v1.18.0 // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
|
|||||||
@@ -57,6 +57,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
|
|||||||
@@ -12,12 +12,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
// Define a command-line flag for the port
|
// Define a command-line flag for the port
|
||||||
port := flag.String("port", "8080", "port to listen on")
|
port := flag.String("port", "8080", "port to listen on")
|
||||||
|
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
|
||||||
|
|
||||||
// Define a command-line flag for the response message
|
// Define a command-line flag for the response message
|
||||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||||
@@ -41,6 +43,25 @@ func main() {
|
|||||||
c.String(200, *responseMessage)
|
c.String(200, *responseMessage)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// for issue #62 to check model name strips profile slug
|
||||||
|
// has to be one of the openAI API endpoints that llama-swap proxies
|
||||||
|
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
|
||||||
|
r.POST("/v1/audio/speech", func(c *gin.Context) {
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Request.Body.Close()
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
if modelName != *expectedModel {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
r.POST("/v1/completions", func(c *gin.Context) {
|
r.POST("/v1/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
c.String(200, *responseMessage)
|
c.String(200, *responseMessage)
|
||||||
|
|||||||
+16
-1
@@ -133,9 +133,24 @@ func (p *Process) start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// There is the possibility of a hard to replicate race condition where
|
||||||
|
// curState *WAS* StateStopped but by the time we get to the p.stateMutex.Lock()
|
||||||
|
// below, it's value has changed!
|
||||||
|
|
||||||
p.stateMutex.Lock()
|
p.stateMutex.Lock()
|
||||||
defer p.stateMutex.Unlock()
|
defer p.stateMutex.Unlock()
|
||||||
|
|
||||||
|
// with the exclusive lock, check if p.state is StateStopped, which is the only valid state
|
||||||
|
// to transition from to StateReady
|
||||||
|
|
||||||
|
if p.state != StateStopped {
|
||||||
|
if p.state == StateReady {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("start() can not proceed expected StateReady but process is in %v", p.state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := p.setState(StateStarting); err != nil {
|
if err := p.setState(StateStarting); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -215,7 +230,7 @@ func (p *Process) start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
<-time.After(time.Second)
|
<-time.After(5 * time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -169,6 +169,8 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// issue #19
|
// issue #19
|
||||||
|
// This test makes sure using Process.Stop() does not affect pending HTTP
|
||||||
|
// requests. All HTTP requests in this test should complete successfully.
|
||||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skipping slow test")
|
t.Skip("skipping slow test")
|
||||||
@@ -192,8 +194,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(key string) {
|
go func(key string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
// send a request that should take 5 * 200ms (1 second) to complete
|
// send a request where simple-responder is will wait 300ms before responding
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
|
// this will simulate an in-progress request.
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
@@ -209,9 +212,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
}(key)
|
}(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop the requests in the middle
|
// Stop the process while requests are still being processed
|
||||||
go func() {
|
go func() {
|
||||||
<-time.After(500 * time.Millisecond)
|
<-time.After(150 * time.Millisecond)
|
||||||
process.Stop()
|
process.Stop()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
+33
-14
@@ -13,6 +13,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -104,6 +106,8 @@ func New(config *Config) *ProxyManager {
|
|||||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||||
|
|
||||||
|
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/", func(c *gin.Context) {
|
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||||
// Set the Content-Type header to text/html
|
// Set the Content-Type header to text/html
|
||||||
c.Header("Content-Type", "text/html")
|
c.Header("Content-Type", "text/html")
|
||||||
@@ -222,11 +226,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||||
profileName, modelName := "", requestedModel
|
profileName, modelName := splitRequestedModel(requestedModel)
|
||||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
|
||||||
profileName = requestedModel[:idx]
|
|
||||||
modelName = requestedModel[idx+1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if profileName != "" {
|
if profileName != "" {
|
||||||
if _, found := pm.config.Profiles[profileName]; !found {
|
if _, found := pm.config.Profiles[profileName]; !found {
|
||||||
@@ -342,21 +342,26 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestBody map[string]interface{}
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
if requestedModel == "" {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
model, ok := requestBody["model"].(string)
|
|
||||||
if !ok {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if process, err := pm.swapModel(model); err != nil {
|
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
|
// strip
|
||||||
|
profileName, modelName := splitRequestedModel(requestedModel)
|
||||||
|
if profileName != "" {
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
@@ -377,6 +382,20 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||||
|
pm.StopProcesses()
|
||||||
|
c.String(http.StatusOK, "OK")
|
||||||
|
}
|
||||||
|
|
||||||
func ProcessKeyName(groupName, modelName string) string {
|
func ProcessKeyName(groupName, modelName string) string {
|
||||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func splitRequestedModel(requestedModel string) (string, string) {
|
||||||
|
profileName, modelName := "", requestedModel
|
||||||
|
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||||
|
profileName = requestedModel[:idx]
|
||||||
|
modelName = requestedModel[idx+1:]
|
||||||
|
}
|
||||||
|
return profileName, modelName
|
||||||
|
}
|
||||||
|
|||||||
@@ -304,3 +304,48 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_Unload(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
proc, err := proxy.swapModel("model1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, proc)
|
||||||
|
|
||||||
|
assert.Len(t, proxy.currentProcesses, 1)
|
||||||
|
req := httptest.NewRequest("GET", "/unload", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
|
assert.Len(t, proxy.currentProcesses, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue 62, strip profile slug from model name
|
||||||
|
func TestProxyManager_StripProfileSlug(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
|
||||||
|
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "ok")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user