Compare commits
59 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 04fc67354a | |||
| 4662cf7699 | |||
| 5dc6b3e6d9 | |||
| 74c69f39ef | |||
| a186318892 | |||
| c4e4d5e1e9 | |||
| 7985e94ba4 | |||
| 74556c3a36 | |||
| 5c381e4b30 | |||
| 10569ed546 | |||
| 5b10b3c23f | |||
| 45ea792a3a | |||
| 1bc2802353 | |||
| 701476c0c4 | |||
| 5c63e0066c | |||
| 8be5073c51 | |||
| 6307bd3205 | |||
| 558a72de17 | |||
| dc42cf366d | |||
| ba0a81937a | |||
| 574fdfabb4 | |||
| 5172cb2e12 | |||
| 5672cb03fd | |||
| 0f583163f7 | |||
| 7905fa9ea3 | |||
| bbaf172956 | |||
| fd50932dbc | |||
| 8c693e7fcf | |||
| 8f2af26a41 | |||
| 01d4838fb3 | |||
| accd65294b | |||
| 7472a25864 | |||
| cce0bc6aa1 | |||
| 36e25125e8 | |||
| 9a54273d15 | |||
| 87dce5f8f6 | |||
| 307e619521 | |||
| 6299c1b874 | |||
| a906cd459b | |||
| 78b2bc3dbc | |||
| 6a058e4191 | |||
| 1921e570d7 | |||
| c867a6c9a2 | |||
| 3bd1b23ce0 | |||
| 10606abf89 | |||
| fefd14903d | |||
| 717d64e336 | |||
| 285191e655 | |||
| 4236cec03a | |||
| 756193d0dd | |||
| a6b2e930d8 | |||
| 9e02c22ff8 | |||
| 0bdbf2fdc1 | |||
| 49035e2e8e | |||
| 9963ae18bf | |||
| 2ae48c713b | |||
| 54c519e365 | |||
| 3fce9ee0e9 | |||
| 5899ae7966 |
@@ -1,11 +1,13 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: Something is not working as expected...
|
||||
about: I found a defect
|
||||
title: ''
|
||||
labels: bug
|
||||
labels: 'unconfirmed bug'
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
> [!IMPORTANT]
|
||||
> If you have questions about llama-swap please post in the Q&A in Discussions. Use bug reports when you've found a defect and wish to discuss a fix.
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
@@ -22,6 +22,13 @@ jobs:
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# Only run in this linux based runner
|
||||
- name: Check Formatting
|
||||
run: |
|
||||
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
||||
gofmt -l . | grep -v 'event/.*_test.go'
|
||||
exit 1
|
||||
fi
|
||||
# cache simple-responder to save the build time
|
||||
- name: Restore Simple Responder
|
||||
id: restore-simple-responder
|
||||
|
||||
@@ -7,6 +7,10 @@ on:
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: 'Tag version to release (e.g. v144)'
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -20,15 +24,15 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
|
||||
-
|
||||
name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '23' # or your preferred version
|
||||
node-version: '23'
|
||||
-
|
||||
name: Install dependencies and build UI
|
||||
run: |
|
||||
@@ -46,4 +50,30 @@ jobs:
|
||||
version: '~> v2'
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
trigger-tap-update:
|
||||
runs-on: ubuntu-latest
|
||||
needs: goreleaser
|
||||
steps:
|
||||
- name: "Resolve tag to dispatch"
|
||||
id: tag
|
||||
run: |
|
||||
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||
echo "tag=${{ github.event.inputs.tag }}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "tag=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: "Trigger tap repository update"
|
||||
uses: peter-evans/repository-dispatch@v2
|
||||
with:
|
||||
token: ${{ secrets.TAP_REPO_PAT }}
|
||||
repository: mostlygeek/homebrew-llama-swap
|
||||
event-type: new-release
|
||||
client-payload: |
|
||||
{
|
||||
"release": {
|
||||
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||
}
|
||||
}
|
||||
@@ -4,3 +4,4 @@ build/
|
||||
dist/
|
||||
.vscode
|
||||
.DS_Store
|
||||
.dev/
|
||||
|
||||
@@ -17,14 +17,16 @@ builds:
|
||||
- goos: windows
|
||||
goarch: arm64
|
||||
|
||||
# use zip format for windows
|
||||
archives:
|
||||
- id: default
|
||||
format: tar.gz
|
||||
formats:
|
||||
- tar.gz
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
builds_info:
|
||||
group: root
|
||||
owner: root
|
||||
format_overrides:
|
||||
# use zip format for windows
|
||||
- goos: windows
|
||||
format: zip
|
||||
formats:
|
||||
- zip
|
||||
@@ -29,9 +29,13 @@ test: proxy/ui_dist/placeholder.txt
|
||||
test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -v -count=1 ./proxy
|
||||
|
||||
ui/node_modules:
|
||||
cd ui && npm install
|
||||
|
||||
# build react UI
|
||||
ui:
|
||||
ui: ui/node_modules
|
||||
cd ui && npm run build
|
||||
|
||||
# Build OSX binary
|
||||
mac: ui
|
||||
@echo "Building Mac binary..."
|
||||
@@ -41,6 +45,7 @@ mac: ui
|
||||
linux: ui
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||
|
||||
# Build Windows binary
|
||||
windows: ui
|
||||
|
||||
@@ -18,62 +18,80 @@ Written in golang, it is very easy to install (single binary with no dependencie
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/embeddings`
|
||||
- `v1/rerank`
|
||||
- `v1/rerank`, `v1/reranking`, `rerank`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- ✅ llama-swap custom API endpoints
|
||||
- `/ui` - web UI
|
||||
- `/log` - remote log monitoring
|
||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- `/health` - just returns "OK"
|
||||
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Docker and Podman support
|
||||
- ✅ Reliable Docker and Podman support with `cmdStart` and `cmdStop`
|
||||
- ✅ Full control over server settings per model
|
||||
- ✅ Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
||||
|
||||
## How does llama-swap work?
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## config.yaml
|
||||
|
||||
llama-swap's configuration is purposefully simple:
|
||||
llama-swap is managed entirely through a yaml configuration file.
|
||||
|
||||
It can be very minimal to start:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"qwen2.5":
|
||||
cmd: |
|
||||
/app/llama-server
|
||||
/path/to/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port ${PORT}
|
||||
|
||||
"smollm2":
|
||||
cmd: |
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port ${PORT}
|
||||
```
|
||||
|
||||
.. but also supports many advanced features:
|
||||
However, there are many more capabilities that llama-swap supports:
|
||||
|
||||
- `groups` to run multiple models at once
|
||||
- `macros` for reusable snippets
|
||||
- `ttl` to automatically unload models
|
||||
- `macros` for reusable snippets
|
||||
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
|
||||
- `env` variables to pass custom environment to inference servers
|
||||
- `env` to pass custom environment variables to inference servers
|
||||
- `cmdStop` for to gracefully stop Docker/Podman containers
|
||||
- `useModelName` to override model names sent to upstream servers
|
||||
- `healthCheckTimeout` to control model startup wait times
|
||||
- `${PORT}` automatic port variables for dynamic port assignment
|
||||
- `cmdStop` for to gracefully stop Docker/Podman containers
|
||||
|
||||
Check the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki for all options.
|
||||
See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki all options and examples.
|
||||
|
||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
## Web UI
|
||||
|
||||
Docker is the quickest way to try out llama-swap:
|
||||
llama-swap includes a real time web interface for monitoring logs and models:
|
||||
|
||||
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/adef4a8e-de0b-49db-885a-8f6dedae6799" />
|
||||
|
||||
The Activity Page shows recent requests:
|
||||
|
||||
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
|
||||
|
||||
## Installation
|
||||
|
||||
llama-swap can be installed in multiple ways
|
||||
|
||||
1. Docker
|
||||
2. Homebrew (OSX and Linux)
|
||||
3. From release binaries
|
||||
4. From source
|
||||
|
||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Docker images with llama-swap and llama-server are built nightly.
|
||||
|
||||
```shell
|
||||
# use CPU inference comes with the example config above
|
||||
@@ -95,7 +113,7 @@ $ curl -s http://localhost:9292/v1/chat/completions \
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Docker images are built nightly for cuda, intel, vulcan, etc ...</summary>
|
||||
<summary>Docker images are built nightly with llama-server for cuda, intel, vulcan and musa.</summary>
|
||||
|
||||
They include:
|
||||
|
||||
@@ -118,13 +136,27 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
|
||||
</details>
|
||||
|
||||
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||
### Homebrew Install (macOS/Linux)
|
||||
|
||||
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
||||
The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh).
|
||||
|
||||
```shell
|
||||
# Set up tap and install formula
|
||||
brew tap mostlygeek/llama-swap
|
||||
brew install llama-swap
|
||||
# Run llama-swap
|
||||
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||
```
|
||||
|
||||
This will install the `llama-swap` binary and make it available in your path. See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration)
|
||||
|
||||
### Pre-built Binaries ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||
|
||||
Binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The binary install works with any OpenAI compatible server, not just llama-server.
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`.
|
||||
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml --listen localhost:8080`.
|
||||
Available flags:
|
||||
- `--config`: Path to the configuration file (default: `config.yaml`).
|
||||
- `--listen`: Address and port to listen on (default: `:8080`).
|
||||
@@ -133,16 +165,16 @@ Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are
|
||||
|
||||
### Building from source
|
||||
|
||||
1. Install golang for your system
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. Build requires golang and nodejs for the user interface.
|
||||
1. `git clone https://github.com/mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
|
||||
## Monitoring Logs
|
||||
|
||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
||||
Open the `http://<host>:<port>/` with your browser to get a web interface with streaming logs.
|
||||
|
||||
Of course, CLI access is also supported:
|
||||
CLI access is also supported:
|
||||
|
||||
```shell
|
||||
# sends up to the last 10KB of logs
|
||||
|
||||
@@ -1,93 +1,232 @@
|
||||
# ======
|
||||
# For a more detailed configuration example:
|
||||
# https://github.com/mostlygeek/llama-swap/wiki/Configuration
|
||||
# ======
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
# 💡 Tip - Use an LLM with this file!
|
||||
# ====================================
|
||||
# This example configuration is written to be LLM friendly! Try
|
||||
# copying this file into an LLM and asking it to explain or generate
|
||||
# sections for you.
|
||||
# ====================================
|
||||
#
|
||||
# - Below are all the available configuration options for llama-swap.
|
||||
# - Settings with a default value, or noted as optional can be omitted.
|
||||
# - Settings that are marked required must be in your configuration file
|
||||
|
||||
# Seconds to wait for llama.cpp to be available to serve requests
|
||||
# Default (and minimum): 15 seconds
|
||||
healthCheckTimeout: 90
|
||||
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||
# - optional, default: 120
|
||||
# - minimum value is 15 seconds, anything less will be set to this value
|
||||
healthCheckTimeout: 500
|
||||
|
||||
# valid log levels: debug, info (default), warn, error
|
||||
logLevel: debug
|
||||
# logLevel: sets the logging value
|
||||
# - optional, default: info
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# creating a coding profile with models for code generation and general questions
|
||||
groups:
|
||||
coding:
|
||||
swap: false
|
||||
members:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
# - useful for limiting memory usage when processing large volumes of metrics
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||
# - optional, default: 5800
|
||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||
# - it is automatically incremented for every model that uses it
|
||||
startPort: 10001
|
||||
|
||||
# macros: sets a dictionary of string:string pairs
|
||||
# - optional, default: empty dictionary
|
||||
# - these are reusable snippets
|
||||
# - used in a model's cmd, cmdStop, proxy and checkEndpoint
|
||||
# - useful for reducing common configuration settings
|
||||
macros:
|
||||
"latest-llama": >
|
||||
/path/to/llama-server/llama-server-ec9e0301
|
||||
--port ${PORT}
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
# - model settings have default values that are used if they are not defined here
|
||||
# - below are examples of the various settings a model can have:
|
||||
# - available model settings: env, cmd, cmdStop, proxy, aliases, checkEndpoint, ttl, unlisted
|
||||
models:
|
||||
|
||||
# keys are the model names used in API requests
|
||||
"llama":
|
||||
# cmd: the command to run to start the inference server.
|
||||
# - required
|
||||
# - it is just a string, similar to what you would run on the CLI
|
||||
# - using `|` allows for comments in the command, these will be parsed out
|
||||
# - macros can be used within cmd
|
||||
cmd: |
|
||||
models/llama-server-osx
|
||||
--port ${PORT}
|
||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/llama-8B-Q4_K_M.gguf
|
||||
|
||||
# list of model name aliases this llama.cpp instance can serve
|
||||
# name: a display name for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
name: "llama 3.1 8B"
|
||||
|
||||
# description: a description for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
description: "A small but capable model used for quick testing"
|
||||
|
||||
# env: define an array of environment variables to inject into cmd's environment
|
||||
# - optional, default: empty array
|
||||
# - each value is a single string
|
||||
# - in the format: ENV_NAME=value
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0,1,2"
|
||||
|
||||
# proxy: the URL where llama-swap routes API requests
|
||||
# - optional, default: http://localhost:${PORT}
|
||||
# - if you used ${PORT} in cmd this can be omitted
|
||||
# - if you use a custom port in cmd this *must* be set
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases: alternative model names that this model configuration is used for
|
||||
# - optional, default: empty array
|
||||
# - aliases must be unique globally
|
||||
# - useful for impersonating a specific model
|
||||
aliases:
|
||||
- gpt-4o-mini
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# check this path for a HTTP 200 response for the server to be ready
|
||||
checkEndpoint: /health
|
||||
# checkEndpoint: URL path to check if the server is ready
|
||||
# - optional, default: /health
|
||||
# - use "none" to skip endpoint ready checking
|
||||
# - endpoint is expected to return an HTTP 200 response
|
||||
# - all requests wait until the endpoint is ready (or fails)
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# unload model after 5 seconds
|
||||
ttl: 5
|
||||
# ttl: automatically unload the model after this many seconds
|
||||
# - optional, default: 0
|
||||
# - ttl values must be a value greater than 0
|
||||
# - a value of 0 disables automatic unloading of the model
|
||||
ttl: 60
|
||||
|
||||
"qwen":
|
||||
cmd: models/llama-server-osx --port ${PORT} -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
aliases:
|
||||
- gpt-3.5-turbo
|
||||
# useModelName: overrides the model name that is sent to upstream server
|
||||
# - optional, default: ""
|
||||
# - useful when the upstream server expects a specific model name or format
|
||||
useModelName: "qwen:qwq"
|
||||
|
||||
# Embedding example with Nomic
|
||||
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
||||
"nomic":
|
||||
cmd: |
|
||||
models/llama-server-osx --port ${PORT}
|
||||
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
||||
--ctx-size 8192
|
||||
--batch-size 8192
|
||||
--rope-scaling yarn
|
||||
--rope-freq-scale 0.75
|
||||
-ngl 99
|
||||
--embeddings
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
filters:
|
||||
# strip_params: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for preventing overriding of default server params by requests
|
||||
# - `model` parameter is never removed
|
||||
# - can be any JSON key in the request body
|
||||
# - recommended to stick to sampling parameters
|
||||
strip_params: "temperature, top_p, top_k"
|
||||
|
||||
# Reranking example with bge-reranker
|
||||
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
||||
"bge-reranker":
|
||||
cmd: |
|
||||
models/llama-server-osx --port ${PORT}
|
||||
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
||||
--ctx-size 8192
|
||||
--reranking
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: true or false
|
||||
# - optional, default: false
|
||||
# - unlisted models do not show up in /v1/models or /upstream lists
|
||||
# - can be requested as normal through all apis
|
||||
unlisted: true
|
||||
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker Support (v26.1.4+ required!)
|
||||
"dockertest":
|
||||
# Docker example:
|
||||
# container run times like Docker and Podman can also be used with a
|
||||
# a combination of cmd and cmdStop.
|
||||
"docker-llama":
|
||||
proxy: "http://127.0.0.1:${PORT}"
|
||||
cmd: |
|
||||
docker run --name dockertest
|
||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
"simple":
|
||||
# example of setting environment variables
|
||||
env:
|
||||
- CUDA_VISIBLE_DEVICES=0,1
|
||||
- env1=hello
|
||||
cmd: build/simple-responder --port ${PORT}
|
||||
unlisted: true
|
||||
# cmdStop: command to run to stop the model gracefully
|
||||
# - optional, default: ""
|
||||
# - useful for stopping commands managed by another system
|
||||
# - on POSIX systems: a SIGTERM is sent for graceful shutdown
|
||||
# - on Windows, taskkill is used
|
||||
# - processes are given 5 seconds to shutdown until they are forcefully killed
|
||||
# - the upstream's process id is available in the ${PID} macro
|
||||
cmdStop: docker stop dockertest
|
||||
|
||||
# use "none" to skip check. Caution this may cause some requests to fail
|
||||
# until the upstream server is ready for traffic
|
||||
checkEndpoint: none
|
||||
# groups: a dictionary of group settings
|
||||
# - optional, default: empty dictionary
|
||||
# - provide advanced controls over model swapping behaviour.
|
||||
# - Using groups some models can be kept loaded indefinitely, while others are swapped out.
|
||||
# - model ids must be defined in the Models section
|
||||
# - a model can only be a member of one group
|
||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||
# - see issue #109 for details
|
||||
#
|
||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||
groups:
|
||||
# group1 is same as the default behaviour of llama-swap where only one model is allowed
|
||||
# to run a time across the whole llama-swap instance
|
||||
"group1":
|
||||
# swap: controls the model swapping behaviour in within the group
|
||||
# - optional, default: true
|
||||
# - true : only one model is allowed to run at a time
|
||||
# - false: all models can run together, no swapping
|
||||
swap: true
|
||||
|
||||
# don't use these, just for testing if things are broken
|
||||
"broken":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
unlisted: true
|
||||
"broken_timeout":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9000
|
||||
unlisted: true
|
||||
# exclusive: controls how the group affects other groups
|
||||
# - optional, default: true
|
||||
# - true: causes all other groups to unload when this group runs a model
|
||||
# - false: does not affect other groups
|
||||
exclusive: true
|
||||
|
||||
# members references the models defined above
|
||||
# required
|
||||
members:
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
|
||||
# Example:
|
||||
# - in this group all the models can run at the same time
|
||||
# - when a different group loads all running models in this group are unloaded
|
||||
"group2":
|
||||
swap: false
|
||||
exclusive: false
|
||||
members:
|
||||
- "docker-llama"
|
||||
- "modelA"
|
||||
- "modelB"
|
||||
|
||||
# Example:
|
||||
# - a persistent group, prevents other groups from unloading it
|
||||
"forever":
|
||||
# persistent: prevents over groups from unloading the models in this group
|
||||
# - optional, default: false
|
||||
# - does not affect individual model behaviour
|
||||
persistent: true
|
||||
|
||||
# set swap/exclusive to false to prevent swapping inside the group
|
||||
# and the unloading of other groups
|
||||
swap: false
|
||||
exclusive: false
|
||||
members:
|
||||
- "forever-modelA"
|
||||
- "forever-modelB"
|
||||
- "forever-modelc"
|
||||
|
||||
# hooks: a dictionary of event triggers and actions
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported hook is on_startup
|
||||
hooks:
|
||||
# on_startup: a dictionary of actions to perform on startup
|
||||
# - optional, default: empty dictionar
|
||||
# - the only supported action is preload
|
||||
on_startup:
|
||||
# preload: a list of model ids to load on startup
|
||||
# - optional, default: empty list
|
||||
# - model names must match keys in the models sections
|
||||
# - when preloading multiple models at once, define a group
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
The code in `event` was originally a part of https://github.com/kelindar/event (v1.5.2)
|
||||
|
||||
The original code uses a `time.Ticker` to process the event queue which caused a large increase in CPU usage ([#189](https://github.com/mostlygeek/llama-swap/issues/189)). This code was ported to remove the ticker and instead be more event driven.
|
||||
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Default initializes a default in-process dispatcher
|
||||
var Default = NewDispatcherConfig(25000)
|
||||
|
||||
// On subscribes to an event, the type of the event will be automatically
|
||||
// inferred from the provided type. Must be constant for this to work. This
|
||||
// functions same way as Subscribe() but uses the default dispatcher instead.
|
||||
func On[T Event](handler func(T)) context.CancelFunc {
|
||||
return Subscribe(Default, handler)
|
||||
}
|
||||
|
||||
// OnType subscribes to an event with the specified event type. This functions
|
||||
// same way as SubscribeTo() but uses the default dispatcher instead.
|
||||
func OnType[T Event](eventType uint32, handler func(T)) context.CancelFunc {
|
||||
return SubscribeTo(Default, eventType, handler)
|
||||
}
|
||||
|
||||
// Emit writes an event into the dispatcher. This functions same way as
|
||||
// Publish() but uses the default dispatcher instead.
|
||||
func Emit[T Event](ev T) {
|
||||
Publish(Default, ev)
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
/*
|
||||
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
|
||||
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
|
||||
*/
|
||||
func BenchmarkSubscribeConcurrent(b *testing.B) {
|
||||
d := NewDispatcher()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
unsub := Subscribe(d, func(ev MyEvent1) {})
|
||||
unsub()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultPublish(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Subscribe
|
||||
var count int64
|
||||
defer On(func(ev MyEvent1) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
defer OnType(TypeEvent1, func(ev MyEvent1) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
// Publish
|
||||
wg.Add(4)
|
||||
Emit(MyEvent1{})
|
||||
Emit(MyEvent1{})
|
||||
|
||||
// Wait and check
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(4), count)
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for details.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Event represents an event contract
|
||||
type Event interface {
|
||||
Type() uint32
|
||||
}
|
||||
|
||||
// registry holds an immutable sorted array of event mappings
|
||||
type registry struct {
|
||||
keys []uint32 // Event types (sorted)
|
||||
grps []any // Corresponding subscribers
|
||||
}
|
||||
|
||||
// ------------------------------------- Dispatcher -------------------------------------
|
||||
|
||||
// Dispatcher represents an event dispatcher.
|
||||
type Dispatcher struct {
|
||||
subs atomic.Pointer[registry] // Atomic pointer to immutable array
|
||||
done chan struct{} // Cancellation
|
||||
maxQueue int // Maximum queue size per consumer
|
||||
mu sync.Mutex // Only for writes (subscribe/unsubscribe)
|
||||
}
|
||||
|
||||
// NewDispatcher creates a new dispatcher of events.
|
||||
func NewDispatcher() *Dispatcher {
|
||||
return NewDispatcherConfig(50000)
|
||||
}
|
||||
|
||||
// NewDispatcherConfig creates a new dispatcher with configurable max queue size
|
||||
func NewDispatcherConfig(maxQueue int) *Dispatcher {
|
||||
d := &Dispatcher{
|
||||
done: make(chan struct{}),
|
||||
maxQueue: maxQueue,
|
||||
}
|
||||
|
||||
d.subs.Store(®istry{
|
||||
keys: make([]uint32, 0, 16),
|
||||
grps: make([]any, 0, 16),
|
||||
})
|
||||
return d
|
||||
}
|
||||
|
||||
// Close closes the dispatcher
|
||||
func (d *Dispatcher) Close() error {
|
||||
close(d.done)
|
||||
return nil
|
||||
}
|
||||
|
||||
// isClosed returns whether the dispatcher is closed or not
|
||||
func (d *Dispatcher) isClosed() bool {
|
||||
select {
|
||||
case <-d.done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// findGroup performs a lock-free binary search for the event type
|
||||
func (d *Dispatcher) findGroup(eventType uint32) any {
|
||||
reg := d.subs.Load()
|
||||
keys := reg.keys
|
||||
|
||||
// Inlined binary search for better cache locality
|
||||
left, right := 0, len(keys)
|
||||
for left < right {
|
||||
mid := left + (right-left)/2
|
||||
if keys[mid] < eventType {
|
||||
left = mid + 1
|
||||
} else {
|
||||
right = mid
|
||||
}
|
||||
}
|
||||
|
||||
if left < len(keys) && keys[left] == eventType {
|
||||
return reg.grps[left]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe subscribes to an event, the type of the event will be automatically
|
||||
// inferred from the provided type. Must be constant for this to work.
|
||||
func Subscribe[T Event](broker *Dispatcher, handler func(T)) context.CancelFunc {
|
||||
var event T
|
||||
return SubscribeTo(broker, event.Type(), handler)
|
||||
}
|
||||
|
||||
// SubscribeTo subscribes to an event with the specified event type.
|
||||
func SubscribeTo[T Event](broker *Dispatcher, eventType uint32, handler func(T)) context.CancelFunc {
|
||||
if broker.isClosed() {
|
||||
panic(errClosed)
|
||||
}
|
||||
|
||||
broker.mu.Lock()
|
||||
defer broker.mu.Unlock()
|
||||
|
||||
// Check if group already exists
|
||||
if existing := broker.findGroup(eventType); existing != nil {
|
||||
grp := groupOf[T](eventType, existing)
|
||||
sub := grp.Add(handler)
|
||||
return func() {
|
||||
grp.Del(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Create new group
|
||||
grp := &group[T]{cond: sync.NewCond(new(sync.Mutex)), maxQueue: broker.maxQueue}
|
||||
sub := grp.Add(handler)
|
||||
|
||||
// Copy-on-write: insert new entry in sorted position
|
||||
old := broker.subs.Load()
|
||||
idx := sort.Search(len(old.keys), func(i int) bool {
|
||||
return old.keys[i] >= eventType
|
||||
})
|
||||
|
||||
// Create new arrays with space for one more element
|
||||
newKeys := make([]uint32, len(old.keys)+1)
|
||||
newGrps := make([]any, len(old.grps)+1)
|
||||
|
||||
// Copy elements before insertion point
|
||||
copy(newKeys[:idx], old.keys[:idx])
|
||||
copy(newGrps[:idx], old.grps[:idx])
|
||||
|
||||
// Insert new element
|
||||
newKeys[idx] = eventType
|
||||
newGrps[idx] = grp
|
||||
|
||||
// Copy elements after insertion point
|
||||
copy(newKeys[idx+1:], old.keys[idx:])
|
||||
copy(newGrps[idx+1:], old.grps[idx:])
|
||||
|
||||
// Atomically store the new registry (mutex ensures no concurrent writers)
|
||||
newReg := ®istry{keys: newKeys, grps: newGrps}
|
||||
broker.subs.Store(newReg)
|
||||
|
||||
return func() {
|
||||
grp.Del(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish writes an event into the dispatcher
|
||||
func Publish[T Event](broker *Dispatcher, ev T) {
|
||||
eventType := ev.Type()
|
||||
if sub := broker.findGroup(eventType); sub != nil {
|
||||
group := groupOf[T](eventType, sub)
|
||||
group.Broadcast(ev)
|
||||
}
|
||||
}
|
||||
|
||||
// Count counts the number of subscribers, this is for testing only.
|
||||
func (d *Dispatcher) count(eventType uint32) int {
|
||||
if group := d.findGroup(eventType); group != nil {
|
||||
return group.(interface{ Count() int }).Count()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// groupOf casts the subscriber group to the specified generic type
|
||||
func groupOf[T Event](eventType uint32, subs any) *group[T] {
|
||||
if group, ok := subs.(*group[T]); ok {
|
||||
return group
|
||||
}
|
||||
|
||||
panic(errConflict[T](eventType, subs))
|
||||
}
|
||||
|
||||
// ------------------------------------- Subscriber -------------------------------------
|
||||
|
||||
// consumer represents a consumer with a message queue
|
||||
type consumer[T Event] struct {
|
||||
queue []T // Current work queue
|
||||
stop bool // Stop signal
|
||||
}
|
||||
|
||||
// Listen listens to the event queue and processes events
|
||||
func (s *consumer[T]) Listen(c *sync.Cond, fn func(T)) {
|
||||
pending := make([]T, 0, 128)
|
||||
|
||||
for {
|
||||
c.L.Lock()
|
||||
for len(s.queue) == 0 {
|
||||
switch {
|
||||
case s.stop:
|
||||
c.L.Unlock()
|
||||
return
|
||||
default:
|
||||
c.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// Swap buffers and reset the current queue
|
||||
temp := s.queue
|
||||
s.queue = pending[:0]
|
||||
pending = temp
|
||||
c.L.Unlock()
|
||||
|
||||
// Outside of the critical section, process the work
|
||||
for _, event := range pending {
|
||||
fn(event)
|
||||
}
|
||||
|
||||
// Notify potential publishers waiting due to backpressure
|
||||
c.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------- Subscriber Group -------------------------------------
|
||||
|
||||
// group represents a consumer group
|
||||
type group[T Event] struct {
|
||||
cond *sync.Cond
|
||||
subs []*consumer[T]
|
||||
maxQueue int // Maximum queue size per consumer
|
||||
maxLen int // Current maximum queue length across all consumers
|
||||
}
|
||||
|
||||
// Broadcast sends an event to all consumers
|
||||
func (s *group[T]) Broadcast(ev T) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
|
||||
// Calculate current maximum queue length
|
||||
s.maxLen = 0
|
||||
for _, sub := range s.subs {
|
||||
if len(sub.queue) > s.maxLen {
|
||||
s.maxLen = len(sub.queue)
|
||||
}
|
||||
}
|
||||
|
||||
// Backpressure: wait if queues are full
|
||||
for s.maxLen >= s.maxQueue {
|
||||
s.cond.Wait()
|
||||
|
||||
// Recalculate after wakeup
|
||||
s.maxLen = 0
|
||||
for _, sub := range s.subs {
|
||||
if len(sub.queue) > s.maxLen {
|
||||
s.maxLen = len(sub.queue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add event to all queues and track new maximum
|
||||
newMax := 0
|
||||
for _, sub := range s.subs {
|
||||
sub.queue = append(sub.queue, ev)
|
||||
if len(sub.queue) > newMax {
|
||||
newMax = len(sub.queue)
|
||||
}
|
||||
}
|
||||
s.maxLen = newMax
|
||||
s.cond.Broadcast() // Wake consumers
|
||||
}
|
||||
|
||||
// Add adds a subscriber to the list
|
||||
func (s *group[T]) Add(handler func(T)) *consumer[T] {
|
||||
sub := &consumer[T]{
|
||||
queue: make([]T, 0, 64),
|
||||
}
|
||||
|
||||
// Add the consumer to the list of active consumers
|
||||
s.cond.L.Lock()
|
||||
s.subs = append(s.subs, sub)
|
||||
s.cond.L.Unlock()
|
||||
|
||||
// Start listening
|
||||
go sub.Listen(s.cond, handler)
|
||||
return sub
|
||||
}
|
||||
|
||||
// Del removes a subscriber from the list
|
||||
func (s *group[T]) Del(sub *consumer[T]) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
|
||||
// Search and remove the subscriber
|
||||
sub.stop = true
|
||||
for i, v := range s.subs {
|
||||
if v == sub {
|
||||
copy(s.subs[i:], s.subs[i+1:])
|
||||
s.subs = s.subs[:len(s.subs)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------- Debugging -------------------------------------
|
||||
|
||||
var errClosed = fmt.Errorf("event dispatcher is closed")
|
||||
|
||||
// Count returns the number of subscribers in this group
|
||||
func (s *group[T]) Count() int {
|
||||
return len(s.subs)
|
||||
}
|
||||
|
||||
// String returns string representation of the type
|
||||
func (s *group[T]) String() string {
|
||||
typ := reflect.TypeOf(s).String()
|
||||
idx := strings.LastIndex(typ, "/")
|
||||
typ = typ[idx+1 : len(typ)-1]
|
||||
return typ
|
||||
}
|
||||
|
||||
// errConflict returns a conflict message
|
||||
func errConflict[T any](eventType uint32, existing any) string {
|
||||
var want T
|
||||
return fmt.Sprintf(
|
||||
"conflicting event type, want=<%T>, registered=<%s>, event=0x%v",
|
||||
want, existing, eventType,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Subscribe, must be received in order
|
||||
var count int64
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
// Publish
|
||||
wg.Add(3)
|
||||
Publish(d, MyEvent1{Number: 1})
|
||||
Publish(d, MyEvent1{Number: 2})
|
||||
Publish(d, MyEvent1{Number: 3})
|
||||
|
||||
// Wait and check
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||
unsubscribe := Subscribe(d, func(ev MyEvent1) {
|
||||
// Nothing
|
||||
})
|
||||
|
||||
assert.Equal(t, 1, d.count(TypeEvent1))
|
||||
unsubscribe()
|
||||
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||
}
|
||||
|
||||
func TestConcurrent(t *testing.T) {
|
||||
const max = 1000000
|
||||
var count int64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
d := NewDispatcher()
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
if current := atomic.AddInt64(&count, 1); current == max {
|
||||
wg.Done()
|
||||
}
|
||||
})()
|
||||
|
||||
// Asynchronously publish
|
||||
go func() {
|
||||
for i := 0; i < max; i++ {
|
||||
Publish(d, MyEvent1{})
|
||||
}
|
||||
}()
|
||||
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
// Subscriber that does nothing
|
||||
})()
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, max, int(count))
|
||||
}
|
||||
|
||||
func TestSubscribeDifferentType(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestPublishDifferentType(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
Publish(d, MyEvent1{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseDispatcher(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
|
||||
|
||||
assert.NoError(t, d.Close())
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatrix(t *testing.T) {
|
||||
const amount = 1000
|
||||
for _, subs := range []int{1, 10, 100} {
|
||||
for _, topics := range []int{1, 10} {
|
||||
expected := subs * topics * amount
|
||||
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(expected)
|
||||
|
||||
d := NewDispatcher()
|
||||
for i := 0; i < subs; i++ {
|
||||
for id := 0; id < topics; id++ {
|
||||
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
|
||||
count.Add(1)
|
||||
wg.Done()
|
||||
})()
|
||||
}
|
||||
}
|
||||
|
||||
for n := 0; n < amount; n++ {
|
||||
for id := 0; id < topics; id++ {
|
||||
go Publish(d, MyEvent3{ID: id})
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, expected, int(count.Load()))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentSubscriptionRace(t *testing.T) {
|
||||
// This test specifically targets the race condition that occurs when multiple
|
||||
// goroutines try to subscribe to different event types simultaneously.
|
||||
// Without the CAS loop, subscriptions could be lost due to registry corruption.
|
||||
|
||||
const numGoroutines = 100
|
||||
const numEventTypes = 50
|
||||
|
||||
d := NewDispatcher()
|
||||
defer d.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var receivedCount int64
|
||||
var subscribedTypes sync.Map // Thread-safe map
|
||||
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
// Start multiple goroutines that subscribe to different event types concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Each goroutine subscribes to a unique event type
|
||||
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
|
||||
|
||||
// Subscribe to the event type
|
||||
SubscribeTo(d, eventType, func(ev MyEvent3) {
|
||||
atomic.AddInt64(&receivedCount, 1)
|
||||
})
|
||||
|
||||
// Record that this type was subscribed
|
||||
subscribedTypes.Store(eventType, true)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all subscriptions to complete
|
||||
wg.Wait()
|
||||
|
||||
// Count the number of unique event types subscribed
|
||||
expectedTypes := 0
|
||||
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||
expectedTypes++
|
||||
return true
|
||||
})
|
||||
|
||||
// Small delay to ensure all subscriptions are fully processed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Publish events to each subscribed type
|
||||
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||
eventType := key.(uint32)
|
||||
Publish(d, MyEvent3{ID: int(eventType)})
|
||||
return true
|
||||
})
|
||||
|
||||
// Wait for all events to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Verify that we received at least the expected number of events
|
||||
// (there might be more if multiple goroutines subscribed to the same event type)
|
||||
received := atomic.LoadInt64(&receivedCount)
|
||||
assert.GreaterOrEqual(t, int(received), expectedTypes,
|
||||
"Should have received at least %d events, got %d", expectedTypes, received)
|
||||
|
||||
// Verify that we have the expected number of unique event types
|
||||
assert.Equal(t, numEventTypes, expectedTypes,
|
||||
"Should have exactly %d unique event types", numEventTypes)
|
||||
}
|
||||
|
||||
func TestConcurrentHandlerRegistration(t *testing.T) {
|
||||
const numGoroutines = 100
|
||||
|
||||
// Test concurrent subscriptions to the same event type
|
||||
t.Run("SameEventType", func(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var handlerCount int64
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start multiple goroutines subscribing to the same event type (0x1)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
|
||||
atomic.AddInt64(&handlerCount, 1)
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all handlers were registered by publishing an event
|
||||
atomic.StoreInt64(&handlerCount, 0)
|
||||
Publish(d, MyEvent1{})
|
||||
|
||||
// Small delay to ensure all handlers have executed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
|
||||
"Not all handlers were registered due to race condition")
|
||||
})
|
||||
|
||||
// Test concurrent subscriptions to different event types
|
||||
t.Run("DifferentEventTypes", func(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var wg sync.WaitGroup
|
||||
receivedEvents := make(map[uint32]*int64)
|
||||
|
||||
// Create multiple event types and subscribe concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
eventType := uint32(100 + i)
|
||||
counter := new(int64)
|
||||
receivedEvents[eventType] = counter
|
||||
|
||||
wg.Add(1)
|
||||
go func(et uint32, cnt *int64) {
|
||||
defer wg.Done()
|
||||
SubscribeTo(d, et, func(ev MyEvent3) {
|
||||
atomic.AddInt64(cnt, 1)
|
||||
})
|
||||
}(eventType, counter)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Publish events to all types
|
||||
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
|
||||
Publish(d, MyEvent3{ID: int(eventType)})
|
||||
}
|
||||
|
||||
// Small delay to ensure all handlers have executed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Verify all event types received their events
|
||||
for eventType, counter := range receivedEvents {
|
||||
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
|
||||
"Event type %d did not receive its event", eventType)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackpressure(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
d.maxQueue = 10
|
||||
|
||||
var processedCount int64
|
||||
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
|
||||
atomic.AddInt64(&processedCount, 1)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
const eventsToPublish = 1000
|
||||
for i := 0; i < eventsToPublish; i++ {
|
||||
Publish(d, MyEvent3{ID: 0x200})
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify all events were eventually processed
|
||||
finalProcessed := atomic.LoadInt64(&processedCount)
|
||||
assert.Equal(t, int64(eventsToPublish), finalProcessed)
|
||||
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
|
||||
}
|
||||
|
||||
// ------------------------------------- Test Events -------------------------------------
|
||||
|
||||
const (
|
||||
TypeEvent1 = 0x1
|
||||
TypeEvent2 = 0x2
|
||||
)
|
||||
|
||||
type MyEvent1 struct {
|
||||
Number int
|
||||
}
|
||||
|
||||
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
|
||||
|
||||
type MyEvent2 struct {
|
||||
Text string
|
||||
}
|
||||
|
||||
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
|
||||
|
||||
type MyEvent3 struct {
|
||||
ID int
|
||||
}
|
||||
|
||||
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
|
||||
@@ -3,6 +3,7 @@ module github.com/mostlygeek/llama-swap
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
@@ -12,7 +13,6 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
|
||||
@@ -32,8 +32,6 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
)
|
||||
|
||||
@@ -53,137 +54,135 @@ func main() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
|
||||
// Setup channels for server management
|
||||
reloadChan := make(chan *proxy.ProxyManager)
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Create server with initial handler
|
||||
srv := &http.Server{
|
||||
Addr: *listenStr,
|
||||
Handler: proxyManager,
|
||||
Addr: *listenStr,
|
||||
}
|
||||
|
||||
// Support for watching config and reloading when it changes
|
||||
reloadProxyManager := func() {
|
||||
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
config, err = proxy.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Configuration Changed")
|
||||
currentPM.Shutdown()
|
||||
srv.Handler = proxy.New(config)
|
||||
fmt.Println("Configuration Reloaded")
|
||||
|
||||
// wait a few seconds and tell any UI to reload
|
||||
time.AfterFunc(3*time.Second, func() {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateEnd,
|
||||
})
|
||||
})
|
||||
} else {
|
||||
config, err = proxy.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
srv.Handler = proxy.New(config)
|
||||
}
|
||||
}
|
||||
|
||||
// load the initial proxy manager
|
||||
reloadProxyManager()
|
||||
debouncedReload := debounce(time.Second, reloadProxyManager)
|
||||
if *watchConfig {
|
||||
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
||||
if e.ReloadingState == proxy.ReloadingStateStart {
|
||||
debouncedReload()
|
||||
}
|
||||
})()
|
||||
|
||||
fmt.Println("Watching Configuration for changes")
|
||||
go func() {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
||||
return
|
||||
}
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
configDir := filepath.Dir(absConfigPath)
|
||||
err = watcher.Add(configDir)
|
||||
if err != nil {
|
||||
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
|
||||
return
|
||||
}
|
||||
|
||||
defer watcher.Close()
|
||||
for {
|
||||
select {
|
||||
case changeEvent := <-watcher.Events:
|
||||
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateStart,
|
||||
})
|
||||
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
|
||||
// the change for k8s configmap
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateStart,
|
||||
})
|
||||
}
|
||||
|
||||
case err := <-watcher.Errors:
|
||||
log.Printf("File watcher error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// shutdown on signal
|
||||
go func() {
|
||||
sig := <-sigChan
|
||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
pm.Shutdown()
|
||||
} else {
|
||||
fmt.Println("srv.Handler is not of type *proxy.ProxyManager")
|
||||
}
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
fmt.Printf("Server shutdown error: %v\n", err)
|
||||
}
|
||||
close(exitChan)
|
||||
}()
|
||||
|
||||
// Start server
|
||||
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
fmt.Printf("Fatal server error: %v\n", err)
|
||||
close(exitChan)
|
||||
log.Fatalf("Fatal server error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Handle config reloads and signals
|
||||
go func() {
|
||||
currentManager := proxyManager
|
||||
for {
|
||||
select {
|
||||
case newManager := <-reloadChan:
|
||||
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
||||
// Stop old manager processes gracefully (this waits for in-flight requests)
|
||||
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
|
||||
// Now do a full shutdown to clear the process map
|
||||
currentManager.Shutdown()
|
||||
currentManager = newManager
|
||||
srv.Handler = newManager
|
||||
log.Println("Server handler updated with new config")
|
||||
case sig := <-sigChan:
|
||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentManager.Shutdown()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
fmt.Printf("Server shutdown error: %v\n", err)
|
||||
}
|
||||
close(exitChan)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start file watcher if requested
|
||||
if *watchConfig {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
|
||||
} else {
|
||||
go watchConfigFileWithReload(absConfigPath, reloadChan)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for exit signal
|
||||
<-exitChan
|
||||
}
|
||||
|
||||
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
|
||||
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
err = watcher.Add(configPath)
|
||||
if err != nil {
|
||||
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Watching config file for changes: %s", configPath)
|
||||
|
||||
var debounceTimer *time.Timer
|
||||
debounceDuration := 2 * time.Second
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// We only care about writes to the specific config file
|
||||
if event.Name == configPath && event.Has(fsnotify.Write) {
|
||||
// Reset or start the debounce timer
|
||||
if debounceTimer != nil {
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
debounceTimer = time.AfterFunc(debounceDuration, func() {
|
||||
log.Printf("Config file modified: %s, reloading...", event.Name)
|
||||
|
||||
// Try up to 3 times with exponential backoff
|
||||
var newConfig proxy.Config
|
||||
var err error
|
||||
for retries := 0; retries < 3; retries++ {
|
||||
// Load new configuration
|
||||
newConfig, err = proxy.LoadConfig(configPath)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
|
||||
if retries < 2 {
|
||||
time.Sleep(time.Duration(1<<retries) * time.Second)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Failed to load new config after retries: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create new ProxyManager with new config
|
||||
newPM := proxy.New(newConfig)
|
||||
reloadChan <- newPM
|
||||
log.Println("Config reloaded successfully")
|
||||
})
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
log.Println("File watcher error channel closed.")
|
||||
return
|
||||
}
|
||||
log.Printf("File watcher error: %v", err)
|
||||
func debounce(interval time.Duration, f func()) func() {
|
||||
var timer *time.Timer
|
||||
return func() {
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
timer = time.AfterFunc(interval, f)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package main
|
||||
|
||||
// created for issue: #252 https://github.com/mostlygeek/llama-swap/issues/252
|
||||
// this simple benchmark tool sends a lot of small chat completion requests to llama-swap
|
||||
// to make sure all the requests are accounted for.
|
||||
//
|
||||
// requests can be sent in parallel, and the tool will report the results.
|
||||
// usage: go run main.go -baseurl http://localhost:8080/v1 -model llama3 -requests 1000 -par 5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// ----- CLI arguments ----------------------------------------------------
|
||||
var (
|
||||
baseurl string
|
||||
modelName string
|
||||
totalRequests int
|
||||
parallelization int
|
||||
)
|
||||
|
||||
flag.StringVar(&baseurl, "baseurl", "http://localhost:8080/v1", "Base URL of the API (e.g., https://api.example.com)")
|
||||
flag.StringVar(&modelName, "model", "", "Model name to use")
|
||||
flag.IntVar(&totalRequests, "requests", 1, "Total number of requests to send")
|
||||
flag.IntVar(¶llelization, "par", 1, "Maximum number of concurrent requests")
|
||||
flag.Parse()
|
||||
|
||||
if baseurl == "" || modelName == "" {
|
||||
fmt.Println("Error: both -baseurl and -model are required.")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
if totalRequests <= 0 {
|
||||
fmt.Println("Error: -requests must be greater than 0.")
|
||||
os.Exit(1)
|
||||
}
|
||||
if parallelization <= 0 {
|
||||
fmt.Println("Error: -parallelization must be greater than 0.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// ----- HTTP client -------------------------------------------------------
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// ----- Tracking response codes -------------------------------------------
|
||||
statusCounts := make(map[int]int) // map[statusCode]count
|
||||
var mu sync.Mutex // protects statusCounts
|
||||
|
||||
// ----- Request queue (buffered channel) ----------------------------------
|
||||
requests := make(chan int, 10) // Buffered channel with capacity 10
|
||||
|
||||
// Goroutine to fill the request queue
|
||||
go func() {
|
||||
for i := 0; i < totalRequests; i++ {
|
||||
requests <- i + 1
|
||||
}
|
||||
close(requests)
|
||||
}()
|
||||
|
||||
// ----- Worker pool -------------------------------------------------------
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < parallelization; i++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for reqID := range requests {
|
||||
// Build request payload as a single line JSON string
|
||||
payload := `{"model":"` + modelName + `","max_tokens":100,"stream":false,"messages":[{"role":"user","content":"write a snake game in python"}]}`
|
||||
|
||||
// Send POST request
|
||||
req, err := http.NewRequest(http.MethodPost,
|
||||
fmt.Sprintf("%s/chat/completions", baseurl),
|
||||
bytes.NewReader([]byte(payload)))
|
||||
if err != nil {
|
||||
log.Printf("[worker %d][req %d] request creation error: %v", workerID, reqID, err)
|
||||
mu.Lock()
|
||||
statusCounts[-1]++
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[worker %d][req %d] HTTP request error: %v", workerID, reqID, err)
|
||||
mu.Lock()
|
||||
statusCounts[-1]++
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
// Record status code
|
||||
mu.Lock()
|
||||
statusCounts[resp.StatusCode]++
|
||||
mu.Unlock()
|
||||
}
|
||||
}(i + 1)
|
||||
}
|
||||
|
||||
// ----- Status ticker (prints every second) -------------------------------
|
||||
done := make(chan struct{})
|
||||
tickerDone := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
startTime := time.Now()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mu.Lock()
|
||||
// Compute how many requests have completed so far
|
||||
completed := 0
|
||||
for _, cnt := range statusCounts {
|
||||
completed += cnt
|
||||
}
|
||||
// Calculate duration and progress
|
||||
duration := time.Since(startTime)
|
||||
progress := completed * 100 / totalRequests
|
||||
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, progress)
|
||||
mu.Unlock()
|
||||
case <-done:
|
||||
duration := time.Since(startTime)
|
||||
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, 100)
|
||||
close(tickerDone)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for all workers to finish
|
||||
wg.Wait()
|
||||
close(done) // stops the status-update goroutine
|
||||
<-tickerDone // give ticker time to finish / print
|
||||
|
||||
// ----- Summary ------------------------------------------------------------
|
||||
fmt.Println("\n\n=== HTTP response code summary ===")
|
||||
mu.Lock()
|
||||
for code, cnt := range statusCounts {
|
||||
if code == -1 {
|
||||
fmt.Printf("Client-side errors (no HTTP response): %d\n", cnt)
|
||||
} else {
|
||||
fmt.Printf("%d : %d\n", code, cnt)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
@@ -35,17 +35,90 @@ func main() {
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
// Check if streaming is requested
|
||||
// Query is checked instead of JSON body since that event stream conflicts with other tests
|
||||
isStreaming := c.Query("stream") == "true"
|
||||
|
||||
if isStreaming {
|
||||
// Set headers for streaming
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
// Send 10 "asdf" tokens
|
||||
for i := 0; i < 10; i++ {
|
||||
data := gin.H{
|
||||
"created": time.Now().Unix(),
|
||||
"choices": []gin.H{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": gin.H{
|
||||
"content": "asdf",
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
c.SSEvent("message", data)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// Send final data with usage info
|
||||
finalData := gin.H{
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
// add timings to simulate llama.cpp
|
||||
"timings": gin.H{
|
||||
"prompt_n": 25,
|
||||
"prompt_ms": 13,
|
||||
"predicted_n": 10,
|
||||
"predicted_ms": 17,
|
||||
"predicted_per_second": 10,
|
||||
},
|
||||
}
|
||||
c.SSEvent("message", finalData)
|
||||
c.Writer.Flush()
|
||||
|
||||
// Send [DONE]
|
||||
c.SSEvent("message", "[DONE]")
|
||||
c.Writer.Flush()
|
||||
} else {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||
"request_body": string(bodyBytes),
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
"timings": gin.H{
|
||||
"prompt_n": 25,
|
||||
"prompt_ms": 13,
|
||||
"predicted_n": 10,
|
||||
"predicted_ms": 17,
|
||||
"predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
// for issue #62 to check model name strips profile slug
|
||||
@@ -71,6 +144,11 @@ func main() {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -27,8 +28,15 @@ type ModelConfig struct {
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// #179 for /v1/models
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
|
||||
// Model filters see issue #174
|
||||
Filters ModelFilters `yaml:"filters"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
@@ -44,6 +52,8 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
@@ -63,6 +73,46 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters see issue #174
|
||||
type ModelFilters struct {
|
||||
StripParams string `yaml:"strip_params"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{
|
||||
StripParams: "",
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelFilters(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
if f.StripParams == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
if trimmed == "model" || trimmed == "" {
|
||||
continue
|
||||
}
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
// sort cleaned
|
||||
slices.Sort(cleaned)
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
@@ -88,10 +138,19 @@ func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type HooksConfig struct {
|
||||
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||
}
|
||||
|
||||
type HookOnStartup struct {
|
||||
Preload []string `yaml:"preload"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
@@ -104,6 +163,9 @@ type Config struct {
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
|
||||
// hooks, see: #209
|
||||
Hooks HooksConfig `yaml:"hooks"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
@@ -144,6 +206,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
MetricsMaxInMemory: 1000,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
@@ -205,6 +268,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields before macro expansion
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
|
||||
for macroName, macroValue := range config.Macros {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
@@ -212,6 +279,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
|
||||
}
|
||||
|
||||
// enforce ${PORT} used in both cmd and proxy
|
||||
@@ -273,6 +341,22 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -355,3 +439,16 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
@@ -83,6 +83,9 @@ models:
|
||||
assert.Equal(t, "", model1.UseModelName)
|
||||
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||
}
|
||||
|
||||
// default empty filter exists
|
||||
assert.Equal(t, "", model1.Filters.StripParams)
|
||||
}
|
||||
|
||||
func TestConfig_LoadPosix(t *testing.T) {
|
||||
@@ -97,10 +100,15 @@ func TestConfig_LoadPosix(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
hooks:
|
||||
on_startup:
|
||||
preload: ["model1", "model2"]
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
name: "Model 1"
|
||||
description: "This is model 1"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
@@ -158,6 +166,11 @@ groups:
|
||||
Macros: map[string]string{
|
||||
"svr-path": "path/to/server",
|
||||
},
|
||||
Hooks: HooksConfig{
|
||||
OnStartup: HookOnStartup{
|
||||
Preload: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -165,6 +178,8 @@ groups:
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
@@ -189,6 +204,7 @@ groups:
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -300,3 +301,142 @@ models:
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ModelFilters(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
default_strip: "temperature, top_p"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
strip_params: "model, top_k, ${default_strip}, , ,"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
modelConfig, ok := config.Models["model1"]
|
||||
if !assert.True(t, ok) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// make sure `model` and enmpty strings are not in the list
|
||||
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripComments(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no comments",
|
||||
input: "echo hello\necho world",
|
||||
expected: "echo hello\necho world",
|
||||
},
|
||||
{
|
||||
name: "single comment line",
|
||||
input: "# this is a comment\necho hello",
|
||||
expected: "echo hello",
|
||||
},
|
||||
{
|
||||
name: "multiple comment lines",
|
||||
input: "# comment 1\necho hello\n# comment 2\necho world",
|
||||
expected: "echo hello\necho world",
|
||||
},
|
||||
{
|
||||
name: "comment with spaces",
|
||||
input: " # indented comment\necho hello",
|
||||
expected: "echo hello",
|
||||
},
|
||||
{
|
||||
name: "empty lines preserved",
|
||||
input: "echo hello\n\necho world",
|
||||
expected: "echo hello\n\necho world",
|
||||
},
|
||||
{
|
||||
name: "only comments",
|
||||
input: "# comment 1\n# comment 2",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := StripComments(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("StripComments() = %q, expected %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_MacroInCommentStrippedBeforeExpansion(t *testing.T) {
|
||||
// Test case that reproduces the original bug where a macro in a comment
|
||||
// would get expanded and cause the comment text to be included in the command
|
||||
content := `
|
||||
startPort: 9990
|
||||
macros:
|
||||
"latest-llama": >
|
||||
/user/llama.cpp/build/bin/llama-server
|
||||
--port ${PORT}
|
||||
|
||||
models:
|
||||
"test-model":
|
||||
cmd: |
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model /path/to/model.gguf
|
||||
-ngl 99
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Get the sanitized command
|
||||
sanitizedCmd, err := SanitizeCommand(config.Models["test-model"].Cmd)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Join the command for easier inspection
|
||||
cmdStr := strings.Join(sanitizedCmd, " ")
|
||||
|
||||
// Verify that comment text is NOT present in the final command as separate arguments
|
||||
commentWords := []string{"is", "macro", "that", "defined", "above"}
|
||||
for _, word := range commentWords {
|
||||
found := slices.Contains(sanitizedCmd, word)
|
||||
assert.False(t, found, "Comment text '%s' should not be present as a separate argument in final command", word)
|
||||
}
|
||||
|
||||
// Verify that the actual command components ARE present
|
||||
expectedParts := []string{
|
||||
"/user/llama.cpp/build/bin/llama-server",
|
||||
"--port",
|
||||
"9990",
|
||||
"--model",
|
||||
"/path/to/model.gguf",
|
||||
"-ngl",
|
||||
"99",
|
||||
}
|
||||
|
||||
for _, part := range expectedParts {
|
||||
assert.Contains(t, cmdStr, part, "Expected command part '%s' not found in final command", part)
|
||||
}
|
||||
|
||||
// Verify the server path appears exactly once (not duplicated due to macro expansion)
|
||||
serverPath := "/user/llama.cpp/build/bin/llama-server"
|
||||
count := strings.Count(cmdStr, serverPath)
|
||||
assert.Equal(t, 1, count, "Expected exactly 1 occurrence of server path, found %d", count)
|
||||
|
||||
// Verify the expected final command structure
|
||||
expectedCmd := "/user/llama.cpp/build/bin/llama-server --port 9990 --model /path/to/model.gguf -ngl 99"
|
||||
assert.Equal(t, expectedCmd, cmdStr, "Final command does not match expected structure")
|
||||
}
|
||||
|
||||
@@ -80,6 +80,9 @@ models:
|
||||
assert.Equal(t, "", model1.UseModelName)
|
||||
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||
}
|
||||
|
||||
// default empty filter exists
|
||||
assert.Equal(t, "", model1.Filters.StripParams)
|
||||
}
|
||||
|
||||
func TestConfig_LoadWindows(t *testing.T) {
|
||||
@@ -190,6 +193,7 @@ groups:
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package proxy
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Custom discard writer that implements http.ResponseWriter but just discards everything
|
||||
type DiscardWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Write(data []byte) (int, error) {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
}
|
||||
|
||||
// Satisfy the http.Flusher interface for streaming responses
|
||||
func (w *DiscardWriter) Flush() {}
|
||||
@@ -0,0 +1,60 @@
|
||||
package proxy
|
||||
|
||||
// package level registry of the different event types
|
||||
|
||||
const ProcessStateChangeEventID = 0x01
|
||||
const ChatCompletionStatsEventID = 0x02
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const LogDataEventID = 0x04
|
||||
const TokenMetricsEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
NewState ProcessState
|
||||
OldState ProcessState
|
||||
}
|
||||
|
||||
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||
return ProcessStateChangeEventID
|
||||
}
|
||||
|
||||
type ChatCompletionStats struct {
|
||||
TokensGenerated int
|
||||
}
|
||||
|
||||
func (e ChatCompletionStats) Type() uint32 {
|
||||
return ChatCompletionStatsEventID
|
||||
}
|
||||
|
||||
type ReloadingState int
|
||||
|
||||
const (
|
||||
ReloadingStateStart ReloadingState = iota
|
||||
ReloadingStateEnd
|
||||
)
|
||||
|
||||
type ConfigFileChangedEvent struct {
|
||||
ReloadingState ReloadingState
|
||||
}
|
||||
|
||||
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||
return ConfigFileChangedEventID
|
||||
}
|
||||
|
||||
type LogDataEvent struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (e LogDataEvent) Type() uint32 {
|
||||
return LogDataEventID
|
||||
}
|
||||
|
||||
type ModelPreloadedEvent struct {
|
||||
ModelName string
|
||||
Success bool
|
||||
}
|
||||
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
@@ -13,9 +13,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||
simpleResponderPath = getSimpleResponderPath()
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
@@ -69,13 +70,11 @@ func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, binaryPath, port, expectedMessage, port)
|
||||
`, simpleResponderPath, port, expectedMessage, port)
|
||||
|
||||
var cfg ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
|
||||
@@ -2,10 +2,13 @@ package proxy
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
type LogLevel int
|
||||
@@ -18,7 +21,7 @@ const (
|
||||
)
|
||||
|
||||
type LogMonitor struct {
|
||||
clients map[chan []byte]bool
|
||||
eventbus *event.Dispatcher
|
||||
mu sync.RWMutex
|
||||
buffer *ring.Ring
|
||||
bufferMu sync.RWMutex
|
||||
@@ -37,11 +40,11 @@ func NewLogMonitor() *LogMonitor {
|
||||
|
||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
return &LogMonitor{
|
||||
clients: make(map[chan []byte]bool),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
eventbus: event.NewDispatcherConfig(1000),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,34 +84,14 @@ func (w *LogMonitor) GetHistory() []byte {
|
||||
return history
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Subscribe() chan []byte {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
ch := make(chan []byte, 100)
|
||||
w.clients[ch] = true
|
||||
return ch
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
delete(w.clients, ch)
|
||||
close(ch)
|
||||
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||
return event.Subscribe(w.eventbus, func(e LogDataEvent) {
|
||||
callback(e.Data)
|
||||
})
|
||||
}
|
||||
|
||||
func (w *LogMonitor) broadcast(msg []byte) {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
|
||||
for client := range w.clients {
|
||||
select {
|
||||
case client <- msg:
|
||||
default:
|
||||
// If client buffer is full, skip
|
||||
}
|
||||
}
|
||||
event.Publish(w.eventbus, LogDataEvent{Data: msg})
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||
|
||||
@@ -10,38 +10,29 @@ import (
|
||||
func TestLogMonitor(t *testing.T) {
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Test subscription
|
||||
client1 := logMonitor.Subscribe()
|
||||
client2 := logMonitor.Subscribe()
|
||||
|
||||
defer logMonitor.Unsubscribe(client1)
|
||||
defer logMonitor.Unsubscribe(client2)
|
||||
// A WaitGroup is used to wait for all the expected writes to complete
|
||||
var wg sync.WaitGroup
|
||||
|
||||
client1Messages := make([]byte, 0)
|
||||
client2Messages := make([]byte, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
defer logMonitor.OnLogData(func(data []byte) {
|
||||
client1Messages = append(client1Messages, data...)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case data := <-client1:
|
||||
client1Messages = append(client1Messages, data...)
|
||||
case data := <-client2:
|
||||
client2Messages = append(client2Messages, data...)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer logMonitor.OnLogData(func(data []byte) {
|
||||
client2Messages = append(client2Messages, data...)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
wg.Add(6) // 2 x 3 writes
|
||||
|
||||
logMonitor.Write([]byte("1"))
|
||||
logMonitor.Write([]byte("2"))
|
||||
logMonitor.Write([]byte("3"))
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
// wait for all writes to complete
|
||||
wg.Wait()
|
||||
|
||||
// Check the buffer
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
|
||||
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
writer := &MetricsResponseWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
metricsRecorder: &MetricsRecorder{
|
||||
metricsMonitor: pm.metricsMonitor,
|
||||
realModelName: realModelName,
|
||||
isStreaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
|
||||
startTime: time.Now(),
|
||||
},
|
||||
}
|
||||
c.Writer = writer
|
||||
c.Next()
|
||||
|
||||
rec := writer.metricsRecorder
|
||||
rec.processBody(writer.body)
|
||||
}
|
||||
}
|
||||
|
||||
type MetricsRecorder struct {
|
||||
metricsMonitor *MetricsMonitor
|
||||
realModelName string
|
||||
isStreaming bool
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// processBody handles response processing after request completes
|
||||
func (rec *MetricsRecorder) processBody(body []byte) {
|
||||
if rec.isStreaming {
|
||||
rec.processStreamingResponse(body)
|
||||
} else {
|
||||
rec.processNonStreamingResponse(body)
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
|
||||
usage := jsonData.Get("usage")
|
||||
if !usage.Exists() {
|
||||
return false
|
||||
}
|
||||
|
||||
// default values
|
||||
outputTokens := int(jsonData.Get("usage.completion_tokens").Int())
|
||||
inputTokens := int(jsonData.Get("usage.prompt_tokens").Int())
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
durationMs := int(time.Since(rec.startTime).Milliseconds())
|
||||
|
||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||
if timings := jsonData.Get("timings"); timings.Exists() {
|
||||
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
||||
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
||||
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
||||
}
|
||||
|
||||
rec.metricsMonitor.addMetrics(TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: rec.realModelName,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
DurationMs: durationMs,
|
||||
})
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
|
||||
// Iterate **backwards** through the lines looking for the data payload with
|
||||
// usage data
|
||||
lines := bytes.Split(body, []byte("\n"))
|
||||
|
||||
for i := len(lines) - 1; i >= 0; i-- {
|
||||
line := bytes.TrimSpace(lines[i])
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// SSE payload always follows "data:"
|
||||
prefix := []byte("data:")
|
||||
if !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
// [DONE] line itself contains nothing of interest.
|
||||
continue
|
||||
}
|
||||
|
||||
if gjson.ValidBytes(data) {
|
||||
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
|
||||
return // short circuit if a metric was recorded
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
|
||||
if len(body) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse JSON to extract usage information
|
||||
if gjson.ValidBytes(body) {
|
||||
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
|
||||
}
|
||||
}
|
||||
|
||||
// MetricsResponseWriter captures the entire response for non-streaming
|
||||
type MetricsResponseWriter struct {
|
||||
gin.ResponseWriter
|
||||
body []byte
|
||||
metricsRecorder *MetricsRecorder
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
w.body = append(w.body, b...)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||
type TokenMetrics struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
}
|
||||
|
||||
// TokenMetricsEvent represents a token metrics event
|
||||
type TokenMetricsEvent struct {
|
||||
Metrics TokenMetrics
|
||||
}
|
||||
|
||||
func (e TokenMetricsEvent) Type() uint32 {
|
||||
return TokenMetricsEventID // defined in events.go
|
||||
}
|
||||
|
||||
// MetricsMonitor parses llama-server output for token statistics
|
||||
type MetricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics []TokenMetrics
|
||||
maxMetrics int
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewMetricsMonitor(config *Config) *MetricsMonitor {
|
||||
maxMetrics := config.MetricsMaxInMemory
|
||||
if maxMetrics <= 0 {
|
||||
maxMetrics = 1000 // Default fallback
|
||||
}
|
||||
|
||||
mp := &MetricsMonitor{
|
||||
maxMetrics: maxMetrics,
|
||||
}
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
// addMetrics adds a new metric to the collection and publishes an event
|
||||
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
metric.ID = mp.nextID
|
||||
mp.nextID++
|
||||
mp.metrics = append(mp.metrics, metric)
|
||||
if len(mp.metrics) > mp.maxMetrics {
|
||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||
}
|
||||
|
||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the current metrics
|
||||
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := make([]TokenMetrics, len(mp.metrics))
|
||||
copy(result, mp.metrics)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetMetricsJSON returns metrics as JSON
|
||||
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
return json.Marshal(mp.metrics)
|
||||
}
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
@@ -127,6 +129,7 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
||||
|
||||
p.state = newState
|
||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||
return p.state, nil
|
||||
}
|
||||
|
||||
@@ -189,17 +192,19 @@ func (p *Process) start() error {
|
||||
p.waitStarting.Add(1)
|
||||
defer p.waitStarting.Done()
|
||||
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||
|
||||
p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.processLogger
|
||||
p.cmd.Stderr = p.processLogger
|
||||
p.cmd.Env = p.config.Env
|
||||
|
||||
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||
p.cancelUpstream = ctxCancelUpstream
|
||||
p.cmdWaitChan = make(chan struct{})
|
||||
|
||||
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||
err = p.cmd.Start()
|
||||
|
||||
// Set process state to failed
|
||||
@@ -207,11 +212,11 @@ func (p *Process) start() error {
|
||||
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||
p.state = StateStopped // force it into a stopped state
|
||||
return fmt.Errorf(
|
||||
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
err, curState, swapErr,
|
||||
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
strings.Join(args, " "), err, curState, swapErr,
|
||||
)
|
||||
}
|
||||
return fmt.Errorf("start() failed: %v", err)
|
||||
return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err)
|
||||
}
|
||||
|
||||
// Capture the exit error for later signalling
|
||||
@@ -530,7 +535,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Stdout = p.processLogger
|
||||
stopCmd.Stderr = p.processLogger
|
||||
stopCmd.Env = p.config.Env
|
||||
stopCmd.Env = p.cmd.Env
|
||||
|
||||
if err := stopCmd.Run(); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||
|
||||
@@ -107,7 +107,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "start() failed: ")
|
||||
assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
|
||||
}
|
||||
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
@@ -394,6 +394,9 @@ func TestProcess_StopImmediately(t *testing.T) {
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping SIGTERM test on Windows ")
|
||||
}
|
||||
|
||||
expectedMessage := "test_sigkill"
|
||||
binaryPath := getSimpleResponderPath()
|
||||
@@ -405,7 +408,6 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
}
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
@@ -465,3 +467,27 @@ func TestProcess_StopCmd(t *testing.T) {
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||
expectedMessage := "test_env_not_emptied"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// ensure that the the default config does not blank out the inherited environment
|
||||
configWEnv := config
|
||||
|
||||
// ensure the additiona variables are appended to the process' environment
|
||||
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||
|
||||
process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger)
|
||||
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||
|
||||
process1.start()
|
||||
defer process1.Stop()
|
||||
process2.start()
|
||||
defer process2.Stop()
|
||||
|
||||
assert.NotZero(t, len(process1.cmd.Environ()))
|
||||
assert.NotZero(t, len(process2.cmd.Environ()))
|
||||
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||
|
||||
}
|
||||
|
||||
@@ -2,18 +2,20 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -33,7 +35,13 @@ type ProxyManager struct {
|
||||
upstreamLogger *LogMonitor
|
||||
muxLogger *LogMonitor
|
||||
|
||||
metricsMonitor *MetricsMonitor
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
|
||||
// shutdown signaling
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func New(config Config) *ProxyManager {
|
||||
@@ -64,6 +72,8 @@ func New(config Config) *ProxyManager {
|
||||
upstreamLogger.SetLogLevel(LevelInfo)
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
ginEngine: gin.New(),
|
||||
@@ -72,7 +82,12 @@ func New(config Config) *ProxyManager {
|
||||
muxLogger: stdoutLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
|
||||
metricsMonitor: NewMetricsMonitor(&config),
|
||||
|
||||
processGroups: make(map[string]*ProcessGroup),
|
||||
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: shutdownCancel,
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
@@ -82,6 +97,35 @@ func New(config Config) *ProxyManager {
|
||||
}
|
||||
|
||||
pm.setupGinEngine()
|
||||
|
||||
// run any startup hooks
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
// do it in the background, don't block startup -- not sure if good idea yet
|
||||
go func() {
|
||||
discardWriter := &DiscardWriter{}
|
||||
for _, realModelName := range config.Hooks.OnStartup.Preload {
|
||||
proxyLogger.Infof("Preloading model: %s", realModelName)
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
|
||||
if err != nil {
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
Success: false,
|
||||
})
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
|
||||
continue
|
||||
} else {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
processGroup.ProxyRequest(realModelName, discardWriter, req)
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
Success: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
@@ -140,14 +184,18 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
c.Next()
|
||||
})
|
||||
|
||||
mm := MetricsMiddleware(pm)
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
|
||||
|
||||
// Support embeddings
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
@@ -158,9 +206,7 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
// in proxymanager_loghandlers.go
|
||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
|
||||
|
||||
/**
|
||||
* User Interface Endpoints
|
||||
@@ -176,6 +222,9 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
|
||||
@@ -262,6 +311,7 @@ func (pm *ProxyManager) Shutdown() {
|
||||
}(processGroup)
|
||||
}
|
||||
wg.Wait()
|
||||
pm.shutdownCancel()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||
@@ -289,32 +339,48 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := []interface{}{}
|
||||
data := make([]gin.H, 0, len(pm.config.Models))
|
||||
createdTime := time.Now().Unix()
|
||||
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
data = append(data, map[string]interface{}{
|
||||
record := gin.H{
|
||||
"id": id,
|
||||
"object": "model",
|
||||
"created": time.Now().Unix(),
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
})
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
data = append(data, record)
|
||||
}
|
||||
|
||||
// Set the Content-Type header to application/json
|
||||
c.Header("Content-Type", "application/json")
|
||||
// Sort by the "id" key
|
||||
sort.Slice(data, func(i, j int) bool {
|
||||
si, _ := data[i]["id"].(string)
|
||||
sj, _ := data[j]["id"].(string)
|
||||
return si < sj
|
||||
})
|
||||
|
||||
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
||||
// Set CORS headers if origin exists
|
||||
if origin := c.GetHeader("Origin"); origin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"object": "list", "data": data}); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
||||
return
|
||||
}
|
||||
// Use gin's JSON method which handles content-type and encoding
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
@@ -325,7 +391,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
@@ -333,7 +399,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
||||
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
@@ -349,7 +415,13 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
@@ -365,6 +437,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
|
||||
@@ -1,25 +1,30 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
State string `json:"state"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
}
|
||||
|
||||
func addApiHandlers(pm *ProxyManager) {
|
||||
// Add API endpoints for React to consume
|
||||
apiGroup := pm.ginEngine.Group("/api")
|
||||
{
|
||||
apiGroup.GET("/models", pm.apiListModels)
|
||||
apiGroup.GET("/modelsSSE", pm.apiListModelsSSE)
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,37 +70,133 @@ func (pm *ProxyManager) getModelStatus() []Model {
|
||||
}
|
||||
}
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
State: state,
|
||||
Id: modelID,
|
||||
Name: pm.config.Models[modelID].Name,
|
||||
Description: pm.config.Models[modelID].Description,
|
||||
State: state,
|
||||
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||
})
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiListModels(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, pm.getModelStatus())
|
||||
type messageType string
|
||||
|
||||
const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
Type messageType `json:"type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// stream the models as a SSE
|
||||
func (pm *ProxyManager) apiListModelsSSE(c *gin.Context) {
|
||||
// sends a stream of different message types that happen on the server
|
||||
func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
sendBuffer := make(chan messageEnvelope, 25)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
sendModels := func() {
|
||||
data, err := json.Marshal(pm.getModelStatus())
|
||||
if err == nil {
|
||||
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
|
||||
select {
|
||||
case sendBuffer <- msg:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
sendLogData := func(source string, data []byte) {
|
||||
data, err := json.Marshal(gin.H{
|
||||
"source": source,
|
||||
"data": string(data),
|
||||
})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sendMetrics := func(metrics []TokenMetrics) {
|
||||
jsonData, err := json.Marshal(metrics)
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send updated models list
|
||||
*/
|
||||
defer event.On(func(e ProcessStateChangeEvent) {
|
||||
sendModels()
|
||||
})()
|
||||
defer event.On(func(e ConfigFileChangedEvent) {
|
||||
sendModels()
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Log data
|
||||
*/
|
||||
defer pm.proxyLogger.OnLogData(func(data []byte) {
|
||||
sendLogData("proxy", data)
|
||||
})()
|
||||
defer pm.upstreamLogger.OnLogData(func(data []byte) {
|
||||
sendLogData("upstream", data)
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Metrics data
|
||||
*/
|
||||
defer event.On(func(e TokenMetricsEvent) {
|
||||
sendMetrics([]TokenMetrics{e.Metrics})
|
||||
})()
|
||||
|
||||
// send initial batch of data
|
||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||
sendModels()
|
||||
sendMetrics(pm.metricsMonitor.GetMetrics())
|
||||
|
||||
// Stream new events
|
||||
for {
|
||||
select {
|
||||
case <-notify:
|
||||
case <-c.Request.Context().Done():
|
||||
cancel()
|
||||
return
|
||||
default:
|
||||
models := pm.getModelStatus()
|
||||
c.SSEvent("message", models)
|
||||
case <-pm.shutdownCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case msg := <-sendBuffer:
|
||||
c.SSEvent("message", msg)
|
||||
c.Writer.Flush()
|
||||
<-time.After(1000 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonData)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -34,10 +35,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
ch := logger.Subscribe()
|
||||
defer logger.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||
@@ -55,57 +53,28 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
sendChan := make(chan []byte, 10)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer logger.OnLogData(func(data []byte) {
|
||||
select {
|
||||
case sendChan <- data:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
})()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
_, err := c.Writer.Write(msg)
|
||||
if err != nil {
|
||||
// just break the loop if we can't write for some reason
|
||||
return
|
||||
}
|
||||
case <-c.Request.Context().Done():
|
||||
cancel()
|
||||
return
|
||||
case <-pm.shutdownCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case data := <-sendChan:
|
||||
c.Writer.Write(data)
|
||||
flusher.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
logMonitorId := c.Param("logMonitorID")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
ch := logger.Subscribe()
|
||||
defer logger.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
|
||||
// Send history first if not skipped
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
if !skipHistory {
|
||||
history := logger.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.SSEvent("message", string(history))
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
c.SSEvent("message", string(msg))
|
||||
c.Writer.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,10 +9,12 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -165,9 +167,11 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
var response map[string]string
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
results[key] = response["responseMessage"]
|
||||
result, ok := response["responseMessage"].(string)
|
||||
assert.Equal(t, ok, true)
|
||||
results[key] = result
|
||||
mu.Unlock()
|
||||
}(key)
|
||||
|
||||
@@ -183,11 +187,20 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
|
||||
model1Config := getTestSimpleResponderConfig("model1")
|
||||
model1Config.Name = "Model 1"
|
||||
model1Config.Description = "Model 1 description is used for testing"
|
||||
|
||||
model2Config := getTestSimpleResponderConfig("model2")
|
||||
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||
model2Config.Description = " "
|
||||
|
||||
config := Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -213,6 +226,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
var response struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
@@ -227,6 +241,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
"model3": {},
|
||||
}
|
||||
|
||||
// make all models
|
||||
for _, model := range response.Data {
|
||||
modelID, ok := model["id"].(string)
|
||||
assert.True(t, ok, "model ID should be a string")
|
||||
@@ -245,12 +260,72 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
ownedBy, ok := model["owned_by"].(string)
|
||||
assert.True(t, ok, "owned_by should be a string")
|
||||
assert.Equal(t, "llama-swap", ownedBy)
|
||||
|
||||
// check for optional name and description
|
||||
if modelID == "model1" {
|
||||
name, ok := model["name"].(string)
|
||||
assert.True(t, ok, "name should be a string")
|
||||
assert.Equal(t, "Model 1", name)
|
||||
description, ok := model["description"].(string)
|
||||
assert.True(t, ok, "description should be a string")
|
||||
assert.Equal(t, "Model 1 description is used for testing", description)
|
||||
} else {
|
||||
_, exists := model["name"]
|
||||
assert.False(t, exists, "unexpected name field for model: %s", modelID)
|
||||
_, exists = model["description"]
|
||||
assert.False(t, exists, "unexpected description field for model: %s", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure all expected models were returned
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
||||
// Intentionally add models in non-sorted order and with an unlisted model
|
||||
config := Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"zeta": getTestSimpleResponderConfig("zeta"),
|
||||
"alpha": getTestSimpleResponderConfig("alpha"),
|
||||
"beta": getTestSimpleResponderConfig("beta"),
|
||||
"hidden": func() ModelConfig {
|
||||
mc := getTestSimpleResponderConfig("hidden")
|
||||
mc.Unlisted = true
|
||||
return mc
|
||||
}(),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Request models list
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// We expect only the listed models in sorted order by id
|
||||
expectedOrder := []string{"alpha", "beta", "zeta"}
|
||||
if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") {
|
||||
got := make([]string, 0, len(response.Data))
|
||||
for _, m := range response.Data {
|
||||
id, _ := m["id"].(string)
|
||||
got = append(got, id)
|
||||
}
|
||||
assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
@@ -583,21 +658,34 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_Upstream(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
models:
|
||||
model1:
|
||||
cmd: %s -port ${PORT} -silent -respond model1
|
||||
aliases: [model-alias]
|
||||
`, getSimpleResponderPath())
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
t.Run("main model name", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
})
|
||||
|
||||
t.Run("model alias", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
@@ -618,8 +706,189 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var response map[string]string
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
assert.Equal(t, "81", response["h_content_length"])
|
||||
assert.Equal(t, "model1", response["responseMessage"])
|
||||
}
|
||||
|
||||
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
modelConfig := getTestSimpleResponderConfig("model1")
|
||||
modelConfig.Filters = ModelFilters{
|
||||
StripParams: "temperature, model, stream",
|
||||
}
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
LogLevel: "error",
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": modelConfig,
|
||||
},
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// `temperature` and `stream` are gone but model remains
|
||||
assert.Equal(t, `{"model":"model1", "x_param":"123", "y_param":"abc"}`, response["request_body"])
|
||||
|
||||
// assert.Nil(t, response["temperature"])
|
||||
// assert.Equal(t, "123", response["x_param"])
|
||||
// assert.Equal(t, "abc", response["y_param"])
|
||||
// t.Logf("%v", response)
|
||||
}
|
||||
|
||||
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// Make a non-streaming request
|
||||
reqBody := `{"model":"model1", "stream": false}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check that metrics were recorded
|
||||
metrics := proxy.metricsMonitor.GetMetrics()
|
||||
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the last metric has the correct model
|
||||
lastMetric := metrics[len(metrics)-1]
|
||||
assert.Equal(t, "model1", lastMetric.Model)
|
||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
||||
}
|
||||
|
||||
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// Make a streaming request
|
||||
reqBody := `{"model":"model1", "stream": true}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check that metrics were recorded
|
||||
metrics := proxy.metricsMonitor.GetMetrics()
|
||||
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the last metric has the correct model
|
||||
lastMetric := metrics[len(metrics)-1]
|
||||
assert.Equal(t, "model1", lastMetric.Model)
|
||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
||||
}
|
||||
|
||||
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "OK", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyManager_StartupHooks(t *testing.T) {
|
||||
|
||||
// using real YAML as the configuration has gotten more complex
|
||||
// is the right approach as LoadConfigFromReader() does a lot more
|
||||
// than parse YAML now. Eventually migrate all tests to use this approach
|
||||
configStr := strings.Replace(`
|
||||
logLevel: error
|
||||
hooks:
|
||||
on_startup:
|
||||
preload:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
preloadTestGroup:
|
||||
swap: false
|
||||
members:
|
||||
- model1
|
||||
- model2
|
||||
models:
|
||||
model1:
|
||||
cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model1
|
||||
model2:
|
||||
cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model2
|
||||
`, "${simpleresponderpath}", simpleResponderPath, -1)
|
||||
|
||||
// Create a test model configuration
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configStr))
|
||||
if !assert.NoError(t, err, "Invalid configuration") {
|
||||
return
|
||||
}
|
||||
|
||||
preloadChan := make(chan ModelPreloadedEvent, 2) // buffer for 2 expected events
|
||||
|
||||
unsub := event.On(func(e ModelPreloadedEvent) {
|
||||
preloadChan <- e
|
||||
})
|
||||
|
||||
defer unsub()
|
||||
|
||||
// Create the proxy which should trigger preloading
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-preloadChan:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for models to preload")
|
||||
}
|
||||
}
|
||||
// make sure they are both loaded
|
||||
_, foundGroup := proxy.processGroups["preloadTestGroup"]
|
||||
if !assert.True(t, foundGroup, "preloadTestGroup should exist") {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model1"].CurrentState())
|
||||
assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model2"].CurrentState())
|
||||
}
|
||||
|
||||
@@ -3,7 +3,11 @@
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<link rel="icon" type="image/png" href="/favicon.ico" />
|
||||
<link rel="icon" type="image/png" href="/favicon-96x96.png" sizes="96x96" />
|
||||
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
|
||||
<link rel="shortcut icon" href="/favicon.ico" />
|
||||
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
|
||||
<link rel="manifest" href="/site.webmanifest" />
|
||||
<title>llama-swap</title>
|
||||
</head>
|
||||
<body >
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
"@tanstack/react-query": "^5.80.6",
|
||||
"react": "^19.1.0",
|
||||
"react-dom": "^19.1.0",
|
||||
"react-icons": "^5.5.0",
|
||||
"react-resizable-panels": "^3.0.4",
|
||||
"react-router-dom": "^7.6.2",
|
||||
"tailwindcss": "^4.1.8"
|
||||
},
|
||||
@@ -3460,6 +3462,15 @@
|
||||
"react": "^19.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-icons": {
|
||||
"version": "5.5.0",
|
||||
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-5.5.0.tgz",
|
||||
"integrity": "sha512-MEFcXdkP3dLo8uumGI5xN3lDFNsRtrjbOEKDLD7yv76v4wpnEq2Lt2qeHaQOr34I/wPN3s3+N08WkQ+CW37Xiw==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"react": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/react-refresh": {
|
||||
"version": "0.17.0",
|
||||
"resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.17.0.tgz",
|
||||
@@ -3470,6 +3481,16 @@
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-resizable-panels": {
|
||||
"version": "3.0.4",
|
||||
"resolved": "https://registry.npmjs.org/react-resizable-panels/-/react-resizable-panels-3.0.4.tgz",
|
||||
"integrity": "sha512-8Y4KNgV94XhUvI2LeByyPIjoUJb71M/0hyhtzkHaqpVHs+ZQs8b627HmzyhmVYi3C9YP6R+XD1KmG7hHjEZXFQ==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"react": "^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc"
|
||||
}
|
||||
},
|
||||
"node_modules/react-router": {
|
||||
"version": "7.6.2",
|
||||
"resolved": "https://registry.npmjs.org/react-router/-/react-router-7.6.2.tgz",
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
"@tanstack/react-query": "^5.80.6",
|
||||
"react": "^19.1.0",
|
||||
"react-dom": "^19.1.0",
|
||||
"react-icons": "^5.5.0",
|
||||
"react-resizable-panels": "^3.0.4",
|
||||
"react-router-dom": "^7.6.2",
|
||||
"tailwindcss": "^4.1.8"
|
||||
},
|
||||
@@ -30,4 +32,4 @@
|
||||
"typescript-eslint": "^8.30.1",
|
||||
"vite": "^6.3.5"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
After Width: | Height: | Size: 5.9 KiB |
|
After Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
|
After Width: | Height: | Size: 38 KiB |
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"name": "llama-swap",
|
||||
"short_name": "llama-swap",
|
||||
"icons": [
|
||||
{
|
||||
"src": "/web-app-manifest-192x192.png",
|
||||
"sizes": "192x192",
|
||||
"type": "image/png",
|
||||
"purpose": "maskable"
|
||||
},
|
||||
{
|
||||
"src": "/web-app-manifest-512x512.png",
|
||||
"sizes": "512x512",
|
||||
"type": "image/png",
|
||||
"purpose": "maskable"
|
||||
}
|
||||
],
|
||||
"theme_color": "#ffffff",
|
||||
"background_color": "#ffffff",
|
||||
"display": "standalone"
|
||||
}
|
||||
|
After Width: | Height: | Size: 6.5 KiB |
|
After Width: | Height: | Size: 28 KiB |
@@ -3,17 +3,21 @@ import { useTheme } from "./contexts/ThemeProvider";
|
||||
import { APIProvider } from "./contexts/APIProvider";
|
||||
import LogViewerPage from "./pages/LogViewer";
|
||||
import ModelPage from "./pages/Models";
|
||||
import ActivityPage from "./pages/Activity";
|
||||
import ConnectionStatus from "./components/ConnectionStatus";
|
||||
import { RiSunFill, RiMoonFill } from "react-icons/ri";
|
||||
|
||||
function App() {
|
||||
const theme = useTheme();
|
||||
const { isNarrow, toggleTheme, isDarkMode } = useTheme();
|
||||
|
||||
return (
|
||||
<Router basename="/ui/">
|
||||
<APIProvider>
|
||||
<div>
|
||||
<nav className="bg-surface border-b border-border p-4">
|
||||
<div className="flex items-center justify-between mx-auto px-4">
|
||||
<h1>llama-swap</h1>
|
||||
<div className="flex space-x-4">
|
||||
<div className="flex flex-col h-screen">
|
||||
<nav className="bg-surface border-b border-border p-2 h-[75px]">
|
||||
<div className="flex items-center justify-between mx-auto px-4 h-full">
|
||||
{!isNarrow && <h1 className="flex items-center p-0">llama-swap</h1>}
|
||||
<div className="flex items-center space-x-4">
|
||||
<NavLink to="/" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Logs
|
||||
</NavLink>
|
||||
@@ -21,17 +25,23 @@ function App() {
|
||||
<NavLink to="/models" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Models
|
||||
</NavLink>
|
||||
<button className="btn btn--sm" onClick={theme.toggleTheme}>
|
||||
{theme.isDarkMode ? "🌙" : "☀️"}
|
||||
|
||||
<NavLink to="/activity" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Activity
|
||||
</NavLink>
|
||||
<button className="" onClick={toggleTheme}>
|
||||
{isDarkMode ? <RiMoonFill /> : <RiSunFill />}
|
||||
</button>
|
||||
<ConnectionStatus />
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<main className="mx-auto py-4 px-4">
|
||||
<main className="flex-1 overflow-auto p-4">
|
||||
<Routes>
|
||||
<Route path="/" element={<LogViewerPage />} />
|
||||
<Route path="/models" element={<ModelPage />} />
|
||||
<Route path="/activity" element={<ActivityPage />} />
|
||||
<Route path="*" element={<Navigate to="/" replace />} />
|
||||
</Routes>
|
||||
</main>
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
import { useEffect, useState, useMemo } from "react";
|
||||
|
||||
type ConnectionStatus = "disconnected" | "connecting" | "connected";
|
||||
|
||||
const ConnectionStatus = () => {
|
||||
const { getConnectionStatus } = useAPI();
|
||||
const [eventStreamStatus, setEventStreamStatus] = useState<ConnectionStatus>("disconnected");
|
||||
|
||||
useEffect(() => {
|
||||
const interval = setInterval(() => {
|
||||
setEventStreamStatus(getConnectionStatus());
|
||||
}, 1000);
|
||||
return () => clearInterval(interval);
|
||||
});
|
||||
|
||||
const eventStatusColor = useMemo(() => {
|
||||
switch (eventStreamStatus) {
|
||||
case "connected":
|
||||
return "bg-green-500";
|
||||
case "connecting":
|
||||
return "bg-yellow-500";
|
||||
case "disconnected":
|
||||
default:
|
||||
return "bg-red-500";
|
||||
}
|
||||
}, [eventStreamStatus]);
|
||||
|
||||
return (
|
||||
<div className="flex items-center" title={`event stream: ${eventStreamStatus}`}>
|
||||
<span className={`inline-block w-3 h-3 rounded-full ${eventStatusColor} mr-2`}></span>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ConnectionStatus;
|
||||
@@ -6,32 +6,56 @@ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
||||
export interface Model {
|
||||
id: string;
|
||||
state: ModelStatus;
|
||||
name: string;
|
||||
description: string;
|
||||
unlisted: boolean;
|
||||
}
|
||||
|
||||
interface APIProviderType {
|
||||
models: Model[];
|
||||
listModels: () => Promise<Model[]>;
|
||||
unloadAllModels: () => Promise<void>;
|
||||
enableProxyLogs: (enabled: boolean) => void;
|
||||
enableUpstreamLogs: (enabled: boolean) => void;
|
||||
enableModelUpdates: (enabled: boolean) => void;
|
||||
loadModel: (model: string) => Promise<void>;
|
||||
enableAPIEvents: (enabled: boolean) => void;
|
||||
proxyLogs: string;
|
||||
upstreamLogs: string;
|
||||
metrics: Metrics[];
|
||||
getConnectionStatus: () => "connected" | "connecting" | "disconnected";
|
||||
}
|
||||
|
||||
interface Metrics {
|
||||
id: number;
|
||||
timestamp: string;
|
||||
model: string;
|
||||
input_tokens: number;
|
||||
output_tokens: number;
|
||||
prompt_per_second: number;
|
||||
tokens_per_second: number;
|
||||
duration_ms: number;
|
||||
}
|
||||
|
||||
interface LogData {
|
||||
source: "upstream" | "proxy";
|
||||
data: string;
|
||||
}
|
||||
interface APIEventEnvelope {
|
||||
type: "modelStatus" | "logData" | "metrics";
|
||||
data: string;
|
||||
}
|
||||
|
||||
const APIContext = createContext<APIProviderType | undefined>(undefined);
|
||||
type APIProviderProps = {
|
||||
children: ReactNode;
|
||||
autoStartAPIEvents?: boolean;
|
||||
};
|
||||
|
||||
export function APIProvider({ children }: APIProviderProps) {
|
||||
export function APIProvider({ children, autoStartAPIEvents = true }: APIProviderProps) {
|
||||
const [proxyLogs, setProxyLogs] = useState("");
|
||||
const [upstreamLogs, setUpstreamLogs] = useState("");
|
||||
const proxyEventSource = useRef<EventSource | null>(null);
|
||||
const upstreamEventSource = useRef<EventSource | null>(null);
|
||||
const [metrics, setMetrics] = useState<Metrics[]>([]);
|
||||
const apiEventSource = useRef<EventSource | null>(null);
|
||||
|
||||
const [models, setModels] = useState<Model[]>([]);
|
||||
const modelStatusEventSource = useRef<EventSource | null>(null);
|
||||
|
||||
const appendLog = useCallback((newData: string, setter: React.Dispatch<React.SetStateAction<string>>) => {
|
||||
setter((prev) => {
|
||||
@@ -40,76 +64,102 @@ export function APIProvider({ children }: APIProviderProps) {
|
||||
});
|
||||
}, []);
|
||||
|
||||
const handleProxyMessage = useCallback(
|
||||
(e: MessageEvent) => {
|
||||
appendLog(e.data, setProxyLogs);
|
||||
},
|
||||
[proxyLogs, appendLog]
|
||||
);
|
||||
const getConnectionStatus = useCallback(() => {
|
||||
if (apiEventSource.current?.readyState === EventSource.OPEN) {
|
||||
return "connected";
|
||||
} else if (apiEventSource.current?.readyState === EventSource.CONNECTING) {
|
||||
return "connecting";
|
||||
} else {
|
||||
return "disconnected";
|
||||
}
|
||||
}, []);
|
||||
|
||||
const handleUpstreamMessage = useCallback(
|
||||
(e: MessageEvent) => {
|
||||
appendLog(e.data, setUpstreamLogs);
|
||||
},
|
||||
[appendLog]
|
||||
);
|
||||
const enableAPIEvents = useCallback((enabled: boolean) => {
|
||||
if (!enabled) {
|
||||
apiEventSource.current?.close();
|
||||
apiEventSource.current = null;
|
||||
setMetrics([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const enableProxyLogs = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (enabled) {
|
||||
const eventSource = new EventSource("/logs/streamSSE/proxy");
|
||||
eventSource.onmessage = handleProxyMessage;
|
||||
proxyEventSource.current = eventSource;
|
||||
} else {
|
||||
proxyEventSource.current?.close();
|
||||
proxyEventSource.current = null;
|
||||
}
|
||||
},
|
||||
[handleProxyMessage]
|
||||
);
|
||||
let retryCount = 0;
|
||||
const initialDelay = 1000; // 1 second
|
||||
|
||||
const enableUpstreamLogs = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (enabled) {
|
||||
const eventSource = new EventSource("/logs/streamSSE/upstream");
|
||||
eventSource.onmessage = handleUpstreamMessage;
|
||||
upstreamEventSource.current = eventSource;
|
||||
} else {
|
||||
upstreamEventSource.current?.close();
|
||||
upstreamEventSource.current = null;
|
||||
}
|
||||
},
|
||||
[upstreamEventSource, handleUpstreamMessage]
|
||||
);
|
||||
const connect = () => {
|
||||
const eventSource = new EventSource("/api/events");
|
||||
|
||||
const enableModelUpdates = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (enabled) {
|
||||
const eventSource = new EventSource("/api/modelsSSE");
|
||||
eventSource.onmessage = (e: MessageEvent) => {
|
||||
try {
|
||||
const models = JSON.parse(e.data) as Model[];
|
||||
setModels(models);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
eventSource.onopen = () => {
|
||||
// clear everything out on connect to keep things in sync
|
||||
setProxyLogs("");
|
||||
setUpstreamLogs("");
|
||||
setMetrics([]); // clear metrics on reconnect
|
||||
setModels([]); // clear models on reconnect
|
||||
};
|
||||
|
||||
eventSource.onmessage = (e: MessageEvent) => {
|
||||
try {
|
||||
const message = JSON.parse(e.data) as APIEventEnvelope;
|
||||
switch (message.type) {
|
||||
case "modelStatus":
|
||||
{
|
||||
const models = JSON.parse(message.data) as Model[];
|
||||
|
||||
// sort models by name and id
|
||||
models.sort((a, b) => {
|
||||
return (a.name + a.id).localeCompare(b.name + b.id);
|
||||
});
|
||||
|
||||
setModels(models);
|
||||
}
|
||||
break;
|
||||
|
||||
case "logData":
|
||||
const logData = JSON.parse(message.data) as LogData;
|
||||
switch (logData.source) {
|
||||
case "proxy":
|
||||
appendLog(logData.data, setProxyLogs);
|
||||
break;
|
||||
case "upstream":
|
||||
appendLog(logData.data, setUpstreamLogs);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
case "metrics":
|
||||
{
|
||||
const newMetrics = JSON.parse(message.data) as Metrics[];
|
||||
setMetrics((prevMetrics) => {
|
||||
return [...newMetrics, ...prevMetrics];
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
modelStatusEventSource.current = eventSource;
|
||||
} else {
|
||||
modelStatusEventSource.current?.close();
|
||||
modelStatusEventSource.current = null;
|
||||
}
|
||||
},
|
||||
[setModels]
|
||||
);
|
||||
} catch (err) {
|
||||
console.error(e.data, err);
|
||||
}
|
||||
};
|
||||
eventSource.onerror = () => {
|
||||
eventSource.close();
|
||||
retryCount++;
|
||||
const delay = Math.min(initialDelay * Math.pow(2, retryCount - 1), 5000);
|
||||
setTimeout(connect, delay);
|
||||
};
|
||||
|
||||
apiEventSource.current = eventSource;
|
||||
};
|
||||
|
||||
connect();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (autoStartAPIEvents) {
|
||||
enableAPIEvents(true);
|
||||
}
|
||||
|
||||
return () => {
|
||||
proxyEventSource.current?.close();
|
||||
upstreamEventSource.current?.close();
|
||||
modelStatusEventSource.current?.close();
|
||||
enableAPIEvents(false);
|
||||
};
|
||||
}, []);
|
||||
}, [enableAPIEvents, autoStartAPIEvents]);
|
||||
|
||||
const listModels = useCallback(async (): Promise<Model[]> => {
|
||||
try {
|
||||
@@ -139,27 +189,33 @@ export function APIProvider({ children }: APIProviderProps) {
|
||||
}
|
||||
}, []);
|
||||
|
||||
const loadModel = useCallback(async (model: string) => {
|
||||
try {
|
||||
const response = await fetch(`/upstream/${model}/`, {
|
||||
method: "GET",
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to load model: ${response.status}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to load model:", error);
|
||||
throw error; // Re-throw to let calling code handle it
|
||||
}
|
||||
}, []);
|
||||
|
||||
const value = useMemo(
|
||||
() => ({
|
||||
models,
|
||||
listModels,
|
||||
unloadAllModels,
|
||||
enableProxyLogs,
|
||||
enableUpstreamLogs,
|
||||
enableModelUpdates,
|
||||
loadModel,
|
||||
enableAPIEvents,
|
||||
proxyLogs,
|
||||
upstreamLogs,
|
||||
metrics,
|
||||
getConnectionStatus,
|
||||
}),
|
||||
[
|
||||
models,
|
||||
listModels,
|
||||
unloadAllModels,
|
||||
enableProxyLogs,
|
||||
enableUpstreamLogs,
|
||||
enableModelUpdates,
|
||||
proxyLogs,
|
||||
upstreamLogs,
|
||||
]
|
||||
[models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics]
|
||||
);
|
||||
|
||||
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import { createContext, useContext, useEffect, type ReactNode } from "react";
|
||||
import { createContext, useContext, useEffect, type ReactNode, useMemo, useState } from "react";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
|
||||
type ScreenWidth = "xs" | "sm" | "md" | "lg" | "xl" | "2xl";
|
||||
type ThemeContextType = {
|
||||
isDarkMode: boolean;
|
||||
screenWidth: ScreenWidth;
|
||||
isNarrow: boolean;
|
||||
toggleTheme: () => void;
|
||||
};
|
||||
|
||||
@@ -14,14 +17,46 @@ type ThemeProviderProps = {
|
||||
|
||||
export function ThemeProvider({ children }: ThemeProviderProps) {
|
||||
const [isDarkMode, setIsDarkMode] = usePersistentState<boolean>("theme", false);
|
||||
const [screenWidth, setScreenWidth] = useState<ScreenWidth>("md"); // Default to md
|
||||
|
||||
// matches tailwind classes
|
||||
// https://tailwindcss.com/docs/responsive-design
|
||||
useEffect(() => {
|
||||
const checkInnerWidth = () => {
|
||||
const innerWidth = window.innerWidth;
|
||||
if (innerWidth < 640) {
|
||||
setScreenWidth("xs");
|
||||
} else if (innerWidth < 768) {
|
||||
setScreenWidth("sm");
|
||||
} else if (innerWidth < 1024) {
|
||||
setScreenWidth("md");
|
||||
} else if (innerWidth < 1280) {
|
||||
setScreenWidth("lg");
|
||||
} else if (innerWidth < 1536) {
|
||||
setScreenWidth("xl");
|
||||
} else {
|
||||
setScreenWidth("2xl");
|
||||
}
|
||||
};
|
||||
|
||||
checkInnerWidth();
|
||||
window.addEventListener("resize", checkInnerWidth);
|
||||
|
||||
return () => window.removeEventListener("resize", checkInnerWidth);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
document.documentElement.setAttribute("data-theme", isDarkMode ? "dark" : "light");
|
||||
}, [isDarkMode]);
|
||||
|
||||
const toggleTheme = () => setIsDarkMode((prev) => !prev);
|
||||
const isNarrow = useMemo(() => {
|
||||
return screenWidth === "xs" || screenWidth === "sm" || screenWidth === "md";
|
||||
}, [screenWidth]);
|
||||
|
||||
return <ThemeContext.Provider value={{ isDarkMode, toggleTheme }}>{children}</ThemeContext.Provider>;
|
||||
return (
|
||||
<ThemeContext.Provider value={{ isDarkMode, toggleTheme, screenWidth, isNarrow }}>{children}</ThemeContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useTheme(): ThemeContextType {
|
||||
|
||||
@@ -143,6 +143,10 @@
|
||||
@apply bg-surface p-2 px-4 text-sm rounded-full border border-2 transition-colors duration-200 border-btn-border;
|
||||
}
|
||||
|
||||
.btn:hover {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.btn--sm {
|
||||
@apply px-2 py-0.5 text-xs;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import { useMemo } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
|
||||
const formatTimestamp = (timestamp: string): string => {
|
||||
return new Date(timestamp).toLocaleString();
|
||||
};
|
||||
|
||||
const formatSpeed = (speed: number): string => {
|
||||
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||
};
|
||||
|
||||
const formatDuration = (ms: number): string => {
|
||||
return (ms / 1000).toFixed(2) + "s";
|
||||
};
|
||||
|
||||
const ActivityPage = () => {
|
||||
const { metrics } = useAPI();
|
||||
const sortedMetrics = useMemo(() => {
|
||||
return [...metrics].sort((a, b) => b.id - a.id);
|
||||
}, [metrics]);
|
||||
|
||||
return (
|
||||
<div className="p-6">
|
||||
<h1 className="text-2xl font-bold mb-4">Activity</h1>
|
||||
|
||||
{metrics.length === 0 ? (
|
||||
<div className="text-center py-8">
|
||||
<p className="text-gray-600">No metrics data available</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="overflow-x-auto">
|
||||
<table className="min-w-full divide-y">
|
||||
<thead>
|
||||
<tr>
|
||||
<th className="px-4 py-3 text-left text-xs font-medium uppercase tracking-wider">Id</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Timestamp</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Model</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Input Tokens</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Output Tokens</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Prompt Processing</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Generation Speed</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Duration</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="divide-y">
|
||||
{sortedMetrics.map((metric) => (
|
||||
<tr key={`metric_${metric.id}`}>
|
||||
<td className="px-4 py-4 whitespace-nowrap text-sm">{metric.id + 1 /* un-zero index */}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatTimestamp(metric.timestamp)}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.model}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.input_tokens.toLocaleString()}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.output_tokens.toLocaleString()}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.prompt_per_second)}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.tokens_per_second)}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatDuration(metric.duration_ms)}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ActivityPage;
|
||||
@@ -1,24 +1,38 @@
|
||||
import { useState, useEffect, useRef, useMemo, useCallback } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||
import {
|
||||
RiTextWrap,
|
||||
RiAlignJustify,
|
||||
RiFontSize,
|
||||
RiMenuSearchLine,
|
||||
RiMenuSearchFill,
|
||||
RiCloseCircleFill,
|
||||
} from "react-icons/ri";
|
||||
import { useTheme } from "../contexts/ThemeProvider";
|
||||
|
||||
const LogViewer = () => {
|
||||
const { proxyLogs, upstreamLogs, enableProxyLogs, enableUpstreamLogs } = useAPI();
|
||||
|
||||
useEffect(() => {
|
||||
enableProxyLogs(true);
|
||||
enableUpstreamLogs(true);
|
||||
return () => {
|
||||
enableProxyLogs(false);
|
||||
enableUpstreamLogs(false);
|
||||
};
|
||||
}, []);
|
||||
const { proxyLogs, upstreamLogs } = useAPI();
|
||||
const { isNarrow } = useTheme();
|
||||
const direction = isNarrow ? "vertical" : "horizontal";
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-5">
|
||||
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
|
||||
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</div>
|
||||
<PanelGroup direction={direction} className="gap-2" autoSaveId="logviewer-panel-group">
|
||||
<Panel id="proxy" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
|
||||
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
|
||||
</Panel>
|
||||
<PanelResizeHandle
|
||||
className={
|
||||
direction === "horizontal"
|
||||
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
|
||||
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
|
||||
}
|
||||
/>
|
||||
<Panel id="upstream" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
|
||||
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -26,20 +40,15 @@ interface LogPanelProps {
|
||||
id: string;
|
||||
title: string;
|
||||
logData: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
|
||||
export const LogPanel = ({ id, title, logData }: LogPanelProps) => {
|
||||
const [filterRegex, setFilterRegex] = useState("");
|
||||
const [panelState, setPanelState] = usePersistentState<"hide" | "small" | "max">(
|
||||
`logPanel-${id}-panelState`,
|
||||
"small"
|
||||
);
|
||||
const [fontSize, setFontSize] = usePersistentState<"xxs" | "xs" | "small" | "normal">(
|
||||
`logPanel-${id}-fontSize`,
|
||||
"normal"
|
||||
);
|
||||
const [wrapText, setTextWrap] = usePersistentState(`logPanel-${id}-wrapText`, false);
|
||||
const [showFilter, setShowFilter] = usePersistentState(`logPanel-${id}-showFilter`, false);
|
||||
|
||||
const textWrapClass = useMemo(() => {
|
||||
return wrapText ? "whitespace-pre-wrap" : "whitespace-pre";
|
||||
@@ -60,14 +69,19 @@ export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
|
||||
});
|
||||
}, []);
|
||||
|
||||
const togglePanelState = useCallback(() => {
|
||||
setPanelState((prev) => {
|
||||
if (prev === "small") return "max";
|
||||
if (prev === "hide") return "small";
|
||||
return "hide";
|
||||
});
|
||||
const toggleWrapText = useCallback(() => {
|
||||
setTextWrap((prev) => !prev);
|
||||
}, []);
|
||||
|
||||
const toggleFilter = useCallback(() => {
|
||||
if (showFilter) {
|
||||
setShowFilter(false);
|
||||
setFilterRegex(""); // Clear filter when closing
|
||||
} else {
|
||||
setShowFilter(true);
|
||||
}
|
||||
}, [filterRegex, setFilterRegex, showFilter]);
|
||||
|
||||
const fontSizeClass = useMemo(() => {
|
||||
switch (fontSize) {
|
||||
case "xxs":
|
||||
@@ -101,60 +115,48 @@ export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
|
||||
}, [filteredLogs]);
|
||||
|
||||
return (
|
||||
<div className={`bg-surface border border-border rounded-lg overflow-hidden flex flex-col ${className || ""}`}>
|
||||
<div className="bg-surface border border-border rounded-lg overflow-hidden flex flex-col h-full">
|
||||
<div className="p-4 border-b border-border bg-secondary">
|
||||
<div className="flex flex-col md:flex-row md:items-center md:justify-between gap-4">
|
||||
{/* Title - Always full width on mobile, normal on desktop */}
|
||||
<div className="w-full md:w-auto" onClick={togglePanelState}>
|
||||
<h3 className="m-0 text-lg">{title}</h3>
|
||||
<div className="flex items-center justify-between">
|
||||
<h3 className="m-0 text-lg p-0">{title}</h3>
|
||||
|
||||
<div className="flex gap-2 items-center">
|
||||
<button className="btn" onClick={toggleFontSize}>
|
||||
<RiFontSize />
|
||||
</button>
|
||||
<button className="btn" onClick={toggleWrapText}>
|
||||
{wrapText ? <RiTextWrap /> : <RiAlignJustify />}
|
||||
</button>
|
||||
<button className="btn" onClick={toggleFilter}>
|
||||
{showFilter ? <RiMenuSearchFill /> : <RiMenuSearchLine />}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col sm:flex-row gap-4 w-full md:w-auto">
|
||||
{/* Sizing Buttons - Stacks vertically on mobile */}
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<button className="btn" onClick={togglePanelState}>
|
||||
size: {panelState}
|
||||
</button>
|
||||
<button className="btn" onClick={toggleFontSize}>
|
||||
font: {fontSize}
|
||||
</button>
|
||||
<button className="btn" onClick={() => setTextWrap((prev) => !prev)}>
|
||||
{wrapText ? "wrap" : "wrap off"}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Filtering Options - Full width on mobile, normal on desktop */}
|
||||
<div className="flex flex-1 min-w-0 gap-2">
|
||||
{/* Filtering Options - Full width on mobile, normal on desktop */}
|
||||
{showFilter && (
|
||||
<div className="mt-2 w-full">
|
||||
<div className="flex gap-2 items-center w-full">
|
||||
<input
|
||||
type="text"
|
||||
className="flex-1 min-w-[120px] text-sm border p-2 rounded"
|
||||
className="w-full text-sm border p-2 rounded"
|
||||
placeholder="Filter logs..."
|
||||
value={filterRegex}
|
||||
onChange={(e) => setFilterRegex(e.target.value)}
|
||||
/>
|
||||
<button className="btn" onClick={() => setFilterRegex("")}>
|
||||
Clear
|
||||
<button className="pl-2" onClick={() => setFilterRegex("")}>
|
||||
<RiCloseCircleFill size="24" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="bg-background font-mono text-sm flex-1 overflow-hidden">
|
||||
<pre ref={preTagRef} className={`${textWrapClass} ${fontSizeClass} h-full overflow-auto p-4`}>
|
||||
{filteredLogs}
|
||||
</pre>
|
||||
</div>
|
||||
|
||||
{panelState !== "hide" && (
|
||||
<div className="flex-1 bg-background font-mono text-sm leading-[1.4] p-3">
|
||||
<pre
|
||||
ref={preTagRef}
|
||||
className={`flex-1 p-4 overflow-y-auto whitespace-pre min-h-0 ${textWrapClass} ${fontSizeClass}`}
|
||||
style={{
|
||||
maxHeight: panelState === "max" ? "1500px" : "500px",
|
||||
}}
|
||||
>
|
||||
{filteredLogs}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default LogViewer;
|
||||
|
||||
@@ -1,19 +1,50 @@
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { useState, useCallback, useMemo } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
import { LogPanel } from "./LogViewer";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||
import { useTheme } from "../contexts/ThemeProvider";
|
||||
import { RiEyeFill, RiEyeOffFill, RiStopCircleLine, RiSwapBoxFill } from "react-icons/ri";
|
||||
|
||||
export default function ModelsPage() {
|
||||
const { models, enableModelUpdates, unloadAllModels, upstreamLogs, enableUpstreamLogs } = useAPI();
|
||||
const [isUnloading, setIsUnloading] = useState(false);
|
||||
const { isNarrow } = useTheme();
|
||||
const direction = isNarrow ? "vertical" : "horizontal";
|
||||
const { upstreamLogs } = useAPI();
|
||||
|
||||
useEffect(() => {
|
||||
enableModelUpdates(true);
|
||||
enableUpstreamLogs(true);
|
||||
return () => {
|
||||
enableModelUpdates(false);
|
||||
enableUpstreamLogs(false);
|
||||
};
|
||||
}, []);
|
||||
return (
|
||||
<PanelGroup direction={direction} className="gap-2" autoSaveId={"models-panel-group"}>
|
||||
<Panel id="models" defaultSize={50} minSize={isNarrow ? 0 : 25} maxSize={100} collapsible={isNarrow}>
|
||||
<ModelsPanel />
|
||||
</Panel>
|
||||
|
||||
<PanelResizeHandle
|
||||
className={
|
||||
direction === "horizontal"
|
||||
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
|
||||
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
|
||||
}
|
||||
/>
|
||||
<Panel collapsible={true} defaultSize={50} minSize={0}>
|
||||
<div className="flex flex-col h-full space-y-4">
|
||||
{direction === "horizontal" && <StatsPanel />}
|
||||
<div className="flex-1 min-h-0">
|
||||
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</div>
|
||||
</div>
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
);
|
||||
}
|
||||
|
||||
function ModelsPanel() {
|
||||
const { models, loadModel, unloadAllModels } = useAPI();
|
||||
const [isUnloading, setIsUnloading] = useState(false);
|
||||
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
||||
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
|
||||
|
||||
const filteredModels = useMemo(() => {
|
||||
return models.filter((model) => showUnlisted || !model.unlisted);
|
||||
}, [models, showUnlisted]);
|
||||
|
||||
const handleUnloadAllModels = useCallback(async () => {
|
||||
setIsUnloading(true);
|
||||
@@ -22,53 +53,121 @@ export default function ModelsPage() {
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
} finally {
|
||||
// at least give it a second to show the unloading message
|
||||
setTimeout(() => {
|
||||
setIsUnloading(false);
|
||||
}, 1000);
|
||||
}
|
||||
}, []);
|
||||
}, [unloadAllModels]);
|
||||
|
||||
const toggleIdorName = useCallback(() => {
|
||||
setShowIdorName((prev) => (prev === "name" ? "id" : "name"));
|
||||
}, [showIdorName]);
|
||||
|
||||
return (
|
||||
<div className="h-screen">
|
||||
<div className="flex flex-col md:flex-row gap-4">
|
||||
{/* Left Column */}
|
||||
<div className="w-full md:w-1/2 flex items-top">
|
||||
<div className="card w-full">
|
||||
<h2 className="">Models</h2>
|
||||
<button className="btn" onClick={handleUnloadAllModels} disabled={isUnloading}>
|
||||
{isUnloading ? "Unloading..." : "Unload All Models"}
|
||||
<div className="card h-full flex flex-col">
|
||||
<div className="shrink-0">
|
||||
<h2>Models</h2>
|
||||
<div className="flex justify-between">
|
||||
<div className="flex gap-2">
|
||||
<button className="btn flex items-center gap-2" onClick={toggleIdorName} style={{ lineHeight: "1.2" }}>
|
||||
<RiSwapBoxFill /> {showIdorName === "id" ? "ID" : "Name"}
|
||||
</button>
|
||||
<table className="w-full mt-4">
|
||||
<thead>
|
||||
<tr className="border-b border-primary">
|
||||
<th className="text-left p-2">Name</th>
|
||||
<th className="text-left p-2">State</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{models.map((model) => (
|
||||
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
|
||||
<td className="p-2">
|
||||
<a href={`/upstream/${model.id}/`} className="underline" target="top">
|
||||
{model.id}
|
||||
</a>
|
||||
</td>
|
||||
<td className="p-2">
|
||||
<span className={`status status--${model.state}`}>{model.state}</span>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right Column */}
|
||||
<div className="w-full md:w-1/2 flex items-top">
|
||||
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} className="h-full" />
|
||||
<button
|
||||
className="btn flex items-center gap-2"
|
||||
onClick={() => setShowUnlisted(!showUnlisted)}
|
||||
style={{ lineHeight: "1.2" }}
|
||||
>
|
||||
{showUnlisted ? <RiEyeFill /> : <RiEyeOffFill />} unlisted
|
||||
</button>
|
||||
</div>
|
||||
<button className="btn flex items-center gap-2" onClick={handleUnloadAllModels} disabled={isUnloading}>
|
||||
<RiStopCircleLine size="24" /> {isUnloading ? "Unloading..." : "Unload"}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
<table className="w-full">
|
||||
<thead className="sticky top-0 bg-card z-10">
|
||||
<tr className="border-b border-primary bg-surface">
|
||||
<th className="text-left p-2">{showIdorName === "id" ? "Model ID" : "Name"}</th>
|
||||
<th className="text-left p-2"></th>
|
||||
<th className="text-left p-2">State</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{filteredModels.map((model) => (
|
||||
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
|
||||
<td className={`p-2 ${model.unlisted ? "text-txtsecondary" : ""}`}>
|
||||
<a href={`/upstream/${model.id}/`} className={`underline`} target="_blank">
|
||||
{showIdorName === "id" ? model.id : model.name !== "" ? model.name : model.id}
|
||||
</a>
|
||||
{model.description !== "" && (
|
||||
<p className={model.unlisted ? "text-opacity-70" : ""}>
|
||||
<em>{model.description}</em>
|
||||
</p>
|
||||
)}
|
||||
</td>
|
||||
<td className="p-2 w-[50px]">
|
||||
<button
|
||||
className="btn btn--sm"
|
||||
disabled={model.state !== "stopped"}
|
||||
onClick={() => loadModel(model.id)}
|
||||
>
|
||||
Load
|
||||
</button>
|
||||
</td>
|
||||
<td className="p-2 w-[75px]">
|
||||
<span className={`status status--${model.state}`}>{model.state}</span>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function StatsPanel() {
|
||||
const { metrics } = useAPI();
|
||||
|
||||
const [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond] = useMemo(() => {
|
||||
const totalRequests = metrics.length;
|
||||
if (totalRequests === 0) {
|
||||
return [0, 0, 0];
|
||||
}
|
||||
const totalInputTokens = metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||
const totalOutputTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||
const avgTokensPerSecond = (metrics.reduce((sum, m) => sum + m.tokens_per_second, 0) / totalRequests).toFixed(2);
|
||||
return [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond];
|
||||
}, [metrics]);
|
||||
|
||||
return (
|
||||
<div className="card">
|
||||
<div className="rounded-lg overflow-hidden border border-gray-200">
|
||||
<table className="w-full">
|
||||
<tbody>
|
||||
<tr>
|
||||
<th className="p-2 font-medium border-b border-gray-200 text-right">Requests</th>
|
||||
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Processed</th>
|
||||
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Generated</th>
|
||||
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Tokens/Sec</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td className="p-2 text-right border-r border-gray-200">{totalRequests}</td>
|
||||
<td className="p-2 text-right border-r border-gray-200">
|
||||
{new Intl.NumberFormat().format(totalInputTokens)}
|
||||
</td>
|
||||
<td className="p-2 text-right border-r border-gray-200">
|
||||
{new Intl.NumberFormat().format(totalOutputTokens)}
|
||||
</td>
|
||||
<td className="p-2 text-right">{avgTokensPerSecond}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||