Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 04b4760e7e | |||
| 9fc5d5b5eb | |||
| cf82b3c633 | |||
| e363f8f498 | |||
| c9629cf3a2 | |||
| 50426935a4 | |||
| 2fceb78e8d | |||
| 9a81c53664 | |||
| 716d37de82 |
@@ -8,12 +8,13 @@ Features:
|
|||||||
|
|
||||||
- ✅ Easy to deploy: single binary with no dependencies
|
- ✅ Easy to deploy: single binary with no dependencies
|
||||||
- ✅ Single yaml configuration file
|
- ✅ Single yaml configuration file
|
||||||
- ✅ Automatically switching between models
|
- ✅ Automatic switching between models
|
||||||
- ✅ Full control over llama.cpp server settings per model
|
- ✅ Full control over llama.cpp server settings per model
|
||||||
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
||||||
- ✅ Multiple GPU support
|
- ✅ Multiple GPU support
|
||||||
- ✅ Run multiple models at once with `profiles`
|
- ✅ Run multiple models at once with `profiles`
|
||||||
- ✅ Remote log monitoring at `/log`
|
- ✅ Remote log monitoring at `/log`
|
||||||
|
- ✅ Automatic unloading of models from GPUs after timeout
|
||||||
|
|
||||||
## config.yaml
|
## config.yaml
|
||||||
|
|
||||||
@@ -71,6 +72,8 @@ profiles:
|
|||||||
- "llama"
|
- "llama"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
More [examples](examples/README.md) are available for different use cases.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
# Example Configurations
|
||||||
|
|
||||||
|
Learning by example is best.
|
||||||
|
|
||||||
|
Here in the `examples/` folder are llama-swap configurations that can be used on your local LLM server.
|
||||||
|
|
||||||
|
## List
|
||||||
|
|
||||||
|
* [Speculative Decoding](speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||||
@@ -0,0 +1,124 @@
|
|||||||
|
# Speculative Decoding
|
||||||
|
|
||||||
|
Speculative decoding can significantly improve the tokens per second. However, this comes at the cost of increased VRAM usage for the draft model. The examples provided are based on a server with three P40s and one 3090.
|
||||||
|
|
||||||
|
## Coding Use Case
|
||||||
|
|
||||||
|
This example uses Qwen2.5 Coder 32B with the 0.5B model as a draft. A quantization of Q8_0 was chosen for the draft model, as quantization has a greater impact on smaller models.
|
||||||
|
|
||||||
|
The models used are:
|
||||||
|
|
||||||
|
* [Bartowski Qwen2.5-Coder-32B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF)
|
||||||
|
* [Bartowski Qwen2.5-Coder-0.5B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-0.5B-Instruct-GGUF)
|
||||||
|
|
||||||
|
The llama-swap configuration is as follows:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
models:
|
||||||
|
"qwen-coder-32b-q4":
|
||||||
|
# main model on 3090, draft on P40 #1
|
||||||
|
cmd: >
|
||||||
|
/mnt/nvme/llama-server/llama-server-be0e35
|
||||||
|
--host 127.0.0.1 --port 9503
|
||||||
|
--flash-attn --metrics
|
||||||
|
--slots
|
||||||
|
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||||
|
-ngl 99
|
||||||
|
--ctx-size 19000
|
||||||
|
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
|
||||||
|
-ngld 99
|
||||||
|
--draft-max 16
|
||||||
|
--draft-min 4
|
||||||
|
--draft-p-min 0.4
|
||||||
|
--device CUDA0
|
||||||
|
--device-draft CUDA1
|
||||||
|
proxy: "http://127.0.0.1:9503"
|
||||||
|
```
|
||||||
|
|
||||||
|
In this configuration, two GPUs are used: a 3090 (CUDA0) for the main model and a P40 (CUDA1) for the draft model. Although both models can fit on the 3090, relocating the draft model to the P40 freed up space for a larger context size. Despite the P40 being about 1/3rd the speed of the 3090, the small model still improved tokens per second.
|
||||||
|
|
||||||
|
Multiple tests were run with various parameters, and the fastest result was chosen for the configuration. In all tests, the 0.5B model produced the largest improvements to tokens per second.
|
||||||
|
|
||||||
|
Baseline: 33.92 tokens/second on 3090 without a draft model.
|
||||||
|
|
||||||
|
| draft-max | draft-min | draft-p-min | python | TS | swift |
|
||||||
|
|-----------|-----------|-------------|--------|----|-------|
|
||||||
|
| 16 | 1 | 0.9 | 71.64 | 55.55 | 48.06 |
|
||||||
|
| 16 | 1 | 0.4 | 83.21 | 58.55 | 45.50 |
|
||||||
|
| 16 | 1 | 0.1 | 79.72 | 55.66 | 43.94 |
|
||||||
|
| 16 | 2 | 0.9 | 68.47 | 55.13 | 43.12 |
|
||||||
|
| 16 | 2 | 0.4 | 82.82 | 57.42 | 48.83 |
|
||||||
|
| 16 | 2 | 0.1 | 81.68 | 51.37 | 45.72 |
|
||||||
|
| 16 | 4 | 0.9 | 66.44 | 48.49 | 42.40 |
|
||||||
|
| 16 | 4 | 0.4 | _83.62_ (fastest)| _58.29_ | _50.17_ |
|
||||||
|
| 16 | 4 | 0.1 | 82.46 | 51.45 | 40.71 |
|
||||||
|
| 8 | 1 | 0.4 | 67.07 | 55.17 | 48.46 |
|
||||||
|
| 4 | 1 | 0.4 | 50.13 | 44.96 | 40.79 |
|
||||||
|
|
||||||
|
The test script can be found in this [gist](https://gist.github.com/mostlygeek/da429769796ac8a111142e75660820f1). It is a simple curl script that prompts generating a snake game in Python, TypeScript, or Swift. Evaluation metrics were pulled from llama.cpp's logs.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
for lang in "python" "typescript" "swift"; do
|
||||||
|
echo "Generating Snake Game in $lang using $model"
|
||||||
|
curl -s --url http://localhost:8080/v1/chat/completions -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"temperature\": 0.1, \"model\":\"$model\"}" > /dev/null
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
Python consistently outperformed Swift in all tests, likely due to the 0.5B draft model being more proficient in generating Python code accepted by the larger 32B model.
|
||||||
|
|
||||||
|
## Chat
|
||||||
|
|
||||||
|
This configuration is for a regular chat use case. It produces approximately 13 tokens/second in typical use, up from ~9 tokens/second with only the 3xP40s. This is great news for P40 owners.
|
||||||
|
|
||||||
|
The models used are:
|
||||||
|
|
||||||
|
* [Bartowski Meta-Llama-3.1-70B-Instruct-GGUF](https://huggingface.co/bartowski/Meta-Llama-3.1-70B-Instruct-GGUF)
|
||||||
|
* [Bartowski Llama-3.2-3B-Instruct-GGUF](https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
models:
|
||||||
|
"llama-70B":
|
||||||
|
cmd: >
|
||||||
|
/mnt/nvme/llama-server/llama-server-be0e35
|
||||||
|
--host 127.0.0.1 --port 9602
|
||||||
|
--flash-attn --metrics
|
||||||
|
--split-mode row
|
||||||
|
--ctx-size 80000
|
||||||
|
--model /mnt/nvme/models/Meta-Llama-3.1-70B-Instruct-Q4_K_L.gguf
|
||||||
|
-ngl 99
|
||||||
|
--model-draft /mnt/nvme/models/Llama-3.2-3B-Instruct-Q4_K_M.gguf
|
||||||
|
-ngld 99
|
||||||
|
--draft-max 16
|
||||||
|
--draft-min 1
|
||||||
|
--draft-p-min 0.4
|
||||||
|
--device-draft CUDA0
|
||||||
|
--tensor-split 0,1,1,1
|
||||||
|
```
|
||||||
|
|
||||||
|
In this configuration, Llama-3.1-70B is split across three P40s, and Llama-3.2-3B is on the 3090.
|
||||||
|
|
||||||
|
Some flags deserve further explanation:
|
||||||
|
|
||||||
|
* `--split-mode row` - increases inference speeds using multiple P40s by about 30%. This is a P40-specific feature.
|
||||||
|
* `--tensor-split 0,1,1,1` - controls how the main model is split across the GPUs. This means 0% on the 3090 and an even split across the P40s. A value of `--tensor-split 0,5,4,1` would mean 0% on the 3090, 50%, 40%, and 10% respectively across the other P40s. However, this would exceed the available VRAM.
|
||||||
|
* `--ctx-size 80000` - the maximum context size that can fit in the remaining VRAM.
|
||||||
|
|
||||||
|
## What is CUDA0, CUDA1, CUDA2, CUDA3?
|
||||||
|
|
||||||
|
These devices are the IDs used by llama.cpp.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ ./llama-server --list-devices
|
||||||
|
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
|
||||||
|
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
|
||||||
|
ggml_cuda_init: found 4 CUDA devices:
|
||||||
|
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
|
||||||
|
Device 1: Tesla P40, compute capability 6.1, VMM: yes
|
||||||
|
Device 2: Tesla P40, compute capability 6.1, VMM: yes
|
||||||
|
Device 3: Tesla P40, compute capability 6.1, VMM: yes
|
||||||
|
Available devices:
|
||||||
|
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 23892 MiB free)
|
||||||
|
CUDA1: Tesla P40 (24438 MiB, 24290 MiB free)
|
||||||
|
CUDA2: Tesla P40 (24438 MiB, 24290 MiB free)
|
||||||
|
CUDA3: Tesla P40 (24438 MiB, 24290 MiB free)
|
||||||
|
```
|
||||||
@@ -20,6 +20,7 @@ require (
|
|||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaC
|
|||||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
|
|||||||
@@ -3,60 +3,137 @@ package main
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
// Define a command-line flag for the port
|
// Define a command-line flag for the port
|
||||||
port := flag.String("port", "8080", "port to listen on")
|
port := flag.String("port", "8080", "port to listen on")
|
||||||
|
|
||||||
// Define a command-line flag for the response message
|
// Define a command-line flag for the response message
|
||||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||||
|
|
||||||
|
silent := flag.Bool("silent", false, "disable all logging")
|
||||||
|
|
||||||
flag.Parse() // Parse the command-line flags
|
flag.Parse() // Parse the command-line flags
|
||||||
|
|
||||||
responseMessageHandler := func(w http.ResponseWriter, r *http.Request) {
|
// Create a new Gin router
|
||||||
// Set the header to text/plain
|
r := gin.New()
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
|
||||||
fmt.Fprintln(w, *responseMessage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up the handler function using the provided response message
|
// Set up the handler function using the provided response message
|
||||||
http.HandleFunc("/v1/chat/completions", responseMessageHandler)
|
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||||
http.HandleFunc("/v1/completions", responseMessageHandler)
|
c.Header("Content-Type", "text/plain")
|
||||||
http.HandleFunc("/test", responseMessageHandler)
|
|
||||||
|
|
||||||
http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) {
|
// add a wait to simulate a slow query
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
fmt.Fprintln(w, *responseMessage)
|
time.Sleep(wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.String(200, *responseMessage)
|
||||||
|
})
|
||||||
|
|
||||||
|
r.POST("/v1/completions", func(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "text/plain")
|
||||||
|
c.String(200, *responseMessage)
|
||||||
|
})
|
||||||
|
|
||||||
|
r.GET("/slow-respond", func(c *gin.Context) {
|
||||||
|
echo := c.Query("echo")
|
||||||
|
delay := c.Query("delay")
|
||||||
|
|
||||||
|
if echo == "" {
|
||||||
|
echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the duration
|
||||||
|
if delay == "" {
|
||||||
|
delay = "100ms"
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err := time.ParseDuration(delay)
|
||||||
|
if err != nil {
|
||||||
|
c.Header("Content-Type", "text/plain")
|
||||||
|
c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Header("Content-Type", "text/plain")
|
||||||
|
for _, char := range echo {
|
||||||
|
c.Writer.Write([]byte(string(char)))
|
||||||
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
// wait
|
||||||
|
<-time.After(t)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
r.GET("/test", func(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "text/plain")
|
||||||
|
c.String(200, *responseMessage)
|
||||||
|
})
|
||||||
|
|
||||||
|
r.GET("/env", func(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "text/plain")
|
||||||
|
c.String(200, *responseMessage)
|
||||||
|
|
||||||
// Get environment variables
|
// Get environment variables
|
||||||
envVars := os.Environ()
|
envVars := os.Environ()
|
||||||
|
|
||||||
// Write each environment variable to the response
|
// Write each environment variable to the response
|
||||||
for _, envVar := range envVars {
|
for _, envVar := range envVars {
|
||||||
fmt.Fprintln(w, envVar)
|
c.String(200, envVar)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Set up the /health endpoint handler function
|
// Set up the /health endpoint handler function
|
||||||
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
r.GET("/health", func(c *gin.Context) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
response := `{"status": "ok"}`
|
c.JSON(200, gin.H{"status": "ok"})
|
||||||
w.Write([]byte(response))
|
|
||||||
})
|
})
|
||||||
|
|
||||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
r.GET("/", func(c *gin.Context) {
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
|
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||||
})
|
})
|
||||||
|
|
||||||
address := "127.0.0.1:" + *port // Address with the specified port
|
address := "127.0.0.1:" + *port // Address with the specified port
|
||||||
fmt.Printf("Server is listening on port %s\n", *port)
|
|
||||||
|
|
||||||
// Start the server and log any error if it occurs
|
srv := &http.Server{
|
||||||
if err := http.ListenAndServe(address, nil); err != nil {
|
Addr: address,
|
||||||
fmt.Printf("Error starting server: %s\n", err)
|
Handler: r.Handler(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Disable logging if the --silent flag is set
|
||||||
|
if *silent {
|
||||||
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
gin.DefaultWriter = io.Discard
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Printf("simple-responder listening on %s\n", address)
|
||||||
|
// service connections
|
||||||
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("simple-responder err: %s\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for interrupt signal to gracefully shutdown the server with
|
||||||
|
// a timeout of 5 seconds.
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
// kill (no param) default send syscall.SIGTERM
|
||||||
|
// kill -2 is syscall.SIGINT
|
||||||
|
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-quit
|
||||||
|
log.Println("simple-responder shutting down")
|
||||||
}
|
}
|
||||||
|
|||||||
+5
-1
@@ -5,6 +5,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/shlex"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,7 +82,10 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
|
|||||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
||||||
|
|
||||||
// Split the command into arguments
|
// Split the command into arguments
|
||||||
args := strings.Fields(cmdStr)
|
args, err := shlex.Split(cmdStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure the command is not empty
|
// Ensure the command is not empty
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
|
|||||||
+17
-8
@@ -148,17 +148,26 @@ func TestConfig_FindConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
// Test a simple command
|
|
||||||
args, err := SanitizeCommand("python model1.py")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"python", "model1.py"}, args)
|
|
||||||
|
|
||||||
// Test a command with spaces and newlines
|
// Test a command with spaces and newlines
|
||||||
args, err = SanitizeCommand(`python model1.py \
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
--arg1 value1 \
|
-a "double quotes" \
|
||||||
--arg2 value2`)
|
--arg2 'single quotes'
|
||||||
|
-s
|
||||||
|
--arg3 123 \
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"--arg2", "single quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", `"string in string"`,
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
// Test an empty command
|
// Test an empty command
|
||||||
args, err = SanitizeCommand("")
|
args, err = SanitizeCommand("")
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelCon
|
|||||||
|
|
||||||
// Create a process configuration
|
// Create a process configuration
|
||||||
return ModelConfig{
|
return ModelConfig{
|
||||||
Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, port, expectedMessage),
|
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
|
||||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
}
|
}
|
||||||
|
|||||||
+105
-33
@@ -14,6 +14,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ProcessState string
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateStopped ProcessState = ProcessState("stopped")
|
||||||
|
StateReady ProcessState = ProcessState("ready")
|
||||||
|
StateFailed ProcessState = ProcessState("failed")
|
||||||
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
@@ -23,8 +31,12 @@ type Process struct {
|
|||||||
logMonitor *LogMonitor
|
logMonitor *LogMonitor
|
||||||
healthCheckTimeout int
|
healthCheckTimeout int
|
||||||
|
|
||||||
isRunning bool
|
|
||||||
lastRequestHandled time.Time
|
lastRequestHandled time.Time
|
||||||
|
|
||||||
|
stateMutex sync.RWMutex
|
||||||
|
state ProcessState
|
||||||
|
|
||||||
|
inFlightRequests sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||||
@@ -34,16 +46,22 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito
|
|||||||
cmd: nil,
|
cmd: nil,
|
||||||
logMonitor: logMonitor,
|
logMonitor: logMonitor,
|
||||||
healthCheckTimeout: healthCheckTimeout,
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
|
state: StateStopped,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// start the process and check it for errors
|
// start the process and returns when it is ready
|
||||||
func (p *Process) start() error {
|
func (p *Process) start() error {
|
||||||
p.Lock()
|
|
||||||
defer p.Unlock()
|
|
||||||
|
|
||||||
if p.isRunning {
|
p.stateMutex.Lock()
|
||||||
return fmt.Errorf("process already running")
|
defer p.stateMutex.Unlock()
|
||||||
|
|
||||||
|
if p.state == StateReady {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.state == StateFailed {
|
||||||
|
return fmt.Errorf("process is in a failed state and can not be restarted")
|
||||||
}
|
}
|
||||||
|
|
||||||
args, err := p.config.SanitizedCommand()
|
args, err := p.config.SanitizedCommand()
|
||||||
@@ -57,34 +75,47 @@ func (p *Process) start() error {
|
|||||||
p.cmd.Env = p.config.Env
|
p.cmd.Env = p.config.Env
|
||||||
|
|
||||||
err = p.cmd.Start()
|
err = p.cmd.Start()
|
||||||
p.isRunning = true
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// watch for the command to exit
|
// One of three things can happen at this stage:
|
||||||
cmdCtx, cancel := context.WithCancelCause(context.Background())
|
// 1. The command exits unexpectedly
|
||||||
|
// 2. The health check fails
|
||||||
|
// 3. The health check passes
|
||||||
|
//
|
||||||
|
// only in the third case will the process be considered Ready to accept
|
||||||
|
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
|
||||||
|
defer cancelHealthCheck(nil) // clean up
|
||||||
|
cmdWaitChan := make(chan error, 1)
|
||||||
|
healthCheckChan := make(chan error, 1)
|
||||||
|
|
||||||
// monitor the command's exit status. Usually this happens if
|
|
||||||
// the process exited unexpectedly
|
|
||||||
go func() {
|
go func() {
|
||||||
err := p.cmd.Wait()
|
// possible cmd exits early
|
||||||
if err != nil {
|
cmdWaitChan <- p.cmd.Wait()
|
||||||
cancel(fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error()))
|
|
||||||
} else {
|
|
||||||
cancel(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.isRunning = false
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// wait a bit for process to start before checking the health endpoint
|
go func() {
|
||||||
time.Sleep(250 * time.Millisecond)
|
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||||
|
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
|
||||||
|
}()
|
||||||
|
|
||||||
// wait for checkHealthEndpoint
|
select {
|
||||||
if err := p.checkHealthEndpoint(cmdCtx); err != nil {
|
case err := <-cmdWaitChan:
|
||||||
|
p.state = StateFailed
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
|
||||||
|
}
|
||||||
|
cancelHealthCheck(err)
|
||||||
return err
|
return err
|
||||||
|
case err := <-healthCheckChan:
|
||||||
|
if err != nil {
|
||||||
|
p.state = StateFailed
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.config.UnloadAfter > 0 {
|
if p.config.UnloadAfter > 0 {
|
||||||
@@ -106,27 +137,64 @@ func (p *Process) start() error {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.state = StateReady
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
p.Lock()
|
// wait for any inflight requests before proceeding
|
||||||
defer p.Unlock()
|
p.inFlightRequests.Wait()
|
||||||
|
|
||||||
if !p.isRunning || p.cmd == nil || p.cmd.Process == nil {
|
p.stateMutex.Lock()
|
||||||
|
defer p.stateMutex.Unlock()
|
||||||
|
|
||||||
|
if p.state != StateReady {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.cmd == nil || p.cmd.Process == nil {
|
||||||
|
// this situation should never happen... but if it does just update the state
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
|
||||||
|
p.state = StateStopped
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pretty sure this stopping code needs some work for windows and
|
||||||
|
// will be a source of pain in the future.
|
||||||
|
|
||||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||||
p.cmd.Process.Wait()
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
p.isRunning = false
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- p.cmd.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
fmt.Printf("!!! process for %s timed out waiting to stop\n", p.ID)
|
||||||
|
p.cmd.Process.Kill()
|
||||||
|
p.cmd.Wait()
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() != "wait: no child processes" {
|
||||||
|
// possible that simple-responder for testing is just not
|
||||||
|
// existing right, so suppress those errors.
|
||||||
|
fmt.Printf("!!! process for %s stopped with error > %v\n", p.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.state = StateStopped
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) IsRunning() bool {
|
func (p *Process) CurrentState() ProcessState {
|
||||||
return p.isRunning
|
p.stateMutex.RLock()
|
||||||
|
defer p.stateMutex.RUnlock()
|
||||||
|
return p.state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
|
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
|
||||||
if p.config.Proxy == "" {
|
if p.config.Proxy == "" {
|
||||||
return fmt.Errorf("no upstream available to check /health")
|
return fmt.Errorf("no upstream available to check /health")
|
||||||
}
|
}
|
||||||
@@ -158,7 +226,7 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(cmdCtx, time.Second)
|
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
@@ -205,7 +273,11 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
if !p.isRunning {
|
|
||||||
|
p.inFlightRequests.Add(1)
|
||||||
|
defer p.inFlightRequests.Done()
|
||||||
|
|
||||||
|
if p.CurrentState() != StateReady {
|
||||||
if err := p.start(); err != nil {
|
if err := p.start(); err != nil {
|
||||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||||
http.Error(w, errstr, http.StatusInternalServerError)
|
http.Error(w, errstr, http.StatusInternalServerError)
|
||||||
|
|||||||
+63
-6
@@ -1,9 +1,12 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -23,9 +26,9 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// process is automatically started
|
// process is automatically started
|
||||||
assert.False(t, process.IsRunning())
|
assert.Equal(t, StateStopped, process.CurrentState())
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
assert.True(t, process.IsRunning())
|
assert.Equal(t, StateReady, process.CurrentState())
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||||
@@ -49,7 +52,7 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
|||||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||||
// Create a process configuration
|
// Create a process configuration
|
||||||
config := ModelConfig{
|
config := ModelConfig{
|
||||||
Cmd: "nonexistant-command",
|
Cmd: "nonexistent-command",
|
||||||
Proxy: "http://127.0.0.1:9913",
|
Proxy: "http://127.0.0.1:9913",
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
}
|
}
|
||||||
@@ -84,13 +87,67 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
|||||||
|
|
||||||
// Proxy the request (auto start)
|
// Proxy the request (auto start)
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||||
|
|
||||||
assert.True(t, process.IsRunning())
|
assert.Equal(t, StateReady, process.CurrentState())
|
||||||
|
|
||||||
// wait 5 seconds
|
// wait 5 seconds
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
|
assert.Equal(t, StateStopped, process.CurrentState())
|
||||||
assert.False(t, process.IsRunning())
|
}
|
||||||
|
|
||||||
|
// issue #19
|
||||||
|
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping long test")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMessage := "12345"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
results := map[string]string{
|
||||||
|
"12345": "",
|
||||||
|
"abcde": "",
|
||||||
|
"fghij": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
for key := range results {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(key string) {
|
||||||
|
defer wg.Done()
|
||||||
|
// send a request that should take 5 * 200ms (1 second) to complete
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
results[key] = w.Body.String()
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
}(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop the requests in the middle
|
||||||
|
go func() {
|
||||||
|
<-time.After(500 * time.Millisecond)
|
||||||
|
process.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for key, result := range results {
|
||||||
|
assert.Equal(t, key, result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+26
-13
@@ -14,6 +14,10 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PROFILE_SPLIT_CHAR = ":"
|
||||||
|
)
|
||||||
|
|
||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
@@ -69,9 +73,14 @@ func (pm *ProxyManager) StopProcesses() {
|
|||||||
|
|
||||||
// for internal usage
|
// for internal usage
|
||||||
func (pm *ProxyManager) stopProcesses() {
|
func (pm *ProxyManager) stopProcesses() {
|
||||||
|
if len(pm.currentProcesses) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
for _, process := range pm.currentProcesses {
|
for _, process := range pm.currentProcesses {
|
||||||
process.Stop()
|
process.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.currentProcesses = make(map[string]*Process)
|
pm.currentProcesses = make(map[string]*Process)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,15 +110,15 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
// Check if requestedModel contains a /
|
// Check if requestedModel contains a /
|
||||||
groupName, modelName := "", requestedModel
|
profileName, modelName := "", requestedModel
|
||||||
if idx := strings.Index(requestedModel, "/"); idx != -1 {
|
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||||
groupName = requestedModel[:idx]
|
profileName = requestedModel[:idx]
|
||||||
modelName = requestedModel[idx+1:]
|
modelName = requestedModel[idx+1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
if groupName != "" {
|
if profileName != "" {
|
||||||
if _, found := pm.config.Profiles[groupName]; !found {
|
if _, found := pm.config.Profiles[profileName]; !found {
|
||||||
return nil, fmt.Errorf("model group not found %s", groupName)
|
return nil, fmt.Errorf("model group not found %s", profileName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,7 +129,8 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// exit early when already running, otherwise stop everything and swap
|
// exit early when already running, otherwise stop everything and swap
|
||||||
requestedProcessKey := groupName + "/" + realModelName
|
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||||
|
|
||||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||||
return process, nil
|
return process, nil
|
||||||
}
|
}
|
||||||
@@ -128,25 +138,25 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
// stop all running models
|
// stop all running models
|
||||||
pm.stopProcesses()
|
pm.stopProcesses()
|
||||||
|
|
||||||
if groupName == "" {
|
if profileName == "" {
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
if !found {
|
if !found {
|
||||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
processKey := groupName + "/" + modelID
|
processKey := ProcessKeyName(profileName, modelID)
|
||||||
pm.currentProcesses[processKey] = process
|
pm.currentProcesses[processKey] = process
|
||||||
} else {
|
} else {
|
||||||
for _, modelName := range pm.config.Profiles[groupName] {
|
for _, modelName := range pm.config.Profiles[profileName] {
|
||||||
if realModelName, found := pm.config.RealModelName(modelName); found {
|
if realModelName, found := pm.config.RealModelName(modelName); found {
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
if !found {
|
if !found {
|
||||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, groupName)
|
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
processKey := groupName + "/" + modelID
|
processKey := ProcessKeyName(profileName, modelID)
|
||||||
pm.currentProcesses[processKey] = process
|
pm.currentProcesses[processKey] = process
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -185,7 +195,6 @@ func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
|
|||||||
|
|
||||||
process.ProxyRequest(c.Writer, c.Request)
|
process.ProxyRequest(c.Writer, c.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
|
||||||
@@ -197,3 +206,7 @@ func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
|
|||||||
|
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ProcessKeyName(groupName, modelName string) string {
|
||||||
|
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||||
|
}
|
||||||
|
|||||||
+77
-10
@@ -5,7 +5,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
@@ -31,7 +33,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), modelName)
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
|
||||||
_, exists := proxy.currentProcesses["/"+modelName]
|
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
|
||||||
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -41,21 +43,31 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||||
|
|
||||||
|
model1 := "path1/model1"
|
||||||
|
model2 := "path2/model2"
|
||||||
|
|
||||||
|
profileModel1 := ProcessKeyName("test", model1)
|
||||||
|
profileModel2 := ProcessKeyName("test", model2)
|
||||||
|
|
||||||
config := &Config{
|
config := &Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
model1: getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
model2: getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
Profiles: map[string][]string{
|
Profiles: map[string][]string{
|
||||||
"test": {"model1", "model2"},
|
"test": {model1, model2},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
for modelID, requestedModel := range map[string]string{"model1": "test/model1", "model2": "test/model2"} {
|
for modelID, requestedModel := range map[string]string{
|
||||||
|
"model1": profileModel1,
|
||||||
|
"model2": profileModel2,
|
||||||
|
} {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -67,10 +79,65 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
|
|
||||||
// make sure there's two loaded models
|
// make sure there's two loaded models
|
||||||
assert.Len(t, proxy.currentProcesses, 2)
|
assert.Len(t, proxy.currentProcesses, 2)
|
||||||
_, exists := proxy.currentProcesses["test/model1"]
|
_, exists := proxy.currentProcesses[profileModel1]
|
||||||
assert.True(t, exists, "expected test/model1 key in currentProcesses")
|
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
|
||||||
|
|
||||||
_, exists = proxy.currentProcesses["test/model2"]
|
|
||||||
assert.True(t, exists, "expected test/model2 key in currentProcesses")
|
|
||||||
|
|
||||||
|
_, exists = proxy.currentProcesses[profileModel2]
|
||||||
|
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
|
||||||
|
}
|
||||||
|
|
||||||
|
// When a request for a different model comes in ProxyManager should wait until
|
||||||
|
// the first request is complete before swapping. Both requests should complete
|
||||||
|
func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping slow test")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
results := map[string]string{}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
for key := range config.Models {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(key string) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
|
||||||
|
results[key] = w.Body.String()
|
||||||
|
mu.Unlock()
|
||||||
|
}(key)
|
||||||
|
|
||||||
|
<-time.After(time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.Len(t, results, len(config.Models))
|
||||||
|
|
||||||
|
for key, result := range results {
|
||||||
|
assert.Equal(t, key, result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user