Compare commits

...

18 Commits

Author SHA1 Message Date
Benson Wong b8f888f864 Logging Improvements (#88)
This change revamps the internal logging architecture to be more flexible and descriptive. Previously all logs from both llama-swap and upstream services were mixed together. This makes it harder to troubleshoot and identify problems. This PR adds these new endpoints: 

- `/logs/stream/proxy` - just llama-swap's logs
- `/logs/stream/upstream` - stdout output from the upstream server
2025-04-04 21:01:33 -07:00
Benson Wong 192b2ae621 Remove no longer needed test 2025-04-04 14:46:01 -07:00
Benson Wong b7f8cb5094 Limit Access-Control-Allow-Origin to OPTIONS preflight requests #85 2025-04-04 14:44:35 -07:00
Benson Wong a23da6eb57 Sanitize CORS headers (#85)
Add sanitation step for `Access-Control-Allow-Headers` when echoing back user supplied headers
2025-04-01 08:43:53 -07:00
Grigorii Khvatskii 4c3aa40564 add graceful process termination on windows (#82) 2025-03-25 15:26:33 -07:00
Benson Wong 84e2c07a7e Refactor wildcard out of CORS headers (#81)
Changes to CORS functionality: 

- `Access-Control-Allow-Origin: *` is set for all requests 
- for pre-flight OPTIONS requests
  - specify methods: `Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS`
  - if the client sent `Access-Control-Request-Headers` then echo back the same value in `Access-Control-Allow-Headers`. If no `Access-Control-Request-Headers` were sent, then send back a default set
  - set `Access-Control-Max-Age: 86400` to that may improve performance 
- Add CORS tests to the proxy-manager
2025-03-25 15:24:43 -07:00
Benson Wong 680af28bcc Allow very permissive CORS headers (#77) 2025-03-20 15:50:21 -07:00
Benson Wong d94db42ffe fix bug checking incorrect error 2025-03-20 15:49:36 -07:00
Benson Wong 93cd83c55c add override for windows (#76) 2025-03-20 13:23:04 -07:00
Benson Wong 5565fca3ac add some badges to README 2025-03-19 11:25:06 -07:00
Benson Wong d625ab8d92 Refactor process state management (#70) (#73)
* add isValidStateTransition helper function
* Replace Process.setState() with Process.swapState()
* Refactor locking logic in Process
2025-03-15 17:14:03 -07:00
Benson Wong a3f82c140b tidy up config examples in README 2025-03-15 10:36:45 -07:00
Benson Wong 5c97299e7b Add support for sending a custom model name to upstream (#69) (#71)
* add test for splitRequestedModel()
* Add `useModelName` parameter to model configuration
* add docs to README
2025-03-14 21:07:52 -07:00
Benson Wong 671c1a5a7b update deps 2025-03-13 14:00:15 -07:00
Benson Wong 52c0196e0f clean up feature list in readme 2025-03-13 13:55:20 -07:00
Benson Wong 3201a68a04 Add /v1/audio/transcriptions support (#41)
* add support for /v1/audio/transcriptions
2025-03-13 13:49:39 -07:00
Florin-Gabriel Dumitru 3ac94ad20e Adds an endpoint '/running' (#61)
* Adds an endpoint '/running' that returns either an empty JSON object if no model has been loaded so far, or the last model loaded (model key) and it's current state (state key). Possible state values are: stopped, starting, ready and stopping.

* Improves the `/running` endpoint by allowing multiple entries under the `running` key within the JSON response.
Refactors the `/running` method name (listRunningProcessesHandler).
Removes the unlisted filter implementation.

* Adds tests for:
- no model loaded
- one model loaded
- multiple models loaded

* Adds simple comments.

* Simplified code structure as per 250313 comments on PR #65.

---------

Co-authored-by: FGDumitru|B <xelotx@gmail.com>
2025-03-13 13:42:59 -07:00
Benson Wong 60355bf74a fix some potentially confusing Process.start() comment 2025-03-11 11:00:45 -07:00
17 changed files with 1127 additions and 245 deletions
+12
View File
@@ -16,3 +16,15 @@ builds:
goarch: arm64 goarch: arm64
- goos: windows - goos: windows
goarch: arm64 goarch: arm64
# use zip format for windows
archives:
- id: default
format: tar.gz
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds_info:
group: root
owner: root
format_overrides:
- goos: windows
format: zip
+32 -20
View File
@@ -1,4 +1,7 @@
![llama-swap header image](header.jpeg) ![llama-swap header image](header.jpeg)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap # llama-swap
@@ -11,20 +14,23 @@ Written in golang, it is very easy to install (single binary with no dependancie
- ✅ Easy to deploy: single binary with no dependencies - ✅ Easy to deploy: single binary with no dependencies
- ✅ Easy to config: single yaml file - ✅ Easy to config: single yaml file
- ✅ On-demand model switching - ✅ On-demand model switching
- ✅ Full control over server settings per model
- ✅ OpenAI API supported endpoints: - ✅ OpenAI API supported endpoints:
- `v1/completions` - `v1/completions`
- `v1/chat/completions` - `v1/chat/completions`
- `v1/embeddings` - `v1/embeddings`
- `v1/rerank` - `v1/rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36)) - `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- ✅ llama-swap custom API endpoints
- `/log` - remote log monitoring
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741)) - ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
- ✅ Remote log monitoring at `/log`
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- ✅ Manually unload models via `/unload` endpoint ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- ✅ Automatic unloading of models after timeout by setting a `ttl` - ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc) - ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support - ✅ Docker and Podman support
- ✅ Full control over server settings per model
## How does llama-swap work? ## How does llama-swap work?
@@ -61,13 +67,20 @@ models:
# Default (and minimum) is 15 seconds # Default (and minimum) is 15 seconds
healthCheckTimeout: 60 healthCheckTimeout: 60
# Write HTTP logs (useful for troubleshooting), defaults to false # Valid log levels: debug, info (default), warn, error
logRequests: true logLevel: info
# define valid model values and the upstream server start # define valid model values and the upstream server start
models: models:
"llama": "llama":
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf # multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
# where to reach the server started by cmd, make sure the ports match # where to reach the server started by cmd, make sure the ports match
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
@@ -88,16 +101,9 @@ models:
# default: 0 = never unload model # default: 0 = never unload model
ttl: 60 ttl: 60
"qwen": # `useModelName` overrides the model name in the request
# environment variables to pass to the command # and sends a specific name to the upstream server
env: useModelName: "qwen:qwq"
- "CUDA_VISIBLE_DEVICES=0"
# multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999
# unlisted models do not show up in /v1/models or /upstream lists # unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal # but they can still be requested as normal
@@ -114,7 +120,7 @@ models:
ghcr.io/ggerganov/llama.cpp:server ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# profiles make it easy to managing multi model (and gpu) configurations. # profiles eliminates swapping by running multiple models at the same time
# #
# Tips: # Tips:
# - each model must be listening on a unique address and port # - each model must be listening on a unique address and port
@@ -122,8 +128,8 @@ models:
# - the profile will load and unload all models in the profile at the same time # - the profile will load and unload all models in the profile at the same time
profiles: profiles:
coding: coding:
- "qwen"
- "llama" - "llama"
- "qwen-unlisted"
``` ```
### Use Case Examples ### Use Case Examples
@@ -213,9 +219,15 @@ Of course, CLI access is also supported:
# sends up to the last 10KB of logs # sends up to the last 10KB of logs
curl http://host/logs' curl http://host/logs'
# streams logs # streams combined logs
curl -Ns 'http://host/logs/stream' curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes # stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time' curl -Ns http://host/logs/stream | grep 'eval time'
+3 -3
View File
@@ -1,9 +1,9 @@
# Seconds to wait for llama.cpp to be available to serve requests # Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds # Default (and minimum): 15 seconds
healthCheckTimeout: 15 healthCheckTimeout: 90
# Log HTTP requests helpful for troubleshoot, defaults to False # valid log levels: debug, info (default), warn, error
logRequests: true logLevel: info
models: models:
"llama": "llama":
+8 -8
View File
@@ -3,7 +3,11 @@ module github.com/mostlygeek/llama-swap
go 1.23.0 go 1.23.0
require ( require (
github.com/gin-gonic/gin v1.10.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
@@ -15,12 +19,10 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.10.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
@@ -29,16 +31,14 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.31.0 // indirect golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.33.0 // indirect golang.org/x/net v0.37.0 // indirect
golang.org/x/sys v0.28.0 // indirect golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
) )
+8
View File
@@ -78,20 +78,28 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+40
View File
@@ -67,6 +67,46 @@ func main() {
c.String(200, *responseMessage) c.String(200, *responseMessage)
}) })
// issue #41
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
// Parse the multipart form
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
return
}
// Get the model from the form values
model := c.Request.FormValue("model")
if model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
return
}
// Get the file from the form
file, _, err := c.Request.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
return
}
defer file.Close()
// Read the file content to get its size
fileBytes, err := io.ReadAll(file)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
return
}
fileSize := len(fileBytes)
// Return a JSON response with the model and transcription text including file size
c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model,
})
})
r.GET("/slow-respond", func(c *gin.Context) { r.GET("/slow-respond", func(c *gin.Context) {
echo := c.Query("echo") echo := c.Query("echo")
delay := c.Query("delay") delay := c.Query("delay")
+2
View File
@@ -17,6 +17,7 @@ type ModelConfig struct {
CheckEndpoint string `yaml:"checkEndpoint"` CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"` UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"` Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
} }
func (m *ModelConfig) SanitizedCommand() ([]string, error) { func (m *ModelConfig) SanitizedCommand() ([]string, error) {
@@ -26,6 +27,7 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
type Config struct { type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"` HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"` LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
Models map[string]ModelConfig `yaml:"models"` Models map[string]ModelConfig `yaml:"models"`
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
+90
View File
@@ -2,11 +2,21 @@ package proxy
import ( import (
"container/ring" "container/ring"
"fmt"
"io" "io"
"os" "os"
"sync" "sync"
) )
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct { type LogMonitor struct {
clients map[chan []byte]bool clients map[chan []byte]bool
mu sync.RWMutex mu sync.RWMutex
@@ -15,6 +25,10 @@ type LogMonitor struct {
// typically this can be os.Stdout // typically this can be os.Stdout
stdout io.Writer stdout io.Writer
// logging levels
level LogLevel
prefix string
} }
func NewLogMonitor() *LogMonitor { func NewLogMonitor() *LogMonitor {
@@ -26,6 +40,8 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
clients: make(map[chan []byte]bool), clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout, stdout: stdout,
level: LevelInfo,
prefix: "",
} }
} }
@@ -94,3 +110,77 @@ func (w *LogMonitor) broadcast(msg []byte) {
} }
} }
} }
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+114 -112
View File
@@ -33,8 +33,12 @@ type Process struct {
ID string ID string
config ModelConfig config ModelConfig
cmd *exec.Cmd cmd *exec.Cmd
logMonitor *LogMonitor
processLogger *LogMonitor
proxyLogger *LogMonitor
healthCheckTimeout int healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time lastRequestHandled time.Time
@@ -51,54 +55,68 @@ type Process struct {
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
} }
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Process{ return &Process{
ID: ID, ID: ID,
config: config, config: config,
cmd: nil, cmd: nil,
logMonitor: logMonitor, processLogger: processLogger,
proxyLogger: proxyLogger,
healthCheckTimeout: healthCheckTimeout, healthCheckTimeout: healthCheckTimeout,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
state: StateStopped, state: StateStopped,
shutdownCtx: ctx, shutdownCtx: ctx,
shutdownCancel: cancel, shutdownCancel: cancel,
} }
} }
func (p *Process) setState(newState ProcessState) error { // LogMonitor returns the log monitor associated with the process.
// enforce valid state transitions func (p *Process) LogMonitor() *LogMonitor {
invalidTransition := false return p.processLogger
if p.state == StateStopped {
// stopped -> starting
if newState != StateStarting {
invalidTransition = true
}
} else if p.state == StateStarting {
// starting -> ready | failed | stopping
if newState != StateReady && newState != StateFailed && newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateReady {
// ready -> stopping
if newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateStopping {
// stopping -> stopped | shutdown
if newState != StateStopped && newState != StateShutdown {
invalidTransition = true
}
} else if p.state == StateFailed || p.state == StateShutdown {
invalidTransition = true
} }
if invalidTransition { // custom error types for swapping state
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState)) var (
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState) ErrExpectedStateMismatch = errors.New("expected state mismatch")
ErrInvalidStateTransition = errors.New("invalid state transition")
)
// swapState performs a compare and swap of the state atomically. It returns the current state
// and an error if the swap failed.
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state != expectedState {
return p.state, ErrExpectedStateMismatch
} }
if !isValidTransition(p.state, newState) {
p.proxyLogger.Warnf("Invalid state transition from %s to %s", p.state, newState)
return p.state, ErrInvalidStateTransition
}
p.proxyLogger.Debugf("State transition from %s to %s", expectedState, newState)
p.state = newState p.state = newState
return nil return p.state, nil
}
// Helper function to encapsulate transition rules
func isValidTransition(from, to ProcessState) bool {
switch from {
case StateStopped:
return to == StateStarting
case StateStarting:
return to == StateReady || to == StateFailed || to == StateStopping
case StateReady:
return to == StateStopping
case StateStopping:
return to == StateStopped || to == StateShutdown
case StateFailed, StateShutdown:
return false // No transitions allowed from these states
}
return false
} }
func (p *Process) CurrentState() ProcessState { func (p *Process) CurrentState() ProcessState {
@@ -116,62 +134,48 @@ func (p *Process) start() error {
return fmt.Errorf("can not start(), upstream proxy missing") return fmt.Errorf("can not start(), upstream proxy missing")
} }
// wait for the other start() to complete
curState := p.CurrentState()
if curState == StateReady {
return nil
}
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state != StateReady {
return fmt.Errorf("start() failed current state: %v", state)
}
return nil
}
// There is the possibility of a hard to replicate race condition where
// curState *WAS* StateStopped but by the time we get to the p.stateMutex.Lock()
// below, it's value has changed!
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
// with the exclusive lock, check if p.state is StateStopped, which is the only valid state
// to transition from to StateReady
if p.state != StateStopped {
if p.state == StateReady {
return nil
} else {
return fmt.Errorf("start() can not proceed expected StateReady but process is in %v", p.state)
}
}
if err := p.setState(StateStarting); err != nil {
return err
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
args, err := p.config.SanitizedCommand() args, err := p.config.SanitizedCommand()
if err != nil { if err != nil {
return fmt.Errorf("unable to get sanitized command: %v", err) return fmt.Errorf("unable to get sanitized command: %v", err)
} }
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
if err == ErrExpectedStateMismatch {
// already starting, just wait for it to complete and expect
// it to be be in the Ready start after. If not, return an error
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state == StateReady {
return nil
} else {
return fmt.Errorf("process was already starting but wound up in state %v", state)
}
} else {
return fmt.Errorf("processes was in state %v when start() was called", curState)
}
} else {
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
}
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
p.cmd = exec.Command(args[0], args[1:]...) p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.logMonitor p.cmd.Stdout = p.processLogger
p.cmd.Stderr = p.logMonitor p.cmd.Stderr = p.processLogger
p.cmd.Env = p.config.Env p.cmd.Env = p.config.Env
err = p.cmd.Start() err = p.cmd.Start()
// Set process state to failed
if err != nil { if err != nil {
p.setState(StateFailed) if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
return fmt.Errorf(
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
err, curState, swapErr,
)
}
return fmt.Errorf("start() failed: %v", err) return fmt.Errorf("start() failed: %v", err)
} }
@@ -206,31 +210,35 @@ func (p *Process) start() error {
) )
defer cancelHealthCheck() defer cancelHealthCheck()
// Health check loop
loop: loop:
// Ready Check loop
for { for {
select { select {
case <-checkDeadline.Done(): case <-checkDeadline.Done():
p.setState(StateFailed) if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds()) return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
} else {
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
}
case <-p.shutdownCtx.Done(): case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown") return errors.New("health check interrupted due to shutdown")
default: default:
if err := p.checkHealthEndpoint(healthURL); err == nil { if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("Health check passed on %s", healthURL)
cancelHealthCheck() cancelHealthCheck()
break loop break loop
} else { } else {
if strings.Contains(err.Error(), "connection refused") { if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline() endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime) ttl := time.Until(endTime)
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds()) p.proxyLogger.Infof("Connection refused on %s, retrying in %.0fs", healthURL, ttl.Seconds())
} else { } else {
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err) p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err)
} }
} }
} }
<-time.After(5 * time.Second) <-time.After(p.healthCheckLoopInterval)
} }
} }
@@ -241,7 +249,7 @@ func (p *Process) start() error {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) { for range time.Tick(time.Second) {
if p.state != StateReady { if p.CurrentState() != StateReady {
return return
} }
@@ -249,7 +257,8 @@ func (p *Process) start() error {
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration { if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
p.Stop() p.Stop()
return return
} }
@@ -257,26 +266,28 @@ func (p *Process) start() error {
}() }()
} }
return p.setState(StateReady) if curState, err := p.swapState(StateStarting, StateReady); err != nil {
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
} else {
return nil
}
} }
func (p *Process) Stop() { func (p *Process) Stop() {
// wait for any inflight requests before proceeding // wait for any inflight requests before proceeding
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
// calling Stop() when state is invalid is a no-op // calling Stop() when state is invalid is a no-op
if err := p.setState(StateStopping); err != nil { if curState, err := p.swapState(StateReady, StateStopping); err != nil {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err) p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState)
return return
} }
// stop the process with a graceful exit timeout // stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
if err := p.setState(StateStopped); err != nil { if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err)) p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState)
} }
} }
@@ -284,19 +295,9 @@ func (p *Process) Stop() {
// of time for any inflight requests to complete before shutting down. If the Process // of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down // is in the state of starting, it will cancel it and shut it down
func (p *Process) Shutdown() { func (p *Process) Shutdown() {
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
p.shutdownCancel() p.shutdownCancel()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.setState(StateStopping)
// 5 seconds to stop the process
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
if err := p.setState(StateShutdown); err != nil { p.state = StateShutdown
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
}
p.setState(StateShutdown)
} }
// stopCommand will send a SIGTERM to the process and wait for it to exit. // stopCommand will send a SIGTERM to the process and wait for it to exit.
@@ -311,31 +312,32 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
}() }()
if p.cmd == nil || p.cmd.Process == nil { if p.cmd == nil || p.cmd.Process == nil {
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID) p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
return return
} }
p.cmd.Process.Signal(syscall.SIGTERM) if err := p.terminateProcess(); err != nil {
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
}
select { select {
case <-sigtermTimeout.Done(): case <-sigtermTimeout.Done():
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID) p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
p.cmd.Process.Kill() p.cmd.Process.Kill()
case err := <-sigtermNormal: case err := <-sigtermNormal:
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok { if errno, ok := err.(syscall.Errno); ok {
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno) p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok { } else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") { if strings.Contains(exitError.String(), "signal: terminated") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID) p.proxyLogger.Infof("Process [%s] stopped OK", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") { } else if strings.Contains(exitError.String(), "signal: interrupt") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID) p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID)
} else { } else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode()) p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
} }
} else { } else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err) p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err)
} }
} }
} }
+9
View File
@@ -0,0 +1,9 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}
+14
View File
@@ -0,0 +1,14 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}
+40 -35
View File
@@ -5,7 +5,6 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -13,13 +12,17 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var (
discardLogger = NewLogMonitorWriter(io.Discard)
)
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
// Create a process // Create a process
process := NewProcess("test-process", 5, config, logMonitor) process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
defer process.Stop() defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest("GET", "/test", nil)
@@ -52,11 +55,10 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
// are all handled successfully, even though they all may ask for the process to .start() // are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) { func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, logMonitor) process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
defer process.Stop() defer process.Stop()
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -84,7 +86,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health", CheckEndpoint: "/health",
} }
process := NewProcess("broken", 1, config, NewLogMonitor()) process := NewProcess("broken", 1, config, discardLogger, discardLogger)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -109,7 +111,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter) assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard)) process := NewProcess("ttl_test", 2, config, discardLogger, discardLogger)
defer process.Stop() defer process.Stop()
// this should take 4 seconds // this should take 4 seconds
@@ -151,7 +153,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
config.UnloadAfter = 1 // second config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter) assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout)) process := NewProcess("ttl", 2, config, discardLogger, discardLogger)
defer process.Stop() defer process.Stop()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@@ -178,7 +180,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345" expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout)) process := NewProcess("t", 10, config, discardLogger, discardLogger)
defer process.Stop() defer process.Stop()
results := map[string]string{ results := map[string]string{
@@ -225,39 +227,40 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
} }
} }
func TestSetState(t *testing.T) { func TestProcess_SwapState(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
currentState ProcessState currentState ProcessState
expectedState ProcessState
newState ProcessState newState ProcessState
expectedError error expectedError error
expectedResult ProcessState expectedResult ProcessState
}{ }{
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting}, {"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateReady, nil, StateReady}, {"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed}, {"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping}, {"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping}, {"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped}, {"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown}, {"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped}, {"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting}, {"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady}, {"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady}, {"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping}, {"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed}, {"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed}, {"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown}, {"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown}, {"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
p := &Process{ p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), discardLogger, discardLogger)
state: test.currentState, p.state = test.currentState
}
err := p.setState(test.newState) resultState, err := p.swapState(test.expectedState, test.newState)
if err != nil && test.expectedError == nil { if err != nil && test.expectedError == nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} else if err == nil && test.expectedError != nil { } else if err == nil && test.expectedError != nil {
@@ -268,8 +271,8 @@ func TestSetState(t *testing.T) {
} }
} }
if p.state != test.expectedResult { if resultState != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state) t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
} }
}) })
} }
@@ -280,7 +283,6 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
t.Skip("skipping long shutdown test") t.Skip("skipping long shutdown test")
} }
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong // make a config where the healthcheck will always fail because port is wrong
@@ -288,13 +290,16 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
config.Proxy = "http://localhost:9998/test" config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30 healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor) process := NewProcess("test-process", healthCheckTTLSeconds, config, discardLogger, discardLogger)
// make it a lot faster
process.healthCheckLoopInterval = time.Second
// start a goroutine to simulate a shutdown // start a goroutine to simulate a shutdown
var wg sync.WaitGroup var wg sync.WaitGroup
go func() { go func() {
defer wg.Done() defer wg.Done()
<-time.After(time.Second * 2) <-time.After(time.Millisecond * 500)
process.Shutdown() process.Shutdown()
}() }()
wg.Add(1) wg.Add(1)
+196 -16
View File
@@ -5,7 +5,9 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"mime/multipart"
"net/http" "net/http"
"os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -26,19 +28,47 @@ type ProxyManager struct {
config *Config config *Config
currentProcesses map[string]*Process currentProcesses map[string]*Process
logMonitor *LogMonitor
ginEngine *gin.Engine ginEngine *gin.Engine
// logging
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
} }
func New(config *Config) *ProxyManager { func New(config *Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
}
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{ pm := &ProxyManager{
config: config, config: config,
currentProcesses: make(map[string]*Process), currentProcesses: make(map[string]*Process),
logMonitor: NewLogMonitor(),
ginEngine: gin.New(), ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
} }
if config.LogRequests {
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
// Start timer // Start timer
start := time.Now() start := time.Now()
@@ -57,9 +87,8 @@ func New(config *Config) *ProxyManager {
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
bodySize := c.Writer.Size() bodySize := c.Writer.Size()
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n", pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
clientIP, clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method, method,
path, path,
c.Request.Proto, c.Request.Proto,
@@ -69,16 +98,26 @@ func New(config *Config) *ProxyManager {
duration, duration,
) )
}) })
}
// see: https://github.com/mostlygeek/llama-swap/issues/42 // see: issue: #81, #77 and #42 for CORS issues
// respond with permissive OPTIONS for any endpoint // respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" { if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS") c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.AbortWithStatus(204) // allow whatever the client requested by default
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
c.Header("Access-Control-Allow-Headers", sanitized)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
return return
} }
c.Next() c.Next()
@@ -95,6 +134,7 @@ func New(config *Config) *ProxyManager {
// Support audio/speech endpoint // Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler) pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler) pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
@@ -102,12 +142,16 @@ func New(config *Config) *ProxyManager {
pm.ginEngine.GET("/logs", pm.sendLogsHandlers) pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex) pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream) pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler) pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.GET("/", func(c *gin.Context) { pm.ginEngine.GET("/", func(c *gin.Context) {
// Set the Content-Type header to text/html // Set the Content-Type header to text/html
c.Header("Content-Type", "text/html") c.Header("Content-Type", "text/html")
@@ -259,19 +303,20 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
requestedProcessKey := ProcessKeyName(profileName, realModelName) requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found { if process, found := pm.currentProcesses[requestedProcessKey]; found {
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
return process, nil return process, nil
} }
// stop all running models // stop all running models
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
pm.stopProcesses() pm.stopProcesses()
if profileName == "" { if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName) modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found { if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName) return nil, fmt.Errorf("could not find configuration for %s", realModelName)
} }
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor) process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
processKey := ProcessKeyName(profileName, modelID) processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process pm.currentProcesses[processKey] = process
} else { } else {
@@ -282,7 +327,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName) return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
} }
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor) process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
processKey := ProcessKeyName(profileName, modelID) processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process pm.currentProcesses[processKey] = process
} }
@@ -347,12 +392,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
} }
if process, err := pm.swapModel(requestedModel); err != nil { process, err := pm.swapModel(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return return
} else { }
// strip // issue #69 allow custom model names to be sent to upstream
if process.config.UseModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
} else {
profileName, modelName := splitRequestedModel(requestedModel) profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" { if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName) bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
@@ -361,6 +415,7 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return return
} }
} }
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -369,7 +424,110 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes))) c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request) process.ProxyRequest(c.Writer, c.Request)
} }
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
// Swap to the requested model
process, err := pm.swapModel(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
}
// Get profile name and model name from the requested model
profileName, modelName := splitRequestedModel(requestedModel)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
if process.config.UseModelName != "" {
fieldValue = process.config.UseModelName
} else if profileName != "" {
fieldValue = modelName
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return
}
}
}
// Copy all files from the original request
for key, fileHeaders := range c.Request.MultipartForm.File {
for _, fileHeader := range fileHeaders {
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
return
}
file, err := fileHeader.Open()
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
return
}
if _, err = io.Copy(formFile, file); err != nil {
file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
return
}
file.Close()
}
}
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// Use the modified request for proxying
process.ProxyRequest(c.Writer, modifiedReq)
} }
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) { func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
@@ -387,6 +545,28 @@ func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
c.String(http.StatusOK, "OK") c.String(http.StatusOK, "OK")
} }
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, process := range pm.currentProcesses {
// Append the process ID and State (multiple entries if profiles are being used).
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
})
}
// Put the results under the `running` key.
response := gin.H{
"running": runningProcesses,
}
context.JSON(http.StatusOK, response) // Always return 200 OK
}
func ProcessKeyName(groupName, modelName string) string { func ProcessKeyName(groupName, modelName string) string {
return groupName + PROFILE_SPLIT_CHAR + modelName return groupName + PROFILE_SPLIT_CHAR + modelName
} }
+37 -8
View File
@@ -9,7 +9,6 @@ import (
) )
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) { func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept") accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") { if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html // Set the Content-Type header to text/html
@@ -28,7 +27,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
} }
} else { } else {
c.Header("Content-Type", "text/plain") c.Header("Content-Type", "text/plain")
history := pm.logMonitor.GetHistory() history := pm.muxLogger.GetHistory()
_, err := c.Writer.Write(history) _, err := c.Writer.Write(history)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
@@ -42,8 +41,14 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Transfer-Encoding", "chunked") c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe() logMonitorId := c.Param("logMonitorID")
defer pm.logMonitor.Unsubscribe(ch) logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
@@ -56,7 +61,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// Send history first if not skipped // Send history first if not skipped
if !skipHistory { if !skipHistory {
history := pm.logMonitor.GetHistory() history := logger.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.Writer.Write(history) c.Writer.Write(history)
flusher.Flush() flusher.Flush()
@@ -85,15 +90,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe() logMonitorId := c.Param("logMonitorID")
defer pm.logMonitor.Unsubscribe(ch) logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
// Send history first if not skipped // Send history first if not skipped
_, skipHistory := c.GetQuery("no-history") _, skipHistory := c.GetQuery("no-history")
if !skipHistory { if !skipHistory {
history := pm.logMonitor.GetHistory() history := logger.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.SSEvent("message", string(history)) c.SSEvent("message", string(history))
c.Writer.Flush() c.Writer.Flush()
@@ -111,3 +122,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
} }
} }
} }
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return logger, nil
}
+359
View File
@@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/rand"
"mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync" "sync"
@@ -349,3 +351,360 @@ func TestProxyManager_StripProfileSlug(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "ok") assert.Contains(t, w.Body.String(), "ok")
} }
// Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
}
// Define a helper struct to parse the JSON response.
type RunningResponse struct {
Running []struct {
Model string `json:"model"`
State string `json:"state"`
} `json:"running"`
}
// Create proxy once for all tests
proxy := New(config)
defer proxy.StopProcesses()
t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response RunningResponse
// Check if this is a valid JSON object.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// We should have an empty running array here.
assert.Empty(t, response.Running, "expected no running models")
})
t.Run("single model loaded", func(t *testing.T) {
// Load just a model.
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// Check if we have a single array element.
assert.Len(t, response.Running, 1)
// Is this the right model?
assert.Equal(t, "model1", response.Running[0].Model)
// Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State)
})
t.Run("multiple models via profile", func(t *testing.T) {
// Load more than one model.
for _, model := range []string{"model1", "model2"} {
profileModel := ProcessKeyName("test", model)
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
// Simulate the browser call.
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
// The JSON response must be valid.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// The response should contain 2 models.
assert.Len(t, response.Running, 2)
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
}
// Iterate through the models and check their states as well.
for _, entry := range response.Running {
_, exists := expectedModels[entry.Model]
assert.True(t, exists, "unexpected model %s", entry.Model)
assert.Equal(t, "ready", entry.State)
delete(expectedModels, entry.Model)
}
// Since we deleted each model while testing for its validity we should have no more models in the response.
assert.Empty(t, expectedModels, "unexpected additional models in response")
})
}
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"},
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
testCases := []struct {
name string
modelInput string
expectModel string
}{
{
name: "With Profile Prefix",
modelInput: "test:TheExpectedModel",
expectModel: "TheExpectedModel", // Profile prefix should be stripped
},
{
name: "Without Profile Prefix",
modelInput: "TheExpectedModel",
expectModel: "TheExpectedModel", // Should remain the same
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tc.modelInput))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
// Generate random content length between 10 and 20
contentLength := rand.Intn(11) + 10 // 10 to 20
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, tc.expectModel, response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
})
}
}
func TestProxyManager_SplitRequestedModel(t *testing.T) {
tests := []struct {
name string
requestedModel string
expectedProfile string
expectedModel string
}{
{"no profile", "gpt-4", "", "gpt-4"},
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
{"only profile", "profile1:", "profile1", ""},
{"empty model", ":gpt-4", "", "gpt-4"},
{"empty profile", ":", "", ""},
{"no split char", "gpt-4", "", "gpt-4"},
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
profileName, modelName := splitRequestedModel(tt.requestedModel)
if profileName != tt.expectedProfile {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
if modelName != tt.expectedModel {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
})
}
}
// Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1"},
},
Models: map[string]ModelConfig{
"model1": modelConfig,
},
}
proxy := New(config)
defer proxy.StopProcesses()
tests := []struct {
description string
requestedModel string
}{
{"useModelName over rides requested model", "model1"},
{"useModelName over rides requested profile:model", "test:model1"},
}
for _, tt := range tests {
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
})
}
for _, tt := range tests {
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tt.requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogRequests: true,
}
tests := []struct {
name string
method string
requestHeaders map[string]string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS with no headers",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses()
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
proxy.ginEngine.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
}
})
}
}
+43
View File
@@ -0,0 +1,43 @@
package proxy
import (
"strings"
)
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
+77
View File
@@ -0,0 +1,77 @@
package proxy
import "testing"
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "single valid value",
input: "content-type",
expected: "content-type",
},
{
name: "multiple valid values",
input: "content-type, authorization, x-requested-with",
expected: "content-type, authorization, x-requested-with",
},
{
name: "values with extra spaces",
input: " content-type , authorization ",
expected: "content-type, authorization",
},
{
name: "values with tabs",
input: "content-type,\tauthorization",
expected: "content-type, authorization",
},
{
name: "values with invalid characters",
input: "content-type, auth\n, x-requested-with\r",
expected: "content-type, auth, x-requested-with",
},
{
name: "empty values in list",
input: "content-type,,authorization",
expected: "content-type, authorization",
},
{
name: "leading and trailing commas",
input: ",content-type,authorization,",
expected: "content-type, authorization",
},
{
name: "mixed valid and invalid values",
input: "content-type, \x00invalid, x-requested-with",
expected: "content-type, x-requested-with",
},
{
name: "mixed case values",
input: "Content-Type, my-Valid-Header, Another-hEader",
expected: "Content-Type, my-Valid-Header, Another-hEader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeAccessControlRequestHeaderValues(tt.input)
if got != tt.expected {
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
tt.input, got, tt.expected)
}
})
}
}