Compare commits

...

138 Commits

Author SHA1 Message Date
Benson Wong b8f888f864 Logging Improvements (#88)
This change revamps the internal logging architecture to be more flexible and descriptive. Previously all logs from both llama-swap and upstream services were mixed together. This makes it harder to troubleshoot and identify problems. This PR adds these new endpoints: 

- `/logs/stream/proxy` - just llama-swap's logs
- `/logs/stream/upstream` - stdout output from the upstream server
2025-04-04 21:01:33 -07:00
Benson Wong 192b2ae621 Remove no longer needed test 2025-04-04 14:46:01 -07:00
Benson Wong b7f8cb5094 Limit Access-Control-Allow-Origin to OPTIONS preflight requests #85 2025-04-04 14:44:35 -07:00
Benson Wong a23da6eb57 Sanitize CORS headers (#85)
Add sanitation step for `Access-Control-Allow-Headers` when echoing back user supplied headers
2025-04-01 08:43:53 -07:00
Grigorii Khvatskii 4c3aa40564 add graceful process termination on windows (#82) 2025-03-25 15:26:33 -07:00
Benson Wong 84e2c07a7e Refactor wildcard out of CORS headers (#81)
Changes to CORS functionality: 

- `Access-Control-Allow-Origin: *` is set for all requests 
- for pre-flight OPTIONS requests
  - specify methods: `Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS`
  - if the client sent `Access-Control-Request-Headers` then echo back the same value in `Access-Control-Allow-Headers`. If no `Access-Control-Request-Headers` were sent, then send back a default set
  - set `Access-Control-Max-Age: 86400` to that may improve performance 
- Add CORS tests to the proxy-manager
2025-03-25 15:24:43 -07:00
Benson Wong 680af28bcc Allow very permissive CORS headers (#77) 2025-03-20 15:50:21 -07:00
Benson Wong d94db42ffe fix bug checking incorrect error 2025-03-20 15:49:36 -07:00
Benson Wong 93cd83c55c add override for windows (#76) 2025-03-20 13:23:04 -07:00
Benson Wong 5565fca3ac add some badges to README 2025-03-19 11:25:06 -07:00
Benson Wong d625ab8d92 Refactor process state management (#70) (#73)
* add isValidStateTransition helper function
* Replace Process.setState() with Process.swapState()
* Refactor locking logic in Process
2025-03-15 17:14:03 -07:00
Benson Wong a3f82c140b tidy up config examples in README 2025-03-15 10:36:45 -07:00
Benson Wong 5c97299e7b Add support for sending a custom model name to upstream (#69) (#71)
* add test for splitRequestedModel()
* Add `useModelName` parameter to model configuration
* add docs to README
2025-03-14 21:07:52 -07:00
Benson Wong 671c1a5a7b update deps 2025-03-13 14:00:15 -07:00
Benson Wong 52c0196e0f clean up feature list in readme 2025-03-13 13:55:20 -07:00
Benson Wong 3201a68a04 Add /v1/audio/transcriptions support (#41)
* add support for /v1/audio/transcriptions
2025-03-13 13:49:39 -07:00
Florin-Gabriel Dumitru 3ac94ad20e Adds an endpoint '/running' (#61)
* Adds an endpoint '/running' that returns either an empty JSON object if no model has been loaded so far, or the last model loaded (model key) and it's current state (state key). Possible state values are: stopped, starting, ready and stopping.

* Improves the `/running` endpoint by allowing multiple entries under the `running` key within the JSON response.
Refactors the `/running` method name (listRunningProcessesHandler).
Removes the unlisted filter implementation.

* Adds tests for:
- no model loaded
- one model loaded
- multiple models loaded

* Adds simple comments.

* Simplified code structure as per 250313 comments on PR #65.

---------

Co-authored-by: FGDumitru|B <xelotx@gmail.com>
2025-03-13 13:42:59 -07:00
Benson Wong 60355bf74a fix some potentially confusing Process.start() comment 2025-03-11 11:00:45 -07:00
Benson Wong 9b2ed244e2 Improve Continuous integration and fix concurrency bugs (#66)
- improvements to the continuous GH actions
- fix edge case concurrency bugs with Process.start() and state transitions discovered setting up CI.
2025-03-11 10:39:14 -07:00
Benson Wong eeb72297f7 add first version of CI for go 2025-03-11 08:45:56 -07:00
Benson Wong eabfe70cc6 add GH action to close inactive issues 2025-03-09 19:51:48 -07:00
Benson Wong 29cd98878d better container build logic when upstream containers do not exist 2025-03-09 13:02:06 -07:00
Benson Wong b3d331da0d Properly strip profile name slug from models fixes (#62)
The profile slug in a model name, `profile:model`, is specific to
llama-swap. This strips `profile:` out of the model name request so
upstreams that expect just `model` work and do not require knowing about
the profile slug.
2025-03-09 12:41:52 -07:00
Benson Wong 62275e078d add examples to restart on config change #59 2025-03-06 10:50:29 -08:00
Benson Wong 88916059e1 add /unload to docs 2025-03-03 10:44:16 -08:00
Benson Wong 082d5d0fc5 Add /unload endpoint (#58) to unload all currently running models 2025-03-03 10:33:36 -08:00
Benson Wong 53338938bd increase health check to a minimum of 5 seconds 2025-03-03 10:04:08 -08:00
Benson Wong af653347ae Update README.md w/ starhistory graph 2025-02-27 16:43:34 -08:00
Benson Wong 1e25b44a06 add workflow_dispatch to release action 2025-02-18 17:27:43 -08:00
Benson Wong 0815bb4cc3 Add windows to goreleaser #54 2025-02-18 17:26:43 -08:00
daschiller 7187cfe52e add Windows build support to Makefile (#54) 2025-02-18 17:24:31 -08:00
Benson Wong 24089d2d9c remove "no musa container" note from README 2025-02-18 16:38:48 -08:00
Benson Wong ebabe55ff3 Delete untagged packages after build and push (#55) 2025-02-18 10:32:32 -08:00
Benson Wong 41a338297c deletion of untagged containers happen after build-and-push 2025-02-18 10:11:59 -08:00
Benson Wong 7e3353efeb add action step to remove untagged containers 2025-02-18 10:08:41 -08:00
Benson Wong 4ed58fb173 update container build action 2025-02-18 09:59:06 -08:00
Benson Wong f5a2be698d revert package src until new ggml-org has them 2025-02-15 18:23:58 -08:00
Benson Wong f5e6ec3b7a fix package src in containerfile 2025-02-15 18:20:35 -08:00
Benson Wong 3f462da146 switch package source from ggerganov to ggml-org 2025-02-15 18:18:49 -08:00
Benson Wong 48bd766536 Update README.md 2025-02-14 22:05:52 -08:00
Benson Wong 8d319da4dd improve README organization (i think...) 2025-02-14 15:59:12 -08:00
Benson Wong be7c502448 improve docs 2025-02-14 15:47:31 -08:00
Benson Wong 92336f00bf more container build fixes 2025-02-14 15:34:38 -08:00
Benson Wong ed2a50d9a6 fix bug in build-container.sh 2025-02-14 15:27:56 -08:00
Benson Wong 0acfdb9f78 update workflow to build cpu and disable musa 2025-02-14 15:26:59 -08:00
Benson Wong 96a8ea0241 add cpu docker container build 2025-02-14 15:25:45 -08:00
Benson Wong f20f2c9b7a add docs and container build improvements #43 2025-02-14 12:20:07 -08:00
Benson Wong 7a97c38828 enable parallel container built #46 2025-02-14 11:04:33 -08:00
Benson Wong 4885132565 more permissions futzing 2025-02-14 11:02:15 -08:00
Benson Wong 8b46a0b7f1 grant package:write to container workflow #46 2025-02-14 10:55:30 -08:00
Benson Wong 1b6736ec6f rename workflow for containers 2025-02-14 10:50:15 -08:00
Benson Wong ddc1ce031e fix container file name #46 2025-02-14 10:49:44 -08:00
Benson Wong 11d024bbaa just build cuda while debugging 2025-02-14 10:48:06 -08:00
Benson Wong 43e23c16dc add check for GITHUB_TOKEN #46 2025-02-14 10:47:25 -08:00
Benson Wong f9c8e763ba add execute bit on build-container.sh 2025-02-14 10:44:53 -08:00
Benson Wong d7e1bb9f7c add GITHUB_TOKEN to container build env 2025-02-14 10:43:44 -08:00
Benson Wong ab93460a8b first container code (#52) 2025-02-14 10:39:25 -08:00
Benson Wong 13d4552edc Add FreeBSD/amd64 to auto built releases (#51) 2025-02-13 16:44:31 -08:00
Benson Wong 6667e307a2 Update README.md 2025-02-08 10:28:35 -08:00
Benson Wong 7ac446e6a9 Update README.md 2025-02-08 10:26:11 -08:00
Benson Wong eab9795bcc remove panic() when cmd or process is nil 2025-02-07 14:00:32 -08:00
Benson Wong 09bdd86b54 Improve shutdown behaviour (#47) (#49)
Introduce `Process.Shutdown()` and `ProxyManager.Shutdown()`. These two function required a lot of internal process state management refactoring. A key benefit is that `Process.start()` is now interruptable. When `Shutdown()` is called it will break the long health check loop. 

State management within Process is also improved. Added `starting`, `stopping` and `shutdown` states. Additionally, introduced a simple finite state machine to manage transitions.
2025-02-05 17:19:59 -08:00
Benson Wong 85cd74a51c Improve process start and stop reliability (#38)
Refactor Process.start()/Stop() logic (#38)
- remove cmd.Wait() call in start(). This seems to conflict with the one
  in .Stop(). Removing it eliminated no child errors
- eliminate goroutines in .start() as it no longer required
2025-02-03 11:50:38 -08:00
Benson Wong 314d2f2212 remove cmd_stop configuration and functionality from PR #40 (#44)
* remove cmd_stop functionality from #40
2025-01-31 12:42:44 -08:00
Benson Wong fad25f3e11 Use client request context in proxy request (#43)
Canceled or closed HTTP requests from clients will also stop the proxied
HTTP requests to upstreamed servers.
2025-01-31 10:21:49 -08:00
Benson Wong 2c3e3e27f7 Support OPTIONS requests (#42)
Add middleware that responds with permissive OPTIONS headers
for all request paths.
2025-01-31 10:09:07 -08:00
Benson Wong baeb0c4e7f Add cmd_stop configuration to better support docker (#35)
Add `cmd_stop` to model configuration to run a command instead of sending a SIGTERM to shutdown a process before swapping.
2025-01-30 16:59:57 -08:00
Benson Wong 2833517eef Improve handling of process that do not handle SIGTERM (#38)
- Process TTL goroutine did not have a return after .Stop()
- Improve logging
- Add test TestProcess_LowTTLValue to measure SIGTERM error rate
2025-01-20 14:39:52 -08:00
Benson Wong abdc2bfdb3 Fix panic when requesting non-members of profiles
A panic occurs when a request for an invalid profile:model pair is made.
The edge case is that the profile exists and the model exists but they're
not configured as a pair.

This adds an additional check to make sure the profile:model pair is
valid before attempting to swap the model.
2025-01-16 12:06:38 -08:00
Benson Wong c3b834737f Update README.md 2025-01-13 22:37:30 -08:00
Benson Wong 3c8e727b73 Update README.md 2025-01-12 19:48:35 -08:00
Benson Wong 3a1e9f81f1 support TTS /v1/audio/speech (#36) 2025-01-12 16:27:01 -08:00
Benson Wong 72c883f36c Update README.md 2025-01-02 09:01:51 -08:00
Benson Wong 1b04d034cf Update README.md 2025-01-02 08:59:11 -08:00
Benson Wong 2e45f5692a Update README.md
Improve README documentation.
2025-01-01 12:51:24 -08:00
Benson Wong c97b80bdfe Update README.md 2025-01-01 12:25:45 -08:00
Benson Wong ae3ef9bc39 Refactor UI (#33)
- add html to / instead of 404
- add client side regex to /logs
2024-12-23 19:48:59 -08:00
Benson Wong db6715bec3 update golang.org/x/net -> v0.33.0 for dependabot 2024-12-20 11:28:32 -08:00
Benson Wong da5d9e8a6a fix HTTP logging so true path is printed 2024-12-20 11:25:01 -08:00
Benson Wong 84b667ca7a improve logging and error reporting for troubleshooting 2024-12-20 10:46:56 -08:00
Benson Wong 29657106fc add more OpenAI API supported in README 2024-12-20 10:08:20 -08:00
Benson Wong 9c8860471e support v1/rerank endpoint 2024-12-17 21:22:25 -08:00
Benson Wong 9b4e3f307e rename proxy handler 2024-12-17 17:25:10 -08:00
Benson Wong 6fe37c3abf support /v1/embeddings (#4) 2024-12-17 17:25:10 -08:00
Benson Wong 7f45493a37 Update README.md 2024-12-17 14:45:41 -08:00
Benson Wong 891f6a5b5a Add /upstream endpoint (#30)
* remove catch-all route to upstream proxy (it was broken anyways)
* add /upstream/:model_id to swap and route to upstream path
* add /upstream HTML endpoint and unlisted option
* add /upstream endpoint to show a list of available models
* add `unlisted` configuration option to omit a model from /v1/models and /upstream lists
* add favicon.ico
2024-12-17 14:37:44 -08:00
Benson Wong 7183f6b43d fix bad logging due to wrong []byte used #28 2024-12-16 16:22:14 -08:00
Benson Wong d89bfeb441 add .DS_Store to .gitignore 2024-12-16 12:30:31 -08:00
Benson Wong 9a0c6bed40 Improve stop exceptions (#28) (#29)
Stop Process TTL goroutine when process is not ready (#28)

- fix issue where the goroutine will continue even though the child
  process is no longer running and the Process' state is not Ready
- fix issue where some logs were going to stdout instead of p.logMonitor
  causing them to not show up in the /logs
- add units to unloading model message
2024-12-16 12:29:25 -08:00
Benson Wong d6ca535939 tweak release tagging so it is not based on number of commits 2024-12-14 15:46:10 -08:00
Benson Wong 27302c0c02 change llama-swap to use goreleaser default ldflag values 2024-12-14 10:30:06 -08:00
Benson Wong d4e22cceaa Fix security vulnerability with golang.org/x/crypto
- does not affect the project as llama-swap does not use the crypto
  libraries
- good practice to keep security deps updated!
2024-12-14 10:20:22 -08:00
Benson Wong 4c94927658 Move release to Makefile out of goreleaser
- less complexity
- easier
- goreleaser, github, pipelines: 1...  mostlygeek: 0
2024-12-14 10:16:46 -08:00
Benson Wong a955a4a5c0 create tag to release 2024-12-14 10:07:20 -08:00
Benson Wong 22d3f1a4f9 Change versioning to use git commits counts instead of semver
- less work for me
- more frequent releases
2024-12-14 09:53:13 -08:00
Benson Wong e2443251ad update readme 2024-12-09 19:14:49 -08:00
Benson Wong 5fbd53c616 delay TTL check until after all requests are complete (#25)
- fixes #25 where requests that last longer than the TTL will cause the
  process to be unloaded before the next request.
- new behavior, TTL waits until all requests are complete before
  checking timeout
2024-12-09 19:08:03 -08:00
Benson Wong 97dae50dc4 update readme 2024-12-08 21:34:16 -08:00
Benson Wong cb978f760f add web interface to /logs 2024-12-08 21:26:22 -08:00
Benson Wong 387f0ef6c4 use new timings data in server response in run-benchmark.sh 2024-12-03 20:48:36 -08:00
Benson Wong 18c134624d Add Access-Control-Allow-Origin CORS header to /v1/models endpoint
- match behavior of llama.cpp where the Origin in request is used
- add test for listModelsHandler
2024-12-03 15:53:59 -08:00
Benson Wong da2326bdc7 add example: optimizing code generation 2024-12-03 10:25:43 -08:00
Benson Wong da46545630 fix profile example in README 2024-12-01 10:13:31 -08:00
Benson Wong 04b4760e7e change profile split character to : (colon) (#21)
- change from `/` to `:` for multiple models loaded as part of a profile
- breaking change now, but allows for more compatibility with other inference engines that may have model references like `coding:Qwen/Qwen-2.5-Coder-32B`
2024-12-01 09:10:50 -08:00
Benson Wong 9fc5d5b5eb improve cmd parsing (#22)
Switch from using a naive strings.Fields() to shlex.Split() for parsing the model startup command into a string[]. This makes parsing much more reliable around newlines, quotes, etc.
2024-12-01 09:02:58 -08:00
Benson Wong cf82b3c633 Improve Concurrency and Parallel Request Handling (#19)
Rewrite the swap behaviour so that in-flight requests block process swapping until they are completed. 

Additionally: 

- add tests for parallel requests with proxy.ProxyManager and proxy.Process
- improve Process startup behaviour and simplified the code 
- stopping of processes are sent SIGTERM and have 5 seconds to terminate, before they are killed
2024-11-30 15:24:42 -08:00
Benson Wong e363f8f498 clean up writing with AI :b 2024-11-28 22:12:44 -08:00
Benson Wong c9629cf3a2 add speculative decoding example 2024-11-28 22:07:22 -08:00
Benson Wong 50426935a4 . 2024-11-28 22:06:29 -08:00
Benson Wong 2fceb78e8d Add examples 2024-11-28 22:05:41 -08:00
Ikko Eltociear Ashimine 9a81c53664 chore: update process_test.go (#17)
nonexistant -> nonexistent
2024-11-26 10:20:16 -08:00
Benson Wong 716d37de82 Update README.md
fix grammar
2024-11-25 12:35:00 -08:00
Benson Wong 73ad85ea69 Implement Multi-Process Handling (#7)
Refactor code to support starting of multiple back end llama.cpp servers. This functionality is exposed as `profiles` to create a simple configuration format. 

Changes: 

* refactor proxy tests to get ready for multi-process support
* update proxy/ProxyManager to support multiple processes (#7)
* Add support for Groups in configuration
* improve handling of Model alias configs
* implement multi-model swapping
* improve code clarity for swapModel
* improve docs, rename groups to profiles in config
2024-11-23 19:45:13 -08:00
Benson Wong 533162ce6a add support for automatically unloading a model (#10) (#14)
* Make starting upstream process on-demand (#10)
* Add automatic unload of model after TTL is reached
* add `ttl` configuration parameter to models in seconds, default is 0 (never unload)
2024-11-19 16:32:51 -08:00
Benson Wong ba39ed4c18 Add support for legacy v1/completions API (#12) 2024-11-19 09:57:39 -08:00
Benson Wong 21f54f96c2 Merge pull request #13 from mostlygeek/set-content-length
Dechunk HTTP requests by default (#11)
2024-11-19 09:46:03 -08:00
Benson Wong 7eec51f3f2 Dechunk HTTP requests by default (#11)
ProxyManager already has all the Request body's data. There is no never
a need to use chunked transfer encoding to the upstream process.
2024-11-19 09:40:44 -08:00
Benson Wong 5021e0f299 remove the process handler override 2024-11-18 21:26:39 -08:00
Benson Wong c9233d2c9a use gin instead of standard http lib in main 2024-11-18 15:58:28 -08:00
Benson Wong a33ac6f8fb update README 2024-11-18 15:37:50 -08:00
Benson Wong 401aa88949 move log handlers to separate file 2024-11-18 15:33:06 -08:00
Benson Wong e9e88fd229 rename proxy.go to proxymanager.go 2024-11-18 15:30:34 -08:00
Benson Wong c3b4bb1684 use gin for http server 2024-11-18 15:30:16 -08:00
Benson Wong e5c909ddf7 add tests for proxy.Process 2024-11-17 20:49:14 -08:00
Benson Wong 36a31f450f add proxy.Process to manage upstream proxy logic 2024-11-17 16:41:15 -08:00
Benson Wong a8e5ee13b9 Add logging with pipes example to README 2024-11-15 09:10:43 -08:00
Benson Wong 5944a86e86 fix early timeout bug 2024-11-09 20:08:40 -08:00
Benson Wong 63d4a7d0eb Improve LogMonitor to handle empty writes and ensure buffer immutability
- Add a check to return immediately if the write buffer is empty
- Create a copy of new history data to ensure it is immutable
- Update the `GetHistory` method to use the `any` type for the buffer interface
- Add a test case to verify that the buffer remains unchanged
  even if the original message is modified after writing
2024-11-02 10:41:23 -07:00
Benson Wong f45469f7ff Merge pull request #8 from mostlygeek/improve-upstream-monitoring-issue5
Improvements to handling of the upstream process so errors happen whenever one of these is first:

    the health check timeout is reached waiting for the upstream process to be ready
    the upstream process exits unexpectedly

With this change llama-swap is more compatible with use cases like containerized upstream services (#5) which pull the container before HTTP endpoints are ready.
2024-11-01 15:28:06 -07:00
Benson Wong 34f9fd7340 Improve timeout and exit handling of child processes. fix #3 and #5
llama-swap only waited a maximum of 5 seconds for an upstream
HTTP server to be available. If it took longer than that it will error
out the request. Now it will wait up to the configured healthCheckTimeout
or the upstream process unexpectedly exits.
2024-11-01 14:32:39 -07:00
Benson Wong 8448efa7fc revise health check logic to not error on 5 second timeout 2024-11-01 09:42:37 -07:00
Benson Wong 8cf2a389d8 Refactor log implementation
- use []byte instead of unnecessary string conversions
- make LogManager.Broadcast private
- make LogManager.GetHistory public
- add tests
2024-10-31 12:16:54 -07:00
Benson Wong 0f133f5b74 Add /logs endpoint to monitor upstream processes
- outputs last 10KB of logs from upstream processes
- supports streaming
2024-10-30 21:02:30 -07:00
Benson Wong 1510b3fbd9 clean up README 2024-10-22 10:37:45 -07:00
Benson Wong 0f8a8e70f1 add header image 2024-10-22 10:30:30 -07:00
Benson Wong 6c3819022c Add compatibility with OpenAI /v1/models endpoint to list models 2024-10-21 15:38:12 -07:00
Benson Wong 8580f0f733 Merge pull request #6 from mostlygeek/multiline-config
Support multiline cmds in YAML configuration
2024-10-19 20:07:36 -07:00
Benson Wong be82d1a6a0 Support multiline cmds in YAML configuration
Add support for multiline `cmd` configurations allowing for nicer looking configuration YAML files.
2024-10-19 20:06:59 -07:00
44 changed files with 4322 additions and 313 deletions
+23
View File
@@ -0,0 +1,23 @@
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
name: Close inactive issues
on:
schedule:
- cron: "32 1 * * *"
jobs:
close-issues:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v9
with:
days-before-issue-stale: 30
days-before-issue-close: 14
stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
days-before-pr-stale: -1
days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }}
+46
View File
@@ -0,0 +1,46 @@
name: Build Containers
on:
# time has no specific meaning, trying to time it after
# the llama.cpp daily packages are published
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
schedule:
- cron: "37 5 * * *"
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
build-and-push:
runs-on: ubuntu-latest
strategy:
matrix:
platform: [intel, cuda, vulkan, cpu, musa]
fail-fast: false
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Run build-container
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: ./docker/build-container.sh ${{ matrix.platform }} true
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
# see: https://github.com/actions/delete-package-versions/issues/74
delete-untagged-containers:
needs: build-and-push
runs-on: ubuntu-latest
steps:
- uses: actions/delete-package-versions@v5
with:
package-name: 'llama-swap'
package-type: 'container'
delete-only-untagged-versions: 'true'
+32
View File
@@ -0,0 +1,32 @@
# This workflow will build a golang project
name: CI
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
run-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
# necessary for testing proxy/Process swapping
- name: Create simple-responder
run: make simple-responder
- name: Test all
run: make test-all
+4 -1
View File
@@ -5,6 +5,9 @@ on:
tags:
- '*'
# Allows manual triggering of the workflow
workflow_dispatch:
permissions:
contents: write
@@ -30,4 +33,4 @@ jobs:
version: '~> v2'
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+2 -1
View File
@@ -2,4 +2,5 @@
.env
build/
dist/
.vscode
.vscode
.DS_Store
+20 -1
View File
@@ -6,6 +6,25 @@ builds:
goos:
- linux
- darwin
- freebsd
- windows
goarch:
- amd64
- arm64
- arm64
ignore:
- goos: freebsd
goarch: arm64
- goos: windows
goarch: arm64
# use zip format for windows
archives:
- id: default
format: tar.gz
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds_info:
group: root
owner: root
format_overrides:
- goos: windows
format: zip
+41 -5
View File
@@ -2,6 +2,16 @@
APP_NAME = llama-swap
BUILD_DIR = build
# Get the current Git hash
GIT_HASH := $(shell git rev-parse --short HEAD)
ifneq ($(shell git status --porcelain),)
# There are untracked changes
GIT_HASH := $(GIT_HASH)+
endif
# Capture the current build date in RFC3339 format
BUILD_DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
# Default target: Builds binaries for both OSX and Linux
all: mac linux simple-responder
@@ -9,24 +19,50 @@ all: mac linux simple-responder
clean:
rm -rf $(BUILD_DIR)
test:
go test -short -v ./proxy
test-all:
go test -v ./proxy
# Build OSX binary
mac:
@echo "Building Mac binary..."
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
# Build Linux binary
linux:
@echo "Building Linux binary..."
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
# for testing things
# Build Windows binary
windows:
@echo "Building Windows binary..."
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
# for testing proxy.Process
simple-responder:
@echo "Building simple responder"
go build -o $(BUILD_DIR)/simple-responder misc/simple-responder/simple-responder.go
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
# Ensure build directory exists
$(BUILD_DIR):
mkdir -p $(BUILD_DIR)
# Create a new release tag
release:
@echo "Checking for unstaged changes..."
@if [ -n "$(shell git status --porcelain)" ]; then \
echo "Error: There are unstaged changes. Please commit or stash your changes before creating a release tag." >&2; \
exit 1; \
fi
# Get the highest tag in v{number} format, increment it, and create a new tag
@highest_tag=$$(git tag --sort=-v:refname | grep -E '^v[0-9]+$$' | head -n 1 || echo "v0"); \
new_tag="v$$(( $${highest_tag#v} + 1 ))"; \
echo "tagging new version: $$new_tag"; \
git tag "$$new_tag";
# Phony targets
.PHONY: all clean osx linux
.PHONY: all clean mac linux windows simple-responder
+223 -31
View File
@@ -1,62 +1,252 @@
![llama-swap header image](header.jpeg)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap
[llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, so let's swap llama-server instead!
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
llama-swap is a proxy server that sits in front of llama-server. When a request for `/v1/chat/completions` comes in it will extract the `model` requested and change the underlying llama-server automatically.
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
- ✅ easy to deploy: single binary with no dependencies
- ✅ full control over llama-server's startup settings
-❤️ for nvidia P40 users who are rely on llama.cpp for inference
## Features:
-Easy to deploy: single binary with no dependencies
- ✅ Easy to config: single yaml file
- ✅ On-demand model switching
- ✅ OpenAI API supported endpoints:
- `v1/completions`
- `v1/chat/completions`
- `v1/embeddings`
- `v1/rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- ✅ llama-swap custom API endpoints
- `/log` - remote log monitoring
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
- ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support
- ✅ Full control over server settings per model
## How does llama-swap work?
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
## config.yaml
llama-swap's configuration purposefully simple.
llama-swap's configuration is purposefully simple.
```yaml
models:
"qwen2.5":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port 9999
"smollm2":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999
```
<details>
<summary>But also very powerful ...</summary>
```yaml
# Seconds to wait for llama.cpp to load and be ready to serve requests
# Default (and minimum) is 15 seconds
healthCheckTimeout: 60
# Valid log levels: debug, info (default), warn, error
logLevel: info
# define valid model values and the upstream server start
models:
"llama":
cmd: "llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf"
# multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
# Where to proxy to, important it matches this format
proxy: "http://127.0.0.1:8999"
# aliases model names to use this configuration for
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# wait for this path to return an HTTP 200 before serving requests
# defaults to /health to match llama.cpp
#
# use "none" to skip endpoint checking. This may cause requests to fail
# until the server is ready
checkEndpoint: "/custom-endpoint"
"qwen":
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
cmd: "llama-server --port 8999 -m path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf"
proxy: "http://127.0.0.1:8999"
# where to reach the server started by cmd, make sure the ports match
proxy: http://127.0.0.1:8999
# aliases names to use this model for
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# check this path for an HTTP 200 OK before serving requests
# default: /health to match llama.cpp
# use "none" to skip endpoint checking, but may cause HTTP errors
# until the model is ready
checkEndpoint: /custom-endpoint
# automatically unload the model after this many seconds
# ttl values must be a value greater than 0
# default: 0 = never unload model
ttl: 60
# `useModelName` overrides the model name in the request
# and sends a specific name to the upstream server
useModelName: "qwen:qwq"
# unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal
"qwen-unlisted":
unlisted: true
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker Support (v26.1.4+ required!)
"docker-llama":
proxy: "http://127.0.0.1:9790"
cmd: >
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# profiles eliminates swapping by running multiple models at the same time
#
# Tips:
# - each model must be listening on a unique address and port
# - the model name is in this format: "profile_name:model", like "coding:qwen"
# - the profile will load and unload all models in the profile at the same time
profiles:
coding:
- "llama"
- "qwen-unlisted"
```
## Installation
### Use Case Examples
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
## Configuration
llama-s
</details>
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker is the quickest way to try out llama-swap:
```
# use CPU inference
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
# qwen2.5 0.5B
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
# SmolLM2 135M
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
```
<details>
<summary>Docker images are nightly ...</summary>
They include:
- `ghcr.io/mostlygeek/llama-swap:cpu`
- `ghcr.io/mostlygeek/llama-swap:cuda`
- `ghcr.io/mostlygeek/llama-swap:intel`
- `ghcr.io/mostlygeek/llama-swap:vulkan`
- ROCm disabled until fixed in llama.cpp container
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
```
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
```
</details>
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
* _Note: Windows currently untested._
1. Run the binary with `llama-swap --config path/to/config.yaml`
### Building from source
1. Install golang for your system
1. `git clone git@github.com:mostlygeek/llama-swap.git`
1. `make clean all`
1. Binaries will be in `build/` subdirectory
## Monitoring Logs
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
Of course, CLI access is also supported:
```
# sends up to the last 10KB of logs
curl http://host/logs'
# streams combined logs
curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time'
# skips history and just streams new log entries
curl -Ns 'http://host/logs/stream?no-history'
```
## Do I need to use llama.cpp's server (llama-server)?
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
## Systemd Unit Files
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
`/etc/systemd/system/llama-swap.service`
```
[Unit]
Description=llama-swap
@@ -77,8 +267,10 @@ StartLimitInterval=30
WantedBy=multi-user.target
```
## Building from Source
## Star History
1. Install golang for your system
1. run `make clean all`
1. binaries will be built into `build/` directory
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
</picture>
+69 -16
View File
@@ -1,39 +1,92 @@
# Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds
healthCheckTimeout: 60
healthCheckTimeout: 90
# valid log levels: debug, info (default), warn, error
logLevel: info
models:
"llama":
cmd: "models/llama-server-osx --port 8999 -m models/Llama-3.2-1B-Instruct-Q4_K_M.gguf"
proxy: "http://127.0.0.1:8999"
cmd: >
models/llama-server-osx
--port 9001
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
proxy: http://127.0.0.1:9001
# list of model name aliases this llama.cpp instance can serve
aliases:
- "gpt-4o-mini"
- gpt-4o-mini
# check this path for a HTTP 200 response for the server to be ready
checkEndpoint: "/health"
checkEndpoint: /health
# unload model after 5 seconds
ttl: 5
"qwen":
cmd: "models/llama-server-osx --port 8999 -m models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf"
proxy: "http://127.0.0.1:8999"
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9002
aliases:
- "gpt-3.5-turbo"
- gpt-3.5-turbo
# Embedding example with Nomic
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
"nomic":
proxy: http://127.0.0.1:9005
cmd: >
models/llama-server-osx --port 9005
-m models/nomic-embed-text-v1.5.Q8_0.gguf
--ctx-size 8192
--batch-size 8192
--rope-scaling yarn
--rope-freq-scale 0.75
-ngl 99
--embeddings
# Reranking example with bge-reranker
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
"bge-reranker":
proxy: http://127.0.0.1:9006
cmd: >
models/llama-server-osx --port 9006
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
--ctx-size 8192
--reranking
# Docker Support (v26.1.4+ required!)
"dockertest":
proxy: "http://127.0.0.1:9790"
cmd: >
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
"simple":
# example of setting environment variables
env:
- "CUDA_VISIBLE_DEVICES=0,1"
- "env1=hello"
cmd: "build/simple-responder --port 8999"
proxy: "http://127.0.0.1:8999"
- CUDA_VISIBLE_DEVICES=0,1
- env1=hello
cmd: build/simple-responder --port 8999
proxy: http://127.0.0.1:8999
unlisted: true
# use "none" to skip check. Caution this may cause some requests to fail
# until the upstream server is ready for traffic
checkEndpoint: "none"
checkEndpoint: none
# don't use this, just for testing if things are broken
# don't use these, just for testing if things are broken
"broken":
cmd: "models/llama-server-osx --port 8999 -m models/doesnotexist.gguf"
proxy: "http://127.0.0.1:8999"
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
proxy: http://127.0.0.1:8999
unlisted: true
"broken_timeout":
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9000
unlisted: true
# creating a coding profile with models for code generation and general questions
profiles:
coding:
- "qwen"
- "llama"
+55
View File
@@ -0,0 +1,55 @@
#!/bin/bash
cd $(dirname "$0")
ARCH=$1
PUSH_IMAGES=${2:-false}
# List of allowed architectures
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
# Check if ARCH is in the allowed list
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
exit 1
fi
# Check if GITHUB_TOKEN is set and not empty
if [[ -z "$GITHUB_TOKEN" ]]; then
echo "Error: GITHUB_TOKEN is not set or is empty."
exit 1
fi
# the most recent llama-swap tag
# have to strip out the 'v' due to .tar.gz file naming
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
if [ "$ARCH" == "cpu" ]; then
# cpu only containers just use the latest available
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
echo "Building ${CONTAINER_LATEST} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_LATEST}
fi
else
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
fi
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
echo "Building ${CONTAINER_TAG} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
fi
+17
View File
@@ -0,0 +1,17 @@
healthCheckTimeout: 300
logRequests: true
models:
"qwen2.5":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port 9999
"smollm2":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999
+16
View File
@@ -0,0 +1,16 @@
ARG BASE_TAG=server-cuda
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
# has to be after the FROM
ARG LS_VER=89
WORKDIR /app
RUN \
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
COPY config.example.yaml /app/config.yaml
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
+6
View File
@@ -0,0 +1,6 @@
# Example Configs and Use Cases
A collections of usecases and examples for getting the most out of llama-swap.
* [Speculative Decoding](speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
* [Optimizing Code Generation](benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
+123
View File
@@ -0,0 +1,123 @@
# Optimizing Code Generation with llama-swap
Finding the best mix of settings for your hardware can be time consuming. This example demonstrates using a custom configuration file to automate testing different scenarios to find the an optimal configuration.
The benchmark writes a snake game in Python, TypeScript, and Swift using the Qwen 2.5 Coder models. The experiments were done using a 3090 and a P40.
**Benchmark Scenarios**
Three scenarios are tested:
- 3090-only: Just the main model on the 3090
- 3090-with-draft: the main and draft models on the 3090
- 3090-P40-draft: the main model on the 3090 with the draft model offloaded to the P40
**Available Devices**
Use the following command to list available devices IDs for the configuration:
```
$ /mnt/nvme/llama-server/llama-server-f3252055 --list-devices
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: Tesla P40, compute capability 6.1, VMM: yes
Device 2: Tesla P40, compute capability 6.1, VMM: yes
Device 3: Tesla P40, compute capability 6.1, VMM: yes
Available devices:
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 406 MiB free)
CUDA1: Tesla P40 (24438 MiB, 22942 MiB free)
CUDA2: Tesla P40 (24438 MiB, 24144 MiB free)
CUDA3: Tesla P40 (24438 MiB, 24144 MiB free)
```
**Configuration**
The configuration file, `benchmark-config.yaml`, defines the three scenarios:
```yaml
models:
"3090-only":
proxy: "http://127.0.0.1:9503"
cmd: >
/mnt/nvme/llama-server/llama-server-f3252055
--host 127.0.0.1 --port 9503
--flash-attn
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--device CUDA0
--ctx-size 32768
--cache-type-k q8_0 --cache-type-v q8_0
"3090-with-draft":
proxy: "http://127.0.0.1:9503"
# --ctx-size 28500 max that can fit on 3090 after draft model
cmd: >
/mnt/nvme/llama-server/llama-server-f3252055
--host 127.0.0.1 --port 9503
--flash-attn
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--device CUDA0
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
-ngld 99
--draft-max 16
--draft-min 4
--draft-p-min 0.4
--device-draft CUDA0
--ctx-size 28500
--cache-type-k q8_0 --cache-type-v q8_0
"3090-P40-draft":
proxy: "http://127.0.0.1:9503"
cmd: >
/mnt/nvme/llama-server/llama-server-f3252055
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--device CUDA0
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
-ngld 99
--draft-max 16
--draft-min 4
--draft-p-min 0.4
--device-draft CUDA1
--ctx-size 32768
--cache-type-k q8_0 --cache-type-v q8_0
```
> Note in the `3090-with-draft` scenario the `--ctx-size` had to be reduced from 32768 to to accommodate the draft model.
**Running the Benchmark**
To run the benchmark, execute the following commands:
1. `llama-swap -config benchmark-config.yaml`
1. `./run-benchmark.sh http://localhost:8080 "3090-only" "3090-with-draft" "3090-P40-draft"`
The [benchmark script](run-benchmark.sh) generates a CSV output of the results, which can be converted to a Markdown table for readability.
**Results (tokens/second)**
| model | python | typescript | swift |
|-----------------|--------|------------|-------|
| 3090-only | 34.03 | 34.01 | 34.01 |
| 3090-with-draft | 106.65 | 70.48 | 57.89 |
| 3090-P40-draft | 81.54 | 60.35 | 46.50 |
Many different factors, like the programming language, can have big impacts on the performance gains. However, with a custom configuration file for benchmarking it is easy to test the different variations to discover what's best for your hardware.
Happy coding!
+40
View File
@@ -0,0 +1,40 @@
#!/usr/bin/env bash
# This script generates a CSV file showing the token/second for generating a Snake Game in python, typescript and swift
# It was created to test the effects of speculative decoding and the various draft settings on performance.
#
# Writing code with a low temperature seems to provide fairly consistent logic.
#
# Usage: ./benchmark.sh <url> <model1> [model2 ...]
# Example: ./benchmark.sh http://localhost:8080 model1 model2
if [ "$#" -lt 2 ]; then
echo "Usage: $0 <url> <model1> [model2 ...]"
exit 1
fi
url=$1; shift
echo "model,python,typescript,swift"
for model in "$@"; do
echo -n "$model,"
for lang in "python" "typescript" "swift"; do
# expects a llama.cpp after PR https://github.com/ggerganov/llama.cpp/pull/10548
# (Dec 3rd/2024)
time=$(curl -s --url "$url/v1/chat/completions" -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"top_k\": 1, \"timings_per_token\":true, \"model\":\"$model\"}" | jq -r .timings.predicted_per_second)
if [ $? -ne 0 ]; then
time="error"
exit 1
fi
if [ "$lang" != "swift" ]; then
printf "%0.2f tps," $time
else
printf "%0.2f tps\n" $time
fi
done
done
@@ -0,0 +1,51 @@
# Restart llama-swap on config change
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
```bash
#!/bin/bash
#
# A simple watch and restart llama-swap when its configuration
# file changes. Useful for trying out configuration changes
# without manually restarting the server each time.
if [ -z "$1" ]; then
echo "Usage: $0 <path to config.yaml>"
exit 1
fi
while true; do
# Start the process again
./llama-swap-linux-amd64 -config $1 -listen :1867 &
PID=$!
echo "Started llama-swap with PID $PID"
# Wait for modifications in the specified directory or file
inotifywait -e modify "$1"
# Check if process exists before sending signal
if kill -0 $PID 2>/dev/null; then
echo "Sending SIGTERM to $PID"
kill -SIGTERM $PID
wait $PID
else
echo "Process $PID no longer exists"
fi
sleep 1
done
```
## Usage and output example
```bash
$ ./watch-and-restart.sh config.yaml
Started llama-swap with PID 495455
Setting up watches.
Watches established.
llama-swap listening on :1867
Sending SIGTERM to 495455
Shutting down llama-swap
Started llama-swap with PID 495486
Setting up watches.
Watches established.
llama-swap listening on :1867
```
+124
View File
@@ -0,0 +1,124 @@
# Speculative Decoding
Speculative decoding can significantly improve the tokens per second. However, this comes at the cost of increased VRAM usage for the draft model. The examples provided are based on a server with three P40s and one 3090.
## Coding Use Case
This example uses Qwen2.5 Coder 32B with the 0.5B model as a draft. A quantization of Q8_0 was chosen for the draft model, as quantization has a greater impact on smaller models.
The models used are:
* [Bartowski Qwen2.5-Coder-32B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF)
* [Bartowski Qwen2.5-Coder-0.5B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-0.5B-Instruct-GGUF)
The llama-swap configuration is as follows:
```yaml
models:
"qwen-coder-32b-q4":
# main model on 3090, draft on P40 #1
cmd: >
/mnt/nvme/llama-server/llama-server-be0e35
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--ctx-size 19000
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
-ngld 99
--draft-max 16
--draft-min 4
--draft-p-min 0.4
--device CUDA0
--device-draft CUDA1
proxy: "http://127.0.0.1:9503"
```
In this configuration, two GPUs are used: a 3090 (CUDA0) for the main model and a P40 (CUDA1) for the draft model. Although both models can fit on the 3090, relocating the draft model to the P40 freed up space for a larger context size. Despite the P40 being about 1/3rd the speed of the 3090, the small model still improved tokens per second.
Multiple tests were run with various parameters, and the fastest result was chosen for the configuration. In all tests, the 0.5B model produced the largest improvements to tokens per second.
Baseline: 33.92 tokens/second on 3090 without a draft model.
| draft-max | draft-min | draft-p-min | python | TS | swift |
|-----------|-----------|-------------|--------|----|-------|
| 16 | 1 | 0.9 | 71.64 | 55.55 | 48.06 |
| 16 | 1 | 0.4 | 83.21 | 58.55 | 45.50 |
| 16 | 1 | 0.1 | 79.72 | 55.66 | 43.94 |
| 16 | 2 | 0.9 | 68.47 | 55.13 | 43.12 |
| 16 | 2 | 0.4 | 82.82 | 57.42 | 48.83 |
| 16 | 2 | 0.1 | 81.68 | 51.37 | 45.72 |
| 16 | 4 | 0.9 | 66.44 | 48.49 | 42.40 |
| 16 | 4 | 0.4 | _83.62_ (fastest)| _58.29_ | _50.17_ |
| 16 | 4 | 0.1 | 82.46 | 51.45 | 40.71 |
| 8 | 1 | 0.4 | 67.07 | 55.17 | 48.46 |
| 4 | 1 | 0.4 | 50.13 | 44.96 | 40.79 |
The test script can be found in this [gist](https://gist.github.com/mostlygeek/da429769796ac8a111142e75660820f1). It is a simple curl script that prompts generating a snake game in Python, TypeScript, or Swift. Evaluation metrics were pulled from llama.cpp's logs.
```bash
for lang in "python" "typescript" "swift"; do
echo "Generating Snake Game in $lang using $model"
curl -s --url http://localhost:8080/v1/chat/completions -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"temperature\": 0.1, \"model\":\"$model\"}" > /dev/null
done
```
Python consistently outperformed Swift in all tests, likely due to the 0.5B draft model being more proficient in generating Python code accepted by the larger 32B model.
## Chat
This configuration is for a regular chat use case. It produces approximately 13 tokens/second in typical use, up from ~9 tokens/second with only the 3xP40s. This is great news for P40 owners.
The models used are:
* [Bartowski Meta-Llama-3.1-70B-Instruct-GGUF](https://huggingface.co/bartowski/Meta-Llama-3.1-70B-Instruct-GGUF)
* [Bartowski Llama-3.2-3B-Instruct-GGUF](https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF)
```yaml
models:
"llama-70B":
cmd: >
/mnt/nvme/llama-server/llama-server-be0e35
--host 127.0.0.1 --port 9602
--flash-attn --metrics
--split-mode row
--ctx-size 80000
--model /mnt/nvme/models/Meta-Llama-3.1-70B-Instruct-Q4_K_L.gguf
-ngl 99
--model-draft /mnt/nvme/models/Llama-3.2-3B-Instruct-Q4_K_M.gguf
-ngld 99
--draft-max 16
--draft-min 1
--draft-p-min 0.4
--device-draft CUDA0
--tensor-split 0,1,1,1
```
In this configuration, Llama-3.1-70B is split across three P40s, and Llama-3.2-3B is on the 3090.
Some flags deserve further explanation:
* `--split-mode row` - increases inference speeds using multiple P40s by about 30%. This is a P40-specific feature.
* `--tensor-split 0,1,1,1` - controls how the main model is split across the GPUs. This means 0% on the 3090 and an even split across the P40s. A value of `--tensor-split 0,5,4,1` would mean 0% on the 3090, 50%, 40%, and 10% respectively across the other P40s. However, this would exceed the available VRAM.
* `--ctx-size 80000` - the maximum context size that can fit in the remaining VRAM.
## What is CUDA0, CUDA1, CUDA2, CUDA3?
These devices are the IDs used by llama.cpp.
```bash
$ ./llama-server --list-devices
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: Tesla P40, compute capability 6.1, VMM: yes
Device 2: Tesla P40, compute capability 6.1, VMM: yes
Device 3: Tesla P40, compute capability 6.1, VMM: yes
Available devices:
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 23892 MiB free)
CUDA1: Tesla P40 (24438 MiB, 24290 MiB free)
CUDA2: Tesla P40 (24438 MiB, 24290 MiB free)
CUDA3: Tesla P40 (24438 MiB, 24290 MiB free)
```
+40 -1
View File
@@ -2,4 +2,43 @@ module github.com/mostlygeek/llama-swap
go 1.23.0
require gopkg.in/yaml.v3 v3.0.1
require (
github.com/gin-gonic/gin v1.10.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.37.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)
+107
View File
@@ -1,4 +1,111 @@
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 261 KiB

+30 -4
View File
@@ -3,31 +3,57 @@ package main
import (
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy"
)
var version string = "0"
var commit string = "abcd1234"
var date = "unknown"
func main() {
// Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port")
showVersion := flag.Bool("version", false, "show version of build")
flag.Parse() // Parse the command-line flags
if *showVersion {
fmt.Printf("version: %s (%s), built at %s\n", version, commit, date)
os.Exit(0)
}
config, err := proxy.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
if mode := os.Getenv("GIN_MODE"); mode != "" {
gin.SetMode(mode)
} else {
gin.SetMode(gin.ReleaseMode)
}
proxyManager := proxy.New(config)
http.HandleFunc("/", proxyManager.HandleFunc)
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigChan
fmt.Println("Shutting down llama-swap")
proxyManager.Shutdown()
os.Exit(0)
}()
fmt.Println("llama-swap listening on " + *listenStr)
if err := http.ListenAndServe(*listenStr, nil); err != nil {
fmt.Printf("Error starting server: %v\n", err)
if err := proxyManager.Run(*listenStr); err != nil {
fmt.Printf("Server error: %v\n", err)
os.Exit(1)
}
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

+166 -15
View File
@@ -3,47 +3,198 @@ package main
import (
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
func main() {
gin.SetMode(gin.TestMode)
// Define a command-line flag for the port
port := flag.String("port", "8080", "port to listen on")
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
// Define a command-line flag for the response message
responseMessage := flag.String("respond", "hi", "message to respond with")
silent := flag.Bool("silent", false, "disable all logging")
flag.Parse() // Parse the command-line flags
// Set up the handler function using the provided response message
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// Set the header to text/plain
w.Header().Set("Content-Type", "text/plain")
// Create a new Gin router
r := gin.New()
fmt.Fprintln(w, *responseMessage)
// Set up the handler function using the provided response message
r.POST("/v1/chat/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
// add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait)
}
c.String(200, *responseMessage)
})
// for issue #62 to check model name strips profile slug
// has to be one of the openAI API endpoints that llama-swap proxies
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
r.POST("/v1/audio/speech", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
if modelName != *expectedModel {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
return
} else {
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}
})
r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
})
// issue #41
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
// Parse the multipart form
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
return
}
// Get the model from the form values
model := c.Request.FormValue("model")
if model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
return
}
// Get the file from the form
file, _, err := c.Request.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
return
}
defer file.Close()
// Read the file content to get its size
fileBytes, err := io.ReadAll(file)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
return
}
fileSize := len(fileBytes)
// Return a JSON response with the model and transcription text including file size
c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model,
})
})
r.GET("/slow-respond", func(c *gin.Context) {
echo := c.Query("echo")
delay := c.Query("delay")
if echo == "" {
echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
}
// Parse the duration
if delay == "" {
delay = "100ms"
}
t, err := time.ParseDuration(delay)
if err != nil {
c.Header("Content-Type", "text/plain")
c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
return
}
c.Header("Content-Type", "text/plain")
for _, char := range echo {
c.Writer.Write([]byte(string(char)))
c.Writer.Flush()
// wait
<-time.After(t)
}
})
r.GET("/test", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
})
r.GET("/env", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
// Get environment variables
envVars := os.Environ()
// Write each environment variable to the response
for _, envVar := range envVars {
fmt.Fprintln(w, envVar)
c.String(200, envVar)
}
})
// Set up the /health endpoint handler function
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
response := `{"status": "ok"}`
w.Write([]byte(response))
r.GET("/health", func(c *gin.Context) {
c.Header("Content-Type", "application/json")
c.JSON(200, gin.H{"status": "ok"})
})
address := ":" + *port // Address with the specified port
fmt.Printf("Server is listening on port %s\n", *port)
r.GET("/", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
})
// Start the server and log any error if it occurs
if err := http.ListenAndServe(address, nil); err != nil {
fmt.Printf("Error starting server: %s\n", err)
address := "127.0.0.1:" + *port // Address with the specified port
srv := &http.Server{
Addr: address,
Handler: r.Handler(),
}
// Disable logging if the --silent flag is set
if *silent {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
log.SetOutput(io.Discard)
}
go func() {
log.Printf("simple-responder listening on %s\n", address)
// service connections
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("simple-responder err: %s\n", err)
}
}()
// Wait for interrupt signal to gracefully shutdown the server with
// a timeout of 5 seconds.
quit := make(chan os.Signal, 1)
// kill (no param) default send syscall.SIGTERM
// kill -2 is syscall.SIGINT
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("simple-responder shutting down")
}
+4
View File
@@ -0,0 +1,4 @@
The rerank-test.json data is from https://github.com/ggerganov/llama.cpp/pull/9510
To run it:
> curl http://127.0.0.1:8080/v1/rerank -H "Content-Type: application/json" -d @reranker-test.json -v | jq .
+17
View File
@@ -0,0 +1,17 @@
{
"model": "bge-reranker",
"query": "Organic skincare products for sensitive skin",
"top_n": 3,
"documents": [
"Organic skincare for sensitive skin with aloe vera and chamomile: Imagine the soothing embrace of nature with our organic skincare range, crafted specifically for sensitive skin. Infused with the calming properties of aloe vera and chamomile, each product provides gentle nourishment and protection. Say goodbye to irritation and hello to a glowing, healthy complexion.",
"New makeup trends focus on bold colors and innovative techniques: Step into the world of cutting-edge beauty with this seasons makeup trends. Bold, vibrant colors and groundbreaking techniques are redefining the art of makeup. From neon eyeliners to holographic highlighters, unleash your creativity and make a statement with every look.",
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille: Erleben Sie die wohltuende Wirkung unserer Bio-Hautpflege, speziell für empfindliche Haut entwickelt. Mit den beruhigenden Eigenschaften von Aloe Vera und Kamille pflegen und schützen unsere Produkte Ihre Haut auf natürliche Weise. Verabschieden Sie sich von Hautirritationen und genießen Sie einen strahlenden Teint.",
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken: Tauchen Sie ein in die Welt der modernen Schönheit mit den neuesten Make-up-Trends. Kräftige, lebendige Farben und innovative Techniken setzen neue Maßstäbe. Von auffälligen Eyelinern bis hin zu holografischen Highlightern lassen Sie Ihrer Kreativität freien Lauf und setzen Sie jedes Mal ein Statement.",
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla: Descubre el poder de la naturaleza con nuestra línea de cuidado de la piel orgánico, diseñada especialmente para pieles sensibles. Enriquecidos con aloe vera y manzanilla, estos productos ofrecen una hidratación y protección suave. Despídete de las irritaciones y saluda a una piel radiante y saludable.",
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras: Entra en el fascinante mundo del maquillaje con las tendencias más actuales. Colores vivos y técnicas innovadoras están revolucionando el arte del maquillaje. Desde delineadores neón hasta iluminadores holográficos, desata tu creatividad y destaca en cada look.",
"针对敏感肌专门设计的天然有机护肤产品:体验由芦荟和洋甘菊提取物带来的自然呵护。我们的护肤产品特别为敏感肌设计,温和滋润,保护您的肌肤不受刺激。让您的肌肤告别不适,迎来健康光彩。",
"新的化妆趋势注重鲜艳的颜色和创新的技巧:进入化妆艺术的新纪元,本季的化妆趋势以大胆的颜色和创新的技巧为主。无论是霓虹眼线还是全息高光,每一款妆容都能让您脱颖而出,展现独特魅力。",
"敏感肌のために特別に設計された天然有機スキンケア製品: アロエベラとカモミールのやさしい力で、自然の抱擁を感じてください。敏感肌用に特別に設計された私たちのスキンケア製品は、肌に優しく栄養を与え、保護します。肌トラブルにさようなら、輝く健康な肌にこんにちは。",
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています: 今シーズンのメイクアップトレンドは、大胆な色彩と革新的な技術に注目しています。ネオンアイライナーからホログラフィックハイライターまで、クリエイティビティを解き放ち、毎回ユニークなルックを演出しましょう。"
]
}
+57 -14
View File
@@ -1,8 +1,11 @@
package proxy
import (
"fmt"
"os"
"strings"
"github.com/google/shlex"
"gopkg.in/yaml.v3"
)
@@ -12,29 +15,42 @@ type ModelConfig struct {
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
type Config struct {
Models map[string]ModelConfig `yaml:"models"`
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
Models map[string]ModelConfig `yaml:"models"`
Profiles map[string][]string `yaml:"profiles"`
// map aliases to actual model IDs
aliases map[string]string
}
func (c *Config) FindConfig(modelName string) (ModelConfig, bool) {
modelConfig, found := c.Models[modelName]
if found {
return modelConfig, true
func (c *Config) RealModelName(search string) (string, bool) {
if _, found := c.Models[search]; found {
return search, true
} else if name, found := c.aliases[search]; found {
return name, found
} else {
return "", false
}
}
// Search through aliases to find the right config
for _, config := range c.Models {
for _, alias := range config.Aliases {
if alias == modelName {
return config, true
}
}
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
if realName, found := c.RealModelName(modelName); !found {
return ModelConfig{}, "", false
} else {
return c.Models[realName], realName, true
}
return ModelConfig{}, false
}
func LoadConfig(path string) (*Config, error) {
@@ -53,5 +69,32 @@ func LoadConfig(path string) (*Config, error) {
config.HealthCheckTimeout = 15
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
config.aliases[alias] = modelName
}
}
return &config, nil
}
func SanitizeCommand(cmdStr string) ([]string, error) {
// Remove trailing backslashes
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
// Split the command into arguments
args, err := shlex.Split(cmdStr)
if err != nil {
return nil, err
}
// Ensure the command is not empty
if len(args) == 0 {
return nil, fmt.Errorf("empty command")
}
return args, nil
}
+176
View File
@@ -0,0 +1,176 @@
package proxy
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_Load(t *testing.T) {
// Create a temporary YAML file for testing
tempDir, err := os.MkdirTemp("", "test-config")
if err != nil {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tempDir)
tempFile := filepath.Join(tempDir, "config.yaml")
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
aliases:
- "m1"
- "model-one"
env:
- "VAR1=value1"
- "VAR2=value2"
checkEndpoint: "/health"
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "m2"
checkEndpoint: "/"
healthCheckTimeout: 15
profiles:
test:
- model1
- model2
`
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
t.Fatalf("Failed to write temporary file: %v", err)
}
// Load the config and verify
config, err := LoadConfig(tempFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
expected := &Config{
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
},
"model2": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: nil,
CheckEndpoint: "/",
},
},
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
aliases: map[string]string{
"m1": "model1",
"model-one": "model1",
"m2": "model2",
},
}
assert.Equal(t, expected, config)
realname, found := config.RealModelName("m1")
assert.True(t, found)
assert.Equal(t, "model1", realname)
}
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{
Cmd: `python model1.py \
--arg1 value1 \
--arg2 value2`,
}
args, err := config.SanitizedCommand()
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
}
func TestConfig_FindConfig(t *testing.T) {
// TODO?
// make make this shared between the different tests
config := &Config{
Models: map[string]ModelConfig{
"model1": {
Cmd: "python model1.py",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
},
"model2": {
Cmd: "python model2.py",
Proxy: "http://localhost:8081",
Aliases: []string{"m2", "model-two"},
Env: []string{"VAR3=value3", "VAR4=value4"},
CheckEndpoint: "/status",
},
},
HealthCheckTimeout: 10,
aliases: map[string]string{
"m1": "model1",
"model-one": "model1",
"m2": "model2",
},
}
// Test finding a model by its name
modelConfig, modelId, found := config.FindConfig("model1")
assert.True(t, found)
assert.Equal(t, "model1", modelId)
assert.Equal(t, config.Models["model1"], modelConfig)
// Test finding a model by its alias
modelConfig, modelId, found = config.FindConfig("m1")
assert.True(t, found)
assert.Equal(t, "model1", modelId)
assert.Equal(t, config.Models["model1"], modelConfig)
// Test finding a model that does not exist
modelConfig, modelId, found = config.FindConfig("model3")
assert.False(t, found)
assert.Equal(t, "", modelId)
assert.Equal(t, ModelConfig{}, modelConfig)
}
func TestConfig_SanitizeCommand(t *testing.T) {
// Test a command with spaces and newlines
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
--arg2 'single quotes'
-s
--arg3 123 \
--arg4 '"string in string"'
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"--arg2", "single quotes",
"-s",
"--arg3", "123",
"--arg4", `"string in string"`,
"-c", `'single quoted'`,
}, args)
// Test an empty command
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
}
+58
View File
@@ -0,0 +1,58 @@
package proxy
import (
"fmt"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"github.com/gin-gonic/gin"
)
var (
nextTestPort int = 12000
portMutex sync.Mutex
)
// Check if the binary exists
func TestMain(m *testing.M) {
binaryPath := getSimpleResponderPath()
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
os.Exit(1)
}
gin.SetMode(gin.TestMode)
m.Run()
}
// Helper function to get the binary path
func getSimpleResponderPath() string {
goos := runtime.GOOS
goarch := runtime.GOARCH
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
portMutex.Lock()
defer portMutex.Unlock()
port := nextTestPort
nextTestPort++
return getTestSimpleResponderConfigPort(expectedMessage, port)
}
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
binaryPath := getSimpleResponderPath()
// Create a process configuration
return ModelConfig{
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
}
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

+14
View File
@@ -0,0 +1,14 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>llama-swap</title>
</head>
<body>
<h1>llama-swap</h1>
<p>
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
</p>
</body>
</html>
+145
View File
@@ -0,0 +1,145 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Logs</title>
<style>
body {
margin: 0;
height: 100vh;
display: flex;
flex-direction: column;
font-family: "Courier New", Courier, monospace;
}
#log-controls {
margin: 0.5em;
display: flex;
align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1;
}
#log-controls input:focus {
outline: none; /* Ensures no outline is shown when the input is focused */
}
#log-stream {
flex: 1;
margin: 0.5em;
padding: 1em;
background: #f4f4f4;
overflow-y: auto;
white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; /* Ensures long words wrap */
}
.regex-error {
background-color: #ff0000 !important;
}
/* Dark mode styles */
@media (prefers-color-scheme: dark) {
body {
background-color: #333;
color: #fff;
}
#log-stream {
background: #444;
color: #fff;
}
#log-controls input {
background: #555;
color: #fff;
border: 1px solid #777;
}
#log-controls button {
background: #555;
color: #fff;
border: 1px solid #777;
}
}
</style>
</head>
<body>
<pre id="log-stream">Waiting for logs...</pre>
<div id="log-controls">
<input type="text" id="filter-input" placeholder="regex filter">
<button id="clear-button">clear</button>
</div>
<script>
const logStream = document.getElementById('log-stream');
const filterInput = document.getElementById('filter-input');
var logData = "";
let regexFilter = null;
function setupEventSource() {
if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
eventSource.onmessage = function(event) {
logData += event.data;
render()
};
eventSource.onerror = function(err) {
logData = "EventSource failed: " + err.message;
};
} else {
logData = "SSE Not supported by this browser."
}
}
// poor-ai's react ¯\_(ツ)_/¯
function render() {
if (regexFilter) {
const lines = logData.split('\n');
const filteredLines = lines.filter(line => {
return regexFilter === null || regexFilter.test(line);
});
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
}
logStream.scrollTop = logStream.scrollHeight;
}
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try {
regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
regexFilter = null;
filterInput.classList.add('regex-error');
return
}
} else {
regexFilter = null;
}
render();
}
filterInput.addEventListener('input', updateFilter);
document.getElementById('clear-button').addEventListener('click', () => {
filterInput.value = "";
regexFilter = null;
render();
});
setupEventSource();
updateFilter();
</script>
</body>
</html>
+10
View File
@@ -0,0 +1,10 @@
package proxy
import "embed"
//go:embed html
var htmlFiles embed.FS
func getHTMLFile(path string) ([]byte, error) {
return htmlFiles.ReadFile("html/" + path)
}
+186
View File
@@ -0,0 +1,186 @@
package proxy
import (
"container/ring"
"fmt"
"io"
"os"
"sync"
)
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct {
clients map[chan []byte]bool
mu sync.RWMutex
buffer *ring.Ring
bufferMu sync.RWMutex
// typically this can be os.Stdout
stdout io.Writer
// logging levels
level LogLevel
prefix string
}
func NewLogMonitor() *LogMonitor {
return NewLogMonitorWriter(os.Stdout)
}
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
return &LogMonitor{
clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
level: LevelInfo,
prefix: "",
}
}
func (w *LogMonitor) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
n, err = w.stdout.Write(p)
if err != nil {
return n, err
}
w.bufferMu.Lock()
bufferCopy := make([]byte, len(p))
copy(bufferCopy, p)
w.buffer.Value = bufferCopy
w.buffer = w.buffer.Next()
w.bufferMu.Unlock()
w.broadcast(bufferCopy)
return n, nil
}
func (w *LogMonitor) GetHistory() []byte {
w.bufferMu.RLock()
defer w.bufferMu.RUnlock()
var history []byte
w.buffer.Do(func(p any) {
if p != nil {
if content, ok := p.([]byte); ok {
history = append(history, content...)
}
}
})
return history
}
func (w *LogMonitor) Subscribe() chan []byte {
w.mu.Lock()
defer w.mu.Unlock()
ch := make(chan []byte, 100)
w.clients[ch] = true
return ch
}
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
w.mu.Lock()
defer w.mu.Unlock()
delete(w.clients, ch)
close(ch)
}
func (w *LogMonitor) broadcast(msg []byte) {
w.mu.RLock()
defer w.mu.RUnlock()
for client := range w.clients {
select {
case client <- msg:
default:
// If client buffer is full, skip
}
}
}
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+95
View File
@@ -0,0 +1,95 @@
package proxy
import (
"bytes"
"io"
"sync"
"testing"
)
func TestLogMonitor(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
// Test subscription
client1 := logMonitor.Subscribe()
client2 := logMonitor.Subscribe()
defer logMonitor.Unsubscribe(client1)
defer logMonitor.Unsubscribe(client2)
client1Messages := make([]byte, 0)
client2Messages := make([]byte, 0)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case data := <-client1:
client1Messages = append(client1Messages, data...)
case data := <-client2:
client2Messages = append(client2Messages, data...)
default:
return
}
}
}()
logMonitor.Write([]byte("1"))
logMonitor.Write([]byte("2"))
logMonitor.Write([]byte("3"))
// Wait for the goroutine to finish
wg.Wait()
// Check the buffer
expectedHistory := "123"
history := string(logMonitor.GetHistory())
if history != expectedHistory {
t.Errorf("Expected history: %s, got: %s", expectedHistory, history)
}
c1Data := string(client1Messages)
if c1Data != expectedHistory {
t.Errorf("Client1 expected %s, got: %s", expectedHistory, c1Data)
}
c2Data := string(client2Messages)
if c2Data != expectedHistory {
t.Errorf("Client2 expected %s, got: %s", expectedHistory, c2Data)
}
}
func TestWrite_ImmutableBuffer(t *testing.T) {
// Create a new LogMonitor instance
lm := NewLogMonitorWriter(io.Discard)
// Prepare a message to write
msg := []byte("Hello, World!")
lenmsg := len(msg)
// Write the message to the LogMonitor
n, err := lm.Write(msg)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != lenmsg {
t.Errorf("Expected %d bytes written but got %d", lenmsg, n)
}
// Change the original message
msg[0] = 'B' // This should not affect the buffer
// Get the history from the LogMonitor
history := lm.GetHistory()
// Check that the history contains the original message, not the modified one
expected := []byte("Hello, World!")
if !bytes.Equal(history, expected) {
t.Errorf("Expected history to be %q, got %q", expected, history)
}
}
-224
View File
@@ -1,224 +0,0 @@
package proxy
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"sync"
"syscall"
"time"
)
type ProxyManager struct {
sync.Mutex
config *Config
currentCmd *exec.Cmd
currentConfig ModelConfig
}
func New(config *Config) *ProxyManager {
return &ProxyManager{config: config}
}
func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) {
// https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#api-endpoints
if r.URL.Path == "/v1/chat/completions" {
// extracts the `model` from json body
pm.proxyChatRequest(w, r)
} else {
pm.proxyRequest(w, r)
}
}
func (pm *ProxyManager) swapModel(requestedModel string) error {
pm.Lock()
defer pm.Unlock()
// find the model configuration matching requestedModel
modelConfig, found := pm.config.FindConfig(requestedModel)
if !found {
return fmt.Errorf("could not find configuration for %s", requestedModel)
}
// no need to swap llama.cpp instances
if pm.currentConfig.Cmd == modelConfig.Cmd {
return nil
}
// kill the current running one to swap it
if pm.currentCmd != nil {
pm.currentCmd.Process.Signal(syscall.SIGTERM)
// wait for it to end
pm.currentCmd.Process.Wait()
}
pm.currentConfig = modelConfig
args := strings.Fields(modelConfig.Cmd)
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = modelConfig.Env
err := cmd.Start()
if err != nil {
return err
}
pm.currentCmd = cmd
if err := pm.checkHealthEndpoint(); err != nil {
return err
}
return nil
}
func (pm *ProxyManager) checkHealthEndpoint() error {
if pm.currentConfig.Proxy == "" {
return fmt.Errorf("no upstream available to check /health")
}
checkEndpoint := strings.TrimSpace(pm.currentConfig.CheckEndpoint)
if checkEndpoint == "none" {
return nil
}
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
proxyTo := pm.currentConfig.Proxy
maxDuration := time.Second * time.Duration(pm.config.HealthCheckTimeout)
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
}
client := &http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(req.Context(), 250*time.Millisecond)
defer cancel()
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil {
if strings.Contains(err.Error(), "connection refused") {
// if TCP dial can't connect any HTTP response after 5 seconds
// exit quickly.
if time.Since(startTime) > 5*time.Second {
return fmt.Errorf("health check endpoint took more than 5 seconds to respond")
}
}
if time.Since(startTime) >= maxDuration {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
time.Sleep(time.Second)
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
if time.Since(startTime) >= maxDuration {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
time.Sleep(time.Second)
}
}
func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request) {
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
return
}
var requestBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
return
}
model, ok := requestBody["model"].(string)
if !ok {
http.Error(w, "Missing or invalid 'model' key", http.StatusBadRequest)
return
}
if err := pm.swapModel(model); err != nil {
http.Error(w, fmt.Sprintf("unable to swap to model: %s", err.Error()), http.StatusNotFound)
return
}
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
pm.proxyRequest(w, r)
}
func (pm *ProxyManager) proxyRequest(w http.ResponseWriter, r *http.Request) {
if pm.currentConfig.Proxy == "" {
http.Error(w, "No upstream proxy", http.StatusInternalServerError)
return
}
proxyTo := pm.currentConfig.Proxy
client := &http.Client{}
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
// faster than io.Copy when streaming
buf := make([]byte, 32*1024)
for {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
http.Error(w, writeErr.Error(), http.StatusInternalServerError)
return
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
}
if err == io.EOF {
break
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
}
}
+436
View File
@@ -0,0 +1,436 @@
package proxy
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os/exec"
"strings"
"sync"
"syscall"
"time"
)
type ProcessState string
const (
StateStopped ProcessState = ProcessState("stopped")
StateStarting ProcessState = ProcessState("starting")
StateReady ProcessState = ProcessState("ready")
StateStopping ProcessState = ProcessState("stopping")
// failed a health check on start and will not be recovered
StateFailed ProcessState = ProcessState("failed")
// process is shutdown and will not be restarted
StateShutdown ProcessState = ProcessState("shutdown")
)
type Process struct {
ID string
config ModelConfig
cmd *exec.Cmd
processLogger *LogMonitor
proxyLogger *LogMonitor
healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time
stateMutex sync.RWMutex
state ProcessState
inFlightRequests sync.WaitGroup
// used to block on multiple start() calls
waitStarting sync.WaitGroup
// for managing shutdown state
shutdownCtx context.Context
shutdownCancel context.CancelFunc
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background())
return &Process{
ID: ID,
config: config,
cmd: nil,
processLogger: processLogger,
proxyLogger: proxyLogger,
healthCheckTimeout: healthCheckTimeout,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
state: StateStopped,
shutdownCtx: ctx,
shutdownCancel: cancel,
}
}
// LogMonitor returns the log monitor associated with the process.
func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// custom error types for swapping state
var (
ErrExpectedStateMismatch = errors.New("expected state mismatch")
ErrInvalidStateTransition = errors.New("invalid state transition")
)
// swapState performs a compare and swap of the state atomically. It returns the current state
// and an error if the swap failed.
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state != expectedState {
return p.state, ErrExpectedStateMismatch
}
if !isValidTransition(p.state, newState) {
p.proxyLogger.Warnf("Invalid state transition from %s to %s", p.state, newState)
return p.state, ErrInvalidStateTransition
}
p.proxyLogger.Debugf("State transition from %s to %s", expectedState, newState)
p.state = newState
return p.state, nil
}
// Helper function to encapsulate transition rules
func isValidTransition(from, to ProcessState) bool {
switch from {
case StateStopped:
return to == StateStarting
case StateStarting:
return to == StateReady || to == StateFailed || to == StateStopping
case StateReady:
return to == StateStopping
case StateStopping:
return to == StateStopped || to == StateShutdown
case StateFailed, StateShutdown:
return false // No transitions allowed from these states
}
return false
}
func (p *Process) CurrentState() ProcessState {
p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
}
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called
// at any time.
func (p *Process) start() error {
if p.config.Proxy == "" {
return fmt.Errorf("can not start(), upstream proxy missing")
}
args, err := p.config.SanitizedCommand()
if err != nil {
return fmt.Errorf("unable to get sanitized command: %v", err)
}
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
if err == ErrExpectedStateMismatch {
// already starting, just wait for it to complete and expect
// it to be be in the Ready start after. If not, return an error
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state == StateReady {
return nil
} else {
return fmt.Errorf("process was already starting but wound up in state %v", state)
}
} else {
return fmt.Errorf("processes was in state %v when start() was called", curState)
}
} else {
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
}
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.processLogger
p.cmd.Stderr = p.processLogger
p.cmd.Env = p.config.Env
err = p.cmd.Start()
// Set process state to failed
if err != nil {
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
return fmt.Errorf(
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
err, curState, swapErr,
)
}
return fmt.Errorf("start() failed: %v", err)
}
// One of three things can happen at this stage:
// 1. The command exits unexpectedly
// 2. The health check fails
// 3. The health check passes
//
// only in the third case will the process be considered Ready to accept
<-time.After(250 * time.Millisecond) // give process a bit of time to start
checkStartTime := time.Now()
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
// a "none" means don't check for health ... I could have picked a better word :facepalm:
if checkEndpoint != "none" {
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
proxyTo := p.config.Proxy
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
}
checkDeadline, cancelHealthCheck := context.WithDeadline(
context.Background(),
checkStartTime.Add(maxDuration),
)
defer cancelHealthCheck()
loop:
// Ready Check loop
for {
select {
case <-checkDeadline.Done():
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
} else {
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
}
case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown")
default:
if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("Health check passed on %s", healthURL)
cancelHealthCheck()
break loop
} else {
if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime)
p.proxyLogger.Infof("Connection refused on %s, retrying in %.0fs", healthURL, ttl.Seconds())
} else {
p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err)
}
}
}
<-time.After(p.healthCheckLoopInterval)
}
}
if p.config.UnloadAfter > 0 {
// start a goroutine to check every second if
// the process should be stopped
go func() {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) {
if p.CurrentState() != StateReady {
return
}
// wait for all inflight requests to complete and ticker
p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration {
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
p.Stop()
return
}
}
}()
}
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
} else {
return nil
}
}
func (p *Process) Stop() {
// wait for any inflight requests before proceeding
p.inFlightRequests.Wait()
// calling Stop() when state is invalid is a no-op
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState)
return
}
// stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second)
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState)
}
}
// Shutdown is called when llama-swap is shutting down. It will give a little bit
// of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down
func (p *Process) Shutdown() {
p.shutdownCancel()
p.stopCommand(5 * time.Second)
p.state = StateShutdown
}
// stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) {
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil {
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
return
}
if err := p.terminateProcess(); err != nil {
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
}
select {
case <-sigtermTimeout.Done():
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
p.cmd.Process.Kill()
case err := <-sigtermNormal:
if err != nil {
if errno, ok := err.(syscall.Errno); ok {
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") {
p.proxyLogger.Infof("Process [%s] stopped OK", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") {
p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID)
} else {
p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
}
} else {
p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err)
}
}
}
}
func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{
Timeout: 500 * time.Millisecond,
}
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
// got a response but it was not an OK
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code: %d", resp.StatusCode)
}
return nil
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
// prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState()
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
return
}
p.inFlightRequests.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
p.inFlightRequests.Done()
}()
// start the process on demand
if p.CurrentState() != StateReady {
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway)
return
}
}
proxyTo := p.config.Proxy
client := &http.Client{}
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header.Clone()
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
// faster than io.Copy when streaming
buf := make([]byte, 32*1024)
for {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
return
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
}
if err == io.EOF {
break
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
}
}
+9
View File
@@ -0,0 +1,9 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}
+14
View File
@@ -0,0 +1,14 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}
+313
View File
@@ -0,0 +1,313 @@
package proxy
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var (
discardLogger = NewLogMonitorWriter(io.Discard)
)
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
// Create a process
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
// process is automatically started
assert.Equal(t, StateStopped, process.CurrentState())
process.ProxyRequest(w, req)
assert.Equal(t, StateReady, process.CurrentState())
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expectedMessage)
// Stop the process
process.Stop()
req = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
// Proxy the request
process.ProxyRequest(w, req)
// should have automatically started the process again
if w.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
}
}
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
// are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
defer process.Stop()
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func(reqID int) {
defer wg.Done()
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
}(i)
}
wg.Wait()
assert.Equal(t, StateReady, process.CurrentState())
}
// test that the automatic start returns the expected error type
func TestProcess_BrokenModelConfig(t *testing.T) {
// Create a process configuration
config := ModelConfig{
Cmd: "nonexistent-command",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("broken", 1, config, discardLogger, discardLogger)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "unable to start process")
w = httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
}
func TestProcess_UnloadAfterTTL(t *testing.T) {
if testing.Short() {
t.Skip("skipping long auto unload TTL test")
}
expectedMessage := "I_sense_imminent_danger"
config := getTestSimpleResponderConfig(expectedMessage)
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, discardLogger, discardLogger)
defer process.Stop()
// this should take 4 seconds
req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil)
req2 := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
// Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter
process.ProxyRequest(w, req1)
t.Log("sending slow first request (4 seconds)")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "1234")
assert.Equal(t, StateReady, process.CurrentState())
// ensure the TTL timeout does not race slow requests (see issue #25)
t.Log("sending second request (1 second)")
time.Sleep(time.Second)
w = httptest.NewRecorder()
process.ProxyRequest(w, req2)
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expectedMessage)
assert.Equal(t, StateReady, process.CurrentState())
// wait 5 seconds
t.Log("sleep 5 seconds and check if unloaded")
time.Sleep(5 * time.Second)
assert.Equal(t, StateStopped, process.CurrentState())
}
func TestProcess_LowTTLValue(t *testing.T) {
if true { // change this code to run this ...
t.Skip("skipping test, edit process_test.go to run it ")
}
config := getTestSimpleResponderConfig("fast_ttl")
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, discardLogger, discardLogger)
defer process.Stop()
for i := 0; i < 100; i++ {
t.Logf("Waiting before sending request %d", i)
time.Sleep(1500 * time.Millisecond)
expected := fmt.Sprintf("echo=test_%d", i)
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expected)
}
}
// issue #19
// This test makes sure using Process.Stop() does not affect pending HTTP
// requests. All HTTP requests in this test should complete successfully.
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, discardLogger, discardLogger)
defer process.Stop()
results := map[string]string{
"12345": "",
"abcde": "",
"fghij": "",
}
var wg sync.WaitGroup
var mu sync.Mutex
for key := range results {
wg.Add(1)
go func(key string) {
defer wg.Done()
// send a request where simple-responder is will wait 300ms before responding
// this will simulate an in-progress request.
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
}
mu.Lock()
results[key] = w.Body.String()
mu.Unlock()
}(key)
}
// Stop the process while requests are still being processed
go func() {
<-time.After(150 * time.Millisecond)
process.Stop()
}()
wg.Wait()
for key, result := range results {
assert.Equal(t, key, result)
}
}
func TestProcess_SwapState(t *testing.T) {
tests := []struct {
name string
currentState ProcessState
expectedState ProcessState
newState ProcessState
expectedError error
expectedResult ProcessState
}{
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), discardLogger, discardLogger)
p.state = test.currentState
resultState, err := p.swapState(test.expectedState, test.newState)
if err != nil && test.expectedError == nil {
t.Errorf("Unexpected error: %v", err)
} else if err == nil && test.expectedError != nil {
t.Errorf("Expected error: %v, but got none", test.expectedError)
} else if err != nil && test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
}
}
if resultState != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
}
})
}
}
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping long shutdown test")
}
expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, discardLogger, discardLogger)
// make it a lot faster
process.healthCheckLoopInterval = time.Second
// start a goroutine to simulate a shutdown
var wg sync.WaitGroup
go func() {
defer wg.Done()
<-time.After(time.Millisecond * 500)
process.Shutdown()
}()
wg.Add(1)
// start the process, this is a blocking call
err := process.start()
wg.Wait()
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState())
}
+581
View File
@@ -0,0 +1,581 @@
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
PROFILE_SPLIT_CHAR = ":"
)
type ProxyManager struct {
sync.Mutex
config *Config
currentProcesses map[string]*Process
ginEngine *gin.Engine
// logging
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
}
func New(config *Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
}
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{
config: config,
currentProcesses: make(map[string]*Process),
ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
}
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
clientIP,
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
// see: issue: #81, #77 and #42 for CORS issues
// respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
// allow whatever the client requested by default
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
c.Header("Access-Control-Allow-Headers", sanitized)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
})
// Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
// Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
// Support embeddings
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
// in proxymanager_loghandlers.go
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.GET("/", func(c *gin.Context) {
// Set the Content-Type header to text/html
c.Header("Content-Type", "text/html")
// Write the embedded HTML content to the response
htmlData, err := getHTMLFile("index.html")
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
return
}
_, err = c.Writer.Write(htmlData)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
return
}
})
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
if data, err := getHTMLFile("favicon.ico"); err == nil {
c.Data(http.StatusOK, "image/x-icon", data)
} else {
c.String(http.StatusInternalServerError, err.Error())
}
})
// Disable console color for testing
gin.DisableConsoleColor()
return pm
}
func (pm *ProxyManager) Run(addr ...string) error {
return pm.ginEngine.Run(addr...)
}
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r)
}
func (pm *ProxyManager) StopProcesses() {
pm.Lock()
defer pm.Unlock()
pm.stopProcesses()
}
// for internal usage
func (pm *ProxyManager) stopProcesses() {
if len(pm.currentProcesses) == 0 {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Stop()
}(process)
}
wg.Wait()
pm.currentProcesses = make(map[string]*Process)
}
// Shutdown is called to shutdown all upstream processes
// when llama-swap is shutting down.
func (pm *ProxyManager) Shutdown() {
pm.Lock()
defer pm.Unlock()
// shutdown process in parallel
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Shutdown()
}(process)
}
wg.Wait()
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{}
for id, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
data = append(data, map[string]interface{}{
"id": id,
"object": "model",
"created": time.Now().Unix(),
"owned_by": "llama-swap",
})
}
// Set the Content-Type header to application/json
c.Header("Content-Type", "application/json")
if origin := c.Request.Header.Get("Origin"); origin != "" {
c.Header("Access-Control-Allow-Origin", origin)
}
// Encode the data as JSON and write it to the response writer
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
return
}
}
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
pm.Lock()
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
return nil, fmt.Errorf("model group not found %s", profileName)
}
}
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(modelName)
if !found {
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
}
// check if model is part of the profile
if profileName != "" {
found := false
for _, item := range pm.config.Profiles[profileName] {
if item == realModelName {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
}
}
// exit early when already running, otherwise stop everything and swap
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
return process, nil
}
// stop all running models
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
pm.stopProcesses()
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
for _, modelName := range pm.config.Profiles[profileName] {
if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
}
}
// requestedProcessKey should exist due to swap
return pm.currentProcesses[requestedProcessKey], nil
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
requestedModel := c.Param("model_id")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
return
}
if process, err := pm.swapModel(requestedModel); err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
} else {
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
process.ProxyRequest(c.Writer, c.Request)
}
}
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
var html strings.Builder
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
// Extract keys and sort them
var modelIDs []string
for modelID, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
modelIDs = append(modelIDs, modelID)
}
sort.Strings(modelIDs)
// Iterate over sorted keys
for _, modelID := range modelIDs {
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
}
html.WriteString("</ul></body></html>")
c.Header("Content-Type", "text/html")
c.String(http.StatusOK, html.String())
}
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
return
}
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
}
process, err := pm.swapModel(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
}
// issue #69 allow custom model names to be sent to upstream
if process.config.UseModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
} else {
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
// Swap to the requested model
process, err := pm.swapModel(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
}
// Get profile name and model name from the requested model
profileName, modelName := splitRequestedModel(requestedModel)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
if process.config.UseModelName != "" {
fieldValue = process.config.UseModelName
} else if profileName != "" {
fieldValue = modelName
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return
}
}
}
// Copy all files from the original request
for key, fileHeaders := range c.Request.MultipartForm.File {
for _, fileHeader := range fileHeaders {
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
return
}
file, err := fileHeader.Open()
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
return
}
if _, err = io.Copy(formFile, file); err != nil {
file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
return
}
file.Close()
}
}
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// Use the modified request for proxying
process.ProxyRequest(c.Writer, modifiedReq)
}
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
acceptHeader := c.GetHeader("Accept")
if strings.Contains(acceptHeader, "application/json") {
c.JSON(statusCode, gin.H{"error": message})
} else {
c.String(statusCode, message)
}
}
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses()
c.String(http.StatusOK, "OK")
}
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, process := range pm.currentProcesses {
// Append the process ID and State (multiple entries if profiles are being used).
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
})
}
// Put the results under the `running` key.
response := gin.H{
"running": runningProcesses,
}
context.JSON(http.StatusOK, response) // Always return 200 OK
}
func ProcessKeyName(groupName, modelName string) string {
return groupName + PROFILE_SPLIT_CHAR + modelName
}
func splitRequestedModel(requestedModel string) (string, string) {
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
return profileName, modelName
}
+142
View File
@@ -0,0 +1,142 @@
package proxy
import (
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html
c.Header("Content-Type", "text/html")
// Write the embedded HTML content to the response
logsHTML, err := getHTMLFile("logs.html")
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
return
}
_, err = c.Writer.Write(logsHTML)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
return
}
} else {
c.Header("Content-Type", "text/plain")
history := pm.muxLogger.GetHistory()
_, err := c.Writer.Write(history)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}
}
}
func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
return
}
_, skipHistory := c.GetQuery("no-history")
// Send history first if not skipped
if !skipHistory {
history := logger.GetHistory()
if len(history) != 0 {
c.Writer.Write(history)
flusher.Flush()
}
}
// Stream new logs
for {
select {
case msg := <-ch:
_, err := c.Writer.Write(msg)
if err != nil {
// just break the loop if we can't write for some reason
return
}
flusher.Flush()
case <-notify:
return
}
}
}
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done()
// Send history first if not skipped
_, skipHistory := c.GetQuery("no-history")
if !skipHistory {
history := logger.GetHistory()
if len(history) != 0 {
c.SSEvent("message", string(history))
c.Writer.Flush()
}
}
// Stream new logs
for {
select {
case msg := <-ch:
c.SSEvent("message", string(msg))
c.Writer.Flush()
case <-notify:
return
}
}
}
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return logger, nil
}
+710
View File
@@ -0,0 +1,710 @@
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"math/rand"
"mime/multipart"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
}
// make sure there's only one loaded model
assert.Len(t, proxy.currentProcesses, 1)
}
func TestProxyManager_SwapMultiProcess(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileModel1 := ProcessKeyName("test", model1)
profileModel2 := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1, model2},
},
}
proxy := New(config)
defer proxy.StopProcesses()
for modelID, requestedModel := range map[string]string{
"model1": profileModel1,
"model2": profileModel2,
} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelID)
}
// make sure there's two loaded models
assert.Len(t, proxy.currentProcesses, 2)
_, exists := proxy.currentProcesses[profileModel1]
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
_, exists = proxy.currentProcesses[profileModel2]
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
}
// When a request for a different model comes in ProxyManager should wait until
// the first request is complete before swapping. Both requests should complete
func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
results := map[string]string{}
var wg sync.WaitGroup
var mu sync.Mutex
for key := range config.Models {
wg.Add(1)
go func(key string) {
defer wg.Done()
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
}
mu.Lock()
results[key] = w.Body.String()
mu.Unlock()
}(key)
<-time.After(time.Millisecond)
}
wg.Wait()
assert.Len(t, results, len(config.Models))
for key, result := range results {
assert.Equal(t, key, result)
}
}
func TestProxyManager_ListModelsHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
}
proxy := New(config)
// Create a test request
req := httptest.NewRequest("GET", "/v1/models", nil)
req.Header.Add("Origin", "i-am-the-origin")
w := httptest.NewRecorder()
// Call the listModelsHandler
proxy.HandlerFunc(w, req)
// Check the response status code
assert.Equal(t, http.StatusOK, w.Code)
// Check for Access-Control-Allow-Origin
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
// Parse the JSON response
var response struct {
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse JSON response: %v", err)
}
// Check the number of models returned
assert.Len(t, response.Data, 3)
// Check the details of each model
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
"model3": {},
}
for _, model := range response.Data {
modelID, ok := model["id"].(string)
assert.True(t, ok, "model ID should be a string")
_, exists := expectedModels[modelID]
assert.True(t, exists, "unexpected model ID: %s", modelID)
delete(expectedModels, modelID)
object, ok := model["object"].(string)
assert.True(t, ok, "object should be a string")
assert.Equal(t, "model", object)
created, ok := model["created"].(float64)
assert.True(t, ok, "created should be a number")
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
ownedBy, ok := model["owned_by"].(string)
assert.True(t, ok, "owned_by should be a string")
assert.Equal(t, "llama-swap", ownedBy)
}
// Ensure all expected models were returned
assert.Empty(t, expectedModels, "not all expected models were returned")
}
func TestProxyManager_ProfileNonMember(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileMemberName := ProcessKeyName("test", model1)
profileNonMemberName := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1},
},
}
proxy := New(config)
defer proxy.StopProcesses()
// actual member of profile
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "model1")
}
// actual model, but non-member will 404
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
}
func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
model1Config.Proxy = "http://localhost:10001/"
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
model2Config.Proxy = "http://localhost:10002/"
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/"
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2", "model3"},
},
Models: map[string]ModelConfig{
"model1": model1Config,
"model2": model2Config,
"model3": model3Config,
},
}
proxy := New(config)
// Start all the processes
var wg sync.WaitGroup
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
wg.Add(1)
go func(modelName string) {
defer wg.Done()
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
// send a request to trigger the proxy to load
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
//fmt.Println(w.Code, w.Body.String())
}(modelName)
}
go func() {
<-time.After(time.Second)
proxy.Shutdown()
}()
wg.Wait()
}
func TestProxyManager_Unload(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
}
proxy := New(config)
proc, err := proxy.swapModel("model1")
assert.NoError(t, err)
assert.NotNil(t, proc)
assert.Len(t, proxy.currentProcesses, 1)
req := httptest.NewRequest("GET", "/unload", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")
assert.Len(t, proxy.currentProcesses, 0)
}
// issue 62, strip profile slug from model name
func TestProxyManager_StripProfileSlug(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "ok")
}
// Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
}
// Define a helper struct to parse the JSON response.
type RunningResponse struct {
Running []struct {
Model string `json:"model"`
State string `json:"state"`
} `json:"running"`
}
// Create proxy once for all tests
proxy := New(config)
defer proxy.StopProcesses()
t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response RunningResponse
// Check if this is a valid JSON object.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// We should have an empty running array here.
assert.Empty(t, response.Running, "expected no running models")
})
t.Run("single model loaded", func(t *testing.T) {
// Load just a model.
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// Check if we have a single array element.
assert.Len(t, response.Running, 1)
// Is this the right model?
assert.Equal(t, "model1", response.Running[0].Model)
// Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State)
})
t.Run("multiple models via profile", func(t *testing.T) {
// Load more than one model.
for _, model := range []string{"model1", "model2"} {
profileModel := ProcessKeyName("test", model)
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
// Simulate the browser call.
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
// The JSON response must be valid.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// The response should contain 2 models.
assert.Len(t, response.Running, 2)
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
}
// Iterate through the models and check their states as well.
for _, entry := range response.Running {
_, exists := expectedModels[entry.Model]
assert.True(t, exists, "unexpected model %s", entry.Model)
assert.Equal(t, "ready", entry.State)
delete(expectedModels, entry.Model)
}
// Since we deleted each model while testing for its validity we should have no more models in the response.
assert.Empty(t, expectedModels, "unexpected additional models in response")
})
}
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"},
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
testCases := []struct {
name string
modelInput string
expectModel string
}{
{
name: "With Profile Prefix",
modelInput: "test:TheExpectedModel",
expectModel: "TheExpectedModel", // Profile prefix should be stripped
},
{
name: "Without Profile Prefix",
modelInput: "TheExpectedModel",
expectModel: "TheExpectedModel", // Should remain the same
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tc.modelInput))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
// Generate random content length between 10 and 20
contentLength := rand.Intn(11) + 10 // 10 to 20
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, tc.expectModel, response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
})
}
}
func TestProxyManager_SplitRequestedModel(t *testing.T) {
tests := []struct {
name string
requestedModel string
expectedProfile string
expectedModel string
}{
{"no profile", "gpt-4", "", "gpt-4"},
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
{"only profile", "profile1:", "profile1", ""},
{"empty model", ":gpt-4", "", "gpt-4"},
{"empty profile", ":", "", ""},
{"no split char", "gpt-4", "", "gpt-4"},
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
profileName, modelName := splitRequestedModel(tt.requestedModel)
if profileName != tt.expectedProfile {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
if modelName != tt.expectedModel {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
})
}
}
// Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1"},
},
Models: map[string]ModelConfig{
"model1": modelConfig,
},
}
proxy := New(config)
defer proxy.StopProcesses()
tests := []struct {
description string
requestedModel string
}{
{"useModelName over rides requested model", "model1"},
{"useModelName over rides requested profile:model", "test:model1"},
}
for _, tt := range tests {
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
})
}
for _, tt := range tests {
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tt.requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogRequests: true,
}
tests := []struct {
name string
method string
requestHeaders map[string]string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS with no headers",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses()
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
proxy.ginEngine.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
}
})
}
}
+43
View File
@@ -0,0 +1,43 @@
package proxy
import (
"strings"
)
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
+77
View File
@@ -0,0 +1,77 @@
package proxy
import "testing"
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "single valid value",
input: "content-type",
expected: "content-type",
},
{
name: "multiple valid values",
input: "content-type, authorization, x-requested-with",
expected: "content-type, authorization, x-requested-with",
},
{
name: "values with extra spaces",
input: " content-type , authorization ",
expected: "content-type, authorization",
},
{
name: "values with tabs",
input: "content-type,\tauthorization",
expected: "content-type, authorization",
},
{
name: "values with invalid characters",
input: "content-type, auth\n, x-requested-with\r",
expected: "content-type, auth, x-requested-with",
},
{
name: "empty values in list",
input: "content-type,,authorization",
expected: "content-type, authorization",
},
{
name: "leading and trailing commas",
input: ",content-type,authorization,",
expected: "content-type, authorization",
},
{
name: "mixed valid and invalid values",
input: "content-type, \x00invalid, x-requested-with",
expected: "content-type, x-requested-with",
},
{
name: "mixed case values",
input: "Content-Type, my-Valid-Header, Another-hEader",
expected: "Content-Type, my-Valid-Header, Another-hEader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeAccessControlRequestHeaderValues(tt.input)
if got != tt.expected {
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
tt.input, got, tt.expected)
}
})
}
}