Compare commits

..

185 Commits

Author SHA1 Message Date
Benson Wong 45ea792a3a Fix UI panel not saving position correctly 2025-08-06 14:02:22 -07:00
Benson Wong 1bc2802353 fix panels not saving sizing state 2025-08-06 14:00:21 -07:00
Benson Wong 701476c0c4 Update README.md - remove contributor block [skip ci]
Contributor information available on the Github page's sidebar. Redundant.
2025-08-06 11:11:47 -07:00
Ben Greene 5c63e0066c return models sorted by id in /v1/models (#222) 2025-08-06 10:04:52 -07:00
Martin Garton 8be5073c51 Fix typo (#223) [skip ci]
Fix typo `lama-swap` -> `llama-swap`
2025-08-06 10:02:38 -07:00
Aaron Ang 6307bd3205 Add support for building Linux ARM64 binary in Makefile (#221) 2025-08-05 16:26:06 -07:00
Benson Wong 558a72de17 UI Improvements (#219)
- use react-resizable-panels for UI
- improve icons for buttons
- improve mobile layout with drag/resize panels
2025-08-03 17:49:13 -07:00
Leoyzen dc42cf366d Add config monitor support for k8s configmap. (#217) 2025-08-03 08:05:48 -07:00
Ryein Goddard ba0a81937a Update README.md (#216)
Update git clone protocol to https
2025-08-01 19:48:09 -07:00
Benson Wong 574fdfabb4 UI improvements (#213)
* use two column for logs view on wider screens

* hide log controls when panel is minimized
2025-07-31 11:59:21 -07:00
Benson Wong 5172cb2e12 Update docs in Readme [skip ci] 2025-07-30 11:51:14 -07:00
Benson Wong 5672cb03fd Update github actions for notifying homebrew build (#212)
Combine homebrew-llama-swap event with the release action
2025-07-30 11:29:03 -07:00
Benson Wong 0f583163f7 add /health (#211) 2025-07-30 10:37:10 -07:00
Benson Wong 7905fa9ea3 Update trigger-homebrew-update.yml [skip ci] 2025-07-30 10:13:49 -07:00
Ian Sebastian Mathew bbaf172956 add trigger to rebuild homebrew formula (#210) 2025-07-30 10:12:21 -07:00
Benson Wong fd50932dbc Decouple MetricsMiddleware from downstream handlers (#206)
* Decouple MetricsMiddleware from downstream handlers

Remove ls-real-model-name optimization. Within proxyOAIHandler the
request body's bytes are required for various rewriting features
anyways. This negated any benefits from trying not to parse it twice.
2025-07-27 10:36:06 -07:00
Gaël James 8c693e7fcf Add endpoint aliases for reranking models (#201)
* Add endpoint aliases for reranking models
* Add MetricsMiddleware to the previous reranking endpoint
* Fix the embeddings endpoint not having model set
2025-07-24 08:32:47 -07:00
Benson Wong 8f2af26a41 fix stats on model page 2025-07-23 13:57:33 -07:00
Benson Wong 01d4838fb3 Fix token metrics parsing (#199)
Fix #198

- use llama-server's `timings` info if available in response body
- send "-1" for token/sec when not able to accurately calculate
  performance
- optimize streaming body search for metrics information
2025-07-22 23:10:14 -07:00
Benson Wong accd65294b add contributors to README [skip ci] 2025-07-21 23:16:48 -07:00
Benson Wong 7472a25864 Update README.md [skip ci]
update screenshot for web UI
2025-07-21 23:08:19 -07:00
Benson Wong cce0bc6aa1 add guard to ensure ls-real-model-name is set in context 2025-07-21 22:59:41 -07:00
Benson Wong 36e25125e8 UI tidy [skip ci] 2025-07-21 22:47:55 -07:00
Benson Wong 9a54273d15 Update UI with new Activity event stream from #195
- use new metrics data instead of log parsing
- auto-start events connection to server, improves responsiveness
- remove unnecessary libraries and code
2025-07-21 22:42:30 -07:00
g2mt 87dce5f8f6 Add metrics logging for chat completion requests (#195)
- Add token and performance metrics  for v1/chat/completions 
- Add Activity Page in UI
- Add /api/metrics endpoint

Contributed by @g2mt
2025-07-21 22:19:55 -07:00
Benson Wong 307e619521 remove old eventsources from UI 2025-07-19 15:36:40 -07:00
Benson Wong 6299c1b874 Fix High CPU (#189)
* vendor in kelindar/event lib and refactor to remove time.Ticker
2025-07-15 18:04:30 -07:00
Yathi a906cd459b Strip comments before macro expansion in config (#193)
A bug fix that ensures comments don't interfere with macro expansion by
removing them first. This prevents unwanted comment text from appearing
in the final expanded command.

Co-authored-by: Yathiraj Bollimbala G <yathi@yStudio.localdomain>
2025-07-15 10:14:16 -07:00
Benson Wong 78b2bc3dbc add toggle to hide/show unlisted models (#187) 2025-07-02 16:14:20 -07:00
Benson Wong 6a058e4191 Change fsnotify to watch config directory instead of file
The fsnotify library suggests watching a directory and checking that the
name matches the configuration file.
2025-07-02 10:23:52 -07:00
Benson Wong 1921e570d7 Add Event Bus (#184)
Major internal refactor to use an event bus to pass event/messages along. These changes are largely invisible user facing but sets up internal design for real time stats and information.

- `--watch-config` logic refactored for events
- remove multiple SSE api endpoints, replaced with /api/events
- keep all functionality essentially the same
- UI/backend sync is in near real time now
2025-07-01 22:17:35 -07:00
Benson Wong c867a6c9a2 Add name and description to v1/models list (#179)
* Add support for name and description in v1/models list
* add configuration example for name and description
2025-06-30 23:02:44 -07:00
Leoyzen 3bd1b23ce0 fix config hot-reload on k8s (#181)
Co-authored-by: Leoyzen <leoyzen@gmial.com>
2025-06-27 11:49:31 -07:00
srevn 10606abf89 fix config hot-reload on macos (#180)
Co-authored-by: srevn <srevn@github>
2025-06-26 09:20:50 -07:00
Benson Wong fefd14903d improve log display and add a small stats table in ui (#178) 2025-06-25 12:27:49 -07:00
Benson Wong 717d64e336 update GUI image in README [skip ci] 2025-06-24 10:38:28 -07:00
Benson Wong 285191e655 Various UI improvements (#176)
* add retry/backoff to reconnecting log streams
* update favicons
2025-06-23 16:17:21 -07:00
Benson Wong 4236cec03a Add Filters to Model Configuration (#174)
llama-swap can strip specific keys in JSON requests. This is useful for removing the ability for clients to set sampling parameters like temperature, top_k, top_p, etc.
2025-06-23 10:52:29 -07:00
Alex O'Connell 756193d0dd Load models in the UI without navigating the page (#173)
* Load models in the UI without navigating the page

* fix table layout for mobile
2025-06-19 14:39:07 -07:00
Benson Wong a6b2e930d8 Update README.md [skip ci] 2025-06-18 11:47:08 -07:00
Benson Wong 9e02c22ff8 stopCmd should use same environment as p.cmd.Env (#171, #172) 2025-06-18 11:36:59 -07:00
Benson Wong 0bdbf2fdc1 fix more goreleaser deprecation warnings [skip ci] 2025-06-18 11:15:12 -07:00
Benson Wong 49035e2e8e Append custom env vars instead of replace in Process (#171)
Append custom env vars instead of replace in Process (#168, #169)

PR #162 refactored the default configuration code. This
introduced a subtle bug where `env` became `[]string{}` instead of the
default of `nil`.

In golang, `exec.Cmd.Env == nil` means to use the "current process's
environment". By setting it to `[]string{}` as a default the Process's
environment was emptied out which caused an array of strange and
difficult to troubleshoot behaviour. See issues #168 and #169

This commit changes the behaviour to append model configured environment
variables to the default list rather than replace them.
2025-06-18 11:09:13 -07:00
Benson Wong 9963ae18bf fix? deprecation warning in .goreleaser.yaml [skip-ci] 2025-06-18 07:49:33 -07:00
Benson Wong 2ae48c713b add debug output for start command 2025-06-18 07:43:23 -07:00
Benson Wong 54c519e365 update Makefile to install ui deps 2025-06-17 09:54:01 -07:00
Benson Wong 3fce9ee0e9 Update README.md [skip ci] 2025-06-17 09:53:22 -07:00
Benson Wong 5899ae7966 Update README.md [skip ci] 2025-06-17 09:52:47 -07:00
Benson Wong 591a9cdf4d update release.yml 2025-06-16 16:50:25 -07:00
Benson Wong 9a3c656738 New UI (#157, #164)
- Add a react UI to replace the plain HTML one. 
- Serve as a foundation for better GUI interactions
2025-06-16 16:45:19 -07:00
Benson Wong 75015f82ea fix bug caused by macro replacement order (#166)
User defined macros should be applied before checking for ${PORT} constraint in model.cmd and model.proxy.
2025-06-16 15:32:09 -07:00
Thammachart Chinvarapon cc33b6c270 restore intel docker builds (#163) 2025-06-16 11:13:49 -07:00
Benson Wong 4fa12a429c Refactor all default config values into config.go (#162)
- Move all default values into one place.
- Update tests to be more cross platform
2025-06-15 12:32:00 -07:00
Benson Wong 2dc0ca0663 improve llama-swap upstream process recovery and restarts (#155)
Refactor internal upstream process life cycle management to recover better from unexpected situations. With this change llama-swap should never need to be restarted due to a crashed upstream child process.  The `StateFailed` state was removed in favour of always trying to start/restart a process.
2025-06-05 16:24:55 -07:00
Daniel Hofer a84098d3b4 Add missing object type to /v1/models endpoint (#154) 2025-06-02 09:25:45 -07:00
Benson Wong 4d02ccd26a Update README.md [skip ci] 2025-05-30 09:38:45 -07:00
Benson Wong dfd47eeac4 Readme updates [skip ci] 2025-05-30 09:19:08 -07:00
Benson Wong 1ac6499c08 Add macros to Configuration schema (#149)
* Add macros to Configuration schema
* update docs
2025-05-29 21:51:25 -07:00
Benson Wong 25f3dc25e7 small doc update [skip ci] 2025-05-26 16:03:27 -07:00
Benson Wong 8422e4e6a1 move some docs to the wiki [no-ci] 2025-05-26 15:46:08 -07:00
Benson Wong 02ee29d881 increase default healthCheckTimeout to 120s 2025-05-26 09:57:53 -07:00
Benson Wong b2a891f8f4 Disable building of intel container until it's fixed upstream 2025-05-23 22:54:43 -07:00
Yuta Hayashibe 8d2b568897 Improve install script (#144)
* Use `python3` instead of `curl` and `jq`

* Use quote to word splitting

* Remove undefined `local` in POSIX sh

* Added `LLAMA_SWAP_DEFAULT_ADDRESS` to customize the server address

* Added `mktemp` to `NEEDS`
2025-05-23 09:39:55 -07:00
Yuta Hayashibe fb44cf4e08 Fix typos (#143) 2025-05-23 08:40:15 -07:00
Benson Wong 02aee4e86d remove noisy debug print message 2025-05-20 10:43:10 -07:00
Benson Wong f45896d395 add guard to avoid unnecessary logic in Process.Shutdown 2025-05-20 10:43:09 -07:00
choyuansu f7e46a359f Add link to unload endpoint in upstream list (#140)
* Add link to open /unload
2025-05-20 08:31:44 -07:00
choyuansu c260907415 Add linux install and uninstall shell scripts (#139)
Contribution for install, and uninstall llama-swap in linux.
2025-05-19 12:03:33 -07:00
Benson Wong b83a5fa291 make Failed stated recoverable (#137)
A process in the failed state can transition to stopped either by calling /unload or swapping to another model.
2025-05-16 19:54:44 -07:00
Benson Wong 6e2ff28d59 improve cmdStop docs [no ci] 2025-05-16 13:52:04 -07:00
Benson Wong a8b81f2799 Add stopCmd for custom stopping instructions (#136)
Allow configuration of how a model is stopped before swapping. Setting `cmdStop` in the configuration will override the default behaviour and enables better integration with other process/container managers like docker or podman.
2025-05-16 13:48:42 -07:00
Benson Wong f9ee7156dc update configuration examples for multiline yaml commands #133 2025-05-16 11:45:39 -07:00
fakezeta 2d00120781 Update proxymanager.go (#135) 2025-05-16 06:45:09 -07:00
Benson Wong afc9aef058 Fix #133 SanitizeCommand removes comments (#134) 2025-05-15 15:28:50 -07:00
Benson Wong d7b390df74 Add GH Action for Testing on Windows (#132)
* Add windows specific test changes
* Change the command line parsing library - Possible breaking changes for windows users!
2025-05-14 21:51:53 -07:00
Benson Wong 5025c2f1f3 Add GH windows tests (not working yet) 2025-05-14 19:58:22 -07:00
Benson Wong e3a0b013c1 add content length test for #131 2025-05-14 19:50:01 -07:00
Fadenfire f5763a94a0 Fix content length being incorrect when useModelName is used (#131)
* Fix content length being incorrect when useModelName is used
* Update c.Request.ContentLength as well
2025-05-14 19:37:54 -07:00
Benson Wong 8ada72eb57 Update issue templates 2025-05-14 16:36:32 -07:00
Benson Wong 2441b383d3 Make checking for process killed status more robust 2025-05-14 16:26:56 -07:00
Benson Wong 25f251699c Prevent StateFailed after SIGKILL (#129)
Closes #125
2025-05-14 10:47:35 -07:00
Benson Wong 7f37bcc6eb Improve testing around using SIGKILL (#127)
* Add test for SIGKILL of process
* silent TestProxyManager_RunningEndpoint debug output
* Ref #125
2025-05-13 21:21:52 -07:00
Benson Wong 519c3a4d22 Change /unload to not wait for inflight requests (#125)
Sometimes upstreams can accept HTTP but never respond causing requests
to build up waiting for a response. This can block Process.Stop() as
that waits for inflight requests to finish. This change refactors the
code to not wait when attempting to shutdown the process.
2025-05-13 11:39:19 -07:00
Benson Wong 9dc4bcb46c Add a concurrency limit to Process.ProxyRequest (#123) 2025-05-12 18:12:52 -07:00
Benson Wong cb876c143b update example config 2025-05-12 10:20:18 -07:00
Sam bc652709a5 Add config hot-reload (#106)
introduce --watch-config command line option to reload ProxyManager when configuration changes.
2025-05-11 17:37:00 -07:00
Thammachart Chinvarapon 9548931258 ci: re-enabled intel build pipeline (#121) 2025-05-11 00:19:57 -07:00
Benson Wong 5c5a5da664 Update README.md
removed extra section.
2025-05-06 06:59:15 -07:00
Benson Wong aa9ef59aa5 Create .coderabbit.yaml 2025-05-05 19:47:23 -07:00
Benson Wong 09e52c0500 Automatic Port Numbers (#105)
Add automatic port numbers assignment in configuration file. The string `${PORT}` will be substituted in model.cmd and model.proxy for an actual port number. This also allows model.proxy to be omitted from the configuration.
2025-05-05 17:07:43 -07:00
Benson Wong ca9063ffbe ensure aliases are unique (#116) 2025-05-05 15:34:18 -07:00
Benson Wong 21d7973d11 Improve content-length handling (#115)
ref: See #114

* Improve content-length handling
- Content length was not always being sent
- Add tests for content-length
2025-05-05 10:46:26 -07:00
Yi Hong Ang cc450e9c5f fix issue where proxy is still proxying with chunked transfer-encoding (#114) 2025-05-05 10:00:03 -07:00
Benson Wong 27465fe053 bug fix with missing early return statements fix #112 2025-05-05 09:32:44 -07:00
Benson Wong 9667989727 Disabling intel container build since it's been broken for weeks. 2025-05-04 21:39:42 -07:00
Benson Wong d9a1ddea0d Truncate web logs to 100K characters (#111)
* set log limit to 100K in browser
2025-05-02 23:43:21 -07:00
Benson Wong e7ab024ca0 small locking optimization 2025-05-02 23:18:07 -07:00
Benson Wong 448ccae959 Introduce Groups Feature (#107)
Groups allows more control over swapping behaviour when a model is requested. The new groups feature provides three ways to control swapping: within the group, swapping out other groups or keep the models in the group loaded persistently (never swapped out). 

Closes #96, #99 and #106.
2025-05-02 22:35:38 -07:00
Benson Wong ec0348e431 Reduce stale time for issues 2025-04-29 21:16:34 -07:00
Benson Wong 06eda7f591 tag all process logs with its ID (#103)
Makes identifying Process of log messages easier
2025-04-25 12:58:25 -07:00
Benson Wong 5fad24c16f Make checkHealthTimeout Interruptable during startup (#102)
interrupt and exit Process.start() early if the upstream process exits prematurely or unexpectedly.
2025-04-24 14:39:33 -07:00
Benson Wong 8404244fab Moderate security update for golang/x/net -> v0.38.0 2025-04-24 09:58:40 -07:00
Benson Wong 712cd01081 fix confusing INFO message [no ci] 2025-04-24 09:56:20 -07:00
Benson Wong 1f7aa359b1 Update header image
AI has finally made my dreams of llamas in funny clothing and stuck in
a claw machine waiting to be picked come true!
2025-04-23 13:02:12 -07:00
Benson Wong b138d6cf25 fix starhistory in README 2025-04-15 20:23:46 -07:00
Benson Wong fb7c808082 add timing for Process start, stop, total request time (#91) 2025-04-14 14:34:59 -07:00
Benson Wong a7e640b0f7 add aider example 2025-04-07 12:37:14 -07:00
Benson Wong 593604dfdc Show proxy and upstream logs in separate columns in logs UI 2025-04-05 10:36:54 -07:00
Benson Wong b8f888f864 Logging Improvements (#88)
This change revamps the internal logging architecture to be more flexible and descriptive. Previously all logs from both llama-swap and upstream services were mixed together. This makes it harder to troubleshoot and identify problems. This PR adds these new endpoints: 

- `/logs/stream/proxy` - just llama-swap's logs
- `/logs/stream/upstream` - stdout output from the upstream server
2025-04-04 21:01:33 -07:00
Benson Wong 192b2ae621 Remove no longer needed test 2025-04-04 14:46:01 -07:00
Benson Wong b7f8cb5094 Limit Access-Control-Allow-Origin to OPTIONS preflight requests #85 2025-04-04 14:44:35 -07:00
Benson Wong a23da6eb57 Sanitize CORS headers (#85)
Add sanitation step for `Access-Control-Allow-Headers` when echoing back user supplied headers
2025-04-01 08:43:53 -07:00
Grigorii Khvatskii 4c3aa40564 add graceful process termination on windows (#82) 2025-03-25 15:26:33 -07:00
Benson Wong 84e2c07a7e Refactor wildcard out of CORS headers (#81)
Changes to CORS functionality: 

- `Access-Control-Allow-Origin: *` is set for all requests 
- for pre-flight OPTIONS requests
  - specify methods: `Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS`
  - if the client sent `Access-Control-Request-Headers` then echo back the same value in `Access-Control-Allow-Headers`. If no `Access-Control-Request-Headers` were sent, then send back a default set
  - set `Access-Control-Max-Age: 86400` to that may improve performance 
- Add CORS tests to the proxy-manager
2025-03-25 15:24:43 -07:00
Benson Wong 680af28bcc Allow very permissive CORS headers (#77) 2025-03-20 15:50:21 -07:00
Benson Wong d94db42ffe fix bug checking incorrect error 2025-03-20 15:49:36 -07:00
Benson Wong 93cd83c55c add override for windows (#76) 2025-03-20 13:23:04 -07:00
Benson Wong 5565fca3ac add some badges to README 2025-03-19 11:25:06 -07:00
Benson Wong d625ab8d92 Refactor process state management (#70) (#73)
* add isValidStateTransition helper function
* Replace Process.setState() with Process.swapState()
* Refactor locking logic in Process
2025-03-15 17:14:03 -07:00
Benson Wong a3f82c140b tidy up config examples in README 2025-03-15 10:36:45 -07:00
Benson Wong 5c97299e7b Add support for sending a custom model name to upstream (#69) (#71)
* add test for splitRequestedModel()
* Add `useModelName` parameter to model configuration
* add docs to README
2025-03-14 21:07:52 -07:00
Benson Wong 671c1a5a7b update deps 2025-03-13 14:00:15 -07:00
Benson Wong 52c0196e0f clean up feature list in readme 2025-03-13 13:55:20 -07:00
Benson Wong 3201a68a04 Add /v1/audio/transcriptions support (#41)
* add support for /v1/audio/transcriptions
2025-03-13 13:49:39 -07:00
Florin-Gabriel Dumitru 3ac94ad20e Adds an endpoint '/running' (#61)
* Adds an endpoint '/running' that returns either an empty JSON object if no model has been loaded so far, or the last model loaded (model key) and it's current state (state key). Possible state values are: stopped, starting, ready and stopping.

* Improves the `/running` endpoint by allowing multiple entries under the `running` key within the JSON response.
Refactors the `/running` method name (listRunningProcessesHandler).
Removes the unlisted filter implementation.

* Adds tests for:
- no model loaded
- one model loaded
- multiple models loaded

* Adds simple comments.

* Simplified code structure as per 250313 comments on PR #65.

---------

Co-authored-by: FGDumitru|B <xelotx@gmail.com>
2025-03-13 13:42:59 -07:00
Benson Wong 60355bf74a fix some potentially confusing Process.start() comment 2025-03-11 11:00:45 -07:00
Benson Wong 9b2ed244e2 Improve Continuous integration and fix concurrency bugs (#66)
- improvements to the continuous GH actions
- fix edge case concurrency bugs with Process.start() and state transitions discovered setting up CI.
2025-03-11 10:39:14 -07:00
Benson Wong eeb72297f7 add first version of CI for go 2025-03-11 08:45:56 -07:00
Benson Wong eabfe70cc6 add GH action to close inactive issues 2025-03-09 19:51:48 -07:00
Benson Wong 29cd98878d better container build logic when upstream containers do not exist 2025-03-09 13:02:06 -07:00
Benson Wong b3d331da0d Properly strip profile name slug from models fixes (#62)
The profile slug in a model name, `profile:model`, is specific to
llama-swap. This strips `profile:` out of the model name request so
upstreams that expect just `model` work and do not require knowing about
the profile slug.
2025-03-09 12:41:52 -07:00
Benson Wong 62275e078d add examples to restart on config change #59 2025-03-06 10:50:29 -08:00
Benson Wong 88916059e1 add /unload to docs 2025-03-03 10:44:16 -08:00
Benson Wong 082d5d0fc5 Add /unload endpoint (#58) to unload all currently running models 2025-03-03 10:33:36 -08:00
Benson Wong 53338938bd increase health check to a minimum of 5 seconds 2025-03-03 10:04:08 -08:00
Benson Wong af653347ae Update README.md w/ starhistory graph 2025-02-27 16:43:34 -08:00
Benson Wong 1e25b44a06 add workflow_dispatch to release action 2025-02-18 17:27:43 -08:00
Benson Wong 0815bb4cc3 Add windows to goreleaser #54 2025-02-18 17:26:43 -08:00
daschiller 7187cfe52e add Windows build support to Makefile (#54) 2025-02-18 17:24:31 -08:00
Benson Wong 24089d2d9c remove "no musa container" note from README 2025-02-18 16:38:48 -08:00
Benson Wong ebabe55ff3 Delete untagged packages after build and push (#55) 2025-02-18 10:32:32 -08:00
Benson Wong 41a338297c deletion of untagged containers happen after build-and-push 2025-02-18 10:11:59 -08:00
Benson Wong 7e3353efeb add action step to remove untagged containers 2025-02-18 10:08:41 -08:00
Benson Wong 4ed58fb173 update container build action 2025-02-18 09:59:06 -08:00
Benson Wong f5a2be698d revert package src until new ggml-org has them 2025-02-15 18:23:58 -08:00
Benson Wong f5e6ec3b7a fix package src in containerfile 2025-02-15 18:20:35 -08:00
Benson Wong 3f462da146 switch package source from ggerganov to ggml-org 2025-02-15 18:18:49 -08:00
Benson Wong 48bd766536 Update README.md 2025-02-14 22:05:52 -08:00
Benson Wong 8d319da4dd improve README organization (i think...) 2025-02-14 15:59:12 -08:00
Benson Wong be7c502448 improve docs 2025-02-14 15:47:31 -08:00
Benson Wong 92336f00bf more container build fixes 2025-02-14 15:34:38 -08:00
Benson Wong ed2a50d9a6 fix bug in build-container.sh 2025-02-14 15:27:56 -08:00
Benson Wong 0acfdb9f78 update workflow to build cpu and disable musa 2025-02-14 15:26:59 -08:00
Benson Wong 96a8ea0241 add cpu docker container build 2025-02-14 15:25:45 -08:00
Benson Wong f20f2c9b7a add docs and container build improvements #43 2025-02-14 12:20:07 -08:00
Benson Wong 7a97c38828 enable parallel container built #46 2025-02-14 11:04:33 -08:00
Benson Wong 4885132565 more permissions futzing 2025-02-14 11:02:15 -08:00
Benson Wong 8b46a0b7f1 grant package:write to container workflow #46 2025-02-14 10:55:30 -08:00
Benson Wong 1b6736ec6f rename workflow for containers 2025-02-14 10:50:15 -08:00
Benson Wong ddc1ce031e fix container file name #46 2025-02-14 10:49:44 -08:00
Benson Wong 11d024bbaa just build cuda while debugging 2025-02-14 10:48:06 -08:00
Benson Wong 43e23c16dc add check for GITHUB_TOKEN #46 2025-02-14 10:47:25 -08:00
Benson Wong f9c8e763ba add execute bit on build-container.sh 2025-02-14 10:44:53 -08:00
Benson Wong d7e1bb9f7c add GITHUB_TOKEN to container build env 2025-02-14 10:43:44 -08:00
Benson Wong ab93460a8b first container code (#52) 2025-02-14 10:39:25 -08:00
Benson Wong 13d4552edc Add FreeBSD/amd64 to auto built releases (#51) 2025-02-13 16:44:31 -08:00
Benson Wong 6667e307a2 Update README.md 2025-02-08 10:28:35 -08:00
Benson Wong 7ac446e6a9 Update README.md 2025-02-08 10:26:11 -08:00
Benson Wong eab9795bcc remove panic() when cmd or process is nil 2025-02-07 14:00:32 -08:00
Benson Wong 09bdd86b54 Improve shutdown behaviour (#47) (#49)
Introduce `Process.Shutdown()` and `ProxyManager.Shutdown()`. These two function required a lot of internal process state management refactoring. A key benefit is that `Process.start()` is now interruptable. When `Shutdown()` is called it will break the long health check loop. 

State management within Process is also improved. Added `starting`, `stopping` and `shutdown` states. Additionally, introduced a simple finite state machine to manage transitions.
2025-02-05 17:19:59 -08:00
Benson Wong 85cd74a51c Improve process start and stop reliability (#38)
Refactor Process.start()/Stop() logic (#38)
- remove cmd.Wait() call in start(). This seems to conflict with the one
  in .Stop(). Removing it eliminated no child errors
- eliminate goroutines in .start() as it no longer required
2025-02-03 11:50:38 -08:00
Benson Wong 314d2f2212 remove cmd_stop configuration and functionality from PR #40 (#44)
* remove cmd_stop functionality from #40
2025-01-31 12:42:44 -08:00
Benson Wong fad25f3e11 Use client request context in proxy request (#43)
Canceled or closed HTTP requests from clients will also stop the proxied
HTTP requests to upstreamed servers.
2025-01-31 10:21:49 -08:00
Benson Wong 2c3e3e27f7 Support OPTIONS requests (#42)
Add middleware that responds with permissive OPTIONS headers
for all request paths.
2025-01-31 10:09:07 -08:00
Benson Wong baeb0c4e7f Add cmd_stop configuration to better support docker (#35)
Add `cmd_stop` to model configuration to run a command instead of sending a SIGTERM to shutdown a process before swapping.
2025-01-30 16:59:57 -08:00
Benson Wong 2833517eef Improve handling of process that do not handle SIGTERM (#38)
- Process TTL goroutine did not have a return after .Stop()
- Improve logging
- Add test TestProcess_LowTTLValue to measure SIGTERM error rate
2025-01-20 14:39:52 -08:00
Benson Wong abdc2bfdb3 Fix panic when requesting non-members of profiles
A panic occurs when a request for an invalid profile:model pair is made.
The edge case is that the profile exists and the model exists but they're
not configured as a pair.

This adds an additional check to make sure the profile:model pair is
valid before attempting to swap the model.
2025-01-16 12:06:38 -08:00
Benson Wong c3b834737f Update README.md 2025-01-13 22:37:30 -08:00
Benson Wong 3c8e727b73 Update README.md 2025-01-12 19:48:35 -08:00
Benson Wong 3a1e9f81f1 support TTS /v1/audio/speech (#36) 2025-01-12 16:27:01 -08:00
Benson Wong 72c883f36c Update README.md 2025-01-02 09:01:51 -08:00
Benson Wong 1b04d034cf Update README.md 2025-01-02 08:59:11 -08:00
Benson Wong 2e45f5692a Update README.md
Improve README documentation.
2025-01-01 12:51:24 -08:00
Benson Wong c97b80bdfe Update README.md 2025-01-01 12:25:45 -08:00
Benson Wong ae3ef9bc39 Refactor UI (#33)
- add html to / instead of 404
- add client side regex to /logs
2024-12-23 19:48:59 -08:00
87 changed files with 11907 additions and 942 deletions
+15
View File
@@ -0,0 +1,15 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: true
poem: false
review_status: true
collapse_walkthrough: false
auto_review:
enabled: true
drafts: false
chat:
auto_reply: true
+37
View File
@@ -0,0 +1,37 @@
---
name: Bug Report
about: Something is not working as expected...
title: ''
labels: bug
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Expected behaviour**
A clear and concise description of what you expected to happen.
**Operating system and version**
- OS: (linux, osx, windows, freebsd, etc)
- GPUs: (list architecture)
**My Configuration**
```yaml
# copy / paste your configuration here
```
**Proxy Logs**
```
# copy / paste from /logs
```
**Upstream Logs**
```
# copy/paste from /logs
```
+23
View File
@@ -0,0 +1,23 @@
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
name: Close inactive issues
on:
schedule:
- cron: "32 1 * * *"
jobs:
close-issues:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v9
with:
days-before-issue-stale: 14
days-before-issue-close: 14
stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
days-before-pr-stale: -1
days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }}
+46
View File
@@ -0,0 +1,46 @@
name: Build Containers
on:
# time has no specific meaning, trying to time it after
# the llama.cpp daily packages are published
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
schedule:
- cron: "37 5 * * *"
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
build-and-push:
runs-on: ubuntu-latest
strategy:
matrix:
platform: [intel, cuda, vulkan, cpu, musa]
fail-fast: false
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Run build-container
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: ./docker/build-container.sh ${{ matrix.platform }} true
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
# see: https://github.com/actions/delete-package-versions/issues/74
delete-untagged-containers:
needs: build-and-push
runs-on: ubuntu-latest
steps:
- uses: actions/delete-package-versions@v5
with:
package-name: 'llama-swap'
package-type: 'container'
delete-only-untagged-versions: 'true'
+50
View File
@@ -0,0 +1,50 @@
name: Windows CI
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
run-tests:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
# cache simple-responder to save the build time
- name: Restore Simple Responder
id: restore-simple-responder
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping
- name: Create simple-responder
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
shell: bash
run: make simple-responder-windows
- name: Save Simple Responder
# nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder
uses: actions/cache/save@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
- name: Test all
shell: bash
run: make test-all
+47
View File
@@ -0,0 +1,47 @@
name: Linux CI
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
run-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
# cache simple-responder to save the build time
- name: Restore Simple Responder
id: restore-simple-responder
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping
- name: Create simple-responder
run: make simple-responder
- name: Save Simple Responder
# nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder
uses: actions/cache/save@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
- name: Test all
run: make test-all
+47 -1
View File
@@ -5,6 +5,13 @@ on:
tags:
- '*'
# Allows manual triggering of the workflow
workflow_dispatch:
inputs:
tag:
description: 'Tag version to release (e.g. v144)'
required: true
permissions:
contents: write
@@ -17,9 +24,22 @@ jobs:
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.tag || github.ref }}
-
name: Set up Go
uses: actions/setup-go@v5
-
name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '23'
-
name: Install dependencies and build UI
run: |
cd ui
npm ci
npm run build
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
@@ -30,4 +50,30 @@ jobs:
version: '~> v2'
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
trigger-tap-update:
runs-on: ubuntu-latest
needs: goreleaser
steps:
- name: "Resolve tag to dispatch"
id: tag
run: |
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
echo "tag=${{ github.event.inputs.tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
fi
- name: "Trigger tap repository update"
uses: peter-evans/repository-dispatch@v2
with:
token: ${{ secrets.TAP_REPO_PAT }}
repository: mostlygeek/homebrew-llama-swap
event-type: new-release
client-payload: |
{
"release": {
"tag_name": "${{ steps.tag.outputs.tag }}"
}
}
+22 -1
View File
@@ -6,6 +6,27 @@ builds:
goos:
- linux
- darwin
- freebsd
- windows
goarch:
- amd64
- arm64
- arm64
ignore:
- goos: freebsd
goarch: arm64
- goos: windows
goarch: arm64
archives:
- id: default
formats:
- tar.gz
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds_info:
group: root
owner: root
format_overrides:
# use zip format for windows
- goos: windows
formats:
- zip
+28 -7
View File
@@ -19,21 +19,38 @@ all: mac linux simple-responder
clean:
rm -rf $(BUILD_DIR)
test:
go test -short -v ./proxy
proxy/ui_dist/placeholder.txt:
mkdir -p proxy/ui_dist
touch $@
test-all:
go test -v ./proxy
test: proxy/ui_dist/placeholder.txt
go test -short -v -count=1 ./proxy
test-all: proxy/ui_dist/placeholder.txt
go test -v -count=1 ./proxy
ui/node_modules:
cd ui && npm install
# build react UI
ui: ui/node_modules
cd ui && npm run build
# Build OSX binary
mac:
mac: ui
@echo "Building Mac binary..."
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
# Build Linux binary
linux:
linux: ui
@echo "Building Linux binary..."
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
# Build Windows binary
windows: ui
@echo "Building Windows binary..."
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
# for testing proxy.Process
simple-responder:
@@ -41,6 +58,10 @@ simple-responder:
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
simple-responder-windows:
@echo "Building simple responder for windows"
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
# Ensure build directory exists
$(BUILD_DIR):
mkdir -p $(BUILD_DIR)
@@ -60,4 +81,4 @@ release:
git tag "$$new_tag";
# Phony targets
.PHONY: all clean osx linux
.PHONY: all clean ui mac linux windows simple-responder
+156 -108
View File
@@ -1,126 +1,189 @@
![llama-swap header image](header2.png)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap
![llama-swap header image](header.jpeg)
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
# Introduction
llama-swap is an OpenAI API compatible server that gives you complete control over how you use your hardware. It automatically swaps to the configuration of your choice for serving a model. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
Written in golang, it is very easy to install (single binary with no dependencies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
Features:
## Features:
- ✅ Easy to deploy: single binary with no dependencies
- ✅ Easy to config: single yaml file
- ✅ On-demand model switching
- ✅ OpenAI API supported endpoints:
- `v1/completions`
- `v1/chat/completions`
- `v1/embeddings`
- `v1/rerank`, `v1/reranking`, `rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- ✅ llama-swap custom API endpoints
- `/ui` - web UI
- `/log` - remote log monitoring
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- `/health` - just returns "OK"
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support
- ✅ Full control over server settings per model
- ✅ OpenAI API support (`v1/completions`, `v1/chat/completions`, `v1/embeddings` and `v1/rerank`)
- ✅ Multiple GPU support
- ✅ Run multiple models at once with `profiles`
- ✅ Remote log monitoring at `/log`
- ✅ Automatic unloading of models from GPUs after timeout
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabblyAPI, etc)
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
## Releases
## How does llama-swap work?
Builds for Linux and OSX are available on the [Releases](https://github.com/mostlygeek/llama-swap/releases) page.
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
### Building from source
1. Install golang for your system
1. `git clone git@github.com:mostlygeek/llama-swap.git`
1. `make clean all`
1. Binaries will be in `build/` subdirectory
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
## config.yaml
llama-swap's configuration is purposefully simple.
llama-swap is managed entirely through a yaml configuration file.
It can be very minimal to start:
```yaml
# Seconds to wait for llama.cpp to load and be ready to serve requests
# Default (and minimum) is 15 seconds
healthCheckTimeout: 60
# Write HTTP logs (useful for troubleshooting), defaults to false
logRequests: true
# define valid model values and the upstream server start
models:
"llama":
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
# where to reach the server started by cmd, make sure the ports match
proxy: http://127.0.0.1:8999
# aliases names to use this model for
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# check this path for an HTTP 200 OK before serving requests
# default: /health to match llama.cpp
# use "none" to skip endpoint checking, but may cause HTTP errors
# until the model is ready
checkEndpoint: /custom-endpoint
# automatically unload the model after this many seconds
# ttl values must be a value greater than 0
# default: 0 = never unload model
ttl: 60
"qwen":
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
# multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999
# unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal
"qwen-unlisted":
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
unlisted: true
# profiles make it easy to managing multi model (and gpu) configurations.
#
# Tips:
# - each model must be listening on a unique address and port
# - the model name is in this format: "profile_name:model", like "coding:qwen"
# - the profile will load and unload all models in the profile at the same time
profiles:
coding:
- "qwen"
- "llama"
"qwen2.5":
cmd: |
/path/to/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port ${PORT}
```
**Advanced examples**
However, there are many more capabilities that llama-swap supports:
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
- `groups` to run multiple models at once
- `ttl` to automatically unload models
- `macros` for reusable snippets
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
- `env` to pass custom environment variables to inference servers
- `cmdStop` for to gracefully stop Docker/Podman containers
- `useModelName` to override model names sent to upstream servers
- `healthCheckTimeout` to control model startup wait times
- `${PORT}` automatic port variables for dynamic port assignment
See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki all options and examples.
## Web UI
llama-swap ships with a real time web interface to monitor logs and status of models:
<img width="1786" height="1334" alt="image" src="https://github.com/user-attachments/assets/d6258cb9-1dad-40db-828f-2be860aec8fe" />
## Installation
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
llama-swap can be installed in multiple ways
1. Docker
2. Homebrew (OSX and Linux)
3. From release binaries
4. From source
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker images with llama-swap and llama-server are built nightly.
```shell
# use CPU inference comes with the example config above
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
# qwen2.5 0.5B
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
# SmolLM2 135M
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
```
<details>
<summary>Docker images are built nightly with llama-server for cuda, intel, vulcan and musa.</summary>
They include:
- `ghcr.io/mostlygeek/llama-swap:cpu`
- `ghcr.io/mostlygeek/llama-swap:cuda`
- `ghcr.io/mostlygeek/llama-swap:intel`
- `ghcr.io/mostlygeek/llama-swap:vulkan`
- ROCm disabled until fixed in llama.cpp container
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
```shell
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
```
</details>
### Homebrew Install (macOS/Linux)
The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh).
```shell
# Set up tap and install formula
brew tap mostlygeek/llama-swap
brew install llama-swap
# Run llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080
```
This will install the `llama-swap` binary and make it available in your path. See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration)
### Pre-built Binaries ([download](https://github.com/mostlygeek/llama-swap/releases))
Binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The binary install works with any OpenAI compatible server, not just llama-server.
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
* _Note: Windows currently untested._
1. Run the binary with `llama-swap --config path/to/config.yaml`
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
1. Run the binary with `llama-swap --config path/to/config.yaml --listen localhost:8080`.
Available flags:
- `--config`: Path to the configuration file (default: `config.yaml`).
- `--listen`: Address and port to listen on (default: `:8080`).
- `--version`: Show version information and exit.
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
### Building from source
1. Build requires golang and nodejs for the user interface.
1. `git clone https://github.com/mostlygeek/llama-swap.git`
1. `make clean all`
1. Binaries will be in `build/` subdirectory
## Monitoring Logs
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
Open the `http://<host>:<port>/` with your browser to get a web interface with streaming logs.
Of course, CLI access is also supported:
CLI access is also supported:
```
```shell
# sends up to the last 10KB of logs
curl http://host/logs'
# streams logs
# streams combined logs
curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time'
@@ -128,27 +191,12 @@ curl -Ns http://host/logs/stream | grep 'eval time'
curl -Ns 'http://host/logs/stream?no-history'
```
## Systemd Unit Files
## Do I need to use llama.cpp's server (llama-server)?
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
`/etc/systemd/system/llama-swap.service`
```
[Unit]
Description=llama-swap
After=network.target
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
[Service]
User=nobody
## Star History
# set this to match your environment
ExecStart=/path/to/llama-swap --config /path/to/llama-swap.config.yml
Restart=on-failure
RestartSec=3
StartLimitBurst=3
StartLimitInterval=30
[Install]
WantedBy=multi-user.target
```
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date)
+197 -72
View File
@@ -1,84 +1,209 @@
# Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds
healthCheckTimeout: 15
# llama-swap YAML configuration example
# -------------------------------------
#
# - Below are all the available configuration options for llama-swap.
# - Settings with a default value, or noted as optional can be omitted.
# - Settings that are marked required must be in your configuration file
# Log HTTP requests helpful for troubleshoot, defaults to False
logRequests: true
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
# - optional, default: 120
# - minimum value is 15 seconds, anything less will be set to this value
healthCheckTimeout: 500
# logLevel: sets the logging value
# - optional, default: info
# - Valid log levels: debug, info, warn, error
logLevel: info
# metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded
# - useful for limiting memory usage when processing large volumes of metrics
metricsMaxInMemory: 1000
# startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
# - it is automatically incremented for every model that uses it
startPort: 10001
# macros: sets a dictionary of string:string pairs
# - optional, default: empty dictionary
# - these are reusable snippets
# - used in a model's cmd, cmdStop, proxy and checkEndpoint
# - useful for reducing common configuration settings
macros:
"latest-llama": >
/path/to/llama-server/llama-server-ec9e0301
--port ${PORT}
# models: a dictionary of model configurations
# - required
# - each key is the model's ID, used in API requests
# - model settings have default values that are used if they are not defined here
# - below are examples of the various settings a model can have:
# - available model settings: env, cmd, cmdStop, proxy, aliases, checkEndpoint, ttl, unlisted
models:
# keys are the model names used in API requests
"llama":
cmd: >
models/llama-server-osx
--port 9001
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
proxy: http://127.0.0.1:9001
# cmd: the command to run to start the inference server.
# - required
# - it is just a string, similar to what you would run on the CLI
# - using `|` allows for comments in the command, these will be parsed out
# - macros can be used within cmd
cmd: |
# ${latest-llama} is a macro that is defined above
${latest-llama}
--model path/to/llama-8B-Q4_K_M.gguf
# list of model name aliases this llama.cpp instance can serve
aliases:
- gpt-4o-mini
# name: a display name for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
name: "llama 3.1 8B"
# check this path for a HTTP 200 response for the server to be ready
checkEndpoint: /health
# description: a description for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
description: "A small but capable model used for quick testing"
# unload model after 5 seconds
ttl: 5
"qwen":
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9002
aliases:
- gpt-3.5-turbo
# Embedding example with Nomic
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
"nomic":
proxy: http://127.0.0.1:9005
cmd: >
models/llama-server-osx --port 9005
-m models/nomic-embed-text-v1.5.Q8_0.gguf
--ctx-size 8192
--batch-size 8192
--rope-scaling yarn
--rope-freq-scale 0.75
-ngl 99
--embeddings
# Reranking example with bge-reranker
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
"bge-reranker":
proxy: http://127.0.0.1:9006
cmd: >
models/llama-server-osx --port 9006
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
--ctx-size 8192
--reranking
"simple":
# example of setting environment variables
# env: define an array of environment variables to inject into cmd's environment
# - optional, default: empty array
# - each value is a single string
# - in the format: ENV_NAME=value
env:
- CUDA_VISIBLE_DEVICES=0,1
- env1=hello
cmd: build/simple-responder --port 8999
- "CUDA_VISIBLE_DEVICES=0,1,2"
# proxy: the URL where llama-swap routes API requests
# - optional, default: http://localhost:${PORT}
# - if you used ${PORT} in cmd this can be omitted
# - if you use a custom port in cmd this *must* be set
proxy: http://127.0.0.1:8999
unlisted: true
# use "none" to skip check. Caution this may cause some requests to fail
# until the upstream server is ready for traffic
checkEndpoint: none
# aliases: alternative model names that this model configuration is used for
# - optional, default: empty array
# - aliases must be unique globally
# - useful for impersonating a specific model
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# don't use these, just for testing if things are broken
"broken":
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
proxy: http://127.0.0.1:8999
unlisted: true
"broken_timeout":
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9000
unlisted: true
# checkEndpoint: URL path to check if the server is ready
# - optional, default: /health
# - use "none" to skip endpoint ready checking
# - endpoint is expected to return an HTTP 200 response
# - all requests wait until the endpoint is ready (or fails)
checkEndpoint: /custom-endpoint
# creating a coding profile with models for code generation and general questions
profiles:
coding:
- "qwen"
- "llama"
# ttl: automatically unload the model after this many seconds
# - optional, default: 0
# - ttl values must be a value greater than 0
# - a value of 0 disables automatic unloading of the model
ttl: 60
# useModelName: overrides the model name that is sent to upstream server
# - optional, default: ""
# - useful when the upstream server expects a specific model name or format
useModelName: "qwen:qwq"
# filters: a dictionary of filter settings
# - optional, default: empty dictionary
filters:
# strip_params: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for preventing overriding of default server params by requests
# - `model` parameter is never removed
# - can be any JSON key in the request body
# - recommended to stick to sampling parameters
strip_params: "temperature, top_p, top_k"
# Unlisted model example:
"qwen-unlisted":
# unlisted: true or false
# - optional, default: false
# - unlisted models do not show up in /v1/models or /upstream lists
# - can be requested as normal through all apis
unlisted: true
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker example:
# container run times like Docker and Podman can also be used with a
# a combination of cmd and cmdStop.
"docker-llama":
proxy: "http://127.0.0.1:${PORT}"
cmd: |
docker run --name dockertest
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
ghcr.io/ggml-org/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# cmdStop: command to run to stop the model gracefully
# - optional, default: ""
# - useful for stopping commands managed by another system
# - on POSIX systems: a SIGTERM is sent for graceful shutdown
# - on Windows, taskkill is used
# - processes are given 5 seconds to shutdown until they are forcefully killed
# - the upstream's process id is available in the ${PID} macro
cmdStop: docker stop dockertest
# groups: a dictionary of group settings
# - optional, default: empty dictionary
# - provide advanced controls over model swapping behaviour.
# - Using groups some models can be kept loaded indefinitely, while others are swapped out.
# - model ids must be defined in the Models section
# - a model can only be a member of one group
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
# - see issue #109 for details
#
# NOTE: the example below uses model names that are not defined above for demonstration purposes
groups:
# group1 is same as the default behaviour of llama-swap where only one model is allowed
# to run a time across the whole llama-swap instance
"group1":
# swap: controls the model swapping behaviour in within the group
# - optional, default: true
# - true : only one model is allowed to run at a time
# - false: all models can run together, no swapping
swap: true
# exclusive: controls how the group affects other groups
# - optional, default: true
# - true: causes all other groups to unload when this group runs a model
# - false: does not affect other groups
exclusive: true
# members references the models defined above
# required
members:
- "llama"
- "qwen-unlisted"
# Example:
# - in this group all the models can run at the same time
# - when a different group loads all running models in this group are unloaded
"group2":
swap: false
exclusive: false
members:
- "docker-llama"
- "modelA"
- "modelB"
# Example:
# - a persistent group, prevents other groups from unloading it
"forever":
# persistent: prevents over groups from unloading the models in this group
# - optional, default: false
# - does not affect individual model behaviour
persistent: true
# set swap/exclusive to false to prevent swapping inside the group
# and the unloading of other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
+55
View File
@@ -0,0 +1,55 @@
#!/bin/bash
cd $(dirname "$0")
ARCH=$1
PUSH_IMAGES=${2:-false}
# List of allowed architectures
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
# Check if ARCH is in the allowed list
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
exit 1
fi
# Check if GITHUB_TOKEN is set and not empty
if [[ -z "$GITHUB_TOKEN" ]]; then
echo "Error: GITHUB_TOKEN is not set or is empty."
exit 1
fi
# the most recent llama-swap tag
# have to strip out the 'v' due to .tar.gz file naming
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
if [ "$ARCH" == "cpu" ]; then
# cpu only containers just use the latest available
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
echo "Building ${CONTAINER_LATEST} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_LATEST}
fi
else
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
fi
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
echo "Building ${CONTAINER_TAG} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
fi
+18
View File
@@ -0,0 +1,18 @@
healthCheckTimeout: 300
logRequests: true
metricsMaxInMemory: 1000
models:
"qwen2.5":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port 9999
"smollm2":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999
+16
View File
@@ -0,0 +1,16 @@
ARG BASE_TAG=server-cuda
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
# has to be after the FROM
ARG LS_VER=89
WORKDIR /app
RUN \
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
COPY config.example.yaml /app/config.yaml
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
+3
View File
@@ -0,0 +1,3 @@
The code in `event` was originally a part of https://github.com/kelindar/event (v1.5.2)
The original code uses a `time.Ticker` to process the event queue which caused a large increase in CPU usage ([#189](https://github.com/mostlygeek/llama-swap/issues/189)). This code was ported to remove the ticker and instead be more event driven.
+30
View File
@@ -0,0 +1,30 @@
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event
import (
"context"
)
// Default initializes a default in-process dispatcher
var Default = NewDispatcherConfig(25000)
// On subscribes to an event, the type of the event will be automatically
// inferred from the provided type. Must be constant for this to work. This
// functions same way as Subscribe() but uses the default dispatcher instead.
func On[T Event](handler func(T)) context.CancelFunc {
return Subscribe(Default, handler)
}
// OnType subscribes to an event with the specified event type. This functions
// same way as SubscribeTo() but uses the default dispatcher instead.
func OnType[T Event](eventType uint32, handler func(T)) context.CancelFunc {
return SubscribeTo(Default, eventType, handler)
}
// Emit writes an event into the dispatcher. This functions same way as
// Publish() but uses the default dispatcher instead.
func Emit[T Event](ev T) {
Publish(Default, ev)
}
+54
View File
@@ -0,0 +1,54 @@
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event
import (
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
)
/*
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
*/
func BenchmarkSubscribeConcurrent(b *testing.B) {
d := NewDispatcher()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
unsub := Subscribe(d, func(ev MyEvent1) {})
unsub()
}
})
}
func TestDefaultPublish(t *testing.T) {
var wg sync.WaitGroup
// Subscribe
var count int64
defer On(func(ev MyEvent1) {
atomic.AddInt64(&count, 1)
wg.Done()
})()
defer OnType(TypeEvent1, func(ev MyEvent1) {
atomic.AddInt64(&count, 1)
wg.Done()
})()
// Publish
wg.Add(4)
Emit(MyEvent1{})
Emit(MyEvent1{})
// Wait and check
wg.Wait()
assert.Equal(t, int64(4), count)
}
+324
View File
@@ -0,0 +1,324 @@
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for details.
package event
import (
"context"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"sync/atomic"
)
// Event represents an event contract
type Event interface {
Type() uint32
}
// registry holds an immutable sorted array of event mappings
type registry struct {
keys []uint32 // Event types (sorted)
grps []any // Corresponding subscribers
}
// ------------------------------------- Dispatcher -------------------------------------
// Dispatcher represents an event dispatcher.
type Dispatcher struct {
subs atomic.Pointer[registry] // Atomic pointer to immutable array
done chan struct{} // Cancellation
maxQueue int // Maximum queue size per consumer
mu sync.Mutex // Only for writes (subscribe/unsubscribe)
}
// NewDispatcher creates a new dispatcher of events.
func NewDispatcher() *Dispatcher {
return NewDispatcherConfig(50000)
}
// NewDispatcherConfig creates a new dispatcher with configurable max queue size
func NewDispatcherConfig(maxQueue int) *Dispatcher {
d := &Dispatcher{
done: make(chan struct{}),
maxQueue: maxQueue,
}
d.subs.Store(&registry{
keys: make([]uint32, 0, 16),
grps: make([]any, 0, 16),
})
return d
}
// Close closes the dispatcher
func (d *Dispatcher) Close() error {
close(d.done)
return nil
}
// isClosed returns whether the dispatcher is closed or not
func (d *Dispatcher) isClosed() bool {
select {
case <-d.done:
return true
default:
return false
}
}
// findGroup performs a lock-free binary search for the event type
func (d *Dispatcher) findGroup(eventType uint32) any {
reg := d.subs.Load()
keys := reg.keys
// Inlined binary search for better cache locality
left, right := 0, len(keys)
for left < right {
mid := left + (right-left)/2
if keys[mid] < eventType {
left = mid + 1
} else {
right = mid
}
}
if left < len(keys) && keys[left] == eventType {
return reg.grps[left]
}
return nil
}
// Subscribe subscribes to an event, the type of the event will be automatically
// inferred from the provided type. Must be constant for this to work.
func Subscribe[T Event](broker *Dispatcher, handler func(T)) context.CancelFunc {
var event T
return SubscribeTo(broker, event.Type(), handler)
}
// SubscribeTo subscribes to an event with the specified event type.
func SubscribeTo[T Event](broker *Dispatcher, eventType uint32, handler func(T)) context.CancelFunc {
if broker.isClosed() {
panic(errClosed)
}
broker.mu.Lock()
defer broker.mu.Unlock()
// Check if group already exists
if existing := broker.findGroup(eventType); existing != nil {
grp := groupOf[T](eventType, existing)
sub := grp.Add(handler)
return func() {
grp.Del(sub)
}
}
// Create new group
grp := &group[T]{cond: sync.NewCond(new(sync.Mutex)), maxQueue: broker.maxQueue}
sub := grp.Add(handler)
// Copy-on-write: insert new entry in sorted position
old := broker.subs.Load()
idx := sort.Search(len(old.keys), func(i int) bool {
return old.keys[i] >= eventType
})
// Create new arrays with space for one more element
newKeys := make([]uint32, len(old.keys)+1)
newGrps := make([]any, len(old.grps)+1)
// Copy elements before insertion point
copy(newKeys[:idx], old.keys[:idx])
copy(newGrps[:idx], old.grps[:idx])
// Insert new element
newKeys[idx] = eventType
newGrps[idx] = grp
// Copy elements after insertion point
copy(newKeys[idx+1:], old.keys[idx:])
copy(newGrps[idx+1:], old.grps[idx:])
// Atomically store the new registry (mutex ensures no concurrent writers)
newReg := &registry{keys: newKeys, grps: newGrps}
broker.subs.Store(newReg)
return func() {
grp.Del(sub)
}
}
// Publish writes an event into the dispatcher
func Publish[T Event](broker *Dispatcher, ev T) {
eventType := ev.Type()
if sub := broker.findGroup(eventType); sub != nil {
group := groupOf[T](eventType, sub)
group.Broadcast(ev)
}
}
// Count counts the number of subscribers, this is for testing only.
func (d *Dispatcher) count(eventType uint32) int {
if group := d.findGroup(eventType); group != nil {
return group.(interface{ Count() int }).Count()
}
return 0
}
// groupOf casts the subscriber group to the specified generic type
func groupOf[T Event](eventType uint32, subs any) *group[T] {
if group, ok := subs.(*group[T]); ok {
return group
}
panic(errConflict[T](eventType, subs))
}
// ------------------------------------- Subscriber -------------------------------------
// consumer represents a consumer with a message queue
type consumer[T Event] struct {
queue []T // Current work queue
stop bool // Stop signal
}
// Listen listens to the event queue and processes events
func (s *consumer[T]) Listen(c *sync.Cond, fn func(T)) {
pending := make([]T, 0, 128)
for {
c.L.Lock()
for len(s.queue) == 0 {
switch {
case s.stop:
c.L.Unlock()
return
default:
c.Wait()
}
}
// Swap buffers and reset the current queue
temp := s.queue
s.queue = pending[:0]
pending = temp
c.L.Unlock()
// Outside of the critical section, process the work
for _, event := range pending {
fn(event)
}
// Notify potential publishers waiting due to backpressure
c.Broadcast()
}
}
// ------------------------------------- Subscriber Group -------------------------------------
// group represents a consumer group
type group[T Event] struct {
cond *sync.Cond
subs []*consumer[T]
maxQueue int // Maximum queue size per consumer
maxLen int // Current maximum queue length across all consumers
}
// Broadcast sends an event to all consumers
func (s *group[T]) Broadcast(ev T) {
s.cond.L.Lock()
defer s.cond.L.Unlock()
// Calculate current maximum queue length
s.maxLen = 0
for _, sub := range s.subs {
if len(sub.queue) > s.maxLen {
s.maxLen = len(sub.queue)
}
}
// Backpressure: wait if queues are full
for s.maxLen >= s.maxQueue {
s.cond.Wait()
// Recalculate after wakeup
s.maxLen = 0
for _, sub := range s.subs {
if len(sub.queue) > s.maxLen {
s.maxLen = len(sub.queue)
}
}
}
// Add event to all queues and track new maximum
newMax := 0
for _, sub := range s.subs {
sub.queue = append(sub.queue, ev)
if len(sub.queue) > newMax {
newMax = len(sub.queue)
}
}
s.maxLen = newMax
s.cond.Broadcast() // Wake consumers
}
// Add adds a subscriber to the list
func (s *group[T]) Add(handler func(T)) *consumer[T] {
sub := &consumer[T]{
queue: make([]T, 0, 64),
}
// Add the consumer to the list of active consumers
s.cond.L.Lock()
s.subs = append(s.subs, sub)
s.cond.L.Unlock()
// Start listening
go sub.Listen(s.cond, handler)
return sub
}
// Del removes a subscriber from the list
func (s *group[T]) Del(sub *consumer[T]) {
s.cond.L.Lock()
defer s.cond.L.Unlock()
// Search and remove the subscriber
sub.stop = true
for i, v := range s.subs {
if v == sub {
copy(s.subs[i:], s.subs[i+1:])
s.subs = s.subs[:len(s.subs)-1]
break
}
}
}
// ------------------------------------- Debugging -------------------------------------
var errClosed = fmt.Errorf("event dispatcher is closed")
// Count returns the number of subscribers in this group
func (s *group[T]) Count() int {
return len(s.subs)
}
// String returns string representation of the type
func (s *group[T]) String() string {
typ := reflect.TypeOf(s).String()
idx := strings.LastIndex(typ, "/")
typ = typ[idx+1 : len(typ)-1]
return typ
}
// errConflict returns a conflict message
func errConflict[T any](eventType uint32, existing any) string {
var want T
return fmt.Sprintf(
"conflicting event type, want=<%T>, registered=<%s>, event=0x%v",
want, existing, eventType,
)
}
+324
View File
@@ -0,0 +1,324 @@
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event
import (
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPublish(t *testing.T) {
d := NewDispatcher()
var wg sync.WaitGroup
// Subscribe, must be received in order
var count int64
defer Subscribe(d, func(ev MyEvent1) {
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
wg.Done()
})()
// Publish
wg.Add(3)
Publish(d, MyEvent1{Number: 1})
Publish(d, MyEvent1{Number: 2})
Publish(d, MyEvent1{Number: 3})
// Wait and check
wg.Wait()
assert.Equal(t, int64(3), count)
}
func TestUnsubscribe(t *testing.T) {
d := NewDispatcher()
assert.Equal(t, 0, d.count(TypeEvent1))
unsubscribe := Subscribe(d, func(ev MyEvent1) {
// Nothing
})
assert.Equal(t, 1, d.count(TypeEvent1))
unsubscribe()
assert.Equal(t, 0, d.count(TypeEvent1))
}
func TestConcurrent(t *testing.T) {
const max = 1000000
var count int64
var wg sync.WaitGroup
wg.Add(1)
d := NewDispatcher()
defer Subscribe(d, func(ev MyEvent1) {
if current := atomic.AddInt64(&count, 1); current == max {
wg.Done()
}
})()
// Asynchronously publish
go func() {
for i := 0; i < max; i++ {
Publish(d, MyEvent1{})
}
}()
defer Subscribe(d, func(ev MyEvent1) {
// Subscriber that does nothing
})()
wg.Wait()
assert.Equal(t, max, int(count))
}
func TestSubscribeDifferentType(t *testing.T) {
d := NewDispatcher()
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
})
}
func TestPublishDifferentType(t *testing.T) {
d := NewDispatcher()
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
Publish(d, MyEvent1{})
})
}
func TestCloseDispatcher(t *testing.T) {
d := NewDispatcher()
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
assert.NoError(t, d.Close())
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
})
}
func TestMatrix(t *testing.T) {
const amount = 1000
for _, subs := range []int{1, 10, 100} {
for _, topics := range []int{1, 10} {
expected := subs * topics * amount
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
var count atomic.Int64
var wg sync.WaitGroup
wg.Add(expected)
d := NewDispatcher()
for i := 0; i < subs; i++ {
for id := 0; id < topics; id++ {
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
count.Add(1)
wg.Done()
})()
}
}
for n := 0; n < amount; n++ {
for id := 0; id < topics; id++ {
go Publish(d, MyEvent3{ID: id})
}
}
wg.Wait()
assert.Equal(t, expected, int(count.Load()))
})
}
}
}
func TestConcurrentSubscriptionRace(t *testing.T) {
// This test specifically targets the race condition that occurs when multiple
// goroutines try to subscribe to different event types simultaneously.
// Without the CAS loop, subscriptions could be lost due to registry corruption.
const numGoroutines = 100
const numEventTypes = 50
d := NewDispatcher()
defer d.Close()
var wg sync.WaitGroup
var receivedCount int64
var subscribedTypes sync.Map // Thread-safe map
wg.Add(numGoroutines)
// Start multiple goroutines that subscribe to different event types concurrently
for i := 0; i < numGoroutines; i++ {
go func(goroutineID int) {
defer wg.Done()
// Each goroutine subscribes to a unique event type
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
// Subscribe to the event type
SubscribeTo(d, eventType, func(ev MyEvent3) {
atomic.AddInt64(&receivedCount, 1)
})
// Record that this type was subscribed
subscribedTypes.Store(eventType, true)
}(i)
}
// Wait for all subscriptions to complete
wg.Wait()
// Count the number of unique event types subscribed
expectedTypes := 0
subscribedTypes.Range(func(key, value interface{}) bool {
expectedTypes++
return true
})
// Small delay to ensure all subscriptions are fully processed
time.Sleep(10 * time.Millisecond)
// Publish events to each subscribed type
subscribedTypes.Range(func(key, value interface{}) bool {
eventType := key.(uint32)
Publish(d, MyEvent3{ID: int(eventType)})
return true
})
// Wait for all events to be processed
time.Sleep(50 * time.Millisecond)
// Verify that we received at least the expected number of events
// (there might be more if multiple goroutines subscribed to the same event type)
received := atomic.LoadInt64(&receivedCount)
assert.GreaterOrEqual(t, int(received), expectedTypes,
"Should have received at least %d events, got %d", expectedTypes, received)
// Verify that we have the expected number of unique event types
assert.Equal(t, numEventTypes, expectedTypes,
"Should have exactly %d unique event types", numEventTypes)
}
func TestConcurrentHandlerRegistration(t *testing.T) {
const numGoroutines = 100
// Test concurrent subscriptions to the same event type
t.Run("SameEventType", func(t *testing.T) {
d := NewDispatcher()
var handlerCount int64
var wg sync.WaitGroup
// Start multiple goroutines subscribing to the same event type (0x1)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
atomic.AddInt64(&handlerCount, 1)
})
}()
}
wg.Wait()
// Verify all handlers were registered by publishing an event
atomic.StoreInt64(&handlerCount, 0)
Publish(d, MyEvent1{})
// Small delay to ensure all handlers have executed
time.Sleep(10 * time.Millisecond)
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
"Not all handlers were registered due to race condition")
})
// Test concurrent subscriptions to different event types
t.Run("DifferentEventTypes", func(t *testing.T) {
d := NewDispatcher()
var wg sync.WaitGroup
receivedEvents := make(map[uint32]*int64)
// Create multiple event types and subscribe concurrently
for i := 0; i < numGoroutines; i++ {
eventType := uint32(100 + i)
counter := new(int64)
receivedEvents[eventType] = counter
wg.Add(1)
go func(et uint32, cnt *int64) {
defer wg.Done()
SubscribeTo(d, et, func(ev MyEvent3) {
atomic.AddInt64(cnt, 1)
})
}(eventType, counter)
}
wg.Wait()
// Publish events to all types
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
Publish(d, MyEvent3{ID: int(eventType)})
}
// Small delay to ensure all handlers have executed
time.Sleep(10 * time.Millisecond)
// Verify all event types received their events
for eventType, counter := range receivedEvents {
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
"Event type %d did not receive its event", eventType)
}
})
}
func TestBackpressure(t *testing.T) {
d := NewDispatcher()
d.maxQueue = 10
var processedCount int64
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
atomic.AddInt64(&processedCount, 1)
})
defer unsub()
const eventsToPublish = 1000
for i := 0; i < eventsToPublish; i++ {
Publish(d, MyEvent3{ID: 0x200})
}
time.Sleep(100 * time.Millisecond)
// Verify all events were eventually processed
finalProcessed := atomic.LoadInt64(&processedCount)
assert.Equal(t, int64(eventsToPublish), finalProcessed)
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
}
// ------------------------------------- Test Events -------------------------------------
const (
TypeEvent1 = 0x1
TypeEvent2 = 0x2
)
type MyEvent1 struct {
Number int
}
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
type MyEvent2 struct {
Text string
}
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
type MyEvent3 struct {
ID int
}
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
+153
View File
@@ -0,0 +1,153 @@
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
## Here's what you you need:
- aider - [installation docs](https://aider.chat/docs/install.html)
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
- 24GB VRAM video card
## Running aider
The goal is getting this command line to work:
```sh
aider --architect \
--no-show-model-warnings \
--model openai/QwQ \
--editor-model openai/qwen-coder-32B \
--model-settings-file aider.model.settings.yml \
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1" \
```
Set `--openai-api-base` to the IP and port where your llama-swap is running.
## Create an aider model settings file
```yaml
# aider.model.settings.yml
#
# !!! important: model names must match llama-swap configuration names !!!
#
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
```
## llama-swap configuration
```yaml
# config.yaml
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
models:
"qwen-coder-32B":
proxy: "http://127.0.0.1:8999"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--cache-type-k q8_0 --cache-type-v q8_0
-ngl 99
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
"QwQ":
proxy: "http://127.0.0.1:9503"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
--min-p 0.01 --top-k 40 --top-p 0.95
-ngl 99
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
```
## Advanced, Dual GPU Configuration
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
In llama-swap's configuration file:
1. add a `profiles` section with `aider` as the profile name
2. using the `env` field to specify the GPU IDs for each model
```yaml
# config.yaml
# Add a profile for aider
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=0"
proxy: "http://127.0.0.1:8999"
cmd: /path/to/llama-server ...
"QwQ":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
cmd: /path/to/llama-server ...
```
Append the profile tag, `aider:`, to the model names in the model settings file
```yaml
# aider.model.settings.yml
- name: "openai/aider:QwQ"
weak_model_name: "openai/aider:qwen-coder-32B-aider"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
- name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
```
Run aider with:
```sh
$ aider --architect \
--no-show-model-warnings \
--model openai/aider:QwQ \
--editor-model openai/aider:qwen-coder-32B \
--config aider.conf.yml \
--model-settings-file aider.model.settings.yml
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1"
```
@@ -0,0 +1,28 @@
# this makes use of llama-swap's profile feature to
# keep the architect and editor models in VRAM on different GPUs
- name: "openai/aider:QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B"
- name: "openai/aider:qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/aider:qwen-coder-32B"
@@ -0,0 +1,26 @@
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
+49
View File
@@ -0,0 +1,49 @@
healthCheckTimeout: 300
logLevel: debug
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
env:
- "CUDA_VISIBLE_DEVICES=0"
aliases:
- coder
proxy: "http://127.0.0.1:8999"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--ctx-size-draft 16000
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
-ngl 99 -ngld 99
--draft-max 16 --draft-min 4 --draft-p-min 0.4
--cache-type-k q8_0 --cache-type-v q8_0
"QwQ":
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6
--repeat-penalty 1.1
--dry-multiplier 0.5
--min-p 0.01
--top-k 40
--top-p 0.95
-ngl 99 -ngld 99
@@ -0,0 +1,51 @@
# Restart llama-swap on config change
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
```bash
#!/bin/bash
#
# A simple watch and restart llama-swap when its configuration
# file changes. Useful for trying out configuration changes
# without manually restarting the server each time.
if [ -z "$1" ]; then
echo "Usage: $0 <path to config.yaml>"
exit 1
fi
while true; do
# Start the process again
./llama-swap-linux-amd64 -config $1 -listen :1867 &
PID=$!
echo "Started llama-swap with PID $PID"
# Wait for modifications in the specified directory or file
inotifywait -e modify "$1"
# Check if process exists before sending signal
if kill -0 $PID 2>/dev/null; then
echo "Sending SIGTERM to $PID"
kill -SIGTERM $PID
wait $PID
else
echo "Process $PID no longer exists"
fi
sleep 1
done
```
## Usage and output example
```bash
$ ./watch-and-restart.sh config.yaml
Started llama-swap with PID 495455
Setting up watches.
Watches established.
llama-swap listening on :1867
Sending SIGTERM to 495455
Shutting down llama-swap
Started llama-swap with PID 495486
Setting up watches.
Watches established.
llama-swap listening on :1867
```
+11 -6
View File
@@ -3,7 +3,12 @@ module github.com/mostlygeek/llama-swap
go 1.23.0
require (
github.com/billziss-gh/golib v0.2.0
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.0
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
gopkg.in/yaml.v3 v3.0.1
)
@@ -15,12 +20,10 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.10.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
@@ -29,12 +32,14 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)
+28 -18
View File
@@ -1,3 +1,5 @@
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@@ -9,12 +11,16 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
@@ -23,9 +29,9 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
@@ -57,6 +63,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -64,24 +80,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 351 KiB

+149 -8
View File
@@ -1,23 +1,35 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
"github.com/fsnotify/fsnotify"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy"
)
var version string = "0"
var commit string = "abcd1234"
var date = "unknown"
var (
version string = "0"
commit string = "abcd1234"
date string = "unknown"
)
func main() {
// Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port")
showVersion := flag.Bool("version", false, "show version of build")
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
flag.Parse() // Parse the command-line flags
@@ -32,16 +44,145 @@ func main() {
os.Exit(1)
}
if len(config.Profiles) > 0 {
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
}
if mode := os.Getenv("GIN_MODE"); mode != "" {
gin.SetMode(mode)
} else {
gin.SetMode(gin.ReleaseMode)
}
proxyManager := proxy.New(config)
fmt.Println("llama-swap listening on " + *listenStr)
if err := proxyManager.Run(*listenStr); err != nil {
fmt.Printf("Server error: %v\n", err)
os.Exit(1)
// Setup channels for server management
exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create server with initial handler
srv := &http.Server{
Addr: *listenStr,
}
// Support for watching config and reloading when it changes
reloadProxyManager := func() {
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
config, err = proxy.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
return
}
fmt.Println("Configuration Changed")
currentPM.Shutdown()
srv.Handler = proxy.New(config)
fmt.Println("Configuration Reloaded")
// wait a few seconds and tell any UI to reload
time.AfterFunc(3*time.Second, func() {
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateEnd,
})
})
} else {
config, err = proxy.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Error, unable to load configuration: %v\n", err)
os.Exit(1)
}
srv.Handler = proxy.New(config)
}
}
// load the initial proxy manager
reloadProxyManager()
debouncedReload := debounce(time.Second, reloadProxyManager)
if *watchConfig {
defer event.On(func(e proxy.ConfigFileChangedEvent) {
if e.ReloadingState == proxy.ReloadingStateStart {
debouncedReload()
}
})()
fmt.Println("Watching Configuration for changes")
go func() {
absConfigPath, err := filepath.Abs(*configPath)
if err != nil {
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
return
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
return
}
configDir := filepath.Dir(absConfigPath)
err = watcher.Add(configDir)
if err != nil {
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
return
}
defer watcher.Close()
for {
select {
case changeEvent := <-watcher.Events:
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateStart,
})
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
// the change for k8s configmap
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateStart,
})
}
case err := <-watcher.Errors:
log.Printf("File watcher error: %v", err)
}
}
}()
}
// shutdown on signal
go func() {
sig := <-sigChan
fmt.Printf("Received signal %v, shutting down...\n", sig)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
pm.Shutdown()
} else {
fmt.Println("srv.Handler is not of type *proxy.ProxyManager")
}
if err := srv.Shutdown(ctx); err != nil {
fmt.Printf("Server shutdown error: %v\n", err)
}
close(exitChan)
}()
// Start server
fmt.Printf("llama-swap listening on %s\n", *listenStr)
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Fatal server error: %v\n", err)
}
}()
// Wait for exit signal
<-exitChan
}
func debounce(interval time.Duration, f func()) func() {
var timer *time.Timer
return func() {
if timer != nil {
timer.Stop()
}
timer = time.AfterFunc(interval, f)
}
}
+91
View File
@@ -0,0 +1,91 @@
package main
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"os/signal"
"syscall"
"time"
)
/*
**
Test how exec.Cmd.CommandContext behaves under certain conditions:*
- process is killed externally, what happens with cmd.Wait() *
✔︎ it returns. catches crashes.*
- process ignores SIGTERM*
✔︎ `kill()` is called after cmd.WaitDelay*
- this process exits, what happens with children (kill -9 <this process' pid>)*
x they stick around. have to be manually killed.*
- .WithTimeout()'s cancel is called *
✔︎ process is killed after it ignores sigterm, cmd.Wait() catches it.*
- parent receives SIGINT/SIGTERM, what happens
✔︎ waits for child process to exit, then exits gracefully.
*/
func main() {
// swap between these to use kill -9 <pid> on the cli to sim external crash
ctx, cancel := context.WithCancel(context.Background())
//ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
defer cancel()
//cmd := exec.CommandContext(ctx, "sleep", "1")
cmd := exec.CommandContext(ctx,
"../../build/simple-responder_darwin_arm64",
//"-ignore-sig-term", /* so it doesn't exit on receiving SIGTERM, test cmd.WaitTimeout */
)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
// set a wait delay before signing sig kill
cmd.WaitDelay = 500 * time.Millisecond
cmd.Cancel = func() error {
fmt.Println("✔︎ Cancel() called, sending SIGTERM")
cmd.Process.Signal(syscall.SIGTERM)
//return nil
// this error is returned by cmd.Wait(), and can be used to
// single an error when the process couldn't be normally terminated
// but since a SIGTERM is sent, it's probably ok to return a nil
// as WaitDelay timing out will override the any error set here.
//
// test by enabling/disabling -ignore-sig-term on the process
// with -ignore-sig-term enabled, cmd.Wait() will have "signal: killed"
// without it, it will show the "new error from cancel"
return errors.New("error from cmd.Cancel()") // sets error returned by cmd.Wait()
}
if err := cmd.Start(); err != nil {
fmt.Println("Error starting process:", err)
return
}
// catch signals. Calls cancel() which will cause cmd.Wait() to return and
// this program to eventually exit gracefully.
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
signal := <-sigChan
fmt.Printf("✔︎ Received signal: %d, Killing process... with cancel before exiting\n", signal)
cancel()
}()
fmt.Printf("✔︎ Parent Pid: %d, Process Pid: %d\n", os.Getpid(), cmd.Process.Pid)
fmt.Println("✔︎ Process started, cmd.Wait() ... ")
if err := cmd.Wait(); err != nil {
fmt.Println("✔︎ cmd.Wait returned, Error:", err)
} else {
fmt.Println("✔︎ cmd.Wait returned, Process exited on its own")
}
fmt.Println("✔︎ Child process exited, Done.")
}
+190 -10
View File
@@ -12,18 +12,22 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
func main() {
gin.SetMode(gin.TestMode)
// Define a command-line flag for the port
port := flag.String("port", "8080", "port to listen on")
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
// Define a command-line flag for the response message
responseMessage := flag.String("respond", "hi", "message to respond with")
silent := flag.Bool("silent", false, "disable all logging")
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
flag.Parse() // Parse the command-line flags
// Create a new Gin router
@@ -31,19 +35,166 @@ func main() {
// Set up the handler function using the provided response message
r.POST("/v1/chat/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
bodyBytes, _ := io.ReadAll(c.Request.Body)
// add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait)
// Check if streaming is requested
// Query is checked instead of JSON body since that event stream conflicts with other tests
isStreaming := c.Query("stream") == "true"
if isStreaming {
// Set headers for streaming
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Transfer-Encoding", "chunked")
// add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait)
}
// Send 10 "asdf" tokens
for i := 0; i < 10; i++ {
data := gin.H{
"created": time.Now().Unix(),
"choices": []gin.H{
{
"index": 0,
"delta": gin.H{
"content": "asdf",
},
"finish_reason": nil,
},
},
}
c.SSEvent("message", data)
c.Writer.Flush()
}
// Send final data with usage info
finalData := gin.H{
"usage": gin.H{
"completion_tokens": 10,
"prompt_tokens": 25,
"total_tokens": 35,
},
// add timings to simulate llama.cpp
"timings": gin.H{
"prompt_n": 25,
"prompt_ms": 13,
"predicted_n": 10,
"predicted_ms": 17,
"predicted_per_second": 10,
},
}
c.SSEvent("message", finalData)
c.Writer.Flush()
// Send [DONE]
c.SSEvent("message", "[DONE]")
c.Writer.Flush()
} else {
c.Header("Content-Type", "application/json")
// add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait)
}
c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
"h_content_length": c.Request.Header.Get("Content-Length"),
"request_body": string(bodyBytes),
"usage": gin.H{
"completion_tokens": 10,
"prompt_tokens": 25,
"total_tokens": 35,
},
"timings": gin.H{
"prompt_n": 25,
"prompt_ms": 13,
"predicted_n": 10,
"predicted_ms": 17,
"predicted_per_second": 10,
},
})
}
})
c.String(200, *responseMessage)
// for issue #62 to check model name strips profile slug
// has to be one of the openAI API endpoints that llama-swap proxies
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
r.POST("/v1/audio/speech", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
if modelName != *expectedModel {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
return
} else {
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}
})
r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
c.Header("Content-Type", "application/json")
c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
"usage": gin.H{
"completion_tokens": 10,
"prompt_tokens": 25,
"total_tokens": 35,
},
})
})
// issue #41
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
// Parse the multipart form
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
return
}
// Get the model from the form values
model := c.Request.FormValue("model")
if model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
return
}
// Get the file from the form
file, _, err := c.Request.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
return
}
defer file.Close()
// Read the file content to get its size
fileBytes, err := io.ReadAll(file)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
return
}
fileSize := len(fileBytes)
// Return a JSON response with the model and transcription text including file size
c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model,
// expose some header values for testing
"h_content_type": c.GetHeader("Content-Type"),
"h_content_length": c.GetHeader("Content-Length"),
})
})
r.GET("/slow-respond", func(c *gin.Context) {
@@ -119,6 +270,10 @@ func main() {
log.SetOutput(io.Discard)
}
if !*silent {
fmt.Printf("My PID: %d\n", os.Getpid())
}
go func() {
log.Printf("simple-responder listening on %s\n", address)
// service connections
@@ -129,11 +284,36 @@ func main() {
// Wait for interrupt signal to gracefully shutdown the server with
// a timeout of 5 seconds.
quit := make(chan os.Signal, 1)
sigChan := make(chan os.Signal, 1)
// kill (no param) default send syscall.SIGTERM
// kill -2 is syscall.SIGINT
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
countSigInt := 0
runloop:
for {
signal := <-sigChan
switch signal {
case syscall.SIGINT:
countSigInt++
if countSigInt > 1 {
break runloop
} else {
log.Println("Received SIGINT, send another SIGINT to shutdown")
}
case syscall.SIGTERM:
if *ignoreSigTerm {
log.Println("Ignoring SIGTERM")
} else {
log.Println("Received SIGTERM, shutting down")
break runloop
}
default:
break runloop
}
}
log.Println("simple-responder shutting down")
}
+1
View File
@@ -0,0 +1 @@
ui_dist/*
+343 -14
View File
@@ -2,35 +2,159 @@ package proxy
import (
"fmt"
"io"
"os"
"regexp"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"github.com/google/shlex"
"github.com/billziss-gh/golib/shlex"
"gopkg.in/yaml.v3"
)
const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
// #179 for /v1/models
Name string `yaml:"name"`
Description string `yaml:"description"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
// Model filters see issue #174
Filters ModelFilters `yaml:"filters"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelConfig ModelConfig
defaults := rawModelConfig{
Cmd: "",
CmdStop: "",
Proxy: "http://localhost:${PORT}",
Aliases: []string{},
Env: []string{},
CheckEndpoint: "/health",
UnloadAfter: 0,
Unlisted: false,
UseModelName: "",
ConcurrencyLimit: 0,
Name: "",
Description: "",
}
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" {
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelConfig(defaults)
return nil
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters see issue #174
type ModelFilters struct {
StripParams string `yaml:"strip_params"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelFilters(defaults)
return nil
}
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" {
continue
}
cleaned = append(cleaned, trimmed)
}
// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
}
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
Persistent bool `yaml:"persistent"`
Members []string `yaml:"members"`
}
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
defaults := rawGroupConfig{
Swap: true,
Exclusive: true,
Persistent: false,
Members: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
*c = GroupConfig(defaults)
return nil
}
type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
Models map[string]ModelConfig `yaml:"models"`
LogLevel string `yaml:"logLevel"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros map[string]string `yaml:"macros"`
// map aliases to actual model IDs
aliases map[string]string
// automatic port assignments
StartPort int `yaml:"startPort"`
}
func (c *Config) RealModelName(search string) (string, bool) {
@@ -51,42 +175,234 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
}
}
func LoadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
func LoadConfig(path string) (Config, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
return Config{}, err
}
defer file.Close()
return LoadConfigFromReader(file)
}
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
}
var config Config
// default configuration values
config := Config{
HealthCheckTimeout: 120,
StartPort: 5800,
LogLevel: "info",
MetricsMaxInMemory: 1000,
}
err = yaml.Unmarshal(data, &config)
if err != nil {
return nil, err
return Config{}, err
}
if config.HealthCheckTimeout < 15 {
// set a minimum of 15 seconds
config.HealthCheckTimeout = 15
}
if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName
}
}
return &config, nil
/* check macro constraint rules:
- name must fit the regex ^[a-zA-Z0-9_-]+$
- names must be less than 64 characters (no reason, just cause)
- name can not be any reserved macros: PORT
- macro values must be less than 1024 characters
*/
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
for macroName, macroValue := range config.Macros {
if len(macroName) >= 64 {
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName)
}
if !macroNameRegex.MatchString(macroName) {
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
}
if len(macroValue) >= 1024 {
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
}
switch macroName {
case "PORT":
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
}
}
// Get and sort all model IDs first, makes testing more consistent
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds) // This guarantees stable iteration order
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
// Strip comments from command fields before macro expansion
modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
for macroName, macroValue := range config.Macros {
macroSlug := fmt.Sprintf("${%s}", macroName)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
}
// enforce ${PORT} used in both cmd and proxy
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
nextPortStr := strconv.Itoa(nextPort)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
nextPort++
}
// make sure there are no unknown macros that have not been replaced
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
}
for fieldName, fieldValue := range fieldMap {
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
for _, match := range matches {
macroName := match[1]
if macroName == "PID" && fieldName == "cmdStop" {
continue // this is ok, has to be replaced by process later
}
if _, exists := config.Macros[macroName]; !exists {
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
}
}
}
config.Models[modelId] = modelConfig
}
config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
// Check for duplicates within this group
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
// Check if member is used in another group
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config {
if config.Groups == nil {
config.Groups = make(map[string]GroupConfig)
}
defaultGroup := GroupConfig{
Swap: true,
Exclusive: true,
Members: []string{},
}
// if groups is empty, create a default group and put
// all models into it
if len(config.Groups) == 0 {
for modelName := range config.Models {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName, _ := range config.Models {
foundModel := false
found:
// search for the model in existing groups
for _, groupConfig := range config.Groups {
for _, member := range groupConfig.Members {
if member == modelName {
foundModel = true
break found
}
}
}
if !foundModel {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
}
}
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
return config
}
func SanitizeCommand(cmdStr string) ([]string, error) {
// Remove trailing backslashes
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
// Handle trailing backslashes by replacing with space
if strings.HasSuffix(trimmed, "\\") {
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
} else {
cleanedLines = append(cleanedLines, line)
}
}
// put it back together
cmdStr = strings.Join(cleanedLines, "\n")
// Split the command into arguments
args, err := shlex.Split(cmdStr)
if err != nil {
return nil, err
var args []string
if runtime.GOOS == "windows" {
args = shlex.Windows.Split(cmdStr)
} else {
args = shlex.Posix.Split(cmdStr)
}
// Ensure the command is not empty
@@ -96,3 +412,16 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
return args, nil
}
func StripComments(cmdStr string) string {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
cleanedLines = append(cleanedLines, line)
}
return strings.Join(cleanedLines, "\n")
}
+234
View File
@@ -0,0 +1,234 @@
//go:build !windows
package proxy
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_SanitizeCommand(t *testing.T) {
// Test a command with spaces and newlines
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
--arg2 'single quotes'
-s
# comment 1
--arg3 123 \
# comment 2
--arg4 '"string in string"'
# this will get stripped out as well as the white space above
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"--arg2", "single quotes",
"-s",
"--arg3", "123",
"--arg4", `"string in string"`,
"-c", `'single quoted'`,
}, args)
// Test an empty command
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
}
// Test the default values are automatically set for global, model and group configurations
// after loading the configuration
func TestConfig_DefaultValuesPosix(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 120, config.HealthCheckTimeout)
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "info", config.LogLevel)
// Test default group exists
defaultGroup, exists := config.Groups["(default)"]
assert.True(t, exists, "default group should exist")
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
assert.Equal(t, true, defaultGroup.Swap)
assert.Equal(t, true, defaultGroup.Exclusive)
assert.Equal(t, false, defaultGroup.Persistent)
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
}
model1, exists := config.Models["model1"]
assert.True(t, exists, "model1 should exist")
if assert.NotNil(t, model1, "model1 should not be nil") {
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
assert.Equal(t, "", model1.CmdStop)
assert.Equal(t, "http://localhost:5800", model1.Proxy)
assert.Equal(t, "/health", model1.CheckEndpoint)
assert.Equal(t, []string{}, model1.Aliases)
assert.Equal(t, []string{}, model1.Env)
assert.Equal(t, 0, model1.UnloadAfter)
assert.Equal(t, false, model1.Unlisted)
assert.Equal(t, "", model1.UseModelName)
assert.Equal(t, 0, model1.ConcurrencyLimit)
}
// default empty filter exists
assert.Equal(t, "", model1.Filters.StripParams)
}
func TestConfig_LoadPosix(t *testing.T) {
// Create a temporary YAML file for testing
tempDir, err := os.MkdirTemp("", "test-config")
if err != nil {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tempDir)
tempFile := filepath.Join(tempDir, "config.yaml")
content := `
macros:
svr-path: "path/to/server"
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
name: "Model 1"
description: "This is model 1"
aliases:
- "m1"
- "model-one"
env:
- "VAR1=value1"
- "VAR2=value2"
checkEndpoint: "/health"
model2:
cmd: ${svr-path} --arg1 one
proxy: "http://localhost:8081"
aliases:
- "m2"
checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "mthree"
checkEndpoint: "/"
model4:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8082"
checkEndpoint: "/"
healthCheckTimeout: 15
profiles:
test:
- model1
- model2
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
forever:
exclusive: false
persistent: true
members:
- "model4"
`
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
t.Fatalf("Failed to write temporary file: %v", err)
}
// Load the config and verify
config, err := LoadConfig(tempFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
expected := Config{
LogLevel: "info",
StartPort: 5800,
Macros: map[string]string{
"svr-path": "path/to/server",
},
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
},
"model2": {
Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
},
},
HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
aliases: map[string]string{
"m1": "model1",
"model-one": "model1",
"m2": "model2",
"mthree": "model3",
},
Groups: map[string]GroupConfig{
DEFAULT_GROUP_ID: {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model3"},
},
"group1": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
},
},
}
assert.Equal(t, expected, config)
realname, found := config.RealModelName("m1")
assert.True(t, found)
assert.Equal(t, "model1", realname)
}
+355 -89
View File
@@ -1,90 +1,68 @@
package proxy
import (
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_Load(t *testing.T) {
// Create a temporary YAML file for testing
tempDir, err := os.MkdirTemp("", "test-config")
if err != nil {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tempDir)
func TestConfig_GroupMemberIsUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
tempFile := filepath.Join(tempDir, "config.yaml")
healthCheckTimeout: 15
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
group2:
swap: true
exclusive: false
members: ["model2"]
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// a Contains as order of the map is not guaranteed
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
}
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
aliases:
- "m1"
- "model-one"
env:
- "VAR1=value1"
- "VAR2=value2"
checkEndpoint: "/health"
- m1
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "m2"
checkEndpoint: "/"
healthCheckTimeout: 15
profiles:
test:
- model1
- model2
aliases:
- m1
- m2
`
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
t.Fatalf("Failed to write temporary file: %v", err)
}
// Load the config and verify
config, err := LoadConfig(tempFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
_, err := LoadConfigFromReader(strings.NewReader(content))
expected := &Config{
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
},
"model2": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: nil,
CheckEndpoint: "/",
},
},
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
aliases: map[string]string{
"m1": "model1",
"model-one": "model1",
"m2": "model2",
},
}
assert.Equal(t, expected, config)
realname, found := config.RealModelName("m1")
assert.True(t, found)
assert.Equal(t, "model1", realname)
// this is a contains because it could be `model1` or `model2` depending on the order
// go decided on the order of the map
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
}
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
@@ -147,30 +125,318 @@ func TestConfig_FindConfig(t *testing.T) {
assert.Equal(t, ModelConfig{}, modelConfig)
}
func TestConfig_SanitizeCommand(t *testing.T) {
func TestConfig_AutomaticPortAssignments(t *testing.T) {
// Test a command with spaces and newlines
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
--arg2 'single quotes'
-s
--arg3 123 \
--arg4 '"string in string"'
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"--arg2", "single quotes",
"-s",
"--arg3", "123",
"--arg4", `"string in string"`,
"-c", `'single quoted'`,
}, args)
t.Run("Default Port Ranges", func(t *testing.T) {
content := ``
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
// Test an empty command
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
assert.Equal(t, 5800, config.StartPort)
})
t.Run("User specific port ranges", func(t *testing.T) {
content := `startPort: 1000`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 1000, config.StartPort)
})
t.Run("Invalid start port", func(t *testing.T) {
content := `startPort: abcd`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("start port must be greater than 1", func(t *testing.T) {
content := `startPort: -99`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("Automatic port assignments", func(t *testing.T) {
content := `
startPort: 5800
models:
model1:
cmd: svr --port ${PORT}
model2:
cmd: svr --port ${PORT}
proxy: "http://172.11.22.33:${PORT}"
model3:
cmd: svr --port 1999
proxy: "http://1.2.3.4:1999"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
})
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
content := `
models:
model1:
cmd: svr --port 111
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Equal(t, "model model1: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", err.Error())
})
}
func TestConfig_MacroReplacement(t *testing.T) {
content := `
startPort: 9990
macros:
svr-path: "path/to/server"
argOne: "--arg1"
argTwo: "--arg2"
autoPort: "--port ${PORT}"
models:
model1:
cmd: |
${svr-path} ${argTwo}
# the automatic ${PORT} is replaced
${autoPort}
${argOne}
--arg3 three
cmdStop: |
/path/to/stop.sh --port ${PORT} ${argTwo}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
assert.NoError(t, err)
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " "))
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
assert.NoError(t, err)
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
}
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
tests := []struct {
name string
field string
content string
}{
{
name: "unknown macro in cmd",
field: "cmd",
content: `
startPort: 9990
macros:
svr-path: "path/to/server"
models:
model1:
cmd: |
${svr-path} --port ${PORT}
${unknownMacro}
`,
},
{
name: "unknown macro in cmdStop",
field: "cmdStop",
content: `
startPort: 9990
macros:
svr-path: "path/to/server"
models:
model1:
cmd: "${svr-path} --port ${PORT}"
cmdStop: "kill ${unknownMacro}"
`,
},
{
name: "unknown macro in proxy",
field: "proxy",
content: `
startPort: 9990
macros:
svr-path: "path/to/server"
models:
model1:
cmd: "${svr-path} --port ${PORT}"
proxy: "http://localhost:${unknownMacro}"
`,
},
{
name: "unknown macro in checkEndpoint",
field: "checkEndpoint",
content: `
startPort: 9990
macros:
svr-path: "path/to/server"
models:
model1:
cmd: "${svr-path} --port ${PORT}"
checkEndpoint: "http://localhost:${unknownMacro}/health"
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
//t.Log(err)
})
}
}
func TestConfig_ModelFilters(t *testing.T) {
content := `
macros:
default_strip: "temperature, top_p"
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
strip_params: "model, top_k, ${default_strip}, , ,"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
modelConfig, ok := config.Models["model1"]
if !assert.True(t, ok) {
t.FailNow()
}
// make sure `model` and enmpty strings are not in the list
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
sanitized, err := modelConfig.Filters.SanitizedStripParams()
if assert.NoError(t, err) {
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
}
}
func TestStripComments(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "no comments",
input: "echo hello\necho world",
expected: "echo hello\necho world",
},
{
name: "single comment line",
input: "# this is a comment\necho hello",
expected: "echo hello",
},
{
name: "multiple comment lines",
input: "# comment 1\necho hello\n# comment 2\necho world",
expected: "echo hello\necho world",
},
{
name: "comment with spaces",
input: " # indented comment\necho hello",
expected: "echo hello",
},
{
name: "empty lines preserved",
input: "echo hello\n\necho world",
expected: "echo hello\n\necho world",
},
{
name: "only comments",
input: "# comment 1\n# comment 2",
expected: "",
},
{
name: "empty string",
input: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := StripComments(tt.input)
if result != tt.expected {
t.Errorf("StripComments() = %q, expected %q", result, tt.expected)
}
})
}
}
func TestConfig_MacroInCommentStrippedBeforeExpansion(t *testing.T) {
// Test case that reproduces the original bug where a macro in a comment
// would get expanded and cause the comment text to be included in the command
content := `
startPort: 9990
macros:
"latest-llama": >
/user/llama.cpp/build/bin/llama-server
--port ${PORT}
models:
"test-model":
cmd: |
# ${latest-llama} is a macro that is defined above
${latest-llama}
--model /path/to/model.gguf
-ngl 99
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// Get the sanitized command
sanitizedCmd, err := SanitizeCommand(config.Models["test-model"].Cmd)
assert.NoError(t, err)
// Join the command for easier inspection
cmdStr := strings.Join(sanitizedCmd, " ")
// Verify that comment text is NOT present in the final command as separate arguments
commentWords := []string{"is", "macro", "that", "defined", "above"}
for _, word := range commentWords {
found := slices.Contains(sanitizedCmd, word)
assert.False(t, found, "Comment text '%s' should not be present as a separate argument in final command", word)
}
// Verify that the actual command components ARE present
expectedParts := []string{
"/user/llama.cpp/build/bin/llama-server",
"--port",
"9990",
"--model",
"/path/to/model.gguf",
"-ngl",
"99",
}
for _, part := range expectedParts {
assert.Contains(t, cmdStr, part, "Expected command part '%s' not found in final command", part)
}
// Verify the server path appears exactly once (not duplicated due to macro expansion)
serverPath := "/user/llama.cpp/build/bin/llama-server"
count := strings.Count(cmdStr, serverPath)
assert.Equal(t, 1, count, "Expected exactly 1 occurrence of server path, found %d", count)
// Verify the expected final command structure
expectedCmd := "/user/llama.cpp/build/bin/llama-server --port 9990 --model /path/to/model.gguf -ngl 99"
assert.Equal(t, expectedCmd, cmdStr, "Final command does not match expected structure")
}
+231
View File
@@ -0,0 +1,231 @@
//go:build windows
package proxy
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_SanitizeCommand(t *testing.T) {
// does not support single quoted strings like in config_posix_test.go
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
-s
--arg3 123 \
# comment 2
--arg4 '"string in string"'
# this will get stripped out as well as the white space above
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"-s",
"--arg3", "123",
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
"-c", `'single quoted'`,
}, args)
// Test an empty command
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
}
func TestConfig_DefaultValuesWindows(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 120, config.HealthCheckTimeout)
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "info", config.LogLevel)
// Test default group exists
defaultGroup, exists := config.Groups["(default)"]
assert.True(t, exists, "default group should exist")
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
assert.Equal(t, true, defaultGroup.Swap)
assert.Equal(t, true, defaultGroup.Exclusive)
assert.Equal(t, false, defaultGroup.Persistent)
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
}
model1, exists := config.Models["model1"]
assert.True(t, exists, "model1 should exist")
if assert.NotNil(t, model1, "model1 should not be nil") {
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
assert.Equal(t, "taskkill /f /t /pid ${PID}", model1.CmdStop)
assert.Equal(t, "http://localhost:5800", model1.Proxy)
assert.Equal(t, "/health", model1.CheckEndpoint)
assert.Equal(t, []string{}, model1.Aliases)
assert.Equal(t, []string{}, model1.Env)
assert.Equal(t, 0, model1.UnloadAfter)
assert.Equal(t, false, model1.Unlisted)
assert.Equal(t, "", model1.UseModelName)
assert.Equal(t, 0, model1.ConcurrencyLimit)
}
// default empty filter exists
assert.Equal(t, "", model1.Filters.StripParams)
}
func TestConfig_LoadWindows(t *testing.T) {
// Create a temporary YAML file for testing
tempDir, err := os.MkdirTemp("", "test-config")
if err != nil {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tempDir)
tempFile := filepath.Join(tempDir, "config.yaml")
content := `
macros:
svr-path: "path/to/server"
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
aliases:
- "m1"
- "model-one"
env:
- "VAR1=value1"
- "VAR2=value2"
checkEndpoint: "/health"
model2:
cmd: ${svr-path} --arg1 one
proxy: "http://localhost:8081"
aliases:
- "m2"
checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "mthree"
checkEndpoint: "/"
model4:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8082"
checkEndpoint: "/"
healthCheckTimeout: 15
profiles:
test:
- model1
- model2
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
forever:
exclusive: false
persistent: true
members:
- "model4"
`
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
t.Fatalf("Failed to write temporary file: %v", err)
}
// Load the config and verify
config, err := LoadConfig(tempFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
expected := Config{
LogLevel: "info",
StartPort: 5800,
Macros: map[string]string{
"svr-path": "path/to/server",
},
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
},
"model2": {
Cmd: "path/to/server --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
},
},
HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
aliases: map[string]string{
"m1": "model1",
"model-one": "model1",
"m2": "model2",
"mthree": "model3",
},
Groups: map[string]GroupConfig{
DEFAULT_GROUP_ID: {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model3"},
},
"group1": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
},
},
}
assert.Equal(t, expected, config)
realname, found := config.RealModelName("m1")
assert.True(t, found)
assert.Equal(t, "model1", realname)
}
+50
View File
@@ -0,0 +1,50 @@
package proxy
// package level registry of the different event types
const ProcessStateChangeEventID = 0x01
const ChatCompletionStatsEventID = 0x02
const ConfigFileChangedEventID = 0x03
const LogDataEventID = 0x04
const TokenMetricsEventID = 0x05
type ProcessStateChangeEvent struct {
ProcessName string
NewState ProcessState
OldState ProcessState
}
func (e ProcessStateChangeEvent) Type() uint32 {
return ProcessStateChangeEventID
}
type ChatCompletionStats struct {
TokensGenerated int
}
func (e ChatCompletionStats) Type() uint32 {
return ChatCompletionStatsEventID
}
type ReloadingState int
const (
ReloadingStateStart ReloadingState = iota
ReloadingStateEnd
)
type ConfigFileChangedEvent struct {
ReloadingState ReloadingState
}
func (e ConfigFileChangedEvent) Type() uint32 {
return ConfigFileChangedEventID
}
type LogDataEvent struct {
Data []byte
}
func (e LogDataEvent) Type() uint32 {
return LogDataEventID
}
+36 -8
View File
@@ -9,11 +9,13 @@ import (
"testing"
"github.com/gin-gonic/gin"
"gopkg.in/yaml.v3"
)
var (
nextTestPort int = 12000
portMutex sync.Mutex
testLogger = NewLogMonitorWriter(os.Stdout)
)
// Check if the binary exists
@@ -26,6 +28,17 @@ func TestMain(m *testing.M) {
gin.SetMode(gin.TestMode)
switch os.Getenv("LOG_LEVEL") {
case "debug":
testLogger.SetLogLevel(LevelDebug)
case "warn":
testLogger.SetLogLevel(LevelWarn)
case "info":
testLogger.SetLogLevel(LevelInfo)
default:
testLogger.SetLogLevel(LevelWarn)
}
m.Run()
}
@@ -33,26 +46,41 @@ func TestMain(m *testing.M) {
func getSimpleResponderPath() string {
goos := runtime.GOOS
goarch := runtime.GOARCH
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
if goos == "windows" {
return filepath.Join("..", "build", "simple-responder.exe")
} else {
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
}
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
func getTestPort() int {
portMutex.Lock()
defer portMutex.Unlock()
port := nextTestPort
nextTestPort++
return getTestSimpleResponderConfigPort(expectedMessage, port)
return port
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
}
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
binaryPath := getSimpleResponderPath()
// Create a process configuration
return ModelConfig{
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
// Create a YAML string with just the values we want to set
yamlStr := fmt.Sprintf(`
cmd: '%s --port %d --silent --respond %s'
proxy: "http://127.0.0.1:%d"
`, binaryPath, port, expectedMessage, port)
var cfg ModelConfig
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
}
return cfg
}
Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

-53
View File
@@ -1,53 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Logs</title>
<style>
body {
margin: 0;
height: 100vh;
display: flex;
flex-direction: column;
font-family: "Courier New", Courier, monospace;
}
#log-stream {
flex: 1;
margin: 1em;
padding: 10px;
background: #f4f4f4;
overflow-y: auto;
white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; /* Ensures long words wrap */
}
</style>
</head>
<body>
<pre id="log-stream">Waiting for logs...
</pre>
<script>
// Establish an EventSource connection to the SSE endpoint
if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
eventSource.onmessage = function(event) {
// Append the new log message to the <pre> element
const logStream = document.getElementById('log-stream');
logStream.textContent += event.data;
// Auto-scroll to the bottom
logStream.scrollTop = logStream.scrollHeight;
};
eventSource.onerror = function(err) {
console.error("EventSource failed:", err);
};
} else {
console.error("SSE not supported by this browser.");
}
</script>
</body>
</html>
+100 -27
View File
@@ -2,19 +2,36 @@ package proxy
import (
"container/ring"
"context"
"fmt"
"io"
"os"
"sync"
"github.com/mostlygeek/llama-swap/event"
)
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct {
clients map[chan []byte]bool
eventbus *event.Dispatcher
mu sync.RWMutex
buffer *ring.Ring
bufferMu sync.RWMutex
// typically this can be os.Stdout
stdout io.Writer
// logging levels
level LogLevel
prefix string
}
func NewLogMonitor() *LogMonitor {
@@ -23,9 +40,11 @@ func NewLogMonitor() *LogMonitor {
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
return &LogMonitor{
clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
eventbus: event.NewDispatcherConfig(1000),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
level: LevelInfo,
prefix: "",
}
}
@@ -65,32 +84,86 @@ func (w *LogMonitor) GetHistory() []byte {
return history
}
func (w *LogMonitor) Subscribe() chan []byte {
w.mu.Lock()
defer w.mu.Unlock()
ch := make(chan []byte, 100)
w.clients[ch] = true
return ch
}
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
w.mu.Lock()
defer w.mu.Unlock()
delete(w.clients, ch)
close(ch)
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
return event.Subscribe(w.eventbus, func(e LogDataEvent) {
callback(e.Data)
})
}
func (w *LogMonitor) broadcast(msg []byte) {
w.mu.RLock()
defer w.mu.RUnlock()
event.Publish(w.eventbus, LogDataEvent{Data: msg})
}
for client := range w.clients {
select {
case client <- msg:
default:
// If client buffer is full, skip
}
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+13 -22
View File
@@ -10,38 +10,29 @@ import (
func TestLogMonitor(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
// Test subscription
client1 := logMonitor.Subscribe()
client2 := logMonitor.Subscribe()
defer logMonitor.Unsubscribe(client1)
defer logMonitor.Unsubscribe(client2)
// A WaitGroup is used to wait for all the expected writes to complete
var wg sync.WaitGroup
client1Messages := make([]byte, 0)
client2Messages := make([]byte, 0)
var wg sync.WaitGroup
wg.Add(1)
defer logMonitor.OnLogData(func(data []byte) {
client1Messages = append(client1Messages, data...)
wg.Done()
})()
go func() {
defer wg.Done()
for {
select {
case data := <-client1:
client1Messages = append(client1Messages, data...)
case data := <-client2:
client2Messages = append(client2Messages, data...)
default:
return
}
}
}()
defer logMonitor.OnLogData(func(data []byte) {
client2Messages = append(client2Messages, data...)
wg.Done()
})()
wg.Add(6) // 2 x 3 writes
logMonitor.Write([]byte("1"))
logMonitor.Write([]byte("2"))
logMonitor.Write([]byte("3"))
// Wait for the goroutine to finish
// wait for all writes to complete
wg.Wait()
// Check the buffer
+170
View File
@@ -0,0 +1,170 @@
package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
c.Abort()
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
c.Abort()
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
c.Abort()
return
}
writer := &MetricsResponseWriter{
ResponseWriter: c.Writer,
metricsRecorder: &MetricsRecorder{
metricsMonitor: pm.metricsMonitor,
realModelName: realModelName,
isStreaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
startTime: time.Now(),
},
}
c.Writer = writer
c.Next()
rec := writer.metricsRecorder
rec.processBody(writer.body)
}
}
type MetricsRecorder struct {
metricsMonitor *MetricsMonitor
realModelName string
isStreaming bool
startTime time.Time
}
// processBody handles response processing after request completes
func (rec *MetricsRecorder) processBody(body []byte) {
if rec.isStreaming {
rec.processStreamingResponse(body)
} else {
rec.processNonStreamingResponse(body)
}
}
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
usage := jsonData.Get("usage")
if !usage.Exists() {
return false
}
// default values
outputTokens := int(jsonData.Get("usage.completion_tokens").Int())
inputTokens := int(jsonData.Get("usage.prompt_tokens").Int())
tokensPerSecond := -1.0
durationMs := int(time.Since(rec.startTime).Milliseconds())
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings := jsonData.Get("timings"); timings.Exists() {
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
}
rec.metricsMonitor.addMetrics(TokenMetrics{
Timestamp: time.Now(),
Model: rec.realModelName,
InputTokens: inputTokens,
OutputTokens: outputTokens,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
})
return true
}
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
// Iterate **backwards** through the lines looking for the data payload with
// usage data
lines := bytes.Split(body, []byte("\n"))
for i := len(lines) - 1; i >= 0; i-- {
line := bytes.TrimSpace(lines[i])
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
return // short circuit if a metric was recorded
}
}
}
}
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
if len(body) == 0 {
return
}
// Parse JSON to extract usage information
if gjson.ValidBytes(body) {
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
}
}
// MetricsResponseWriter captures the entire response for non-streaming
type MetricsResponseWriter struct {
gin.ResponseWriter
body []byte
metricsRecorder *MetricsRecorder
}
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
n, err := w.ResponseWriter.Write(b)
if err != nil {
return n, err
}
w.body = append(w.body, b...)
return n, nil
}
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *MetricsResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}
+82
View File
@@ -0,0 +1,82 @@
package proxy
import (
"encoding/json"
"sync"
"time"
"github.com/mostlygeek/llama-swap/event"
)
// TokenMetrics represents parsed token statistics from llama-server logs
type TokenMetrics struct {
ID int `json:"id"`
Timestamp time.Time `json:"timestamp"`
Model string `json:"model"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TokensPerSecond float64 `json:"tokens_per_second"`
DurationMs int `json:"duration_ms"`
}
// TokenMetricsEvent represents a token metrics event
type TokenMetricsEvent struct {
Metrics TokenMetrics
}
func (e TokenMetricsEvent) Type() uint32 {
return TokenMetricsEventID // defined in events.go
}
// MetricsMonitor parses llama-server output for token statistics
type MetricsMonitor struct {
mu sync.RWMutex
metrics []TokenMetrics
maxMetrics int
nextID int
}
func NewMetricsMonitor(config *Config) *MetricsMonitor {
maxMetrics := config.MetricsMaxInMemory
if maxMetrics <= 0 {
maxMetrics = 1000 // Default fallback
}
mp := &MetricsMonitor{
maxMetrics: maxMetrics,
}
return mp
}
// addMetrics adds a new metric to the collection and publishes an event
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
mp.mu.Lock()
defer mp.mu.Unlock()
metric.ID = mp.nextID
mp.nextID++
mp.metrics = append(mp.metrics, metric)
if len(mp.metrics) > mp.maxMetrics {
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
}
event.Emit(TokenMetricsEvent{Metrics: metric})
}
// GetMetrics returns a copy of the current metrics
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
mp.mu.RLock()
defer mp.mu.RUnlock()
result := make([]TokenMetrics, len(mp.metrics))
copy(result, mp.metrics)
return result
}
// GetMetricsJSON returns metrics as JSON
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
return json.Marshal(mp.metrics)
}
+400 -184
View File
@@ -8,28 +8,50 @@ import (
"net/http"
"net/url"
"os/exec"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/mostlygeek/llama-swap/event"
)
type ProcessState string
const (
StateStopped ProcessState = ProcessState("stopped")
StateReady ProcessState = ProcessState("ready")
StateFailed ProcessState = ProcessState("failed")
StateStopped ProcessState = ProcessState("stopped")
StateStarting ProcessState = ProcessState("starting")
StateReady ProcessState = ProcessState("ready")
StateStopping ProcessState = ProcessState("stopping")
// process is shutdown and will not be restarted
StateShutdown ProcessState = ProcessState("shutdown")
)
type StopStrategy int
const (
StopImmediately StopStrategy = iota
StopWaitForInflightRequest
)
type Process struct {
sync.Mutex
ID string
config ModelConfig
cmd *exec.Cmd
ID string
config ModelConfig
cmd *exec.Cmd
logMonitor *LogMonitor
healthCheckTimeout int
// PR #155 called to cancel the upstream process
cancelUpstream context.CancelFunc
// closed when command exits
cmdWaitChan chan struct{}
processLogger *LogMonitor
proxyLogger *LogMonitor
healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time
@@ -37,31 +59,110 @@ type Process struct {
state ProcessState
inFlightRequests sync.WaitGroup
// used to block on multiple start() calls
waitStarting sync.WaitGroup
// for managing concurrency limits
concurrencyLimitSemaphore chan struct{}
// used for testing to override the default value
gracefulStopTimeout time.Duration
// track the number of failed starts
failedStartCount int
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
concurrentLimit := 10
if config.ConcurrencyLimit > 0 {
concurrentLimit = config.ConcurrencyLimit
}
return &Process{
ID: ID,
config: config,
cmd: nil,
logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout,
state: StateStopped,
ID: ID,
config: config,
cmd: nil,
cancelUpstream: nil,
processLogger: processLogger,
proxyLogger: proxyLogger,
healthCheckTimeout: healthCheckTimeout,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
state: StateStopped,
// concurrency limit
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
// To be removed when migration over exec.CommandContext is complete
// stop timeout
gracefulStopTimeout: 10 * time.Second,
cmdWaitChan: make(chan struct{}),
}
}
// start the process and returns when it is ready
func (p *Process) start() error {
// LogMonitor returns the log monitor associated with the process.
func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// custom error types for swapping state
var (
ErrExpectedStateMismatch = errors.New("expected state mismatch")
ErrInvalidStateTransition = errors.New("invalid state transition")
)
// swapState performs a compare and swap of the state atomically. It returns the current state
// and an error if the swap failed.
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state == StateReady {
return nil
if p.state != expectedState {
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
return p.state, ErrExpectedStateMismatch
}
if p.state == StateFailed {
return fmt.Errorf("process is in a failed state and can not be restarted")
if !isValidTransition(p.state, newState) {
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
return p.state, ErrInvalidStateTransition
}
p.state = newState
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
return p.state, nil
}
// Helper function to encapsulate transition rules
func isValidTransition(from, to ProcessState) bool {
switch from {
case StateStopped:
return to == StateStarting
case StateStarting:
return to == StateReady || to == StateStopping || to == StateStopped
case StateReady:
return to == StateStopping
case StateStopping:
return to == StateStopped || to == StateShutdown
case StateShutdown:
return false // No transitions allowed from these states
}
return false
}
func (p *Process) CurrentState() ProcessState {
p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
}
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called
// at any time.
func (p *Process) start() error {
if p.config.Proxy == "" {
return fmt.Errorf("can not start(), upstream proxy missing")
}
args, err := p.config.SanitizedCommand()
@@ -69,52 +170,105 @@ func (p *Process) start() error {
return fmt.Errorf("unable to get sanitized command: %v", err)
}
p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.logMonitor
p.cmd.Stderr = p.logMonitor
p.cmd.Env = p.config.Env
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
if err == ErrExpectedStateMismatch {
// already starting, just wait for it to complete and expect
// it to be be in the Ready start after. If not, return an error
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state == StateReady {
return nil
} else {
return fmt.Errorf("process was already starting but wound up in state %v", state)
}
} else {
return fmt.Errorf("processes was in state %v when start() was called", curState)
}
} else {
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
}
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
p.cmd.Stdout = p.processLogger
p.cmd.Stderr = p.processLogger
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
p.cmd.Cancel = p.cmdStopUpstreamProcess
p.cmd.WaitDelay = p.gracefulStopTimeout
p.cancelUpstream = ctxCancelUpstream
p.cmdWaitChan = make(chan struct{})
p.failedStartCount++ // this will be reset to zero when the process has successfully started
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
err = p.cmd.Start()
// Set process state to failed
if err != nil {
return err
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
p.state = StateStopped // force it into a stopped state
return fmt.Errorf(
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
strings.Join(args, " "), err, curState, swapErr,
)
}
return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err)
}
// Capture the exit error for later signalling
go p.waitForCmd()
// One of three things can happen at this stage:
// 1. The command exits unexpectedly
// 2. The health check fails
// 3. The health check passes
//
// only in the third case will the process be considered Ready to accept
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
defer cancelHealthCheck(nil) // clean up
cmdWaitChan := make(chan error, 1)
healthCheckChan := make(chan error, 1)
<-time.After(250 * time.Millisecond) // give process a bit of time to start
go func() {
// possible cmd exits early
cmdWaitChan <- p.cmd.Wait()
}()
checkStartTime := time.Now()
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
go func() {
<-time.After(250 * time.Millisecond) // give process a bit of time to start
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
}()
select {
case err := <-cmdWaitChan:
p.state = StateFailed
// a "none" means don't check for health ... I could have picked a better word :facepalm:
if checkEndpoint != "none" {
proxyTo := p.config.Proxy
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
} else {
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
}
cancelHealthCheck(err)
return err
case err := <-healthCheckChan:
if err != nil {
p.state = StateFailed
return err
// Ready Check loop
for {
currentState := p.CurrentState()
if currentState != StateStarting {
if currentState == StateStopped {
return fmt.Errorf("upstream command exited prematurely but successfully")
}
return errors.New("health check interrupted due to shutdown")
}
if time.Since(checkStartTime) > maxDuration {
p.stopCommand()
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
}
if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
break
} else {
if strings.Contains(err.Error(), "connection refused") {
ttl := time.Until(checkStartTime.Add(maxDuration))
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
} else {
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
}
}
<-time.After(p.healthCheckLoopInterval)
}
}
@@ -125,7 +279,7 @@ func (p *Process) start() error {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) {
if p.state != StateReady {
if p.CurrentState() != StateReady {
return
}
@@ -133,173 +287,155 @@ func (p *Process) start() error {
p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
p.Stop()
return
}
}
}()
}
p.state = StateReady
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
} else {
p.failedStartCount = 0
return nil
}
}
// Stop will wait for inflight requests to complete before stopping the process.
func (p *Process) Stop() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
// wait for any inflight requests before proceeding
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
p.inFlightRequests.Wait()
p.StopImmediately()
}
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
return
}
p.stopCommand()
}
// Shutdown is called when llama-swap is shutting down. It will give a little bit
// of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down. Once a process is in
// the StateShutdown state, it can not be started again.
func (p *Process) Shutdown() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
p.stopCommand()
// just force it to this state since there is no recovery from shutdown
p.state = StateShutdown
}
// stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand() {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
}()
if p.cancelUpstream == nil {
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
return
}
p.cancelUpstream()
<-p.cmdWaitChan
}
func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{
Timeout: 500 * time.Millisecond,
}
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
// got a response but it was not an OK
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code: %d", resp.StatusCode)
}
return nil
}
func (p *Process) Stop() {
// wait for any inflight requests before proceeding
p.inFlightRequests.Wait()
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
requestBeginTime := time.Now()
var startDuration time.Duration
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state != StateReady {
// prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState()
if currentState == StateShutdown || currentState == StateStopping {
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
return
}
if p.cmd == nil || p.cmd.Process == nil {
// this situation should never happen... but if it does just update the state
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
p.state = StateStopped
return
}
// Pretty sure this stopping code needs some work for windows and
// will be a source of pain in the future.
p.cmd.Process.Signal(syscall.SIGTERM)
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
select {
case <-sigtermTimeout.Done():
fmt.Fprintf(p.logMonitor, "!!! process for %s timed out waiting to stop\n", p.ID)
p.cmd.Process.Kill()
p.cmd.Wait()
case err := <-sigtermNormal:
if err != nil {
if err.Error() != "wait: no child processes" {
// possible that simple-responder for testing is just not
// existing right, so suppress those errors.
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
}
}
case p.concurrencyLimitSemaphore <- struct{}{}:
defer func() { <-p.concurrencyLimitSemaphore }()
default:
http.Error(w, "Too many requests", http.StatusTooManyRequests)
return
}
p.state = StateStopped
}
func (p *Process) CurrentState() ProcessState {
p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
}
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
if p.config.Proxy == "" {
return fmt.Errorf("no upstream available to check /health")
}
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
if checkEndpoint == "none" {
return nil
}
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
proxyTo := p.config.Proxy
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
}
client := &http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
defer cancel()
req = req.WithContext(ctx)
resp, err := client.Do(req)
ttl := (maxDuration - time.Since(startTime)).Seconds()
if err != nil {
// check if the context was cancelled
select {
case <-ctx.Done():
err := context.Cause(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
return err
}
default:
}
// wait a bit longer for TCP connection issues
if strings.Contains(err.Error(), "connection refused") {
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
time.Sleep(5 * time.Second)
} else {
time.Sleep(time.Second)
}
if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
time.Sleep(time.Second)
}
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
p.inFlightRequests.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
p.inFlightRequests.Done()
}()
// start the process on demand
if p.CurrentState() != StateReady {
beginStartTime := time.Now()
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusInternalServerError)
http.Error(w, errstr, http.StatusBadGateway)
return
}
startDuration = time.Since(beginStartTime)
}
proxyTo := p.config.Proxy
client := &http.Client{}
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header.Clone()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
@@ -333,4 +469,84 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return
}
}
totalTime := time.Since(requestBeginTime)
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
p.ID, r.RequestURI, startDuration, totalTime)
}
// waitForCmd waits for the command to exit and handles exit conditions depending on current state
func (p *Process) waitForCmd() {
exitErr := p.cmd.Wait()
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
if exitErr != nil {
if errno, ok := exitErr.(syscall.Errno); ok {
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
} else if exitError, ok := exitErr.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") {
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") {
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
} else {
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
}
} else {
if exitErr.Error() != "context canceled" /* this is normal */ {
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, exitErr)
}
}
}
currentState := p.CurrentState()
switch currentState {
case StateStopping:
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
p.state = StateStopped
}
default:
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
p.state = StateStopped // force it to be in this state
}
close(p.cmdWaitChan)
}
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
func (p *Process) cmdStopUpstreamProcess() error {
p.processLogger.Debugf("<%s> cmdStopUpstreamProcess() initiating graceful stop of upstream process", p.ID)
// this should never happen ...
if p.cmd == nil || p.cmd.Process == nil {
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID)
}
if p.config.CmdStop != "" {
// replace ${PID} with the pid of the process
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
if err != nil {
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
return err
}
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
stopCmd.Stdout = p.processLogger
stopCmd.Stderr = p.processLogger
stopCmd.Env = p.cmd.Env
if err := stopCmd.Run(); err != nil {
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
return err
}
} else {
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
return err
}
}
return nil
}
+340 -13
View File
@@ -2,10 +2,10 @@ package proxy
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"runtime"
"sync"
"testing"
"time"
@@ -13,13 +13,26 @@ import (
"github.com/stretchr/testify/assert"
)
var (
debugLogger = NewLogMonitorWriter(os.Stdout)
)
func init() {
// flip to help with debugging tests
if false {
debugLogger.SetLogLevel(LevelDebug)
} else {
debugLogger.SetLogLevel(LevelError)
}
}
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
// Create a process
process := NewProcess("test-process", 5, config, logMonitor)
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil)
@@ -48,6 +61,32 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
}
}
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
// are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
defer process.Stop()
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func(reqID int) {
defer wg.Done()
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
}(i)
}
wg.Wait()
assert.Equal(t, StateReady, process.CurrentState())
}
// test that the automatic start returns the expected error type
func TestProcess_BrokenModelConfig(t *testing.T) {
// Create a process configuration
@@ -57,17 +96,20 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health",
}
process := NewProcess("broken", 1, config, NewLogMonitor())
defer process.Stop()
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "unable to start process")
w = httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
}
// test that the process unloads after the TTL
func TestProcess_UnloadAfterTTL(t *testing.T) {
if testing.Short() {
t.Skip("skipping long auto unload TTL test")
@@ -79,7 +121,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
defer process.Stop()
// this should take 4 seconds
@@ -111,7 +153,36 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
assert.Equal(t, StateStopped, process.CurrentState())
}
func TestProcess_LowTTLValue(t *testing.T) {
if true { // change this code to run this ...
t.Skip("skipping test, edit process_test.go to run it ")
}
config := getTestSimpleResponderConfig("fast_ttl")
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
defer process.Stop()
for i := 0; i < 100; i++ {
t.Logf("Waiting before sending request %d", i)
time.Sleep(1500 * time.Millisecond)
expected := fmt.Sprintf("echo=test_%d", i)
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expected)
}
}
// issue #19
// This test makes sure using Process.Stop() does not affect pending HTTP
// requests. All HTTP requests in this test should complete successfully.
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
@@ -119,7 +190,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
process := NewProcess("t", 10, config, debugLogger, debugLogger)
defer process.Stop()
results := map[string]string{
@@ -135,8 +206,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
wg.Add(1)
go func(key string) {
defer wg.Done()
// send a request that should take 5 * 200ms (1 second) to complete
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
// send a request where simple-responder is will wait 300ms before responding
// this will simulate an in-progress request.
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
@@ -152,9 +224,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
}(key)
}
// stop the requests in the middle
// Stop the process while requests are still being processed
go func() {
<-time.After(500 * time.Millisecond)
<-time.After(150 * time.Millisecond)
process.Stop()
}()
@@ -164,3 +236,258 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
assert.Equal(t, key, result)
}
}
func TestProcess_SwapState(t *testing.T) {
tests := []struct {
name string
currentState ProcessState
expectedState ProcessState
newState ProcessState
expectedError error
expectedResult ProcessState
}{
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped},
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
p.state = test.currentState
resultState, err := p.swapState(test.expectedState, test.newState)
if err != nil && test.expectedError == nil {
t.Errorf("Unexpected error: %v", err)
} else if err == nil && test.expectedError != nil {
t.Errorf("Expected error: %v, but got none", test.expectedError)
} else if err != nil && test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
}
}
if resultState != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
}
})
}
}
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping long shutdown test")
}
expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
// make it a lot faster
process.healthCheckLoopInterval = time.Second
// start a goroutine to simulate a shutdown
var wg sync.WaitGroup
go func() {
defer wg.Done()
<-time.After(time.Millisecond * 500)
process.Shutdown()
}()
wg.Add(1)
// start the process, this is a blocking call
err := process.start()
wg.Wait()
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState())
}
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping Exit Interrupts Health Check test")
}
// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
process.healthCheckLoopInterval = time.Second // make it faster
err := process.start()
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
assert.Equal(t, process.CurrentState(), StateStopped)
}
func TestProcess_ConcurrencyLimit(t *testing.T) {
if testing.Short() {
t.Skip("skipping long concurrency limit test")
}
expectedMessage := "concurrency_limit_test"
config := getTestSimpleResponderConfig(expectedMessage)
// only allow 1 concurrent request at a time
config.ConcurrencyLimit = 1
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
defer process.Stop()
// launch a goroutine first to take up the semaphore
go func() {
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req1)
assert.Equal(t, http.StatusOK, w.Code)
}()
// let the goroutine start
<-time.After(time.Millisecond * 25)
denied := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, denied)
assert.Equal(t, http.StatusTooManyRequests, w.Code)
}
func TestProcess_StopImmediately(t *testing.T) {
expectedMessage := "test_stop_immediate"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
defer process.Stop()
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
go func() {
// slow, but will get killed by StopImmediate
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
}()
<-time.After(time.Millisecond)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
}
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
// the upstream command
func TestProcess_ForceStopWithKill(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping SIGTERM test on Windows ")
}
expectedMessage := "test_sigkill"
binaryPath := getSimpleResponderPath()
port := getTestPort()
config := ModelConfig{
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
// to force the process to exit
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
}
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
defer process.Stop()
// reduce to make testing go faster
process.gracefulStopTimeout = time.Second
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
waitChan := make(chan struct{})
go func() {
// slow, but will get killed by StopImmediate
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
// StatusOK because that was already sent before the kill
assert.Equal(t, http.StatusOK, w.Code)
// unexpected EOF because the kill happened, the "1" is sent before the kill
// then the unexpected EOF is sent after the kill
if runtime.GOOS == "windows" {
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
} else {
assert.Contains(t, w.Body.String(), "unexpected EOF")
}
close(waitChan)
}()
<-time.After(time.Millisecond)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
// the request should have been interrupted by SIGKILL
<-waitChan
}
func TestProcess_StopCmd(t *testing.T) {
config := getTestSimpleResponderConfig("test_stop_cmd")
if runtime.GOOS == "windows" {
config.CmdStop = "taskkill /f /t /pid ${PID}"
} else {
config.CmdStop = "kill -TERM ${PID}"
}
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
defer process.Stop()
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
}
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
expectedMessage := "test_env_not_emptied"
config := getTestSimpleResponderConfig(expectedMessage)
// ensure that the the default config does not blank out the inherited environment
configWEnv := config
// ensure the additiona variables are appended to the process' environment
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger)
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
process1.start()
defer process1.Stop()
process2.start()
defer process2.Stop()
assert.NotZero(t, len(process1.cmd.Environ()))
assert.NotZero(t, len(process2.cmd.Environ()))
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
}
+114
View File
@@ -0,0 +1,114 @@
package proxy
import (
"fmt"
"net/http"
"slices"
"sync"
)
type ProcessGroup struct {
sync.Mutex
config Config
id string
swap bool
exclusive bool
persistent bool
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
// map of current processes
processes map[string]*Process
lastUsedProcess string
}
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
groupConfig, ok := config.Groups[id]
if !ok {
panic("Unable to find configuration for group id: " + id)
}
pg := &ProcessGroup{
id: id,
config: config,
swap: groupConfig.Swap,
exclusive: groupConfig.Exclusive,
persistent: groupConfig.Persistent,
proxyLogger: proxyLogger,
upstreamLogger: upstreamLogger,
processes: make(map[string]*Process),
}
// Create a Process for each member in the group
for _, modelID := range groupConfig.Members {
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
pg.processes[modelID] = process
}
return pg
}
// ProxyRequest proxies a request to the specified model
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
if !pg.HasMember(modelID) {
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
}
if pg.swap {
pg.Lock()
if pg.lastUsedProcess != modelID {
if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop()
}
pg.lastUsedProcess = modelID
}
pg.Unlock()
}
pg.processes[modelID].ProxyRequest(writer, request)
return nil
}
func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock()
defer pg.Unlock()
if len(pg.processes) == 0 {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
}(process)
}
wg.Wait()
}
func (pg *ProcessGroup) Shutdown() {
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Shutdown()
}(process)
}
wg.Wait()
}
+96
View File
@@ -0,0 +1,96 @@
package proxy
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
"model4": getTestSimpleResponderConfig("model4"),
"model5": getTestSimpleResponderConfig("model5"),
},
Groups: map[string]GroupConfig{
"G1": {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model2"},
},
"G2": {
Swap: false,
Exclusive: true,
Members: []string{"model3", "model4"},
},
},
})
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model5"))
}
func TestProcessGroup_HasMember(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model1"))
assert.True(t, pg.HasMember("model2"))
assert.False(t, pg.HasMember("model3"))
}
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model1", "model2"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
// make sure only one process is in the running state
count := 0
for _, process := range pg.processes {
if process.CurrentState() == StateReady {
count++
}
}
assert.Equal(t, 1, count)
})
}
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model3", "model4"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
})
}
// make sure all the processes are running
for _, process := range pg.processes {
assert.Equal(t, StateReady, process.CurrentState())
}
}
+474 -195
View File
@@ -2,11 +2,12 @@ package proxy
import (
"bytes"
"embed"
"encoding/json"
"context"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"sort"
"strconv"
"strings"
@@ -14,216 +15,342 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
PROFILE_SPLIT_CHAR = ":"
)
//go:embed html/favicon.ico
var faviconData []byte
//go:embed html/logs.html
var logsHTML []byte
// make sure embed is kept there by the IDE auto-package importer
var _ = embed.FS{}
type ProxyManager struct {
sync.Mutex
config *Config
currentProcesses map[string]*Process
logMonitor *LogMonitor
ginEngine *gin.Engine
config Config
ginEngine *gin.Engine
// logging
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
metricsMonitor *MetricsMonitor
processGroups map[string]*ProcessGroup
// shutdown signaling
shutdownCtx context.Context
shutdownCancel context.CancelFunc
}
func New(config *Config) *ProxyManager {
pm := &ProxyManager{
config: config,
currentProcesses: make(map[string]*Process),
logMonitor: NewLogMonitor(),
ginEngine: gin.New(),
}
func New(config Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests {
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
}
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
upstreamLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
upstreamLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
}
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
pm := &ProxyManager{
config: config,
ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
metricsMonitor: NewMetricsMonitor(&config),
processGroups: make(map[string]*ProcessGroup),
shutdownCtx: shutdownCtx,
shutdownCancel: shutdownCancel,
}
// create the process groups
for groupID := range config.Groups {
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
pm.setupGinEngine()
return pm
}
func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
clientIP,
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
// see: issue: #81, #77 and #42 for CORS issues
// respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
// allow whatever the client requested by default
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
c.Header("Access-Control-Allow-Headers", sanitized)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
})
mm := MetricsMiddleware(pm)
// Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
// Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
// Support embeddings
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
// in proxymanager_loghandlers.go
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
/**
* User Interface Endpoints
*/
pm.ginEngine.GET("/", func(c *gin.Context) {
c.Redirect(http.StatusFound, "/ui")
})
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
c.Redirect(http.StatusFound, "/ui/models")
})
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
c.Data(http.StatusOK, "image/x-icon", faviconData)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.GET("/health", func(c *gin.Context) {
c.String(http.StatusOK, "OK")
})
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
c.Data(http.StatusOK, "image/x-icon", data)
} else {
c.String(http.StatusInternalServerError, err.Error())
}
})
reactFS, err := GetReactFS()
if err != nil {
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
} else {
// serve files that exist under /ui/*
pm.ginEngine.StaticFS("/ui", reactFS)
// server SPA for UI under /ui/*
pm.ginEngine.NoRoute(func(c *gin.Context) {
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
c.AbortWithStatus(http.StatusNotFound)
return
}
file, err := reactFS.Open("index.html")
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
return
}
defer file.Close()
http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
})
}
// see: proxymanager_api.go
// add API handler functions
addApiHandlers(pm)
// Disable console color for testing
gin.DisableConsoleColor()
return pm
}
func (pm *ProxyManager) Run(addr ...string) error {
return pm.ginEngine.Run(addr...)
}
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
// ServeHTTP implements http.Handler interface
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r)
}
func (pm *ProxyManager) StopProcesses() {
// StopProcesses acquires a lock and stops all running upstream processes.
// This is the public method safe for concurrent calls.
// Unlike Shutdown, this method only stops the processes but doesn't perform
// a complete shutdown, allowing for process replacement without full termination.
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
pm.Lock()
defer pm.Unlock()
pm.stopProcesses()
// stop Processes in parallel
var wg sync.WaitGroup
for _, processGroup := range pm.processGroups {
wg.Add(1)
go func(processGroup *ProcessGroup) {
defer wg.Done()
processGroup.StopProcesses(strategy)
}(processGroup)
}
wg.Wait()
}
// for internal usage
func (pm *ProxyManager) stopProcesses() {
if len(pm.currentProcesses) == 0 {
return
// Shutdown stops all processes managed by this ProxyManager
func (pm *ProxyManager) Shutdown() {
pm.Lock()
defer pm.Unlock()
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
var wg sync.WaitGroup
// Send shutdown signal to all process in groups
for _, processGroup := range pm.processGroups {
wg.Add(1)
go func(processGroup *ProcessGroup) {
defer wg.Done()
processGroup.Shutdown()
}(processGroup)
}
wg.Wait()
pm.shutdownCancel()
}
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
}
for _, process := range pm.currentProcesses {
process.Stop()
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
}
pm.currentProcesses = make(map[string]*Process)
if processGroup.exclusive {
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
for groupId, otherGroup := range pm.processGroups {
if groupId != processGroup.id && !otherGroup.persistent {
otherGroup.StopProcesses(StopWaitForInflightRequest)
}
}
}
return processGroup, realModelName, nil
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{}
data := make([]gin.H, 0, len(pm.config.Models))
createdTime := time.Now().Unix()
for id, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
data = append(data, map[string]interface{}{
record := gin.H{
"id": id,
"object": "model",
"created": time.Now().Unix(),
"created": createdTime,
"owned_by": "llama-swap",
})
}
if name := strings.TrimSpace(modelConfig.Name); name != "" {
record["name"] = name
}
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
record["description"] = desc
}
data = append(data, record)
}
// Set the Content-Type header to application/json
c.Header("Content-Type", "application/json")
// Sort by the "id" key
sort.Slice(data, func(i, j int) bool {
si, _ := data[i]["id"].(string)
sj, _ := data[j]["id"].(string)
return si < sj
})
if origin := c.Request.Header.Get("Origin"); origin != "" {
// Set CORS headers if origin exists
if origin := c.GetHeader("Origin"); origin != "" {
c.Header("Access-Control-Allow-Origin", origin)
}
// Encode the data as JSON and write it to the response writer
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
return
}
}
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
pm.Lock()
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
return nil, fmt.Errorf("model group not found %s", profileName)
}
}
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(modelName)
if !found {
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
}
// exit early when already running, otherwise stop everything and swap
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
return process, nil
}
// stop all running models
pm.stopProcesses()
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
for _, modelName := range pm.config.Profiles[profileName] {
if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
}
}
// requestedProcessKey should exist due to swap
return pm.currentProcesses[requestedProcessKey], nil
// Use gin's JSON method which handles content-type and encoding
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": data,
})
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
@@ -234,38 +361,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return
}
if process, err := pm.swapModel(requestedModel); err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
} else {
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
process.ProxyRequest(c.Writer, c.Request)
processGroup, _, err := pm.swapProcessGroup(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
}
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
var html strings.Builder
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
// Extract keys and sort them
var modelIDs []string
for modelID, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
modelIDs = append(modelIDs, modelID)
}
sort.Strings(modelIDs)
// Iterate over sorted keys
for _, modelID := range modelIDs {
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
}
html.WriteString("</ul></body></html>")
c.Header("Content-Type", "text/html")
c.String(http.StatusOK, html.String())
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
}
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
@@ -275,28 +379,170 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return
}
var requestBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
return
}
model, ok := requestBody["model"].(string)
if !ok {
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
if process, err := pm.swapModel(model); err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
return
}
processGroup, _, err := pm.swapProcessGroup(realModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
// issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return
}
}
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
} else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
return
}
}
}
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
process.ProxyRequest(c.Writer, c.Request)
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
// # issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName
if useModelName != "" {
fieldValue = useModelName
} else {
fieldValue = requestedModel
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return
}
}
}
// Copy all files from the original request
for key, fileHeaders := range c.Request.MultipartForm.File {
for _, fileHeader := range fileHeaders {
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
return
}
file, err := fileHeader.Open()
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
return
}
if _, err = io.Copy(formFile, file); err != nil {
file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
return
}
file.Close()
}
}
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
}
@@ -310,6 +556,39 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
}
}
func ProcessKeyName(groupName, modelName string) string {
return groupName + PROFILE_SPLIT_CHAR + modelName
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses(StopImmediately)
c.String(http.StatusOK, "OK")
}
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, processGroup := range pm.processGroups {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady {
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
})
}
}
}
// Put the results under the `running` key.
response := gin.H{
"running": runningProcesses,
}
context.JSON(http.StatusOK, response) // Always return 200 OK
}
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
for _, group := range pm.processGroups {
if group.HasMember(modelName) {
return group
}
}
return nil
}
+204
View File
@@ -0,0 +1,204 @@
package proxy
import (
"context"
"encoding/json"
"net/http"
"sort"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
)
type Model struct {
Id string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
State string `json:"state"`
Unlisted bool `json:"unlisted"`
}
func addApiHandlers(pm *ProxyManager) {
// Add API endpoints for React to consume
apiGroup := pm.ginEngine.Group("/api")
{
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
apiGroup.GET("/events", pm.apiSendEvents)
apiGroup.GET("/metrics", pm.apiGetMetrics)
}
}
func (pm *ProxyManager) apiUnloadAllModels(c *gin.Context) {
pm.StopProcesses(StopImmediately)
c.JSON(http.StatusOK, gin.H{"msg": "ok"})
}
func (pm *ProxyManager) getModelStatus() []Model {
// Extract keys and sort them
models := []Model{}
modelIDs := make([]string, 0, len(pm.config.Models))
for modelID := range pm.config.Models {
modelIDs = append(modelIDs, modelID)
}
sort.Strings(modelIDs)
// Iterate over sorted keys
for _, modelID := range modelIDs {
// Get process state
processGroup := pm.findGroupByModelName(modelID)
state := "unknown"
if processGroup != nil {
process := processGroup.processes[modelID]
if process != nil {
var stateStr string
switch process.CurrentState() {
case StateReady:
stateStr = "ready"
case StateStarting:
stateStr = "starting"
case StateStopping:
stateStr = "stopping"
case StateShutdown:
stateStr = "shutdown"
case StateStopped:
stateStr = "stopped"
default:
stateStr = "unknown"
}
state = stateStr
}
}
models = append(models, Model{
Id: modelID,
Name: pm.config.Models[modelID].Name,
Description: pm.config.Models[modelID].Description,
State: state,
Unlisted: pm.config.Models[modelID].Unlisted,
})
}
return models
}
type messageType string
const (
msgTypeModelStatus messageType = "modelStatus"
msgTypeLogData messageType = "logData"
msgTypeMetrics messageType = "metrics"
)
type messageEnvelope struct {
Type messageType `json:"type"`
Data string `json:"data"`
}
// sends a stream of different message types that happen on the server
func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff")
sendBuffer := make(chan messageEnvelope, 25)
ctx, cancel := context.WithCancel(c.Request.Context())
sendModels := func() {
data, err := json.Marshal(pm.getModelStatus())
if err == nil {
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
select {
case sendBuffer <- msg:
case <-ctx.Done():
return
default:
}
}
}
sendLogData := func(source string, data []byte) {
data, err := json.Marshal(gin.H{
"source": source,
"data": string(data),
})
if err == nil {
select {
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
case <-ctx.Done():
return
default:
}
}
}
sendMetrics := func(metrics TokenMetrics) {
jsonData, err := json.Marshal(metrics)
if err == nil {
select {
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
case <-ctx.Done():
return
default:
}
}
}
/**
* Send updated models list
*/
defer event.On(func(e ProcessStateChangeEvent) {
sendModels()
})()
defer event.On(func(e ConfigFileChangedEvent) {
sendModels()
})()
/**
* Send Log data
*/
defer pm.proxyLogger.OnLogData(func(data []byte) {
sendLogData("proxy", data)
})()
defer pm.upstreamLogger.OnLogData(func(data []byte) {
sendLogData("upstream", data)
})()
/**
* Send Metrics data
*/
defer event.On(func(e TokenMetricsEvent) {
sendMetrics(e.Metrics)
})()
// send initial batch of data
sendLogData("proxy", pm.proxyLogger.GetHistory())
sendLogData("upstream", pm.upstreamLogger.GetHistory())
sendModels()
for _, metrics := range pm.metricsMonitor.GetMetrics() {
sendMetrics(metrics)
}
for {
select {
case <-c.Request.Context().Done():
cancel()
return
case <-pm.shutdownCtx.Done():
cancel()
return
case msg := <-sendBuffer:
c.SSEvent("message", msg)
c.Writer.Flush()
}
}
}
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
return
}
c.Data(http.StatusOK, "application/json", jsonData)
}
+44 -58
View File
@@ -1,6 +1,7 @@
package proxy
import (
"context"
"fmt"
"net/http"
"strings"
@@ -9,21 +10,12 @@ import (
)
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html
c.Header("Content-Type", "text/html")
// Write the embedded HTML content to the response
_, err := c.Writer.Write(logsHTML)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to write response: %v", err))
return
}
c.Redirect(http.StatusFound, "/ui/")
} else {
c.Header("Content-Type", "text/plain")
history := pm.logMonitor.GetHistory()
history := pm.muxLogger.GetHistory()
_, err := c.Writer.Write(history)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
@@ -37,13 +29,16 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe()
defer pm.logMonitor.Unsubscribe(ch)
logMonitorId := c.Param("logMonitorID")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("Streaming unsupported"))
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
return
}
@@ -51,62 +46,53 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// Send history first if not skipped
if !skipHistory {
history := pm.logMonitor.GetHistory()
history := logger.GetHistory()
if len(history) != 0 {
_, err := c.Writer.Write(history)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}
c.Writer.Write(history)
flusher.Flush()
}
}
// Stream new logs
sendChan := make(chan []byte, 10)
ctx, cancel := context.WithCancel(c.Request.Context())
defer logger.OnLogData(func(data []byte) {
select {
case sendChan <- data:
case <-ctx.Done():
return
default:
}
})()
for {
select {
case msg := <-ch:
_, err := c.Writer.Write(msg)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}
case <-c.Request.Context().Done():
cancel()
return
case <-pm.shutdownCtx.Done():
cancel()
return
case data := <-sendChan:
c.Writer.Write(data)
flusher.Flush()
case <-notify:
return
}
}
}
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff")
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
ch := pm.logMonitor.Subscribe()
defer pm.logMonitor.Unsubscribe(ch)
notify := c.Request.Context().Done()
// Send history first if not skipped
_, skipHistory := c.GetQuery("no-history")
if !skipHistory {
history := pm.logMonitor.GetHistory()
if len(history) != 0 {
c.SSEvent("message", string(history))
c.Writer.Flush()
}
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
// Stream new logs
for {
select {
case msg := <-ch:
c.SSEvent("message", string(msg))
c.Writer.Flush()
case <-notify:
return
}
}
return logger, nil
}
+656 -48
View File
@@ -4,87 +4,124 @@ import (
"bytes"
"encoding/json"
"fmt"
"math/rand"
"mime/multipart"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
)
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
}
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)
for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
}
// make sure there's only one loaded model
assert.Len(t, proxy.currentProcesses, 1)
}
func TestProxyManager_SwapMultiProcess(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileModel1 := ProcessKeyName("test", model1)
profileModel2 := ProcessKeyName("test", model2)
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1, model2},
LogLevel: "error",
Groups: map[string]GroupConfig{
"G1": {
Swap: true,
Exclusive: false,
Members: []string{"model1"},
},
"G2": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
},
}
})
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)
for modelID, requestedModel := range map[string]string{
"model1": profileModel1,
"model2": profileModel2,
} {
tests := []string{"model1", "model2"}
for _, requestedModel := range tests {
t.Run(requestedModel, func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel)
})
}
// make sure there's two loaded models
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
}
// Test that a persistent group is not affected by the swapping behaviour of
// other groups.
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "error",
Groups: map[string]GroupConfig{
// the forever group is persistent and should not be affected by model1
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model2"},
},
},
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// make requests to load all models, loading model1 should not affect model2
tests := []string{"model2", "model1"}
for _, requestedModel := range tests {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelID)
assert.Contains(t, w.Body.String(), requestedModel)
}
// make sure there's two loaded models
assert.Len(t, proxy.currentProcesses, 2)
_, exists := proxy.currentProcesses[profileModel1]
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
_, exists = proxy.currentProcesses[profileModel2]
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
}
// When a request for a different model comes in ProxyManager should wait until
@@ -94,17 +131,18 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
t.Skip("skipping slow test")
}
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
}
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)
results := map[string]string{}
@@ -120,15 +158,18 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
}
mu.Lock()
results[key] = w.Body.String()
var response map[string]interface{}
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
result, ok := response["responseMessage"].(string)
assert.Equal(t, ok, true)
results[key] = result
mu.Unlock()
}(key)
@@ -144,13 +185,23 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
}
func TestProxyManager_ListModelsHandler(t *testing.T) {
config := &Config{
model1Config := getTestSimpleResponderConfig("model1")
model1Config.Name = "Model 1"
model1Config.Description = "Model 1 description is used for testing"
model2Config := getTestSimpleResponderConfig("model2")
model2Config.Name = " " // empty whitespace only strings will get ignored
model2Config.Description = " "
config := Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model1": model1Config,
"model2": model2Config,
"model3": getTestSimpleResponderConfig("model3"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -161,7 +212,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
w := httptest.NewRecorder()
// Call the listModelsHandler
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
// Check the response status code
assert.Equal(t, http.StatusOK, w.Code)
@@ -173,6 +224,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
var response struct {
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse JSON response: %v", err)
}
@@ -187,6 +239,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
"model3": {},
}
// make all models
for _, model := range response.Data {
modelID, ok := model["id"].(string)
assert.True(t, ok, "model ID should be a string")
@@ -205,8 +258,563 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
ownedBy, ok := model["owned_by"].(string)
assert.True(t, ok, "owned_by should be a string")
assert.Equal(t, "llama-swap", ownedBy)
// check for optional name and description
if modelID == "model1" {
name, ok := model["name"].(string)
assert.True(t, ok, "name should be a string")
assert.Equal(t, "Model 1", name)
description, ok := model["description"].(string)
assert.True(t, ok, "description should be a string")
assert.Equal(t, "Model 1 description is used for testing", description)
} else {
_, exists := model["name"]
assert.False(t, exists, "unexpected name field for model: %s", modelID)
_, exists = model["description"]
assert.False(t, exists, "unexpected description field for model: %s", modelID)
}
}
// Ensure all expected models were returned
assert.Empty(t, expectedModels, "not all expected models were returned")
}
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
// Intentionally add models in non-sorted order and with an unlisted model
config := Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"zeta": getTestSimpleResponderConfig("zeta"),
"alpha": getTestSimpleResponderConfig("alpha"),
"beta": getTestSimpleResponderConfig("beta"),
"hidden": func() ModelConfig {
mc := getTestSimpleResponderConfig("hidden")
mc.Unlisted = true
return mc
}(),
},
LogLevel: "error",
}
proxy := New(config)
// Request models list
req := httptest.NewRequest("GET", "/v1/models", nil)
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response struct {
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse JSON response: %v", err)
}
// We expect only the listed models in sorted order by id
expectedOrder := []string{"alpha", "beta", "zeta"}
if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") {
got := make([]string, 0, len(response.Data))
for _, m := range response.Data {
id, _ := m["id"].(string)
got = append(got, id)
}
assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending")
}
}
func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
model1Config.Proxy = "http://localhost:10001/"
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
model2Config.Proxy = "http://localhost:10002/"
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/"
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": model1Config,
"model2": model2Config,
"model3": model3Config,
},
LogLevel: "error",
Groups: map[string]GroupConfig{
"test": {
Swap: false,
Members: []string{"model1", "model2", "model3"},
},
},
})
proxy := New(config)
// Start all the processes
var wg sync.WaitGroup
for _, modelName := range []string{"model1", "model2", "model3"} {
wg.Add(1)
go func(modelName string) {
defer wg.Done()
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
// send a request to trigger the proxy to load ... this should hang waiting for start up
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
}(modelName)
}
go func() {
<-time.After(time.Second)
proxy.Shutdown()
}()
wg.Wait()
}
func TestProxyManager_Unload(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
req = httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")
// give it a bit of time to stop
<-time.After(time.Millisecond * 250)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
}
// Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "warn",
})
// Define a helper struct to parse the JSON response.
type RunningResponse struct {
Running []struct {
Model string `json:"model"`
State string `json:"state"`
} `json:"running"`
}
// Create proxy once for all tests
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response RunningResponse
// Check if this is a valid JSON object.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// We should have an empty running array here.
assert.Empty(t, response.Running, "expected no running models")
})
t.Run("single model loaded", func(t *testing.T) {
// Load just a model.
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder()
proxy.ServeHTTP(w, req)
var response RunningResponse
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// Check if we have a single array element.
assert.Len(t, response.Running, 1)
// Is this the right model?
assert.Equal(t, "model1", response.Running[0].Model)
// Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State)
})
}
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte("TheExpectedModel"))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
// Generate random content length between 10 and 20
contentLength := rand.Intn(11) + 10 // 10 to 20
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, "TheExpectedModel", response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
}
// Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": modelConfig,
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
requestedModel := "model1"
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
// make sure the content length was set correctly
// simple-responder will return the content length it got in the response
body := w.Body.Bytes()
contentLength := int(gjson.GetBytes(body, "h_content_length").Int())
assert.Equal(t, len(fmt.Sprintf(`{"model":"%s"}`, upstreamModelName)), contentLength)
})
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
tests := []struct {
name string
method string
requestHeaders map[string]string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS with no headers",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
}
})
}
}
func TestProxyManager_Upstream(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
}
func TestProxyManager_ChatContentLength(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]interface{}
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.Equal(t, "81", response["h_content_length"])
assert.Equal(t, "model1", response["responseMessage"])
}
func TestProxyManager_FiltersStripParams(t *testing.T) {
modelConfig := getTestSimpleResponderConfig("model1")
modelConfig.Filters = ModelFilters{
StripParams: "temperature, model, stream",
}
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
LogLevel: "error",
Models: map[string]ModelConfig{
"model1": modelConfig,
},
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]interface{}
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// `temperature` and `stream` are gone but model remains
assert.Equal(t, `{"model":"model1", "x_param":"123", "y_param":"abc"}`, response["request_body"])
// assert.Nil(t, response["temperature"])
// assert.Equal(t, "123", response["x_param"])
// assert.Equal(t, "abc", response["y_param"])
// t.Logf("%v", response)
}
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Make a non-streaming request
reqBody := `{"model":"model1", "stream": false}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Check that metrics were recorded
metrics := proxy.metricsMonitor.GetMetrics()
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
return
}
// Verify the last metric has the correct model
lastMetric := metrics[len(metrics)-1]
assert.Equal(t, "model1", lastMetric.Model)
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
}
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Make a streaming request
reqBody := `{"model":"model1", "stream": true}`
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Check that metrics were recorded
metrics := proxy.metricsMonitor.GetMetrics()
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
return
}
// Verify the last metric has the correct model
lastMetric := metrics[len(metrics)-1]
assert.Equal(t, "model1", lastMetric.Model)
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
}
func TestProxyManager_HealthEndpoint(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "OK", rec.Body.String())
}
+43
View File
@@ -0,0 +1,43 @@
package proxy
import (
"strings"
)
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
+77
View File
@@ -0,0 +1,77 @@
package proxy
import "testing"
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "single valid value",
input: "content-type",
expected: "content-type",
},
{
name: "multiple valid values",
input: "content-type, authorization, x-requested-with",
expected: "content-type, authorization, x-requested-with",
},
{
name: "values with extra spaces",
input: " content-type , authorization ",
expected: "content-type, authorization",
},
{
name: "values with tabs",
input: "content-type,\tauthorization",
expected: "content-type, authorization",
},
{
name: "values with invalid characters",
input: "content-type, auth\n, x-requested-with\r",
expected: "content-type, auth, x-requested-with",
},
{
name: "empty values in list",
input: "content-type,,authorization",
expected: "content-type, authorization",
},
{
name: "leading and trailing commas",
input: ",content-type,authorization,",
expected: "content-type, authorization",
},
{
name: "mixed valid and invalid values",
input: "content-type, \x00invalid, x-requested-with",
expected: "content-type, x-requested-with",
},
{
name: "mixed case values",
input: "Content-Type, my-Valid-Header, Another-hEader",
expected: "Content-Type, my-Valid-Header, Another-hEader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeAccessControlRequestHeaderValues(tt.input)
if got != tt.expected {
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
tt.input, got, tt.expected)
}
})
}
}
+24
View File
@@ -0,0 +1,24 @@
package proxy
import (
"embed"
"io/fs"
"net/http"
)
//go:embed ui_dist
var reactStaticFS embed.FS
// GetReactFS returns the embedded React filesystem
func GetReactFS() (http.FileSystem, error) {
subFS, err := fs.Sub(reactStaticFS, "ui_dist")
if err != nil {
return nil, err
}
return http.FS(subFS), nil
}
// GetReactIndexHTML returns the main index.html for the React app
func GetReactIndexHTML() ([]byte, error) {
return reactStaticFS.ReadFile("ui_dist/index.html")
}
+213
View File
@@ -0,0 +1,213 @@
#!/bin/sh
# This script installs llama-swap on Linux.
# It detects the current operating system architecture and installs the appropriate version of llama-swap.
set -eu
LLAMA_SWAP_DEFAULT_ADDRESS=${LLAMA_SWAP_DEFAULT_ADDRESS:-"127.0.0.1:8080"}
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
status() { echo ">>> $*" >&2; }
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
warning() { echo "${red}WARNING:${plain} $*"; }
available() { command -v "$1" >/dev/null; }
require() {
_MISSING=''
for TOOL in "$@"; do
if ! available "$TOOL"; then
_MISSING="$_MISSING $TOOL"
fi
done
echo "$_MISSING"
}
SUDO=
if [ "$(id -u)" -ne 0 ]; then
if ! available sudo; then
error "This script requires superuser permissions. Please re-run as root."
fi
SUDO="sudo"
fi
NEEDS=$(require tee tar python3 mktemp)
if [ -n "$NEEDS" ]; then
status "ERROR: The following tools are required but missing:"
for NEED in $NEEDS; do
echo " - $NEED"
done
exit 1
fi
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
aarch64|arm64) ARCH="arm64" ;;
*) error "Unsupported architecture: $ARCH" ;;
esac
IS_WSL2=false
KERN=$(uname -r)
case "$KERN" in
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
*icrosoft) error "Microsoft WSL1 is not currently supported. Please use WSL2 with 'wsl --set-version <distro> 2'" ;;
*) ;;
esac
download_binary() {
ASSET_NAME="linux_$ARCH"
TMPDIR=$(mktemp -d)
trap 'rm -rf "${TMPDIR}"' EXIT INT TERM HUP
PYTHON_SCRIPT=$(cat <<EOF
import os
import json
import sys
import urllib.request
ASSET_NAME = "${ASSET_NAME}"
with urllib.request.urlopen("https://api.github.com/repos/mostlygeek/llama-swap/releases/latest") as resp:
data = json.load(resp)
for asset in data.get("assets", []):
if ASSET_NAME in asset.get("name", ""):
url = asset["browser_download_url"]
break
else:
print("ERROR: Matching asset not found.", file=sys.stderr)
exit(1)
print("Downloading:", url, file=sys.stderr)
output_path = os.path.join("${TMPDIR}", "llama-swap.tar.gz")
urllib.request.urlretrieve(url, output_path)
print(output_path)
EOF
)
TARFILE=$(python3 -c "$PYTHON_SCRIPT")
if [ ! -f "$TARFILE" ]; then
error "Failed to download binary."
fi
status "Extracting to /usr/local/bin"
$SUDO tar -xzf "$TARFILE" -C /usr/local/bin llama-swap
}
download_binary
configure_systemd() {
if ! id llama-swap >/dev/null 2>&1; then
status "Creating llama-swap user..."
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/llama-swap llama-swap
fi
if getent group render >/dev/null 2>&1; then
status "Adding llama-swap user to render group..."
$SUDO usermod -a -G render llama-swap
fi
if getent group video >/dev/null 2>&1; then
status "Adding llama-swap user to video group..."
$SUDO usermod -a -G video llama-swap
fi
if getent group docker >/dev/null 2>&1; then
status "Adding llama-swap user to docker group..."
$SUDO usermod -a -G docker llama-swap
fi
status "Adding current user to llama-swap group..."
$SUDO usermod -a -G llama-swap "$(whoami)"
if [ ! -f "/usr/share/llama-swap/config.yaml" ]; then
status "Creating default config.yaml..."
cat <<EOF | $SUDO -u llama-swap tee /usr/share/llama-swap/config.yaml >/dev/null
# default 15s likely to fail for default models due to downloading models
healthCheckTimeout: 60
models:
"qwen2.5":
cmd: |
docker run
--rm
-p \${PORT}:8080
--name qwen2.5
ghcr.io/ggml-org/llama.cpp:server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
cmdStop: docker stop qwen2.5
"smollm2":
cmd: |
docker run
--rm
-p \${PORT}:8080
--name smollm2
ghcr.io/ggml-org/llama.cpp:server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
cmdStop: docker stop smollm2
EOF
fi
status "Creating llama-swap systemd service..."
cat <<EOF | $SUDO tee /etc/systemd/system/llama-swap.service >/dev/null
[Unit]
Description=llama-swap
After=network.target
[Service]
User=llama-swap
Group=llama-swap
# set this to match your environment
ExecStart=/usr/local/bin/llama-swap --config /usr/share/llama-swap/config.yaml --watch-config -listen ${LLAMA_SWAP_DEFAULT_ADDRESS}
Restart=on-failure
RestartSec=3
StartLimitBurst=3
StartLimitInterval=30
[Install]
WantedBy=multi-user.target
EOF
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
case $SYSTEMCTL_RUNNING in
running|degraded)
status "Enabling and starting llama-swap service..."
$SUDO systemctl daemon-reload
$SUDO systemctl enable llama-swap
start_service() { $SUDO systemctl restart llama-swap; }
trap start_service EXIT
;;
*)
warning "systemd is not running"
if [ "$IS_WSL2" = true ]; then
warning "see https://learn.microsoft.com/en-us/windows/wsl/systemd#how-to-enable-systemd to enable it"
fi
;;
esac
}
if available systemctl; then
configure_systemd
fi
install_success() {
status "The llama-swap API is now available at http://${LLAMA_SWAP_DEFAULT_ADDRESS}"
status 'Customize the config file at /usr/share/llama-swap/config.yaml.'
status 'Install complete.'
}
# WSL2 only supports GPUs via nvidia passthrough
# so check for nvidia-smi to determine if GPU is available
if [ "$IS_WSL2" = true ]; then
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
status "Nvidia GPU detected."
fi
exit 0
fi
install_success
+68
View File
@@ -0,0 +1,68 @@
#!/bin/sh
# This script uninstalls llama-swap on Linux.
# It removes the binary, systemd service, config.yaml (optional), and llama-swap user and group.
set -eu
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
status() { echo ">>> $*" >&2; }
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
warning() { echo "${red}WARNING:${plain} $*"; }
available() { command -v $1 >/dev/null; }
SUDO=
if [ "$(id -u)" -ne 0 ]; then
if ! available sudo; then
error "This script requires superuser permissions. Please re-run as root."
fi
SUDO="sudo"
fi
configure_systemd() {
status "Stopping llama-swap service..."
$SUDO systemctl stop llama-swap
status "Disabling llama-swap service..."
$SUDO systemctl disable llama-swap
}
if available systemctl; then
configure_systemd
fi
if available llama-swap; then
status "Removing llama-swap binary..."
$SUDO rm $(which llama-swap)
fi
if [ -f "/usr/share/llama-swap/config.yaml" ]; then
while true; do
printf "Delete config.yaml (/usr/share/llama-swap/config.yaml)? [y/N] " >&2
read answer
case "$answer" in
[Yy]* )
$SUDO rm -r /usr/share/llama-swap
break
;;
[Nn]* | "" )
break
;;
* )
echo "Invalid input. Please enter y or n."
;;
esac
done
fi
if id llama-swap >/dev/null 2>&1; then
status "Removing llama-swap user..."
$SUDO userdel llama-swap
fi
if getent group llama-swap >/dev/null 2>&1; then
status "Removing llama-swap group..."
$SUDO groupdel llama-swap
fi
+25
View File
@@ -0,0 +1,25 @@
.vite
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
+54
View File
@@ -0,0 +1,54 @@
# React + TypeScript + Vite
This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
Currently, two official plugins are available:
- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) for Fast Refresh
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
## Expanding the ESLint configuration
If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
```js
export default tseslint.config({
extends: [
// Remove ...tseslint.configs.recommended and replace with this
...tseslint.configs.recommendedTypeChecked,
// Alternatively, use this for stricter rules
...tseslint.configs.strictTypeChecked,
// Optionally, add this for stylistic rules
...tseslint.configs.stylisticTypeChecked,
],
languageOptions: {
// other options...
parserOptions: {
project: ['./tsconfig.node.json', './tsconfig.app.json'],
tsconfigRootDir: import.meta.dirname,
},
},
})
```
You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
```js
// eslint.config.js
import reactX from 'eslint-plugin-react-x'
import reactDom from 'eslint-plugin-react-dom'
export default tseslint.config({
plugins: {
// Add the react-x and react-dom plugins
'react-x': reactX,
'react-dom': reactDom,
},
rules: {
// other rules...
// Enable its recommended typescript rules
...reactX.configs['recommended-typescript'].rules,
...reactDom.configs.recommended.rules,
},
})
```
+28
View File
@@ -0,0 +1,28 @@
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import tseslint from 'typescript-eslint'
export default tseslint.config(
{ ignores: ['dist'] },
{
extends: [js.configs.recommended, ...tseslint.configs.recommended],
files: ['**/*.{ts,tsx}'],
languageOptions: {
ecmaVersion: 2020,
globals: globals.browser,
},
plugins: {
'react-hooks': reactHooks,
'react-refresh': reactRefresh,
},
rules: {
...reactHooks.configs.recommended.rules,
'react-refresh/only-export-components': [
'warn',
{ allowConstantExport: true },
],
},
},
)
+17
View File
@@ -0,0 +1,17 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<link rel="icon" type="image/png" href="/favicon-96x96.png" sizes="96x96" />
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
<link rel="shortcut icon" href="/favicon.ico" />
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
<link rel="manifest" href="/site.webmanifest" />
<title>llama-swap</title>
</head>
<body >
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>
Binary file not shown.
+4049
View File
File diff suppressed because it is too large Load Diff
+35
View File
@@ -0,0 +1,35 @@
{
"name": "ui",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "tsc -b && vite build --emptyOutDir",
"lint": "eslint .",
"preview": "vite preview"
},
"dependencies": {
"@tailwindcss/vite": "^4.1.8",
"@tanstack/react-query": "^5.80.6",
"react": "^19.1.0",
"react-dom": "^19.1.0",
"react-icons": "^5.5.0",
"react-resizable-panels": "^3.0.4",
"react-router-dom": "^7.6.2",
"tailwindcss": "^4.1.8"
},
"devDependencies": {
"@eslint/js": "^9.25.0",
"@types/react": "^19.1.2",
"@types/react-dom": "^19.1.2",
"@vitejs/plugin-react": "^4.4.1",
"eslint": "^9.25.0",
"eslint-plugin-react-hooks": "^5.2.0",
"eslint-plugin-react-refresh": "^0.4.19",
"globals": "^16.0.0",
"typescript": "~5.8.3",
"typescript-eslint": "^8.30.1",
"vite": "^6.3.5"
}
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 38 KiB

+21
View File
@@ -0,0 +1,21 @@
{
"name": "llama-swap",
"short_name": "llama-swap",
"icons": [
{
"src": "/web-app-manifest-192x192.png",
"sizes": "192x192",
"type": "image/png",
"purpose": "maskable"
},
{
"src": "/web-app-manifest-512x512.png",
"sizes": "512x512",
"type": "image/png",
"purpose": "maskable"
}
],
"theme_color": "#ffffff",
"background_color": "#ffffff",
"display": "standalone"
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 6.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

+6
View File
@@ -0,0 +1,6 @@
#root {
max-width: 1280px;
margin: 0 auto;
padding: 2rem;
text-align: center;
}
+52
View File
@@ -0,0 +1,52 @@
import { BrowserRouter as Router, Routes, Route, Navigate, NavLink } from "react-router-dom";
import { useTheme } from "./contexts/ThemeProvider";
import { APIProvider } from "./contexts/APIProvider";
import LogViewerPage from "./pages/LogViewer";
import ModelPage from "./pages/Models";
import ActivityPage from "./pages/Activity";
import { RiSunFill, RiMoonFill } from "react-icons/ri";
function App() {
const { isNarrow, toggleTheme, isDarkMode } = useTheme();
return (
<Router basename="/ui/">
<APIProvider>
<div className="flex flex-col h-screen">
<nav className="bg-surface border-b border-border p-2 h-[75px]">
<div className="flex items-center justify-between mx-auto px-4 h-full">
{!isNarrow && <h1 className="flex items-center p-0">llama-swap</h1>}
<div className="flex items-center space-x-4">
<NavLink to="/" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
Logs
</NavLink>
<NavLink to="/models" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
Models
</NavLink>
<NavLink to="/activity" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
Activity
</NavLink>
<button className="" onClick={toggleTheme}>
{isDarkMode ? <RiMoonFill /> : <RiSunFill />}
</button>
</div>
</div>
</nav>
<main className="flex-1 overflow-auto p-4">
<Routes>
<Route path="/" element={<LogViewerPage />} />
<Route path="/models" element={<ModelPage />} />
<Route path="/activity" element={<ActivityPage />} />
<Route path="*" element={<Navigate to="/" replace />} />
</Routes>
</main>
</div>
</APIProvider>
</Router>
);
}
export default App;
Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

+1
View File
@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="35.93" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 228"><path fill="#00D8FF" d="M210.483 73.824a171.49 171.49 0 0 0-8.24-2.597c.465-1.9.893-3.777 1.273-5.621c6.238-30.281 2.16-54.676-11.769-62.708c-13.355-7.7-35.196.329-57.254 19.526a171.23 171.23 0 0 0-6.375 5.848a155.866 155.866 0 0 0-4.241-3.917C100.759 3.829 77.587-4.822 63.673 3.233C50.33 10.957 46.379 33.89 51.995 62.588a170.974 170.974 0 0 0 1.892 8.48c-3.28.932-6.445 1.924-9.474 2.98C17.309 83.498 0 98.307 0 113.668c0 15.865 18.582 31.778 46.812 41.427a145.52 145.52 0 0 0 6.921 2.165a167.467 167.467 0 0 0-2.01 9.138c-5.354 28.2-1.173 50.591 12.134 58.266c13.744 7.926 36.812-.22 59.273-19.855a145.567 145.567 0 0 0 5.342-4.923a168.064 168.064 0 0 0 6.92 6.314c21.758 18.722 43.246 26.282 56.54 18.586c13.731-7.949 18.194-32.003 12.4-61.268a145.016 145.016 0 0 0-1.535-6.842c1.62-.48 3.21-.974 4.76-1.488c29.348-9.723 48.443-25.443 48.443-41.52c0-15.417-17.868-30.326-45.517-39.844Zm-6.365 70.984c-1.4.463-2.836.91-4.3 1.345c-3.24-10.257-7.612-21.163-12.963-32.432c5.106-11 9.31-21.767 12.459-31.957c2.619.758 5.16 1.557 7.61 2.4c23.69 8.156 38.14 20.213 38.14 29.504c0 9.896-15.606 22.743-40.946 31.14Zm-10.514 20.834c2.562 12.94 2.927 24.64 1.23 33.787c-1.524 8.219-4.59 13.698-8.382 15.893c-8.067 4.67-25.32-1.4-43.927-17.412a156.726 156.726 0 0 1-6.437-5.87c7.214-7.889 14.423-17.06 21.459-27.246c12.376-1.098 24.068-2.894 34.671-5.345a134.17 134.17 0 0 1 1.386 6.193ZM87.276 214.515c-7.882 2.783-14.16 2.863-17.955.675c-8.075-4.657-11.432-22.636-6.853-46.752a156.923 156.923 0 0 1 1.869-8.499c10.486 2.32 22.093 3.988 34.498 4.994c7.084 9.967 14.501 19.128 21.976 27.15a134.668 134.668 0 0 1-4.877 4.492c-9.933 8.682-19.886 14.842-28.658 17.94ZM50.35 144.747c-12.483-4.267-22.792-9.812-29.858-15.863c-6.35-5.437-9.555-10.836-9.555-15.216c0-9.322 13.897-21.212 37.076-29.293c2.813-.98 5.757-1.905 8.812-2.773c3.204 10.42 7.406 21.315 12.477 32.332c-5.137 11.18-9.399 22.249-12.634 32.792a134.718 134.718 0 0 1-6.318-1.979Zm12.378-84.26c-4.811-24.587-1.616-43.134 6.425-47.789c8.564-4.958 27.502 2.111 47.463 19.835a144.318 144.318 0 0 1 3.841 3.545c-7.438 7.987-14.787 17.08-21.808 26.988c-12.04 1.116-23.565 2.908-34.161 5.309a160.342 160.342 0 0 1-1.76-7.887Zm110.427 27.268a347.8 347.8 0 0 0-7.785-12.803c8.168 1.033 15.994 2.404 23.343 4.08c-2.206 7.072-4.956 14.465-8.193 22.045a381.151 381.151 0 0 0-7.365-13.322Zm-45.032-43.861c5.044 5.465 10.096 11.566 15.065 18.186a322.04 322.04 0 0 0-30.257-.006c4.974-6.559 10.069-12.652 15.192-18.18ZM82.802 87.83a323.167 323.167 0 0 0-7.227 13.238c-3.184-7.553-5.909-14.98-8.134-22.152c7.304-1.634 15.093-2.97 23.209-3.984a321.524 321.524 0 0 0-7.848 12.897Zm8.081 65.352c-8.385-.936-16.291-2.203-23.593-3.793c2.26-7.3 5.045-14.885 8.298-22.6a321.187 321.187 0 0 0 7.257 13.246c2.594 4.48 5.28 8.868 8.038 13.147Zm37.542 31.03c-5.184-5.592-10.354-11.779-15.403-18.433c4.902.192 9.899.29 14.978.29c5.218 0 10.376-.117 15.453-.343c-4.985 6.774-10.018 12.97-15.028 18.486Zm52.198-57.817c3.422 7.8 6.306 15.345 8.596 22.52c-7.422 1.694-15.436 3.058-23.88 4.071a382.417 382.417 0 0 0 7.859-13.026a347.403 347.403 0 0 0 7.425-13.565Zm-16.898 8.101a358.557 358.557 0 0 1-12.281 19.815a329.4 329.4 0 0 1-23.444.823c-7.967 0-15.716-.248-23.178-.732a310.202 310.202 0 0 1-12.513-19.846h.001a307.41 307.41 0 0 1-10.923-20.627a310.278 310.278 0 0 1 10.89-20.637l-.001.001a307.318 307.318 0 0 1 12.413-19.761c7.613-.576 15.42-.876 23.31-.876H128c7.926 0 15.743.303 23.354.883a329.357 329.357 0 0 1 12.335 19.695a358.489 358.489 0 0 1 11.036 20.54a329.472 329.472 0 0 1-11 20.722Zm22.56-122.124c8.572 4.944 11.906 24.881 6.52 51.026c-.344 1.668-.73 3.367-1.15 5.09c-10.622-2.452-22.155-4.275-34.23-5.408c-7.034-10.017-14.323-19.124-21.64-27.008a160.789 160.789 0 0 1 5.888-5.4c18.9-16.447 36.564-22.941 44.612-18.3ZM128 90.808c12.625 0 22.86 10.235 22.86 22.86s-10.235 22.86-22.86 22.86s-22.86-10.235-22.86-22.86s10.235-22.86 22.86-22.86Z"></path></svg>

After

Width:  |  Height:  |  Size: 4.0 KiB

+203
View File
@@ -0,0 +1,203 @@
import { useRef, createContext, useState, useContext, useEffect, useCallback, useMemo, type ReactNode } from "react";
type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
export interface Model {
id: string;
state: ModelStatus;
name: string;
description: string;
unlisted: boolean;
}
interface APIProviderType {
models: Model[];
listModels: () => Promise<Model[]>;
unloadAllModels: () => Promise<void>;
loadModel: (model: string) => Promise<void>;
enableAPIEvents: (enabled: boolean) => void;
proxyLogs: string;
upstreamLogs: string;
metrics: Metrics[];
}
interface Metrics {
id: number;
timestamp: string;
model: string;
input_tokens: number;
output_tokens: number;
tokens_per_second: number;
duration_ms: number;
}
interface LogData {
source: "upstream" | "proxy";
data: string;
}
interface APIEventEnvelope {
type: "modelStatus" | "logData" | "metrics";
data: string;
}
const APIContext = createContext<APIProviderType | undefined>(undefined);
type APIProviderProps = {
children: ReactNode;
autoStartAPIEvents?: boolean;
};
export function APIProvider({ children, autoStartAPIEvents = true }: APIProviderProps) {
const [proxyLogs, setProxyLogs] = useState("");
const [upstreamLogs, setUpstreamLogs] = useState("");
const [metrics, setMetrics] = useState<Metrics[]>([]);
const apiEventSource = useRef<EventSource | null>(null);
const [models, setModels] = useState<Model[]>([]);
const appendLog = useCallback((newData: string, setter: React.Dispatch<React.SetStateAction<string>>) => {
setter((prev) => {
const updatedLog = prev + newData;
return updatedLog.length > LOG_LENGTH_LIMIT ? updatedLog.slice(-LOG_LENGTH_LIMIT) : updatedLog;
});
}, []);
const enableAPIEvents = useCallback((enabled: boolean) => {
if (!enabled) {
apiEventSource.current?.close();
apiEventSource.current = null;
setMetrics([]);
return;
}
let retryCount = 0;
const initialDelay = 1000; // 1 second
const connect = () => {
const eventSource = new EventSource("/api/events");
eventSource.onmessage = (e: MessageEvent) => {
try {
const message = JSON.parse(e.data) as APIEventEnvelope;
switch (message.type) {
case "modelStatus":
{
const models = JSON.parse(message.data) as Model[];
setModels(models);
}
break;
case "logData":
const logData = JSON.parse(message.data) as LogData;
switch (logData.source) {
case "proxy":
appendLog(logData.data, setProxyLogs);
break;
case "upstream":
appendLog(logData.data, setUpstreamLogs);
break;
}
break;
case "metrics":
{
const newMetric = JSON.parse(message.data) as Metrics;
setMetrics((prevMetrics) => {
return [newMetric, ...prevMetrics];
});
}
break;
}
} catch (err) {
console.error(e.data, err);
}
};
eventSource.onerror = () => {
eventSource.close();
retryCount++;
const delay = Math.min(initialDelay * Math.pow(2, retryCount - 1), 5000);
setTimeout(connect, delay);
};
apiEventSource.current = eventSource;
};
connect();
}, []);
useEffect(() => {
if (autoStartAPIEvents) {
enableAPIEvents(true);
}
return () => {
enableAPIEvents(false);
};
}, [enableAPIEvents, autoStartAPIEvents]);
const listModels = useCallback(async (): Promise<Model[]> => {
try {
const response = await fetch("/api/models/");
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const data = await response.json();
return data || [];
} catch (error) {
console.error("Failed to fetch models:", error);
return []; // Return empty array as fallback
}
}, []);
const unloadAllModels = useCallback(async () => {
try {
const response = await fetch(`/api/models/unload/`, {
method: "POST",
});
if (!response.ok) {
throw new Error(`Failed to unload models: ${response.status}`);
}
} catch (error) {
console.error("Failed to unload models:", error);
throw error; // Re-throw to let calling code handle it
}
}, []);
const loadModel = useCallback(async (model: string) => {
try {
const response = await fetch(`/upstream/${model}/`, {
method: "GET",
});
if (!response.ok) {
throw new Error(`Failed to load model: ${response.status}`);
}
} catch (error) {
console.error("Failed to load model:", error);
throw error; // Re-throw to let calling code handle it
}
}, []);
const value = useMemo(
() => ({
models,
listModels,
unloadAllModels,
loadModel,
enableAPIEvents,
proxyLogs,
upstreamLogs,
metrics,
}),
[models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics]
);
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
}
export function useAPI() {
const context = useContext(APIContext);
if (context === undefined) {
throw new Error("useAPI must be used within an APIProvider");
}
return context;
}
+68
View File
@@ -0,0 +1,68 @@
import { createContext, useContext, useEffect, type ReactNode, useMemo, useState } from "react";
import { usePersistentState } from "../hooks/usePersistentState";
type ScreenWidth = "xs" | "sm" | "md" | "lg" | "xl" | "2xl";
type ThemeContextType = {
isDarkMode: boolean;
screenWidth: ScreenWidth;
isNarrow: boolean;
toggleTheme: () => void;
};
const ThemeContext = createContext<ThemeContextType | undefined>(undefined);
type ThemeProviderProps = {
children: ReactNode;
};
export function ThemeProvider({ children }: ThemeProviderProps) {
const [isDarkMode, setIsDarkMode] = usePersistentState<boolean>("theme", false);
const [screenWidth, setScreenWidth] = useState<ScreenWidth>("md"); // Default to md
// matches tailwind classes
// https://tailwindcss.com/docs/responsive-design
useEffect(() => {
const checkInnerWidth = () => {
const innerWidth = window.innerWidth;
if (innerWidth < 640) {
setScreenWidth("xs");
} else if (innerWidth < 768) {
setScreenWidth("sm");
} else if (innerWidth < 1024) {
setScreenWidth("md");
} else if (innerWidth < 1280) {
setScreenWidth("lg");
} else if (innerWidth < 1536) {
setScreenWidth("xl");
} else {
setScreenWidth("2xl");
}
};
checkInnerWidth();
window.addEventListener("resize", checkInnerWidth);
return () => window.removeEventListener("resize", checkInnerWidth);
}, []);
useEffect(() => {
document.documentElement.setAttribute("data-theme", isDarkMode ? "dark" : "light");
}, [isDarkMode]);
const toggleTheme = () => setIsDarkMode((prev) => !prev);
const isNarrow = useMemo(() => {
return screenWidth === "xs" || screenWidth === "sm" || screenWidth === "md";
}, [screenWidth]);
return (
<ThemeContext.Provider value={{ isDarkMode, toggleTheme, screenWidth, isNarrow }}>{children}</ThemeContext.Provider>
);
}
export function useTheme(): ThemeContextType {
const context = useContext(ThemeContext);
if (context === undefined) {
throw new Error("useTheme must be used within a ThemeProvider");
}
return context;
}
+39
View File
@@ -0,0 +1,39 @@
import { useState, useEffect, useCallback } from "react";
export function usePersistentState<T>(key: string, initialValue: T): [T, (value: T | ((prevState: T) => T)) => void] {
const [state, setState] = useState<T>(() => {
if (typeof window === "undefined") return initialValue;
try {
const saved = localStorage.getItem(key);
return saved !== null ? JSON.parse(saved) : initialValue;
} catch (e) {
console.error(`Error parsing stored value for ${key}`, e);
return initialValue;
}
});
const setPersistentState = useCallback(
(value: T | ((prevState: T) => T)) => {
setState((prev) => {
const nextValue = typeof value === "function" ? (value as (prevState: T) => T)(prev) : value;
try {
localStorage.setItem(key, JSON.stringify(nextValue));
} catch (e) {
console.error(`Error saving value for ${key}`, e);
}
return nextValue;
});
},
[key]
);
useEffect(() => {
try {
localStorage.setItem(key, JSON.stringify(state));
} catch (e) {
console.error(`Error saving value for ${key}`, e);
}
}, [key, state]);
return [state, setPersistentState];
}
+168
View File
@@ -0,0 +1,168 @@
@import "tailwindcss";
@custom-variant dark (&:where([data-theme=dark], [data-theme=dark] *));
@theme {
--color-background: rgba(252, 252, 249, 1);
--color-surface: rgba(255, 255, 253, 1);
/* text colors */
--color-txtmain: rgba(19, 52, 59, 1);
--color-txtsecondary: rgba(98, 108, 113, 1);
--color-navlink-active: rgba(245, 245, 245, 1);
--color-primary: rgba(50, 184, 198, 1);
--color-primary-hover: rgba(29, 116, 128, 1);
--color-primary-active: rgba(26, 104, 115, 1);
--color-secondary: rgba(94, 82, 64, 0.12);
--color-secondary-hover: rgba(94, 82, 64, 0.2);
--color-secondary-active: rgba(94, 82, 64, 0.25);
--color-border: rgba(94, 82, 64, 0.3);
--color-btn-primary-text: rgba(252, 252, 249, 1);
--color-card-border: rgba(94, 82, 64, 0.12);
--color-card-border-inner: rgba(94, 82, 64, 0.12);
--color-error: rgba(192, 21, 47, 1);
--color-success: rgba(33, 128, 141, 1);
--color-warning: rgb(244, 155, 0);
--color-info: rgba(98, 108, 113, 1);
--color-focus-ring: rgba(33, 128, 141, 0.4);
--color-select-caret: rgba(19, 52, 59, 0.8);
--color-btn-border: rgba(94, 82, 64, 0.7);
}
@layer theme {
/* over ride theme for dark mode */
[data-theme="dark"] {
--color-background: rgba(31, 33, 33, 1);
--color-surface: rgba(38, 40, 40, 1);
/* text colors */
--color-txtmain: rgba(245, 245, 245, 1);
--color-txtsecondary: rgba(167, 169, 169, 0.7);
--color-navlink-active: rgba(245, 245, 245, 1);
--color-primary: rgba(33, 128, 141, 1);
--color-primary-hover: rgba(45, 166, 178, 1);
--color-primary-active: rgba(41, 150, 161, 1);
--color-secondary: rgba(119, 124, 124, 0.15);
--color-secondary-hover: rgba(119, 124, 124, 0.25);
--color-secondary-active: rgba(119, 124, 124, 0.3);
--color-border: rgba(119, 124, 124, 0.3);
--color-error: rgba(255, 84, 89, 1);
--color-success: rgba(50, 184, 198, 1);
--color-warning: rgb(244, 155, 0);
--color-info: rgba(167, 169, 169, 1);
--color-focus-ring: rgba(50, 184, 198, 0.4);
--color-btn-primary-text: rgba(19, 52, 59, 1);
--color-card-border: rgba(119, 124, 124, 0.2);
--color-card-border-inner: rgba(119, 124, 124, 0.15);
--shadow-inset-sm: inset 0 1px 0 rgba(255, 255, 255, 0.1), inset 0 -1px 0 rgba(0, 0, 0, 0.15);
--button-border-secondary: rgba(119, 124, 124, 0.2);
}
}
@layer base {
body {
/* example of how colors using theme colors*/
@apply bg-background text-txtmain;
}
h1 {
@apply text-4xl text-txtmain font-bold pb-4;
}
h2 {
@apply text-3xl text-txtmain font-bold pb-4;
}
h3 {
@apply text-2xl text-txtmain font-bold pb-4;
}
h4 {
@apply text-xl text-txtmain font-bold pb-4;
}
h5 {
@apply text-lg text-txtmain font-bold pb-4;
}
h6 {
@apply text-base text-txtmain font-bold pb-4;
}
}
/* define CSS classes here for specific types of components */
@layer components {
.container {
@apply px-4;
}
/* Navigation Header */
.navlink {
@apply text-txtsecondary hover:bg-secondary hover:text-txtmain rounded-lg p-2;
}
.navlink.active {
@apply bg-primary text-navlink-active;
}
/* Card component */
.card {
@apply bg-surface rounded-lg border border-card-border shadow-sm overflow-hidden p-4;
}
.card:hover {
@apply shadow-md;
}
.card__body {
@apply p-4;
}
.card__header,
.card__footer {
@apply p-4 border-b border-card-border-inner;
}
/* Status Badges */
.status {
@apply inline-block px-2 py-1 text-xs font-medium rounded-full;
}
.status--ready {
@apply bg-success/10 text-success;
}
.status--starting,
.status--stopping {
@apply bg-warning/10 text-warning;
}
.status--stopped {
@apply bg-error/10 text-error;
}
/* Buttons */
.btn {
@apply bg-surface p-2 px-4 text-sm rounded-full border border-2 transition-colors duration-200 border-btn-border;
}
.btn:hover {
cursor: pointer;
}
.btn--sm {
@apply px-2 py-0.5 text-xs;
}
.btn:disabled {
@apply opacity-50 cursor-not-allowed;
}
}
@layer utilities {
.ml-2 {
margin-left: 0.5rem;
}
.my-8 {
margin-top: 2rem;
margin-bottom: 2rem;
}
}
+13
View File
@@ -0,0 +1,13 @@
import { StrictMode } from "react";
import { createRoot } from "react-dom/client";
import "./index.css";
import App from "./App.tsx";
import { ThemeProvider } from "./contexts/ThemeProvider";
createRoot(document.getElementById("root")!).render(
<StrictMode>
<ThemeProvider>
<App />
</ThemeProvider>
</StrictMode>
);
+77
View File
@@ -0,0 +1,77 @@
import { useState, useEffect } from "react";
import { useAPI } from "../contexts/APIProvider";
const formatTimestamp = (timestamp: string): string => {
return new Date(timestamp).toLocaleString();
};
const formatSpeed = (speed: number): string => {
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
};
const formatDuration = (ms: number): string => {
return (ms / 1000).toFixed(2) + "s";
};
const ActivityPage = () => {
const { metrics } = useAPI();
const [error, setError] = useState<string | null>(null);
useEffect(() => {
if (metrics.length > 0) {
setError(null);
}
}, [metrics]);
if (error) {
return (
<div className="p-6">
<h1 className="text-2xl font-bold mb-4">Activity</h1>
<div className="bg-red-50 border border-red-200 rounded-md p-4">
<p className="text-red-800">{error}</p>
</div>
</div>
);
}
return (
<div className="p-6">
<h1 className="text-2xl font-bold mb-4">Activity</h1>
{metrics.length === 0 ? (
<div className="text-center py-8">
<p className="text-gray-600">No metrics data available</p>
</div>
) : (
<div className="overflow-x-auto">
<table className="min-w-full divide-y">
<thead>
<tr>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Timestamp</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Model</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Input Tokens</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Output Tokens</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Generation Speed</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Duration</th>
</tr>
</thead>
<tbody className="divide-y">
{metrics.map((metric, index) => (
<tr key={`${metric.id}-${index}`}>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatTimestamp(metric.timestamp)}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.model}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.input_tokens.toLocaleString()}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.output_tokens.toLocaleString()}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.tokens_per_second)}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatDuration(metric.duration_ms)}</td>
</tr>
))}
</tbody>
</table>
</div>
)}
</div>
);
};
export default ActivityPage;
+162
View File
@@ -0,0 +1,162 @@
import { useState, useEffect, useRef, useMemo, useCallback } from "react";
import { useAPI } from "../contexts/APIProvider";
import { usePersistentState } from "../hooks/usePersistentState";
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
import {
RiTextWrap,
RiAlignJustify,
RiFontSize,
RiMenuSearchLine,
RiMenuSearchFill,
RiCloseCircleFill,
} from "react-icons/ri";
import { useTheme } from "../contexts/ThemeProvider";
const LogViewer = () => {
const { proxyLogs, upstreamLogs } = useAPI();
const { isNarrow } = useTheme();
const direction = isNarrow ? "vertical" : "horizontal";
return (
<PanelGroup direction={direction} className="gap-2" autoSaveId="logviewer-panel-group">
<Panel id="proxy" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
</Panel>
<PanelResizeHandle
className={
direction === "horizontal"
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
}
/>
<Panel id="upstream" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
</Panel>
</PanelGroup>
);
};
interface LogPanelProps {
id: string;
title: string;
logData: string;
}
export const LogPanel = ({ id, title, logData }: LogPanelProps) => {
const [filterRegex, setFilterRegex] = useState("");
const [fontSize, setFontSize] = usePersistentState<"xxs" | "xs" | "small" | "normal">(
`logPanel-${id}-fontSize`,
"normal"
);
const [wrapText, setTextWrap] = usePersistentState(`logPanel-${id}-wrapText`, false);
const [showFilter, setShowFilter] = usePersistentState(`logPanel-${id}-showFilter`, false);
const textWrapClass = useMemo(() => {
return wrapText ? "whitespace-pre-wrap" : "whitespace-pre";
}, [wrapText]);
const toggleFontSize = useCallback(() => {
setFontSize((prev) => {
switch (prev) {
case "xxs":
return "xs";
case "xs":
return "small";
case "small":
return "normal";
case "normal":
return "xxs";
}
});
}, []);
const toggleWrapText = useCallback(() => {
setTextWrap((prev) => !prev);
}, []);
const toggleFilter = useCallback(() => {
if (showFilter) {
setShowFilter(false);
setFilterRegex(""); // Clear filter when closing
} else {
setShowFilter(true);
}
}, [filterRegex, setFilterRegex, showFilter]);
const fontSizeClass = useMemo(() => {
switch (fontSize) {
case "xxs":
return "text-[0.5rem]"; // 0.5rem (8px)
case "xs":
return "text-[0.75rem]"; // 0.75rem (12px)
case "small":
return "text-[0.875rem]"; // 0.875rem (14px)
case "normal":
return "text-base"; // 1rem (16px)
}
}, [fontSize]);
const filteredLogs = useMemo(() => {
if (!filterRegex) return logData;
try {
const regex = new RegExp(filterRegex, "i");
const lines = logData.split("\n");
const filtered = lines.filter((line) => regex.test(line));
return filtered.join("\n");
} catch (e) {
return logData; // Return unfiltered if regex is invalid
}
}, [logData, filterRegex]);
// auto scroll to bottom
const preTagRef = useRef<HTMLPreElement>(null);
useEffect(() => {
if (!preTagRef.current) return;
preTagRef.current.scrollTop = preTagRef.current.scrollHeight;
}, [filteredLogs]);
return (
<div className="bg-surface border border-border rounded-lg overflow-hidden flex flex-col h-full">
<div className="p-4 border-b border-border bg-secondary">
<div className="flex items-center justify-between">
<h3 className="m-0 text-lg p-0">{title}</h3>
<div className="flex gap-2 items-center">
<button className="btn" onClick={toggleFontSize}>
<RiFontSize />
</button>
<button className="btn" onClick={toggleWrapText}>
{wrapText ? <RiTextWrap /> : <RiAlignJustify />}
</button>
<button className="btn" onClick={toggleFilter}>
{showFilter ? <RiMenuSearchFill /> : <RiMenuSearchLine />}
</button>
</div>
</div>
{/* Filtering Options - Full width on mobile, normal on desktop */}
{showFilter && (
<div className="mt-2 w-full">
<div className="flex gap-2 items-center w-full">
<input
type="text"
className="w-full text-sm border p-2 rounded"
placeholder="Filter logs..."
value={filterRegex}
onChange={(e) => setFilterRegex(e.target.value)}
/>
<button className="pl-2" onClick={() => setFilterRegex("")}>
<RiCloseCircleFill size="24" />
</button>
</div>
</div>
)}
</div>
<div className="bg-background font-mono text-sm flex-1 overflow-hidden">
<pre ref={preTagRef} className={`${textWrapClass} ${fontSizeClass} h-full overflow-auto p-4`}>
{filteredLogs}
</pre>
</div>
</div>
);
};
export default LogViewer;
+156
View File
@@ -0,0 +1,156 @@
import { useState, useCallback, useMemo } from "react";
import { useAPI } from "../contexts/APIProvider";
import { LogPanel } from "./LogViewer";
import { usePersistentState } from "../hooks/usePersistentState";
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
import { useTheme } from "../contexts/ThemeProvider";
import { RiEyeFill, RiEyeOffFill, RiStopCircleLine } from "react-icons/ri";
export default function ModelsPage() {
const { isNarrow } = useTheme();
const direction = isNarrow ? "vertical" : "horizontal";
const { upstreamLogs } = useAPI();
return (
<PanelGroup direction={direction} className="gap-2" autoSaveId={"models-panel-group"}>
<Panel id="models" defaultSize={50} minSize={isNarrow ? 0 : 25} maxSize={100} collapsible={isNarrow}>
<ModelsPanel />
</Panel>
<PanelResizeHandle
className={
direction === "horizontal"
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
}
/>
<Panel collapsible={true} defaultSize={50} minSize={0}>
<div className="flex flex-col h-full space-y-4">
{direction === "horizontal" && <StatsPanel />}
<div className="flex-1 min-h-0">
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} />
</div>
</div>
</Panel>
</PanelGroup>
);
}
function ModelsPanel() {
const { models, loadModel, unloadAllModels } = useAPI();
const [isUnloading, setIsUnloading] = useState(false);
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
const filteredModels = useMemo(() => {
return models.filter((model) => showUnlisted || !model.unlisted);
}, [models, showUnlisted]);
const handleUnloadAllModels = useCallback(async () => {
setIsUnloading(true);
try {
await unloadAllModels();
} catch (e) {
console.error(e);
} finally {
setTimeout(() => {
setIsUnloading(false);
}, 1000);
}
}, [unloadAllModels]);
return (
<div className="card h-full flex flex-col">
<div className="shrink-0">
<h2>Models</h2>
<div className="flex justify-between">
<button
className="btn flex items-center gap-2"
onClick={() => setShowUnlisted(!showUnlisted)}
style={{ lineHeight: "1.2" }}
>
{showUnlisted ? <RiEyeFill /> : <RiEyeOffFill />} unlisted
</button>
<button className="btn flex items-center gap-2" onClick={handleUnloadAllModels} disabled={isUnloading}>
<RiStopCircleLine size="24" /> {isUnloading ? "Unloading..." : "Unload"}
</button>
</div>
</div>
<div className="flex-1 overflow-y-auto">
<table className="w-full">
<thead className="sticky top-0 bg-card z-10">
<tr className="border-b border-primary bg-surface">
<th className="text-left p-2">Name</th>
<th className="text-left p-2"></th>
<th className="text-left p-2">State</th>
</tr>
</thead>
<tbody>
{filteredModels.map((model) => (
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
<td className={`p-2 ${model.unlisted ? "text-txtsecondary" : ""}`}>
<a href={`/upstream/${model.id}/`} className={`underline`} target="_blank">
{model.name !== "" ? model.name : model.id}
</a>
{model.description !== "" && (
<p className={model.unlisted ? "text-opacity-70" : ""}>
<em>{model.description}</em>
</p>
)}
</td>
<td className="p-2 w-[50px]">
<button
className="btn btn--sm"
disabled={model.state !== "stopped"}
onClick={() => loadModel(model.id)}
>
Load
</button>
</td>
<td className="p-2 w-[75px]">
<span className={`status status--${model.state}`}>{model.state}</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
);
}
function StatsPanel() {
const { metrics } = useAPI();
const [totalRequests, totalTokens, avgTokensPerSecond] = useMemo(() => {
const totalRequests = metrics.length;
if (totalRequests === 0) {
return [0, 0, 0];
}
const totalTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
const avgTokensPerSecond = (metrics.reduce((sum, m) => sum + m.tokens_per_second, 0) / totalRequests).toFixed(2);
return [totalRequests, totalTokens, avgTokensPerSecond];
}, [metrics]);
return (
<div className="card">
<h2>Chat Activity</h2>
<table className="w-full border border-gray-200">
<tbody>
<tr className="border-b border-gray-200">
<td className="py-2 px-4 font-medium border-r border-gray-200">Requests</td>
<td className="py-2 px-4 text-right">{totalRequests}</td>
</tr>
<tr className="border-b border-gray-200">
<td className="py-2 px-4 font-medium border-r border-gray-200">Total Tokens Generated</td>
<td className="py-2 px-4 text-right">{totalTokens}</td>
</tr>
<tr>
<td className="py-2 px-4 font-medium border-r border-gray-200">Average Tokens/Second</td>
<td className="py-2 px-4 text-right">{avgTokensPerSecond}</td>
</tr>
</tbody>
</table>
</div>
);
}
+1
View File
@@ -0,0 +1 @@
/// <reference types="vite/client" />
+27
View File
@@ -0,0 +1,27 @@
{
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
"target": "ES2020",
"useDefineForClassFields": true,
"lib": ["ES2020", "DOM", "DOM.Iterable"],
"module": "ESNext",
"skipLibCheck": true,
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"verbatimModuleSyntax": true,
"moduleDetection": "force",
"noEmit": true,
"jsx": "react-jsx",
/* Linting */
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true,
"noUncheckedSideEffectImports": true
},
"include": ["src"]
}
+7
View File
@@ -0,0 +1,7 @@
{
"files": [],
"references": [
{ "path": "./tsconfig.app.json" },
{ "path": "./tsconfig.node.json" }
]
}
+25
View File
@@ -0,0 +1,25 @@
{
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
"target": "ES2022",
"lib": ["ES2023"],
"module": "ESNext",
"skipLibCheck": true,
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"verbatimModuleSyntax": true,
"moduleDetection": "force",
"noEmit": true,
/* Linting */
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true,
"noUncheckedSideEffectImports": true
},
"include": ["vite.config.ts"]
}
+20
View File
@@ -0,0 +1,20 @@
import { defineConfig } from "vite";
import react from "@vitejs/plugin-react";
import tailwindcss from "@tailwindcss/vite";
// https://vite.dev/config/
export default defineConfig({
plugins: [react(), tailwindcss()],
base: "/ui/",
build: {
outDir: "../proxy/ui_dist",
assetsDir: "assets",
},
server: {
proxy: {
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
"/logs": "http://localhost:8080",
"/upstream": "http://localhost:8080",
},
},
});