Compare commits
139 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9e02c22ff8 | |||
| 0bdbf2fdc1 | |||
| 49035e2e8e | |||
| 9963ae18bf | |||
| 2ae48c713b | |||
| 54c519e365 | |||
| 3fce9ee0e9 | |||
| 5899ae7966 | |||
| 591a9cdf4d | |||
| 9a3c656738 | |||
| 75015f82ea | |||
| cc33b6c270 | |||
| 4fa12a429c | |||
| 2dc0ca0663 | |||
| a84098d3b4 | |||
| 4d02ccd26a | |||
| dfd47eeac4 | |||
| 1ac6499c08 | |||
| 25f3dc25e7 | |||
| 8422e4e6a1 | |||
| 02ee29d881 | |||
| b2a891f8f4 | |||
| 8d2b568897 | |||
| fb44cf4e08 | |||
| 02aee4e86d | |||
| f45896d395 | |||
| f7e46a359f | |||
| c260907415 | |||
| b83a5fa291 | |||
| 6e2ff28d59 | |||
| a8b81f2799 | |||
| f9ee7156dc | |||
| 2d00120781 | |||
| afc9aef058 | |||
| d7b390df74 | |||
| 5025c2f1f3 | |||
| e3a0b013c1 | |||
| f5763a94a0 | |||
| 8ada72eb57 | |||
| 2441b383d3 | |||
| 25f251699c | |||
| 7f37bcc6eb | |||
| 519c3a4d22 | |||
| 9dc4bcb46c | |||
| cb876c143b | |||
| bc652709a5 | |||
| 9548931258 | |||
| 5c5a5da664 | |||
| aa9ef59aa5 | |||
| 09e52c0500 | |||
| ca9063ffbe | |||
| 21d7973d11 | |||
| cc450e9c5f | |||
| 27465fe053 | |||
| 9667989727 | |||
| d9a1ddea0d | |||
| e7ab024ca0 | |||
| 448ccae959 | |||
| ec0348e431 | |||
| 06eda7f591 | |||
| 5fad24c16f | |||
| 8404244fab | |||
| 712cd01081 | |||
| 1f7aa359b1 | |||
| b138d6cf25 | |||
| fb7c808082 | |||
| a7e640b0f7 | |||
| 593604dfdc | |||
| b8f888f864 | |||
| 192b2ae621 | |||
| b7f8cb5094 | |||
| a23da6eb57 | |||
| 4c3aa40564 | |||
| 84e2c07a7e | |||
| 680af28bcc | |||
| d94db42ffe | |||
| 93cd83c55c | |||
| 5565fca3ac | |||
| d625ab8d92 | |||
| a3f82c140b | |||
| 5c97299e7b | |||
| 671c1a5a7b | |||
| 52c0196e0f | |||
| 3201a68a04 | |||
| 3ac94ad20e | |||
| 60355bf74a | |||
| 9b2ed244e2 | |||
| eeb72297f7 | |||
| eabfe70cc6 | |||
| 29cd98878d | |||
| b3d331da0d | |||
| 62275e078d | |||
| 88916059e1 | |||
| 082d5d0fc5 | |||
| 53338938bd | |||
| af653347ae | |||
| 1e25b44a06 | |||
| 0815bb4cc3 | |||
| 7187cfe52e | |||
| 24089d2d9c | |||
| ebabe55ff3 | |||
| 41a338297c | |||
| 7e3353efeb | |||
| 4ed58fb173 | |||
| f5a2be698d | |||
| f5e6ec3b7a | |||
| 3f462da146 | |||
| 48bd766536 | |||
| 8d319da4dd | |||
| be7c502448 | |||
| 92336f00bf | |||
| ed2a50d9a6 | |||
| 0acfdb9f78 | |||
| 96a8ea0241 | |||
| f20f2c9b7a | |||
| 7a97c38828 | |||
| 4885132565 | |||
| 8b46a0b7f1 | |||
| 1b6736ec6f | |||
| ddc1ce031e | |||
| 11d024bbaa | |||
| 43e23c16dc | |||
| f9c8e763ba | |||
| d7e1bb9f7c | |||
| ab93460a8b | |||
| 13d4552edc | |||
| 6667e307a2 | |||
| 7ac446e6a9 | |||
| eab9795bcc | |||
| 09bdd86b54 | |||
| 85cd74a51c | |||
| 314d2f2212 | |||
| fad25f3e11 | |||
| 2c3e3e27f7 | |||
| baeb0c4e7f | |||
| 2833517eef | |||
| abdc2bfdb3 | |||
| c3b834737f | |||
| 3c8e727b73 |
@@ -0,0 +1,15 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: "en-US"
|
||||
early_access: false
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: true
|
||||
poem: false
|
||||
review_status: true
|
||||
collapse_walkthrough: false
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
@@ -0,0 +1,37 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: Something is not working as expected...
|
||||
title: ''
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**Expected behaviour**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Operating system and version**
|
||||
|
||||
- OS: (linux, osx, windows, freebsd, etc)
|
||||
- GPUs: (list architecture)
|
||||
|
||||
**My Configuration**
|
||||
|
||||
```yaml
|
||||
# copy / paste your configuration here
|
||||
```
|
||||
|
||||
**Proxy Logs**
|
||||
|
||||
```
|
||||
# copy / paste from /logs
|
||||
```
|
||||
|
||||
**Upstream Logs**
|
||||
|
||||
```
|
||||
# copy/paste from /logs
|
||||
```
|
||||
@@ -0,0 +1,23 @@
|
||||
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
|
||||
name: Close inactive issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "32 1 * * *"
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
days-before-issue-stale: 14
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -0,0 +1,46 @@
|
||||
name: Build Containers
|
||||
|
||||
on:
|
||||
# time has no specific meaning, trying to time it after
|
||||
# the llama.cpp daily packages are published
|
||||
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Run build-container
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
||||
|
||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||
delete-untagged-containers:
|
||||
needs: build-and-push
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/delete-package-versions@v5
|
||||
with:
|
||||
package-name: 'llama-swap'
|
||||
package-type: 'container'
|
||||
delete-only-untagged-versions: 'true'
|
||||
@@ -0,0 +1,50 @@
|
||||
name: Windows CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# cache simple-responder to save the build time
|
||||
- name: Restore Simple Responder
|
||||
id: restore-simple-responder
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: make simple-responder-windows
|
||||
|
||||
- name: Save Simple Responder
|
||||
# nothing new to save ... skip this step
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
id: save-simple-responder
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
shell: bash
|
||||
run: make test-all
|
||||
@@ -0,0 +1,47 @@
|
||||
name: Linux CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# cache simple-responder to save the build time
|
||||
- name: Restore Simple Responder
|
||||
id: restore-simple-responder
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
run: make simple-responder
|
||||
|
||||
- name: Save Simple Responder
|
||||
# nothing new to save ... skip this step
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
id: save-simple-responder
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
run: make test-all
|
||||
@@ -5,6 +5,9 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
@@ -20,6 +23,19 @@ jobs:
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
|
||||
-
|
||||
name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '23' # or your preferred version
|
||||
-
|
||||
name: Install dependencies and build UI
|
||||
run: |
|
||||
cd ui
|
||||
npm ci
|
||||
npm run build
|
||||
|
||||
-
|
||||
name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
|
||||
+22
-1
@@ -6,6 +6,27 @@ builds:
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
- freebsd
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm64
|
||||
ignore:
|
||||
- goos: freebsd
|
||||
goarch: arm64
|
||||
- goos: windows
|
||||
goarch: arm64
|
||||
|
||||
archives:
|
||||
- id: default
|
||||
formats:
|
||||
- tar.gz
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
builds_info:
|
||||
group: root
|
||||
owner: root
|
||||
format_overrides:
|
||||
# use zip format for windows
|
||||
- goos: windows
|
||||
formats:
|
||||
- zip
|
||||
@@ -19,28 +19,48 @@ all: mac linux simple-responder
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
|
||||
test:
|
||||
go test -short -v ./proxy
|
||||
proxy/ui_dist/placeholder.txt:
|
||||
mkdir -p proxy/ui_dist
|
||||
touch $@
|
||||
|
||||
test-all:
|
||||
go test -v ./proxy
|
||||
test: proxy/ui_dist/placeholder.txt
|
||||
go test -short -v -count=1 ./proxy
|
||||
|
||||
test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -v -count=1 ./proxy
|
||||
|
||||
ui/node_modules:
|
||||
cd ui && npm install
|
||||
|
||||
# build react UI
|
||||
ui: ui/node_modules
|
||||
cd ui && npm run build
|
||||
|
||||
# Build OSX binary
|
||||
mac:
|
||||
mac: ui
|
||||
@echo "Building Mac binary..."
|
||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||
|
||||
# Build Linux binary
|
||||
linux:
|
||||
linux: ui
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
|
||||
# Build Windows binary
|
||||
windows: ui
|
||||
@echo "Building Windows binary..."
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||
|
||||
# for testing proxy.Process
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
||||
|
||||
simple-responder-windows:
|
||||
@echo "Building simple responder for windows"
|
||||
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
|
||||
|
||||
# Ensure build directory exists
|
||||
$(BUILD_DIR):
|
||||
mkdir -p $(BUILD_DIR)
|
||||
@@ -60,4 +80,4 @@ release:
|
||||
git tag "$$new_tag";
|
||||
|
||||
# Phony targets
|
||||
.PHONY: all clean osx linux
|
||||
.PHONY: all clean ui mac linux windows simple-responder
|
||||
|
||||
@@ -1,137 +1,158 @@
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
# llama-swap
|
||||
|
||||

|
||||
|
||||
# Introduction
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or built it yourself from source with `make clean all`.
|
||||
|
||||
## How does it work?
|
||||
When a request is made to an OpenAI compatible endpoints, lama-swap will extract the `model` value load the appropriate server configuration to serve it. If a server is already running it will stop it and start a new one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported. For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
Written in golang, it is very easy to install (single binary with no dependencies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
||||
|
||||
## Features:
|
||||
|
||||
- ✅ Easy to deploy: single binary with no dependencies
|
||||
- ✅ Easy to config: single yaml file
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Full control over server settings per model
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/embeddings`
|
||||
- `v1/rerank`
|
||||
- `v1/audio/speech`
|
||||
- ✅ Multiple GPU support
|
||||
- ✅ Run multiple models at once with `profiles`
|
||||
- ✅ Remote log monitoring at `/log`
|
||||
- ✅ Automatic unloading of models from GPUs after timeout
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- ✅ llama-swap custom API endpoints
|
||||
- `/log` - remote log monitoring
|
||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- ✅ Docker and Podman support
|
||||
- ✅ Full control over server settings per model
|
||||
|
||||
## How does llama-swap work?
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## config.yaml
|
||||
|
||||
llama-swap's configuration is purposefully simple.
|
||||
llama-swap is managed entirely through a yaml configuration file.
|
||||
|
||||
It can be very minimal to start:
|
||||
|
||||
```yaml
|
||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||
# Default (and minimum) is 15 seconds
|
||||
healthCheckTimeout: 60
|
||||
|
||||
# Write HTTP logs (useful for troubleshooting), defaults to false
|
||||
logRequests: true
|
||||
|
||||
# define valid model values and the upstream server start
|
||||
models:
|
||||
"llama":
|
||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
|
||||
# where to reach the server started by cmd, make sure the ports match
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases names to use this model for
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# check this path for an HTTP 200 OK before serving requests
|
||||
# default: /health to match llama.cpp
|
||||
# use "none" to skip endpoint checking, but may cause HTTP errors
|
||||
# until the model is ready
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# automatically unload the model after this many seconds
|
||||
# ttl values must be a value greater than 0
|
||||
# default: 0 = never unload model
|
||||
ttl: 60
|
||||
|
||||
"qwen":
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
|
||||
# multiline for readability
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# unlisted models do not show up in /v1/models or /upstream lists
|
||||
# but they can still be requested as normal
|
||||
"qwen-unlisted":
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
unlisted: true
|
||||
|
||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||
#
|
||||
# Tips:
|
||||
# - each model must be listening on a unique address and port
|
||||
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
||||
# - the profile will load and unload all models in the profile at the same time
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
"qwen2.5":
|
||||
cmd: |
|
||||
/path/to/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port ${PORT}
|
||||
```
|
||||
|
||||
### Advanced Examples
|
||||
However, there are many more capabilities that llama-swap supports:
|
||||
|
||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||
- `groups` to run multiple models at once
|
||||
- `ttl` to automatically unload models
|
||||
- `macros` for reusable snippets
|
||||
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
|
||||
- `env` to pass custom environment variables to inference servers
|
||||
- `cmdStop` for to gracefully stop Docker/Podman containers
|
||||
- `useModelName` to override model names sent to upstream servers
|
||||
- `healthCheckTimeout` to control model startup wait times
|
||||
- `${PORT}` automatic port variables for dynamic port assignment
|
||||
|
||||
### Installation
|
||||
See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki all options and examples.
|
||||
|
||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Docker is the quickest way to try out llama-swap:
|
||||
|
||||
```shell
|
||||
# use CPU inference comes with the example config above
|
||||
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||
|
||||
# qwen2.5 0.5B
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer no-key" \
|
||||
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
|
||||
# SmolLM2 135M
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer no-key" \
|
||||
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Docker images are built nightly for cuda, intel, vulcan, etc ...</summary>
|
||||
|
||||
They include:
|
||||
|
||||
- `ghcr.io/mostlygeek/llama-swap:cpu`
|
||||
- `ghcr.io/mostlygeek/llama-swap:cuda`
|
||||
- `ghcr.io/mostlygeek/llama-swap:intel`
|
||||
- `ghcr.io/mostlygeek/llama-swap:vulkan`
|
||||
- ROCm disabled until fixed in llama.cpp container
|
||||
|
||||
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
|
||||
|
||||
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
||||
|
||||
```shell
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||
|
||||
Pre-built binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
* _Note: Windows currently untested._
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml --listen localhost:8080`.
|
||||
Available flags:
|
||||
- `--config`: Path to the configuration file (default: `config.yaml`).
|
||||
- `--listen`: Address and port to listen on (default: `:8080`).
|
||||
- `--version`: Show version information and exit.
|
||||
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
|
||||
|
||||
### Building from source
|
||||
|
||||
1. Install golang for your system
|
||||
1. Build requires golang and nodejs for the user interface.
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
|
||||
## Monitoring Logs
|
||||
|
||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
||||
Open the `http://<host>:<port>/` with your browser to get a web interface with streaming logs.
|
||||
|
||||
Of course, CLI access is also supported:
|
||||
CLI access is also supported:
|
||||
|
||||
```
|
||||
```shell
|
||||
# sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
|
||||
# streams logs
|
||||
# streams combined logs
|
||||
curl -Ns 'http://host/logs/stream'
|
||||
|
||||
# just llama-swap's logs
|
||||
curl -Ns 'http://host/logs/stream/proxy'
|
||||
|
||||
# just upstream's logs
|
||||
curl -Ns 'http://host/logs/stream/upstream'
|
||||
|
||||
# stream and filter logs with linux pipes
|
||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
|
||||
@@ -139,27 +160,12 @@ curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
```
|
||||
|
||||
## Systemd Unit Files
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
|
||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||
|
||||
`/etc/systemd/system/llama-swap.service`
|
||||
```
|
||||
[Unit]
|
||||
Description=llama-swap
|
||||
After=network.target
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
|
||||
[Service]
|
||||
User=nobody
|
||||
## Star History
|
||||
|
||||
# set this to match your environment
|
||||
ExecStart=/path/to/llama-swap --config /path/to/llama-swap.config.yml
|
||||
|
||||
Restart=on-failure
|
||||
RestartSec=3
|
||||
StartLimitBurst=3
|
||||
StartLimitInterval=30
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
||||
|
||||
+33
-24
@@ -1,17 +1,29 @@
|
||||
# ======
|
||||
# For a more detailed configuration example:
|
||||
# https://github.com/mostlygeek/llama-swap/wiki/Configuration
|
||||
# ======
|
||||
|
||||
# Seconds to wait for llama.cpp to be available to serve requests
|
||||
# Default (and minimum): 15 seconds
|
||||
healthCheckTimeout: 15
|
||||
healthCheckTimeout: 90
|
||||
|
||||
# Log HTTP requests helpful for troubleshoot, defaults to False
|
||||
logRequests: true
|
||||
# valid log levels: debug, info (default), warn, error
|
||||
logLevel: debug
|
||||
|
||||
# creating a coding profile with models for code generation and general questions
|
||||
groups:
|
||||
coding:
|
||||
swap: false
|
||||
members:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
|
||||
models:
|
||||
"llama":
|
||||
cmd: >
|
||||
cmd: |
|
||||
models/llama-server-osx
|
||||
--port 9001
|
||||
--port ${PORT}
|
||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||
proxy: http://127.0.0.1:9001
|
||||
|
||||
# list of model name aliases this llama.cpp instance can serve
|
||||
aliases:
|
||||
@@ -24,17 +36,15 @@ models:
|
||||
ttl: 5
|
||||
|
||||
"qwen":
|
||||
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9002
|
||||
cmd: models/llama-server-osx --port ${PORT} -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
aliases:
|
||||
- gpt-3.5-turbo
|
||||
- gpt-3.5-turbo
|
||||
|
||||
# Embedding example with Nomic
|
||||
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
||||
"nomic":
|
||||
proxy: http://127.0.0.1:9005
|
||||
cmd: >
|
||||
models/llama-server-osx --port 9005
|
||||
cmd: |
|
||||
models/llama-server-osx --port ${PORT}
|
||||
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
||||
--ctx-size 8192
|
||||
--batch-size 8192
|
||||
@@ -46,21 +56,26 @@ models:
|
||||
# Reranking example with bge-reranker
|
||||
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
||||
"bge-reranker":
|
||||
proxy: http://127.0.0.1:9006
|
||||
cmd: >
|
||||
models/llama-server-osx --port 9006
|
||||
cmd: |
|
||||
models/llama-server-osx --port ${PORT}
|
||||
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
||||
--ctx-size 8192
|
||||
--reranking
|
||||
|
||||
# Docker Support (v26.1.4+ required!)
|
||||
"dockertest":
|
||||
cmd: |
|
||||
docker run --name dockertest
|
||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
"simple":
|
||||
# example of setting environment variables
|
||||
env:
|
||||
- CUDA_VISIBLE_DEVICES=0,1
|
||||
- env1=hello
|
||||
cmd: build/simple-responder --port 8999
|
||||
proxy: http://127.0.0.1:8999
|
||||
cmd: build/simple-responder --port ${PORT}
|
||||
unlisted: true
|
||||
|
||||
# use "none" to skip check. Caution this may cause some requests to fail
|
||||
@@ -75,10 +90,4 @@ models:
|
||||
"broken_timeout":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9000
|
||||
unlisted: true
|
||||
|
||||
# creating a coding profile with models for code generation and general questions
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
unlisted: true
|
||||
Executable
+55
@@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
ARCH=$1
|
||||
PUSH_IMAGES=${2:-false}
|
||||
|
||||
# List of allowed architectures
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
||||
|
||||
# Check if ARCH is in the allowed list
|
||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if GITHUB_TOKEN is set and not empty
|
||||
if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||
echo "Error: GITHUB_TOKEN is not set or is empty."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# the most recent llama-swap tag
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
# cpu only containers just use the latest available
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
|
||||
echo "Building ${CONTAINER_LATEST} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
else
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
fi
|
||||
@@ -0,0 +1,17 @@
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
"smollm2":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
@@ -0,0 +1,16 @@
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
ARG LS_VER=89
|
||||
|
||||
WORKDIR /app
|
||||
RUN \
|
||||
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
|
||||
|
||||
COPY config.example.yaml /app/config.yaml
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
@@ -0,0 +1,153 @@
|
||||
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
|
||||
|
||||
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
|
||||
|
||||
## Here's what you you need:
|
||||
|
||||
- aider - [installation docs](https://aider.chat/docs/install.html)
|
||||
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
|
||||
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
|
||||
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
|
||||
- 24GB VRAM video card
|
||||
|
||||
## Running aider
|
||||
|
||||
The goal is getting this command line to work:
|
||||
|
||||
```sh
|
||||
aider --architect \
|
||||
--no-show-model-warnings \
|
||||
--model openai/QwQ \
|
||||
--editor-model openai/qwen-coder-32B \
|
||||
--model-settings-file aider.model.settings.yml \
|
||||
--openai-api-key "sk-na" \
|
||||
--openai-api-base "http://10.0.1.24:8080/v1" \
|
||||
```
|
||||
|
||||
Set `--openai-api-base` to the IP and port where your llama-swap is running.
|
||||
|
||||
## Create an aider model settings file
|
||||
|
||||
```yaml
|
||||
# aider.model.settings.yml
|
||||
|
||||
#
|
||||
# !!! important: model names must match llama-swap configuration names !!!
|
||||
#
|
||||
|
||||
- name: "openai/QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/qwen-coder-32B"
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
- name: "openai/qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
```
|
||||
|
||||
## llama-swap configuration
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
|
||||
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||
--ctx-size 16000
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
-ngl 99
|
||||
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
|
||||
"QwQ":
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
--ctx-size 32000
|
||||
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
|
||||
--min-p 0.01 --top-k 40 --top-p 0.95
|
||||
-ngl 99
|
||||
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||
```
|
||||
|
||||
## Advanced, Dual GPU Configuration
|
||||
|
||||
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
|
||||
|
||||
In llama-swap's configuration file:
|
||||
|
||||
1. add a `profiles` section with `aider` as the profile name
|
||||
2. using the `env` field to specify the GPU IDs for each model
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
|
||||
# Add a profile for aider
|
||||
profiles:
|
||||
aider:
|
||||
- qwen-coder-32B
|
||||
- QwQ
|
||||
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
# manually set the GPU to run on
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
cmd: /path/to/llama-server ...
|
||||
|
||||
"QwQ":
|
||||
# manually set the GPU to run on
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=1"
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: /path/to/llama-server ...
|
||||
```
|
||||
|
||||
Append the profile tag, `aider:`, to the model names in the model settings file
|
||||
|
||||
```yaml
|
||||
# aider.model.settings.yml
|
||||
- name: "openai/aider:QwQ"
|
||||
weak_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
|
||||
- name: "openai/aider:qwen-coder-32B"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
```
|
||||
|
||||
Run aider with:
|
||||
|
||||
```sh
|
||||
$ aider --architect \
|
||||
--no-show-model-warnings \
|
||||
--model openai/aider:QwQ \
|
||||
--editor-model openai/aider:qwen-coder-32B \
|
||||
--config aider.conf.yml \
|
||||
--model-settings-file aider.model.settings.yml
|
||||
--openai-api-key "sk-na" \
|
||||
--openai-api-base "http://10.0.1.24:8080/v1"
|
||||
```
|
||||
@@ -0,0 +1,28 @@
|
||||
# this makes use of llama-swap's profile feature to
|
||||
# keep the architect and editor models in VRAM on different GPUs
|
||||
|
||||
- name: "openai/aider:QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/aider:qwen-coder-32B"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||
|
||||
- name: "openai/aider:qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||
@@ -0,0 +1,26 @@
|
||||
- name: "openai/QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/qwen-coder-32B"
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
- name: "openai/qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
healthCheckTimeout: 300
|
||||
logLevel: debug
|
||||
|
||||
profiles:
|
||||
aider:
|
||||
- qwen-coder-32B
|
||||
- QwQ
|
||||
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
aliases:
|
||||
- coder
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
|
||||
# set appropriate paths for your environment
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||
--ctx-size 16000
|
||||
--ctx-size-draft 16000
|
||||
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
|
||||
-ngl 99 -ngld 99
|
||||
--draft-max 16 --draft-min 4 --draft-p-min 0.4
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
"QwQ":
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=1"
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
|
||||
# set appropriate paths for your environment
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn --metrics
|
||||
--slots
|
||||
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
--ctx-size 32000
|
||||
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||
--temp 0.6
|
||||
--repeat-penalty 1.1
|
||||
--dry-multiplier 0.5
|
||||
--min-p 0.01
|
||||
--top-k 40
|
||||
--top-p 0.95
|
||||
-ngl 99 -ngld 99
|
||||
@@ -0,0 +1,51 @@
|
||||
# Restart llama-swap on config change
|
||||
|
||||
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#
|
||||
# A simple watch and restart llama-swap when its configuration
|
||||
# file changes. Useful for trying out configuration changes
|
||||
# without manually restarting the server each time.
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <path to config.yaml>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
while true; do
|
||||
# Start the process again
|
||||
./llama-swap-linux-amd64 -config $1 -listen :1867 &
|
||||
PID=$!
|
||||
echo "Started llama-swap with PID $PID"
|
||||
|
||||
# Wait for modifications in the specified directory or file
|
||||
inotifywait -e modify "$1"
|
||||
|
||||
# Check if process exists before sending signal
|
||||
if kill -0 $PID 2>/dev/null; then
|
||||
echo "Sending SIGTERM to $PID"
|
||||
kill -SIGTERM $PID
|
||||
wait $PID
|
||||
else
|
||||
echo "Process $PID no longer exists"
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
```
|
||||
|
||||
## Usage and output example
|
||||
|
||||
```bash
|
||||
$ ./watch-and-restart.sh config.yaml
|
||||
Started llama-swap with PID 495455
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
Sending SIGTERM to 495455
|
||||
Shutting down llama-swap
|
||||
Started llama-swap with PID 495486
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
```
|
||||
@@ -3,11 +3,16 @@ module github.com/mostlygeek/llama-swap
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
@@ -15,12 +20,10 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/gin-gonic/gin v1.10.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
@@ -29,12 +32,14 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
|
||||
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
@@ -9,12 +11,16 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
@@ -23,6 +29,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||
@@ -57,6 +65,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
@@ -64,24 +82,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
BIN
Binary file not shown.
|
After Width: | Height: | Size: 351 KiB |
+149
-7
@@ -1,23 +1,34 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
)
|
||||
|
||||
var version string = "0"
|
||||
var commit string = "abcd1234"
|
||||
var date = "unknown"
|
||||
var (
|
||||
version string = "0"
|
||||
commit string = "abcd1234"
|
||||
date string = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
@@ -32,6 +43,10 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(config.Profiles) > 0 {
|
||||
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||
}
|
||||
|
||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||
gin.SetMode(mode)
|
||||
} else {
|
||||
@@ -39,9 +54,136 @@ func main() {
|
||||
}
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
fmt.Println("llama-swap listening on " + *listenStr)
|
||||
if err := proxyManager.Run(*listenStr); err != nil {
|
||||
fmt.Printf("Server error: %v\n", err)
|
||||
os.Exit(1)
|
||||
|
||||
// Setup channels for server management
|
||||
reloadChan := make(chan *proxy.ProxyManager)
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Create server with initial handler
|
||||
srv := &http.Server{
|
||||
Addr: *listenStr,
|
||||
Handler: proxyManager,
|
||||
}
|
||||
|
||||
// Start server
|
||||
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
fmt.Printf("Fatal server error: %v\n", err)
|
||||
close(exitChan)
|
||||
}
|
||||
}()
|
||||
|
||||
// Handle config reloads and signals
|
||||
go func() {
|
||||
currentManager := proxyManager
|
||||
for {
|
||||
select {
|
||||
case newManager := <-reloadChan:
|
||||
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
||||
// Stop old manager processes gracefully (this waits for in-flight requests)
|
||||
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
|
||||
// Now do a full shutdown to clear the process map
|
||||
currentManager.Shutdown()
|
||||
currentManager = newManager
|
||||
srv.Handler = newManager
|
||||
log.Println("Server handler updated with new config")
|
||||
case sig := <-sigChan:
|
||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentManager.Shutdown()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
fmt.Printf("Server shutdown error: %v\n", err)
|
||||
}
|
||||
close(exitChan)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start file watcher if requested
|
||||
if *watchConfig {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
|
||||
} else {
|
||||
go watchConfigFileWithReload(absConfigPath, reloadChan)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for exit signal
|
||||
<-exitChan
|
||||
}
|
||||
|
||||
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
|
||||
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
err = watcher.Add(configPath)
|
||||
if err != nil {
|
||||
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Watching config file for changes: %s", configPath)
|
||||
|
||||
var debounceTimer *time.Timer
|
||||
debounceDuration := 2 * time.Second
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// We only care about writes to the specific config file
|
||||
if event.Name == configPath && event.Has(fsnotify.Write) {
|
||||
// Reset or start the debounce timer
|
||||
if debounceTimer != nil {
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
debounceTimer = time.AfterFunc(debounceDuration, func() {
|
||||
log.Printf("Config file modified: %s, reloading...", event.Name)
|
||||
|
||||
// Try up to 3 times with exponential backoff
|
||||
var newConfig proxy.Config
|
||||
var err error
|
||||
for retries := 0; retries < 3; retries++ {
|
||||
// Load new configuration
|
||||
newConfig, err = proxy.LoadConfig(configPath)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
|
||||
if retries < 2 {
|
||||
time.Sleep(time.Duration(1<<retries) * time.Second)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Failed to load new config after retries: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create new ProxyManager with new config
|
||||
newPM := proxy.New(newConfig)
|
||||
reloadChan <- newPM
|
||||
log.Println("Config reloaded successfully")
|
||||
})
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
log.Println("File watcher error channel closed.")
|
||||
return
|
||||
}
|
||||
log.Printf("File watcher error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*
|
||||
**
|
||||
Test how exec.Cmd.CommandContext behaves under certain conditions:*
|
||||
|
||||
- process is killed externally, what happens with cmd.Wait() *
|
||||
✔︎ it returns. catches crashes.*
|
||||
|
||||
- process ignores SIGTERM*
|
||||
✔︎ `kill()` is called after cmd.WaitDelay*
|
||||
|
||||
- this process exits, what happens with children (kill -9 <this process' pid>)*
|
||||
x they stick around. have to be manually killed.*
|
||||
|
||||
- .WithTimeout()'s cancel is called *
|
||||
✔︎ process is killed after it ignores sigterm, cmd.Wait() catches it.*
|
||||
|
||||
- parent receives SIGINT/SIGTERM, what happens
|
||||
✔︎ waits for child process to exit, then exits gracefully.
|
||||
*/
|
||||
func main() {
|
||||
|
||||
// swap between these to use kill -9 <pid> on the cli to sim external crash
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
//ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
//cmd := exec.CommandContext(ctx, "sleep", "1")
|
||||
cmd := exec.CommandContext(ctx,
|
||||
"../../build/simple-responder_darwin_arm64",
|
||||
//"-ignore-sig-term", /* so it doesn't exit on receiving SIGTERM, test cmd.WaitTimeout */
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
// set a wait delay before signing sig kill
|
||||
cmd.WaitDelay = 500 * time.Millisecond
|
||||
cmd.Cancel = func() error {
|
||||
fmt.Println("✔︎ Cancel() called, sending SIGTERM")
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
//return nil
|
||||
|
||||
// this error is returned by cmd.Wait(), and can be used to
|
||||
// single an error when the process couldn't be normally terminated
|
||||
// but since a SIGTERM is sent, it's probably ok to return a nil
|
||||
// as WaitDelay timing out will override the any error set here.
|
||||
//
|
||||
// test by enabling/disabling -ignore-sig-term on the process
|
||||
// with -ignore-sig-term enabled, cmd.Wait() will have "signal: killed"
|
||||
// without it, it will show the "new error from cancel"
|
||||
return errors.New("error from cmd.Cancel()") // sets error returned by cmd.Wait()
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
fmt.Println("Error starting process:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// catch signals. Calls cancel() which will cause cmd.Wait() to return and
|
||||
// this program to eventually exit gracefully.
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
signal := <-sigChan
|
||||
fmt.Printf("✔︎ Received signal: %d, Killing process... with cancel before exiting\n", signal)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
fmt.Printf("✔︎ Parent Pid: %d, Process Pid: %d\n", os.Getpid(), cmd.Process.Pid)
|
||||
fmt.Println("✔︎ Process started, cmd.Wait() ... ")
|
||||
if err := cmd.Wait(); err != nil {
|
||||
fmt.Println("✔︎ cmd.Wait returned, Error:", err)
|
||||
} else {
|
||||
fmt.Println("✔︎ cmd.Wait returned, Process exited on its own")
|
||||
}
|
||||
fmt.Println("✔︎ Child process exited, Done.")
|
||||
}
|
||||
@@ -12,18 +12,22 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Define a command-line flag for the port
|
||||
port := flag.String("port", "8080", "port to listen on")
|
||||
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
|
||||
|
||||
// Define a command-line flag for the response message
|
||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||
|
||||
silent := flag.Bool("silent", false, "disable all logging")
|
||||
|
||||
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
// Create a new Gin router
|
||||
@@ -31,19 +35,88 @@ func main() {
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.String(200, *responseMessage)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
// for issue #62 to check model name strips profile slug
|
||||
// has to be one of the openAI API endpoints that llama-swap proxies
|
||||
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
|
||||
r.POST("/v1/audio/speech", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if modelName != *expectedModel {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
// issue #41
|
||||
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||
// Parse the multipart form
|
||||
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the model from the form values
|
||||
model := c.Request.FormValue("model")
|
||||
|
||||
if model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the file from the form
|
||||
file, _, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Read the file content to get its size
|
||||
fileBytes, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
fileSize := len(fileBytes)
|
||||
|
||||
// Return a JSON response with the model and transcription text including file size
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||
"model": model,
|
||||
|
||||
// expose some header values for testing
|
||||
"h_content_type": c.GetHeader("Content-Type"),
|
||||
"h_content_length": c.GetHeader("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
@@ -119,6 +192,10 @@ func main() {
|
||||
log.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
if !*silent {
|
||||
fmt.Printf("My PID: %d\n", os.Getpid())
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("simple-responder listening on %s\n", address)
|
||||
// service connections
|
||||
@@ -129,11 +206,36 @@ func main() {
|
||||
|
||||
// Wait for interrupt signal to gracefully shutdown the server with
|
||||
// a timeout of 5 seconds.
|
||||
quit := make(chan os.Signal, 1)
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
// kill (no param) default send syscall.SIGTERM
|
||||
// kill -2 is syscall.SIGINT
|
||||
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
countSigInt := 0
|
||||
|
||||
runloop:
|
||||
for {
|
||||
signal := <-sigChan
|
||||
switch signal {
|
||||
case syscall.SIGINT:
|
||||
countSigInt++
|
||||
if countSigInt > 1 {
|
||||
break runloop
|
||||
} else {
|
||||
log.Println("Received SIGINT, send another SIGINT to shutdown")
|
||||
}
|
||||
case syscall.SIGTERM:
|
||||
if *ignoreSigTerm {
|
||||
log.Println("Ignoring SIGTERM")
|
||||
} else {
|
||||
log.Println("Received SIGTERM, shutting down")
|
||||
break runloop
|
||||
}
|
||||
default:
|
||||
break runloop
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("simple-responder shutting down")
|
||||
}
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
ui_dist/*
|
||||
+273
-14
@@ -2,35 +2,108 @@ package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/shlex"
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelConfig ModelConfig
|
||||
defaults := rawModelConfig{
|
||||
Cmd: "",
|
||||
CmdStop: "",
|
||||
Proxy: "http://localhost:${PORT}",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: 0,
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
if runtime.GOOS == "windows" {
|
||||
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros map[string]string `yaml:"macros"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
@@ -51,42 +124,228 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Config{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
var config Config
|
||||
// default configuration values
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
// set a minimum of 15 seconds
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
/* check macro constraint rules:
|
||||
|
||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||
- names must be less than 64 characters (no reason, just cause)
|
||||
- name can not be any reserved macros: PORT
|
||||
- macro values must be less than 1024 characters
|
||||
*/
|
||||
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
for macroName, macroValue := range config.Macros {
|
||||
if len(macroName) >= 64 {
|
||||
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName)
|
||||
}
|
||||
if !macroNameRegex.MatchString(macroName) {
|
||||
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
|
||||
}
|
||||
if len(macroValue) >= 1024 {
|
||||
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
|
||||
}
|
||||
switch macroName {
|
||||
case "PORT":
|
||||
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
|
||||
for macroName, macroValue := range config.Macros {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
|
||||
}
|
||||
|
||||
// enforce ${PORT} used in both cmd and proxy
|
||||
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
|
||||
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||
nextPortStr := strconv.Itoa(nextPort)
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // this is ok, has to be replaced by process later
|
||||
}
|
||||
if _, exists := config.Macros[macroName]; !exists {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName, _ := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
// Remove trailing backslashes
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
args, err := shlex.Split(cmdStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
// Test a command with spaces and newlines
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
# comment 1
|
||||
--arg3 123 \
|
||||
|
||||
# comment 2
|
||||
--arg4 '"string in string"'
|
||||
|
||||
|
||||
# this will get stripped out as well as the white space above
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
|
||||
// Test the default values are automatically set for global, model and group configurations
|
||||
// after loading the configuration
|
||||
func TestConfig_DefaultValuesPosix(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
assert.True(t, exists, "default group should exist")
|
||||
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||
assert.Equal(t, true, defaultGroup.Swap)
|
||||
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||
assert.Equal(t, false, defaultGroup.Persistent)
|
||||
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||
}
|
||||
|
||||
model1, exists := config.Models["model1"]
|
||||
assert.True(t, exists, "model1 should exist")
|
||||
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||
assert.Equal(t, "", model1.CmdStop)
|
||||
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||
assert.Equal(t, []string{}, model1.Aliases)
|
||||
assert.Equal(t, []string{}, model1.Env)
|
||||
assert.Equal(t, 0, model1.UnloadAfter)
|
||||
assert.Equal(t, false, model1.Unlisted)
|
||||
assert.Equal(t, "", model1.UseModelName)
|
||||
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_LoadPosix(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: ${svr-path} --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "mthree"
|
||||
checkEndpoint: "/"
|
||||
model4:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8082"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
forever:
|
||||
exclusive: false
|
||||
persistent: true
|
||||
members:
|
||||
- "model4"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
StartPort: 5800,
|
||||
Macros: map[string]string{
|
||||
"svr-path": "path/to/server",
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
+215
-89
@@ -1,90 +1,67 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
model2:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
checkEndpoint: "/"
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
healthCheckTimeout: 15
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
group2:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
`
|
||||
// Load the config and verify
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
|
||||
// a Contains as order of the map is not guaranteed
|
||||
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
|
||||
}
|
||||
|
||||
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
- m1
|
||||
model2:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
aliases:
|
||||
- m1
|
||||
- m2
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
|
||||
expected := &Config{
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: nil,
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
// this is a contains because it could be `model1` or `model2` depending on the order
|
||||
// go decided on the order of the map
|
||||
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
||||
}
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
@@ -147,30 +124,179 @@ func TestConfig_FindConfig(t *testing.T) {
|
||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||
}
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
||||
|
||||
// Test a command with spaces and newlines
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
--arg3 123 \
|
||||
--arg4 '"string in string"'
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
t.Run("Default Port Ranges", func(t *testing.T) {
|
||||
content := ``
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
})
|
||||
t.Run("User specific port ranges", func(t *testing.T) {
|
||||
content := `startPort: 1000`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1000, config.StartPort)
|
||||
})
|
||||
|
||||
t.Run("Invalid start port", func(t *testing.T) {
|
||||
content := `startPort: abcd`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("start port must be greater than 1", func(t *testing.T) {
|
||||
content := `startPort: -99`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("Automatic port assignments", func(t *testing.T) {
|
||||
content := `
|
||||
startPort: 5800
|
||||
models:
|
||||
model1:
|
||||
cmd: svr --port ${PORT}
|
||||
model2:
|
||||
cmd: svr --port ${PORT}
|
||||
proxy: "http://172.11.22.33:${PORT}"
|
||||
model3:
|
||||
cmd: svr --port 1999
|
||||
proxy: "http://1.2.3.4:1999"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
|
||||
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
|
||||
|
||||
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
|
||||
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
|
||||
|
||||
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
|
||||
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
|
||||
|
||||
})
|
||||
|
||||
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: svr --port 111
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Equal(t, "model model1: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_MacroReplacement(t *testing.T) {
|
||||
content := `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
argOne: "--arg1"
|
||||
argTwo: "--arg2"
|
||||
autoPort: "--port ${PORT}"
|
||||
|
||||
models:
|
||||
model1:
|
||||
cmd: |
|
||||
${svr-path} ${argTwo}
|
||||
# the automatic ${PORT} is replaced
|
||||
${autoPort}
|
||||
${argOne}
|
||||
--arg3 three
|
||||
cmdStop: |
|
||||
/path/to/stop.sh --port ${PORT} ${argTwo}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " "))
|
||||
|
||||
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
|
||||
}
|
||||
|
||||
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
content string
|
||||
}{
|
||||
{
|
||||
name: "unknown macro in cmd",
|
||||
field: "cmd",
|
||||
content: `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: |
|
||||
${svr-path} --port ${PORT}
|
||||
${unknownMacro}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "unknown macro in cmdStop",
|
||||
field: "cmdStop",
|
||||
content: `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
cmdStop: "kill ${unknownMacro}"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "unknown macro in proxy",
|
||||
field: "proxy",
|
||||
content: `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
proxy: "http://localhost:${unknownMacro}"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "unknown macro in checkEndpoint",
|
||||
field: "checkEndpoint",
|
||||
content: `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
checkEndpoint: "http://localhost:${unknownMacro}/health"
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
|
||||
//t.Log(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
// does not support single quoted strings like in config_posix_test.go
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
|
||||
-a "double quotes" \
|
||||
-s
|
||||
--arg3 123 \
|
||||
|
||||
# comment 2
|
||||
--arg4 '"string in string"'
|
||||
|
||||
|
||||
|
||||
# this will get stripped out as well as the white space above
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
|
||||
func TestConfig_DefaultValuesWindows(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
assert.True(t, exists, "default group should exist")
|
||||
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||
assert.Equal(t, true, defaultGroup.Swap)
|
||||
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||
assert.Equal(t, false, defaultGroup.Persistent)
|
||||
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||
}
|
||||
|
||||
model1, exists := config.Models["model1"]
|
||||
assert.True(t, exists, "model1 should exist")
|
||||
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||
assert.Equal(t, "taskkill /f /t /pid ${PID}", model1.CmdStop)
|
||||
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||
assert.Equal(t, []string{}, model1.Aliases)
|
||||
assert.Equal(t, []string{}, model1.Env)
|
||||
assert.Equal(t, 0, model1.UnloadAfter)
|
||||
assert.Equal(t, false, model1.Unlisted)
|
||||
assert.Equal(t, "", model1.UseModelName)
|
||||
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_LoadWindows(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: ${svr-path} --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "mthree"
|
||||
checkEndpoint: "/"
|
||||
model4:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8082"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
forever:
|
||||
exclusive: false
|
||||
persistent: true
|
||||
members:
|
||||
- "model4"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
StartPort: 5800,
|
||||
Macros: map[string]string{
|
||||
"svr-path": "path/to/server",
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
+36
-8
@@ -9,11 +9,13 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
@@ -26,6 +28,17 @@ func TestMain(m *testing.M) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
switch os.Getenv("LOG_LEVEL") {
|
||||
case "debug":
|
||||
testLogger.SetLogLevel(LevelDebug)
|
||||
case "warn":
|
||||
testLogger.SetLogLevel(LevelWarn)
|
||||
case "info":
|
||||
testLogger.SetLogLevel(LevelInfo)
|
||||
default:
|
||||
testLogger.SetLogLevel(LevelWarn)
|
||||
}
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
@@ -33,26 +46,41 @@ func TestMain(m *testing.M) {
|
||||
func getSimpleResponderPath() string {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
|
||||
if goos == "windows" {
|
||||
return filepath.Join("..", "build", "simple-responder.exe")
|
||||
} else {
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
func getTestPort() int {
|
||||
portMutex.Lock()
|
||||
defer portMutex.Unlock()
|
||||
|
||||
port := nextTestPort
|
||||
nextTestPort++
|
||||
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
||||
return port
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
|
||||
// Create a process configuration
|
||||
return ModelConfig{
|
||||
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, binaryPath, port, expectedMessage, port)
|
||||
|
||||
var cfg ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 15 KiB |
@@ -1,14 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>llama-swap</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>llama-swap</h1>
|
||||
<p>
|
||||
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
|
||||
</p>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,145 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Logs</title>
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
font-family: "Courier New", Courier, monospace;
|
||||
}
|
||||
#log-controls {
|
||||
margin: 0.5em;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between; /* Spaces out elements evenly */
|
||||
}
|
||||
#log-controls input {
|
||||
flex: 1;
|
||||
}
|
||||
#log-controls input:focus {
|
||||
outline: none; /* Ensures no outline is shown when the input is focused */
|
||||
}
|
||||
#log-stream {
|
||||
flex: 1;
|
||||
margin: 0.5em;
|
||||
padding: 1em;
|
||||
background: #f4f4f4;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap; /* Ensures line wrapping */
|
||||
word-wrap: break-word; /* Ensures long words wrap */
|
||||
}
|
||||
|
||||
.regex-error {
|
||||
background-color: #ff0000 !important;
|
||||
}
|
||||
|
||||
/* Dark mode styles */
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body {
|
||||
background-color: #333;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#log-stream {
|
||||
background: #444;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#log-controls input {
|
||||
background: #555;
|
||||
color: #fff;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
|
||||
#log-controls button {
|
||||
background: #555;
|
||||
color: #fff;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<pre id="log-stream">Waiting for logs...</pre>
|
||||
<div id="log-controls">
|
||||
<input type="text" id="filter-input" placeholder="regex filter">
|
||||
<button id="clear-button">clear</button>
|
||||
</div>
|
||||
<script>
|
||||
const logStream = document.getElementById('log-stream');
|
||||
const filterInput = document.getElementById('filter-input');
|
||||
var logData = "";
|
||||
let regexFilter = null;
|
||||
|
||||
function setupEventSource() {
|
||||
if (typeof(EventSource) !== "undefined") {
|
||||
const eventSource = new EventSource("/logs/streamSSE");
|
||||
|
||||
eventSource.onmessage = function(event) {
|
||||
logData += event.data;
|
||||
render()
|
||||
};
|
||||
|
||||
eventSource.onerror = function(err) {
|
||||
logData = "EventSource failed: " + err.message;
|
||||
};
|
||||
} else {
|
||||
logData = "SSE Not supported by this browser."
|
||||
}
|
||||
}
|
||||
|
||||
// poor-ai's react ¯\_(ツ)_/¯
|
||||
function render() {
|
||||
if (regexFilter) {
|
||||
const lines = logData.split('\n');
|
||||
const filteredLines = lines.filter(line => {
|
||||
return regexFilter === null || regexFilter.test(line);
|
||||
});
|
||||
|
||||
if (filteredLines.length > 0) {
|
||||
logStream.textContent = filteredLines.join('\n') + '\n';
|
||||
} else {
|
||||
logStream.textContent = "";
|
||||
}
|
||||
} else {
|
||||
logStream.textContent = logData;
|
||||
}
|
||||
|
||||
logStream.scrollTop = logStream.scrollHeight;
|
||||
}
|
||||
|
||||
function updateFilter() {
|
||||
const pattern = filterInput.value.trim();
|
||||
filterInput.classList.remove('regex-error');
|
||||
if (pattern) {
|
||||
try {
|
||||
regexFilter = new RegExp(pattern);
|
||||
} catch (e) {
|
||||
console.error("Invalid regex pattern:", e);
|
||||
regexFilter = null;
|
||||
filterInput.classList.add('regex-error');
|
||||
return
|
||||
}
|
||||
} else {
|
||||
regexFilter = null;
|
||||
}
|
||||
|
||||
render();
|
||||
}
|
||||
|
||||
filterInput.addEventListener('input', updateFilter);
|
||||
document.getElementById('clear-button').addEventListener('click', () => {
|
||||
filterInput.value = "";
|
||||
regexFilter = null;
|
||||
render();
|
||||
});
|
||||
setupEventSource();
|
||||
updateFilter();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,10 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed html
|
||||
var htmlFiles embed.FS
|
||||
|
||||
func getHTMLFile(path string) ([]byte, error) {
|
||||
return htmlFiles.ReadFile("html/" + path)
|
||||
}
|
||||
@@ -2,11 +2,21 @@ package proxy
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
LevelDebug LogLevel = iota
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
)
|
||||
|
||||
type LogMonitor struct {
|
||||
clients map[chan []byte]bool
|
||||
mu sync.RWMutex
|
||||
@@ -15,6 +25,10 @@ type LogMonitor struct {
|
||||
|
||||
// typically this can be os.Stdout
|
||||
stdout io.Writer
|
||||
|
||||
// logging levels
|
||||
level LogLevel
|
||||
prefix string
|
||||
}
|
||||
|
||||
func NewLogMonitor() *LogMonitor {
|
||||
@@ -26,6 +40,8 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
clients: make(map[chan []byte]bool),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,3 +110,77 @@ func (w *LogMonitor) broadcast(msg []byte) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.prefix = prefix
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetLogLevel(level LogLevel) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.level = level
|
||||
}
|
||||
|
||||
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
||||
prefix := ""
|
||||
if w.prefix != "" {
|
||||
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) log(level LogLevel, msg string) {
|
||||
if level < w.level {
|
||||
return
|
||||
}
|
||||
w.Write(w.formatMessage(level.String(), msg))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Debug(msg string) {
|
||||
w.log(LevelDebug, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Info(msg string) {
|
||||
w.log(LevelInfo, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Warn(msg string) {
|
||||
w.log(LevelWarn, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Error(msg string) {
|
||||
w.log(LevelError, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
|
||||
w.log(LevelDebug, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Infof(format string, args ...interface{}) {
|
||||
w.log(LevelInfo, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
|
||||
w.log(LevelWarn, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
|
||||
w.log(LevelError, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l LogLevel) String() string {
|
||||
switch l {
|
||||
case LevelDebug:
|
||||
return "DEBUG"
|
||||
case LevelInfo:
|
||||
return "INFO"
|
||||
case LevelWarn:
|
||||
return "WARN"
|
||||
case LevelError:
|
||||
return "ERROR"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
+397
-184
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -17,19 +18,38 @@ import (
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateFailed ProcessState = ProcessState("failed")
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateStarting ProcessState = ProcessState("starting")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type StopStrategy int
|
||||
|
||||
const (
|
||||
StopImmediately StopStrategy = iota
|
||||
StopWaitForInflightRequest
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
sync.Mutex
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
logMonitor *LogMonitor
|
||||
healthCheckTimeout int
|
||||
// PR #155 called to cancel the upstream process
|
||||
cancelUpstream context.CancelFunc
|
||||
|
||||
// closed when command exits
|
||||
cmdWaitChan chan struct{}
|
||||
|
||||
processLogger *LogMonitor
|
||||
proxyLogger *LogMonitor
|
||||
|
||||
healthCheckTimeout int
|
||||
healthCheckLoopInterval time.Duration
|
||||
|
||||
lastRequestHandled time.Time
|
||||
|
||||
@@ -37,31 +57,109 @@ type Process struct {
|
||||
state ProcessState
|
||||
|
||||
inFlightRequests sync.WaitGroup
|
||||
|
||||
// used to block on multiple start() calls
|
||||
waitStarting sync.WaitGroup
|
||||
|
||||
// for managing concurrency limits
|
||||
concurrencyLimitSemaphore chan struct{}
|
||||
|
||||
// used for testing to override the default value
|
||||
gracefulStopTimeout time.Duration
|
||||
|
||||
// track the number of failed starts
|
||||
failedStartCount int
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||
concurrentLimit := 10
|
||||
if config.ConcurrencyLimit > 0 {
|
||||
concurrentLimit = config.ConcurrencyLimit
|
||||
}
|
||||
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
logMonitor: logMonitor,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
state: StateStopped,
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
cancelUpstream: nil,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||
state: StateStopped,
|
||||
|
||||
// concurrency limit
|
||||
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||
|
||||
// To be removed when migration over exec.CommandContext is complete
|
||||
// stop timeout
|
||||
gracefulStopTimeout: 10 * time.Second,
|
||||
cmdWaitChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// start the process and returns when it is ready
|
||||
func (p *Process) start() error {
|
||||
// LogMonitor returns the log monitor associated with the process.
|
||||
func (p *Process) LogMonitor() *LogMonitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
// custom error types for swapping state
|
||||
var (
|
||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||
ErrInvalidStateTransition = errors.New("invalid state transition")
|
||||
)
|
||||
|
||||
// swapState performs a compare and swap of the state atomically. It returns the current state
|
||||
// and an error if the swap failed.
|
||||
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state == StateReady {
|
||||
return nil
|
||||
if p.state != expectedState {
|
||||
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
|
||||
return p.state, ErrExpectedStateMismatch
|
||||
}
|
||||
|
||||
if p.state == StateFailed {
|
||||
return fmt.Errorf("process is in a failed state and can not be restarted")
|
||||
if !isValidTransition(p.state, newState) {
|
||||
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
|
||||
return p.state, ErrInvalidStateTransition
|
||||
}
|
||||
|
||||
p.state = newState
|
||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||
return p.state, nil
|
||||
}
|
||||
|
||||
// Helper function to encapsulate transition rules
|
||||
func isValidTransition(from, to ProcessState) bool {
|
||||
switch from {
|
||||
case StateStopped:
|
||||
return to == StateStarting
|
||||
case StateStarting:
|
||||
return to == StateReady || to == StateStopping || to == StateStopped
|
||||
case StateReady:
|
||||
return to == StateStopping
|
||||
case StateStopping:
|
||||
return to == StateStopped || to == StateShutdown
|
||||
case StateShutdown:
|
||||
return false // No transitions allowed from these states
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||
// it is a private method because starting is automatic but stopping can be called
|
||||
// at any time.
|
||||
func (p *Process) start() error {
|
||||
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||
}
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
@@ -69,52 +167,105 @@ func (p *Process) start() error {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
|
||||
p.cmd = exec.Command(args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.logMonitor
|
||||
p.cmd.Stderr = p.logMonitor
|
||||
p.cmd.Env = p.config.Env
|
||||
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||
if err == ErrExpectedStateMismatch {
|
||||
// already starting, just wait for it to complete and expect
|
||||
// it to be be in the Ready start after. If not, return an error
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
if state := p.CurrentState(); state == StateReady {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("process was already starting but wound up in state %v", state)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("processes was in state %v when start() was called", curState)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
}
|
||||
|
||||
p.waitStarting.Add(1)
|
||||
defer p.waitStarting.Done()
|
||||
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||
|
||||
p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.processLogger
|
||||
p.cmd.Stderr = p.processLogger
|
||||
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||
p.cancelUpstream = ctxCancelUpstream
|
||||
p.cmdWaitChan = make(chan struct{})
|
||||
|
||||
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||
err = p.cmd.Start()
|
||||
|
||||
// Set process state to failed
|
||||
if err != nil {
|
||||
return err
|
||||
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||
p.state = StateStopped // force it into a stopped state
|
||||
return fmt.Errorf(
|
||||
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
err, curState, swapErr,
|
||||
)
|
||||
}
|
||||
return fmt.Errorf("start() failed: %v", err)
|
||||
}
|
||||
|
||||
// Capture the exit error for later signalling
|
||||
go p.waitForCmd()
|
||||
|
||||
// One of three things can happen at this stage:
|
||||
// 1. The command exits unexpectedly
|
||||
// 2. The health check fails
|
||||
// 3. The health check passes
|
||||
//
|
||||
// only in the third case will the process be considered Ready to accept
|
||||
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
|
||||
defer cancelHealthCheck(nil) // clean up
|
||||
cmdWaitChan := make(chan error, 1)
|
||||
healthCheckChan := make(chan error, 1)
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
|
||||
go func() {
|
||||
// possible cmd exits early
|
||||
cmdWaitChan <- p.cmd.Wait()
|
||||
}()
|
||||
checkStartTime := time.Now()
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
go func() {
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-cmdWaitChan:
|
||||
p.state = StateFailed
|
||||
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||
if checkEndpoint != "none" {
|
||||
proxyTo := p.config.Proxy
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
|
||||
} else {
|
||||
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
|
||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||
}
|
||||
cancelHealthCheck(err)
|
||||
return err
|
||||
case err := <-healthCheckChan:
|
||||
if err != nil {
|
||||
p.state = StateFailed
|
||||
return err
|
||||
|
||||
// Ready Check loop
|
||||
for {
|
||||
currentState := p.CurrentState()
|
||||
if currentState != StateStarting {
|
||||
if currentState == StateStopped {
|
||||
return fmt.Errorf("upstream command exited prematurely but successfully")
|
||||
}
|
||||
return errors.New("health check interrupted due to shutdown")
|
||||
}
|
||||
|
||||
if time.Since(checkStartTime) > maxDuration {
|
||||
p.stopCommand()
|
||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||
}
|
||||
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
||||
break
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
ttl := time.Until(checkStartTime.Add(maxDuration))
|
||||
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
||||
}
|
||||
}
|
||||
<-time.After(p.healthCheckLoopInterval)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,7 +276,7 @@ func (p *Process) start() error {
|
||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
|
||||
for range time.Tick(time.Second) {
|
||||
if p.state != StateReady {
|
||||
if p.CurrentState() != StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -133,173 +284,155 @@ func (p *Process) start() error {
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
p.state = StateReady
|
||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||
} else {
|
||||
p.failedStartCount = 0
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Stop will wait for inflight requests to complete before stopping the process.
|
||||
func (p *Process) Stop() {
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
return
|
||||
}
|
||||
|
||||
// wait for any inflight requests before proceeding
|
||||
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
|
||||
p.inFlightRequests.Wait()
|
||||
p.StopImmediately()
|
||||
}
|
||||
|
||||
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||
func (p *Process) StopImmediately() {
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
return
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
|
||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
||||
return
|
||||
}
|
||||
|
||||
p.stopCommand()
|
||||
}
|
||||
|
||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||
// of time for any inflight requests to complete before shutting down. If the Process
|
||||
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||
// the StateShutdown state, it can not be started again.
|
||||
func (p *Process) Shutdown() {
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
return
|
||||
}
|
||||
|
||||
p.stopCommand()
|
||||
// just force it to this state since there is no recovery from shutdown
|
||||
p.state = StateShutdown
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||
func (p *Process) stopCommand() {
|
||||
stopStartTime := time.Now()
|
||||
defer func() {
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
}()
|
||||
|
||||
if p.cancelUpstream == nil {
|
||||
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
p.cancelUpstream()
|
||||
<-p.cmdWaitChan
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
client := &http.Client{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// got a response but it was not an OK
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) Stop() {
|
||||
// wait for any inflight requests before proceeding
|
||||
p.inFlightRequests.Wait()
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
requestBeginTime := time.Now()
|
||||
var startDuration time.Duration
|
||||
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != StateReady {
|
||||
// prevent new requests from being made while stopping or irrecoverable
|
||||
currentState := p.CurrentState()
|
||||
if currentState == StateShutdown || currentState == StateStopping {
|
||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
// this situation should never happen... but if it does just update the state
|
||||
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
|
||||
p.state = StateStopped
|
||||
return
|
||||
}
|
||||
|
||||
// Pretty sure this stopping code needs some work for windows and
|
||||
// will be a source of pain in the future.
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
sigtermNormal := make(chan error, 1)
|
||||
go func() {
|
||||
sigtermNormal <- p.cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sigtermTimeout.Done():
|
||||
fmt.Fprintf(p.logMonitor, "!!! process for %s timed out waiting to stop\n", p.ID)
|
||||
p.cmd.Process.Kill()
|
||||
p.cmd.Wait()
|
||||
case err := <-sigtermNormal:
|
||||
if err != nil {
|
||||
if err.Error() != "wait: no child processes" {
|
||||
// possible that simple-responder for testing is just not
|
||||
// existing right, so suppress those errors.
|
||||
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
|
||||
}
|
||||
}
|
||||
case p.concurrencyLimitSemaphore <- struct{}{}:
|
||||
defer func() { <-p.concurrencyLimitSemaphore }()
|
||||
default:
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
p.state = StateStopped
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("no upstream available to check /health")
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
if checkEndpoint == "none" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
|
||||
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
||||
|
||||
if err != nil {
|
||||
// check if the context was cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := context.Cause(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// wait a bit longer for TCP connection issues
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
||||
time.Sleep(5 * time.Second)
|
||||
} else {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
|
||||
defer func() {
|
||||
p.lastRequestHandled = time.Now()
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
// start the process on demand
|
||||
if p.CurrentState() != StateReady {
|
||||
beginStartTime := time.Now()
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
http.Error(w, errstr, http.StatusInternalServerError)
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
startDuration = time.Since(beginStartTime)
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.Header = r.Header.Clone()
|
||||
|
||||
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
||||
if err == nil {
|
||||
req.ContentLength = contentLength
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
@@ -333,4 +466,84 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
totalTime := time.Since(requestBeginTime)
|
||||
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
||||
p.ID, r.RequestURI, startDuration, totalTime)
|
||||
}
|
||||
|
||||
// waitForCmd waits for the command to exit and handles exit conditions depending on current state
|
||||
func (p *Process) waitForCmd() {
|
||||
exitErr := p.cmd.Wait()
|
||||
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||
|
||||
if exitErr != nil {
|
||||
if errno, ok := exitErr.(syscall.Errno); ok {
|
||||
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||
} else if exitError, ok := exitErr.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
} else {
|
||||
if exitErr.Error() != "context canceled" /* this is normal */ {
|
||||
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, exitErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentState := p.CurrentState()
|
||||
switch currentState {
|
||||
case StateStopping:
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||
p.state = StateStopped
|
||||
}
|
||||
default:
|
||||
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||
p.state = StateStopped // force it to be in this state
|
||||
}
|
||||
close(p.cmdWaitChan)
|
||||
}
|
||||
|
||||
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||
func (p *Process) cmdStopUpstreamProcess() error {
|
||||
p.processLogger.Debugf("<%s> cmdStopUpstreamProcess() initiating graceful stop of upstream process", p.ID)
|
||||
|
||||
// this should never happen ...
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
||||
return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID)
|
||||
}
|
||||
|
||||
if p.config.CmdStop != "" {
|
||||
// replace ${PID} with the pid of the process
|
||||
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||
if err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
|
||||
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Stdout = p.processLogger
|
||||
stopCmd.Stderr = p.processLogger
|
||||
stopCmd.Env = p.cmd.Env
|
||||
|
||||
if err := stopCmd.Run(); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+340
-13
@@ -2,10 +2,10 @@ package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,13 +13,26 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
debugLogger = NewLogMonitorWriter(os.Stdout)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// flip to help with debugging tests
|
||||
if false {
|
||||
debugLogger.SetLogLevel(LevelDebug)
|
||||
} else {
|
||||
debugLogger.SetLogLevel(LevelError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// Create a process
|
||||
process := NewProcess("test-process", 5, config, logMonitor)
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -48,6 +61,32 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
|
||||
// are all handled successfully, even though they all may ask for the process to .start()
|
||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(reqID int) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
@@ -57,17 +96,20 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||
defer process.Stop()
|
||||
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "start() failed: ")
|
||||
}
|
||||
|
||||
// test that the process unloads after the TTL
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long auto unload TTL test")
|
||||
@@ -79,7 +121,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
@@ -111,7 +153,36 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_LowTTLValue(t *testing.T) {
|
||||
if true { // change this code to run this ...
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
config := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Logf("Waiting before sending request %d", i)
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
expected := fmt.Sprintf("echo=test_%d", i)
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// issue #19
|
||||
// This test makes sure using Process.Stop() does not affect pending HTTP
|
||||
// requests. All HTTP requests in this test should complete successfully.
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
@@ -119,7 +190,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
|
||||
expectedMessage := "12345"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
|
||||
process := NewProcess("t", 10, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
results := map[string]string{
|
||||
@@ -135,8 +206,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
// send a request that should take 5 * 200ms (1 second) to complete
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
|
||||
// send a request where simple-responder is will wait 300ms before responding
|
||||
// this will simulate an in-progress request.
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
process.ProxyRequest(w, req)
|
||||
@@ -152,9 +224,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
}(key)
|
||||
}
|
||||
|
||||
// stop the requests in the middle
|
||||
// Stop the process while requests are still being processed
|
||||
go func() {
|
||||
<-time.After(500 * time.Millisecond)
|
||||
<-time.After(150 * time.Millisecond)
|
||||
process.Stop()
|
||||
}()
|
||||
|
||||
@@ -164,3 +236,258 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_SwapState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentState ProcessState
|
||||
expectedState ProcessState
|
||||
newState ProcessState
|
||||
expectedError error
|
||||
expectedResult ProcessState
|
||||
}{
|
||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped},
|
||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
|
||||
p.state = test.currentState
|
||||
|
||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||
if err != nil && test.expectedError == nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
} else if err == nil && test.expectedError != nil {
|
||||
t.Errorf("Expected error: %v, but got none", test.expectedError)
|
||||
} else if err != nil && test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
|
||||
}
|
||||
}
|
||||
|
||||
if resultState != test.expectedResult {
|
||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long shutdown test")
|
||||
}
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
|
||||
// make a config where the healthcheck will always fail because port is wrong
|
||||
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
|
||||
config.Proxy = "http://localhost:9998/test"
|
||||
|
||||
healthCheckTTLSeconds := 30
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
|
||||
|
||||
// make it a lot faster
|
||||
process.healthCheckLoopInterval = time.Second
|
||||
|
||||
// start a goroutine to simulate a shutdown
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-time.After(time.Millisecond * 500)
|
||||
process.Shutdown()
|
||||
}()
|
||||
wg.Add(1)
|
||||
|
||||
// start the process, this is a blocking call
|
||||
err := process.start()
|
||||
|
||||
wg.Wait()
|
||||
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping Exit Interrupts Health Check test")
|
||||
}
|
||||
|
||||
// should run and exit but interrupt the long checkHealthTimeout
|
||||
checkHealthTimeout := 5
|
||||
config := ModelConfig{
|
||||
Cmd: "sleep 1",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
||||
process.healthCheckLoopInterval = time.Second // make it faster
|
||||
err := process.start()
|
||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long concurrency limit test")
|
||||
}
|
||||
|
||||
expectedMessage := "concurrency_limit_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// only allow 1 concurrent request at a time
|
||||
config.ConcurrencyLimit = 1
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
|
||||
defer process.Stop()
|
||||
|
||||
// launch a goroutine first to take up the semaphore
|
||||
go func() {
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req1)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}()
|
||||
|
||||
// let the goroutine start
|
||||
<-time.After(time.Millisecond * 25)
|
||||
|
||||
denied := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, denied)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
}
|
||||
|
||||
func TestProcess_StopImmediately(t *testing.T) {
|
||||
expectedMessage := "test_stop_immediate"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
}()
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping SIGTERM test on Windows ")
|
||||
}
|
||||
|
||||
expectedMessage := "test_sigkill"
|
||||
binaryPath := getSimpleResponderPath()
|
||||
port := getTestPort()
|
||||
|
||||
config := ModelConfig{
|
||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||
// to force the process to exit
|
||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// reduce to make testing go faster
|
||||
process.gracefulStopTimeout = time.Second
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
|
||||
waitChan := make(chan struct{})
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// StatusOK because that was already sent before the kill
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
||||
// then the unexpected EOF is sent after the kill
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||
} else {
|
||||
assert.Contains(t, w.Body.String(), "unexpected EOF")
|
||||
}
|
||||
|
||||
close(waitChan)
|
||||
}()
|
||||
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
|
||||
// the request should have been interrupted by SIGKILL
|
||||
<-waitChan
|
||||
}
|
||||
|
||||
func TestProcess_StopCmd(t *testing.T) {
|
||||
config := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
} else {
|
||||
config.CmdStop = "kill -TERM ${PID}"
|
||||
}
|
||||
|
||||
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||
expectedMessage := "test_env_not_emptied"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// ensure that the the default config does not blank out the inherited environment
|
||||
configWEnv := config
|
||||
|
||||
// ensure the additiona variables are appended to the process' environment
|
||||
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||
|
||||
process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger)
|
||||
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||
|
||||
process1.start()
|
||||
defer process1.Stop()
|
||||
process2.start()
|
||||
defer process2.Stop()
|
||||
|
||||
assert.NotZero(t, len(process1.cmd.Environ()))
|
||||
assert.NotZero(t, len(process2.cmd.Environ()))
|
||||
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ProcessGroup struct {
|
||||
sync.Mutex
|
||||
|
||||
config Config
|
||||
id string
|
||||
swap bool
|
||||
exclusive bool
|
||||
persistent bool
|
||||
|
||||
proxyLogger *LogMonitor
|
||||
upstreamLogger *LogMonitor
|
||||
|
||||
// map of current processes
|
||||
processes map[string]*Process
|
||||
lastUsedProcess string
|
||||
}
|
||||
|
||||
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||
groupConfig, ok := config.Groups[id]
|
||||
if !ok {
|
||||
panic("Unable to find configuration for group id: " + id)
|
||||
}
|
||||
|
||||
pg := &ProcessGroup{
|
||||
id: id,
|
||||
config: config,
|
||||
swap: groupConfig.Swap,
|
||||
exclusive: groupConfig.Exclusive,
|
||||
persistent: groupConfig.Persistent,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
processes: make(map[string]*Process),
|
||||
}
|
||||
|
||||
// Create a Process for each member in the group
|
||||
for _, modelID := range groupConfig.Members {
|
||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
|
||||
pg.processes[modelID] = process
|
||||
}
|
||||
|
||||
return pg
|
||||
}
|
||||
|
||||
// ProxyRequest proxies a request to the specified model
|
||||
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
|
||||
if !pg.HasMember(modelID) {
|
||||
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
|
||||
}
|
||||
|
||||
if pg.swap {
|
||||
pg.Lock()
|
||||
if pg.lastUsedProcess != modelID {
|
||||
if pg.lastUsedProcess != "" {
|
||||
pg.processes[pg.lastUsedProcess].Stop()
|
||||
}
|
||||
pg.lastUsedProcess = modelID
|
||||
}
|
||||
pg.Unlock()
|
||||
}
|
||||
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||
pg.Lock()
|
||||
defer pg.Unlock()
|
||||
|
||||
if len(pg.processes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
"model4": getTestSimpleResponderConfig("model4"),
|
||||
"model5": getTestSimpleResponderConfig("model5"),
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
"G2": {
|
||||
Swap: false,
|
||||
Exclusive: true,
|
||||
Members: []string{"model3", "model4"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model5"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_HasMember(t *testing.T) {
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model1"))
|
||||
assert.True(t, pg.HasMember("model2"))
|
||||
assert.False(t, pg.HasMember("model3"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model1", "model2"}
|
||||
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
|
||||
// make sure only one process is in the running state
|
||||
count := 0
|
||||
for _, process := range pg.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model3", "model4"}
|
||||
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
})
|
||||
}
|
||||
|
||||
// make sure all the processes are running
|
||||
for _, process := range pg.processes {
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
}
|
||||
+393
-185
@@ -5,14 +5,17 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"sort"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,53 +25,121 @@ const (
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
config *Config
|
||||
currentProcesses map[string]*Process
|
||||
logMonitor *LogMonitor
|
||||
ginEngine *gin.Engine
|
||||
config Config
|
||||
ginEngine *gin.Engine
|
||||
|
||||
// logging
|
||||
proxyLogger *LogMonitor
|
||||
upstreamLogger *LogMonitor
|
||||
muxLogger *LogMonitor
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
}
|
||||
|
||||
func New(config *Config) *ProxyManager {
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
currentProcesses: make(map[string]*Process),
|
||||
logMonitor: NewLogMonitor(),
|
||||
ginEngine: gin.New(),
|
||||
}
|
||||
func New(config Config) *ProxyManager {
|
||||
// set up loggers
|
||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
|
||||
if config.LogRequests {
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
|
||||
// capture these because /upstream/:model rewrites them in c.Next()
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
|
||||
// Stop timer
|
||||
duration := time.Since(start)
|
||||
|
||||
statusCode := c.Writer.Status()
|
||||
bodySize := c.Writer.Size()
|
||||
|
||||
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
|
||||
clientIP,
|
||||
time.Now().Format("2006-01-02 15:04:05"),
|
||||
method,
|
||||
path,
|
||||
c.Request.Proto,
|
||||
statusCode,
|
||||
bodySize,
|
||||
c.Request.UserAgent(),
|
||||
duration,
|
||||
)
|
||||
})
|
||||
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
||||
case "debug":
|
||||
proxyLogger.SetLogLevel(LevelDebug)
|
||||
upstreamLogger.SetLogLevel(LevelDebug)
|
||||
case "info":
|
||||
proxyLogger.SetLogLevel(LevelInfo)
|
||||
upstreamLogger.SetLogLevel(LevelInfo)
|
||||
case "warn":
|
||||
proxyLogger.SetLogLevel(LevelWarn)
|
||||
upstreamLogger.SetLogLevel(LevelWarn)
|
||||
case "error":
|
||||
proxyLogger.SetLogLevel(LevelError)
|
||||
upstreamLogger.SetLogLevel(LevelError)
|
||||
default:
|
||||
proxyLogger.SetLogLevel(LevelInfo)
|
||||
upstreamLogger.SetLogLevel(LevelInfo)
|
||||
}
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
ginEngine: gin.New(),
|
||||
|
||||
proxyLogger: proxyLogger,
|
||||
muxLogger: stdoutLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
|
||||
processGroups: make(map[string]*ProcessGroup),
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
for groupID := range config.Groups {
|
||||
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
||||
pm.processGroups[groupID] = processGroup
|
||||
}
|
||||
|
||||
pm.setupGinEngine()
|
||||
return pm
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) setupGinEngine() {
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
|
||||
// capture these because /upstream/:model rewrites them in c.Next()
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
|
||||
// Stop timer
|
||||
duration := time.Since(start)
|
||||
|
||||
statusCode := c.Writer.Status()
|
||||
bodySize := c.Writer.Size()
|
||||
|
||||
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
|
||||
clientIP,
|
||||
method,
|
||||
path,
|
||||
c.Request.Proto,
|
||||
statusCode,
|
||||
bodySize,
|
||||
c.Request.UserAgent(),
|
||||
duration,
|
||||
)
|
||||
})
|
||||
|
||||
// see: issue: #81, #77 and #42 for CORS issues
|
||||
// respond with permissive OPTIONS for any endpoint
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
|
||||
// allow whatever the client requested by default
|
||||
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
||||
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
|
||||
c.Header("Access-Control-Allow-Headers", sanitized)
|
||||
} else {
|
||||
c.Header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Content-Type, Authorization, Accept, X-Requested-With",
|
||||
)
|
||||
}
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
@@ -80,6 +151,7 @@ func New(config *Config) *ProxyManager {
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
|
||||
@@ -87,67 +159,133 @@ func New(config *Config) *ProxyManager {
|
||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
|
||||
|
||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
/**
|
||||
* User Interface Endpoints
|
||||
*/
|
||||
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||
// Set the Content-Type header to text/html
|
||||
c.Header("Content-Type", "text/html")
|
||||
|
||||
// Write the embedded HTML content to the response
|
||||
htmlData, err := getHTMLFile("index.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_, err = c.Writer.Write(htmlData)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusFound, "/ui")
|
||||
})
|
||||
|
||||
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusFound, "/ui/models")
|
||||
})
|
||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
|
||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||
if data, err := getHTMLFile("favicon.ico"); err == nil {
|
||||
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
|
||||
c.Data(http.StatusOK, "image/x-icon", data)
|
||||
} else {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
reactFS, err := GetReactFS()
|
||||
if err != nil {
|
||||
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
||||
} else {
|
||||
|
||||
// serve files that exist under /ui/*
|
||||
pm.ginEngine.StaticFS("/ui", reactFS)
|
||||
|
||||
// server SPA for UI under /ui/*
|
||||
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
||||
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
file, err := reactFS.Open("index.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
// see: proxymanager_api.go
|
||||
// add API handler functions
|
||||
addApiHandlers(pm)
|
||||
|
||||
// Disable console color for testing
|
||||
gin.DisableConsoleColor()
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Run(addr ...string) error {
|
||||
return pm.ginEngine.Run(addr...)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
|
||||
// ServeHTTP implements http.Handler interface
|
||||
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
pm.ginEngine.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) StopProcesses() {
|
||||
// StopProcesses acquires a lock and stops all running upstream processes.
|
||||
// This is the public method safe for concurrent calls.
|
||||
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
||||
// a complete shutdown, allowing for process replacement without full termination.
|
||||
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
pm.stopProcesses()
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, processGroup := range pm.processGroups {
|
||||
wg.Add(1)
|
||||
go func(processGroup *ProcessGroup) {
|
||||
defer wg.Done()
|
||||
processGroup.StopProcesses(strategy)
|
||||
}(processGroup)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// for internal usage
|
||||
func (pm *ProxyManager) stopProcesses() {
|
||||
if len(pm.currentProcesses) == 0 {
|
||||
return
|
||||
// Shutdown stops all processes managed by this ProxyManager
|
||||
func (pm *ProxyManager) Shutdown() {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
// Send shutdown signal to all process in groups
|
||||
for _, processGroup := range pm.processGroups {
|
||||
wg.Add(1)
|
||||
go func(processGroup *ProcessGroup) {
|
||||
defer wg.Done()
|
||||
processGroup.Shutdown()
|
||||
}(processGroup)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
for _, process := range pm.currentProcesses {
|
||||
process.Stop()
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
||||
}
|
||||
|
||||
pm.currentProcesses = make(map[string]*Process)
|
||||
if processGroup.exclusive {
|
||||
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
||||
for groupId, otherGroup := range pm.processGroups {
|
||||
if groupId != processGroup.id && !otherGroup.persistent {
|
||||
otherGroup.StopProcesses(StopWaitForInflightRequest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return processGroup, realModelName, nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
@@ -173,73 +311,12 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"object": "list", "data": data}); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||
profileName, modelName := "", requestedModel
|
||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||
profileName = requestedModel[:idx]
|
||||
modelName = requestedModel[idx+1:]
|
||||
}
|
||||
|
||||
if profileName != "" {
|
||||
if _, found := pm.config.Profiles[profileName]; !found {
|
||||
return nil, fmt.Errorf("model group not found %s", profileName)
|
||||
}
|
||||
}
|
||||
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(modelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
// exit early when already running, otherwise stop everything and swap
|
||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||
|
||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||
return process, nil
|
||||
}
|
||||
|
||||
// stop all running models
|
||||
pm.stopProcesses()
|
||||
|
||||
if profileName == "" {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
} else {
|
||||
for _, modelName := range pm.config.Profiles[profileName] {
|
||||
if realModelName, found := pm.config.RealModelName(modelName); found {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// requestedProcessKey should exist due to swap
|
||||
return pm.currentProcesses[requestedProcessKey], nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
requestedModel := c.Param("model_id")
|
||||
|
||||
@@ -248,38 +325,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
} else {
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
||||
var html strings.Builder
|
||||
|
||||
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
|
||||
|
||||
// Extract keys and sort them
|
||||
var modelIDs []string
|
||||
for modelID, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
modelIDs = append(modelIDs, modelID)
|
||||
}
|
||||
sort.Strings(modelIDs)
|
||||
|
||||
// Iterate over sorted keys
|
||||
for _, modelID := range modelIDs {
|
||||
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
|
||||
}
|
||||
html.WriteString("</ul></body></html>")
|
||||
c.Header("Content-Type", "text/html")
|
||||
c.String(http.StatusOK, html.String())
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
@@ -289,28 +343,149 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var requestBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
model, ok := requestBody["model"].(string)
|
||||
if !ok {
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(model); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
c.Request.ContentLength = int64(len(bodyBytes))
|
||||
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// Parse multipart form
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get model parameter from the form
|
||||
requestedModel := c.Request.FormValue("model")
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||
// Create a new buffer for the reconstructed request
|
||||
var requestBuffer bytes.Buffer
|
||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||
|
||||
// Copy all form values
|
||||
for key, values := range c.Request.MultipartForm.Value {
|
||||
for _, value := range values {
|
||||
fieldValue := value
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
// # issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
|
||||
if useModelName != "" {
|
||||
fieldValue = useModelName
|
||||
} else {
|
||||
fieldValue = requestedModel
|
||||
}
|
||||
}
|
||||
field, err := multipartWriter.CreateFormField(key)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
||||
return
|
||||
}
|
||||
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy all files from the original request
|
||||
for key, fileHeaders := range c.Request.MultipartForm.File {
|
||||
for _, fileHeader := range fileHeaders {
|
||||
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = io.Copy(formFile, file); err != nil {
|
||||
file.Close()
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
||||
return
|
||||
}
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close the multipart writer to finalize the form
|
||||
if err := multipartWriter.Close(); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new request with the reconstructed form data
|
||||
modifiedReq, err := http.NewRequestWithContext(
|
||||
c.Request.Context(),
|
||||
c.Request.Method,
|
||||
c.Request.URL.String(),
|
||||
&requestBuffer,
|
||||
)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the headers from the original request
|
||||
modifiedReq.Header = c.Request.Header.Clone()
|
||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||
|
||||
// set the content length of the body
|
||||
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||
|
||||
// Use the modified request for proxying
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,6 +499,39 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
||||
}
|
||||
}
|
||||
|
||||
func ProcessKeyName(groupName, modelName string) string {
|
||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||
pm.StopProcesses(StopImmediately)
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
context.Header("Content-Type", "application/json")
|
||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||
|
||||
for _, processGroup := range pm.processGroups {
|
||||
for _, process := range processGroup.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Put the results under the `running` key.
|
||||
response := gin.H{
|
||||
"running": runningProcesses,
|
||||
}
|
||||
|
||||
context.JSON(http.StatusOK, response) // Always return 200 OK
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
||||
for _, group := range pm.processGroups {
|
||||
if group.HasMember(modelName) {
|
||||
return group
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
func addApiHandlers(pm *ProxyManager) {
|
||||
// Add API endpoints for React to consume
|
||||
apiGroup := pm.ginEngine.Group("/api")
|
||||
{
|
||||
apiGroup.GET("/models", pm.apiListModels)
|
||||
apiGroup.GET("/modelsSSE", pm.apiListModelsSSE)
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiUnloadAllModels(c *gin.Context) {
|
||||
pm.StopProcesses(StopImmediately)
|
||||
c.JSON(http.StatusOK, gin.H{"msg": "ok"})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) getModelStatus() []Model {
|
||||
// Extract keys and sort them
|
||||
models := []Model{}
|
||||
|
||||
modelIDs := make([]string, 0, len(pm.config.Models))
|
||||
for modelID := range pm.config.Models {
|
||||
modelIDs = append(modelIDs, modelID)
|
||||
}
|
||||
sort.Strings(modelIDs)
|
||||
|
||||
// Iterate over sorted keys
|
||||
for _, modelID := range modelIDs {
|
||||
// Get process state
|
||||
processGroup := pm.findGroupByModelName(modelID)
|
||||
state := "unknown"
|
||||
if processGroup != nil {
|
||||
process := processGroup.processes[modelID]
|
||||
if process != nil {
|
||||
var stateStr string
|
||||
switch process.CurrentState() {
|
||||
case StateReady:
|
||||
stateStr = "ready"
|
||||
case StateStarting:
|
||||
stateStr = "starting"
|
||||
case StateStopping:
|
||||
stateStr = "stopping"
|
||||
case StateShutdown:
|
||||
stateStr = "shutdown"
|
||||
case StateStopped:
|
||||
stateStr = "stopped"
|
||||
default:
|
||||
stateStr = "unknown"
|
||||
}
|
||||
state = stateStr
|
||||
}
|
||||
}
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
State: state,
|
||||
})
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiListModels(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, pm.getModelStatus())
|
||||
}
|
||||
|
||||
// stream the models as a SSE
|
||||
func (pm *ProxyManager) apiListModelsSSE(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
|
||||
// Stream new events
|
||||
for {
|
||||
select {
|
||||
case <-notify:
|
||||
return
|
||||
default:
|
||||
models := pm.getModelStatus()
|
||||
c.SSEvent("message", models)
|
||||
c.Writer.Flush()
|
||||
<-time.After(1000 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -9,26 +9,12 @@ import (
|
||||
)
|
||||
|
||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "text/html") {
|
||||
// Set the Content-Type header to text/html
|
||||
c.Header("Content-Type", "text/html")
|
||||
|
||||
// Write the embedded HTML content to the response
|
||||
logsHTML, err := getHTMLFile("logs.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_, err = c.Writer.Write(logsHTML)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusFound, "/ui/")
|
||||
} else {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
history := pm.logMonitor.GetHistory()
|
||||
history := pm.muxLogger.GetHistory()
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
@@ -42,8 +28,14 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
logMonitorId := c.Param("logMonitorID")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
ch := logger.Subscribe()
|
||||
defer logger.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
@@ -56,7 +48,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
// Send history first if not skipped
|
||||
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
history := logger.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.Writer.Write(history)
|
||||
flusher.Flush()
|
||||
@@ -85,15 +77,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
logMonitorId := c.Param("logMonitorID")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
ch := logger.Subscribe()
|
||||
defer logger.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
|
||||
// Send history first if not skipped
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
history := logger.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.SSEvent("message", string(history))
|
||||
c.Writer.Flush()
|
||||
@@ -111,3 +109,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
||||
var logger *LogMonitor
|
||||
|
||||
if logMonitorId == "" {
|
||||
// maintain the default
|
||||
logger = pm.muxLogger
|
||||
} else if logMonitorId == "proxy" {
|
||||
logger = pm.proxyLogger
|
||||
} else if logMonitorId == "upstream" {
|
||||
logger = pm.upstreamLogger
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
|
||||
}
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
+459
-46
@@ -4,87 +4,124 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
}
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
for _, modelName := range []string{"model1", "model2"} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
|
||||
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
|
||||
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
||||
|
||||
}
|
||||
|
||||
// make sure there's only one loaded model
|
||||
assert.Len(t, proxy.currentProcesses, 1)
|
||||
}
|
||||
|
||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
|
||||
model1 := "path1/model1"
|
||||
model2 := "path2/model2"
|
||||
|
||||
profileModel1 := ProcessKeyName("test", model1)
|
||||
profileModel2 := ProcessKeyName("test", model2)
|
||||
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
model1: getTestSimpleResponderConfig("model1"),
|
||||
model2: getTestSimpleResponderConfig("model2"),
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {model1, model2},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model1"},
|
||||
},
|
||||
"G2": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
for modelID, requestedModel := range map[string]string{
|
||||
"model1": profileModel1,
|
||||
"model2": profileModel2,
|
||||
} {
|
||||
tests := []string{"model1", "model2"}
|
||||
for _, requestedModel := range tests {
|
||||
t.Run(requestedModel, func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), requestedModel)
|
||||
})
|
||||
}
|
||||
|
||||
// make sure there's two loaded models
|
||||
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
|
||||
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
|
||||
}
|
||||
|
||||
// Test that a persistent group is not affected by the swapping behaviour of
|
||||
// other groups.
|
||||
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
// the forever group is persistent and should not be affected by model1
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// make requests to load all models, loading model1 should not affect model2
|
||||
tests := []string{"model2", "model1"}
|
||||
for _, requestedModel := range tests {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelID)
|
||||
assert.Contains(t, w.Body.String(), requestedModel)
|
||||
}
|
||||
|
||||
// make sure there's two loaded models
|
||||
assert.Len(t, proxy.currentProcesses, 2)
|
||||
_, exists := proxy.currentProcesses[profileModel1]
|
||||
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
|
||||
|
||||
_, exists = proxy.currentProcesses[profileModel2]
|
||||
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
|
||||
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
|
||||
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
|
||||
}
|
||||
|
||||
// When a request for a different model comes in ProxyManager should wait until
|
||||
@@ -94,17 +131,18 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
}
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
results := map[string]string{}
|
||||
|
||||
@@ -120,15 +158,16 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
|
||||
results[key] = w.Body.String()
|
||||
var response map[string]string
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
results[key] = response["responseMessage"]
|
||||
mu.Unlock()
|
||||
}(key)
|
||||
|
||||
@@ -144,13 +183,14 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
config := Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
@@ -161,7 +201,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the listModelsHandler
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
// Check the response status code
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -210,3 +250,376 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
// Ensure all expected models were returned
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
model1Config.Proxy = "http://localhost:10001/"
|
||||
|
||||
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
|
||||
model2Config.Proxy = "http://localhost:10002/"
|
||||
|
||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||
model3Config.Proxy = "http://localhost:10003/"
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": model3Config,
|
||||
},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
"test": {
|
||||
Swap: false,
|
||||
Members: []string{"model1", "model2", "model3"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Start all the processes
|
||||
var wg sync.WaitGroup
|
||||
for _, modelName := range []string{"model1", "model2", "model3"} {
|
||||
wg.Add(1)
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
||||
}(modelName)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-time.After(time.Second)
|
||||
proxy.Shutdown()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestProxyManager_Unload(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||
req = httptest.NewRequest("GET", "/unload", nil)
|
||||
w = httptest.NewRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, w.Body.String(), "OK")
|
||||
|
||||
// give it a bit of time to stop
|
||||
<-time.After(time.Millisecond * 250)
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Shared configuration
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
LogLevel: "warn",
|
||||
})
|
||||
|
||||
// Define a helper struct to parse the JSON response.
|
||||
type RunningResponse struct {
|
||||
Running []struct {
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
} `json:"running"`
|
||||
}
|
||||
|
||||
// Create proxy once for all tests
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
t.Run("no models loaded", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/running", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response RunningResponse
|
||||
|
||||
// Check if this is a valid JSON object.
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// We should have an empty running array here.
|
||||
assert.Empty(t, response.Running, "expected no running models")
|
||||
})
|
||||
|
||||
t.Run("single model loaded", func(t *testing.T) {
|
||||
// Load just a model.
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Simulate browser call for the `/running` endpoint.
|
||||
req = httptest.NewRequest("GET", "/running", nil)
|
||||
w = httptest.NewRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
var response RunningResponse
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// Check if we have a single array element.
|
||||
assert.Len(t, response.Running, 1)
|
||||
|
||||
// Is this the right model?
|
||||
assert.Equal(t, "model1", response.Running[0].Model)
|
||||
|
||||
// Is the model loaded?
|
||||
assert.Equal(t, "ready", response.Running[0].State)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("TheExpectedModel"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
// Generate random content length between 10 and 20
|
||||
contentLength := rand.Intn(11) + 10 // 10 to 20
|
||||
content := make([]byte, contentLength)
|
||||
_, err = fw.Write(content)
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "TheExpectedModel", response["model"])
|
||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
|
||||
}
|
||||
|
||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||
func TestProxyManager_UseModelName(t *testing.T) {
|
||||
upstreamModelName := "upstreamModel"
|
||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||
modelConfig.UseModelName = upstreamModelName
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": modelConfig,
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
requestedModel := "model1"
|
||||
|
||||
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||
|
||||
// make sure the content length was set correctly
|
||||
// simple-responder will return the content length it got in the response
|
||||
body := w.Body.Bytes()
|
||||
contentLength := int(gjson.GetBytes(body, "h_content_length").Int())
|
||||
assert.Equal(t, len(fmt.Sprintf(`{"model":"%s"}`, upstreamModelName)), contentLength)
|
||||
})
|
||||
|
||||
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(requestedModel))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, upstreamModelName, response["model"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
requestHeaders map[string]string
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "OPTIONS with no headers",
|
||||
method: "OPTIONS",
|
||||
expectedStatus: http.StatusNoContent,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OPTIONS with specific headers",
|
||||
method: "OPTIONS",
|
||||
requestHeaders: map[string]string{
|
||||
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
|
||||
},
|
||||
expectedStatus: http.StatusNoContent,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Non-OPTIONS request",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
||||
for k, v := range tt.requestHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
|
||||
for header, expectedValue := range tt.expectedHeaders {
|
||||
assert.Equal(t, expectedValue, w.Header().Get(header))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_Upstream(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var response map[string]string
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
assert.Equal(t, "81", response["h_content_length"])
|
||||
assert.Equal(t, "model1", response["responseMessage"])
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isTokenChar(r rune) bool {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
||||
parts := strings.Split(headerValues, ",")
|
||||
valid := make([]string, 0, len(parts))
|
||||
|
||||
for _, p := range parts {
|
||||
v := strings.TrimSpace(p)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
validPart := true
|
||||
for _, c := range v {
|
||||
if !isTokenChar(c) {
|
||||
validPart = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if validPart {
|
||||
valid = append(valid, v)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(valid, ", ")
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package proxy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: " ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single valid value",
|
||||
input: "content-type",
|
||||
expected: "content-type",
|
||||
},
|
||||
{
|
||||
name: "multiple valid values",
|
||||
input: "content-type, authorization, x-requested-with",
|
||||
expected: "content-type, authorization, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "values with extra spaces",
|
||||
input: " content-type , authorization ",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "values with tabs",
|
||||
input: "content-type,\tauthorization",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "values with invalid characters",
|
||||
input: "content-type, auth\n, x-requested-with\r",
|
||||
expected: "content-type, auth, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "empty values in list",
|
||||
input: "content-type,,authorization",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "leading and trailing commas",
|
||||
input: ",content-type,authorization,",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid values",
|
||||
input: "content-type, \x00invalid, x-requested-with",
|
||||
expected: "content-type, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "mixed case values",
|
||||
input: "Content-Type, my-Valid-Header, Another-hEader",
|
||||
expected: "Content-Type, my-Valid-Header, Another-hEader",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SanitizeAccessControlRequestHeaderValues(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
|
||||
tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
//go:embed ui_dist
|
||||
var reactStaticFS embed.FS
|
||||
|
||||
// GetReactFS returns the embedded React filesystem
|
||||
func GetReactFS() (http.FileSystem, error) {
|
||||
subFS, err := fs.Sub(reactStaticFS, "ui_dist")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return http.FS(subFS), nil
|
||||
}
|
||||
|
||||
// GetReactIndexHTML returns the main index.html for the React app
|
||||
func GetReactIndexHTML() ([]byte, error) {
|
||||
return reactStaticFS.ReadFile("ui_dist/index.html")
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
#!/bin/sh
|
||||
# This script installs llama-swap on Linux.
|
||||
# It detects the current operating system architecture and installs the appropriate version of llama-swap.
|
||||
|
||||
set -eu
|
||||
|
||||
LLAMA_SWAP_DEFAULT_ADDRESS=${LLAMA_SWAP_DEFAULT_ADDRESS:-"127.0.0.1:8080"}
|
||||
|
||||
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||
|
||||
status() { echo ">>> $*" >&2; }
|
||||
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||
|
||||
available() { command -v "$1" >/dev/null; }
|
||||
require() {
|
||||
_MISSING=''
|
||||
for TOOL in "$@"; do
|
||||
if ! available "$TOOL"; then
|
||||
_MISSING="$_MISSING $TOOL"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "$_MISSING"
|
||||
}
|
||||
|
||||
SUDO=
|
||||
if [ "$(id -u)" -ne 0 ]; then
|
||||
if ! available sudo; then
|
||||
error "This script requires superuser permissions. Please re-run as root."
|
||||
fi
|
||||
|
||||
SUDO="sudo"
|
||||
fi
|
||||
|
||||
NEEDS=$(require tee tar python3 mktemp)
|
||||
if [ -n "$NEEDS" ]; then
|
||||
status "ERROR: The following tools are required but missing:"
|
||||
for NEED in $NEEDS; do
|
||||
echo " - $NEED"
|
||||
done
|
||||
exit 1
|
||||
fi
|
||||
|
||||
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||
|
||||
ARCH=$(uname -m)
|
||||
case "$ARCH" in
|
||||
x86_64) ARCH="amd64" ;;
|
||||
aarch64|arm64) ARCH="arm64" ;;
|
||||
*) error "Unsupported architecture: $ARCH" ;;
|
||||
esac
|
||||
|
||||
IS_WSL2=false
|
||||
|
||||
KERN=$(uname -r)
|
||||
case "$KERN" in
|
||||
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
|
||||
*icrosoft) error "Microsoft WSL1 is not currently supported. Please use WSL2 with 'wsl --set-version <distro> 2'" ;;
|
||||
*) ;;
|
||||
esac
|
||||
|
||||
download_binary() {
|
||||
ASSET_NAME="linux_$ARCH"
|
||||
|
||||
TMPDIR=$(mktemp -d)
|
||||
trap 'rm -rf "${TMPDIR}"' EXIT INT TERM HUP
|
||||
PYTHON_SCRIPT=$(cat <<EOF
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
import urllib.request
|
||||
|
||||
ASSET_NAME = "${ASSET_NAME}"
|
||||
|
||||
with urllib.request.urlopen("https://api.github.com/repos/mostlygeek/llama-swap/releases/latest") as resp:
|
||||
data = json.load(resp)
|
||||
for asset in data.get("assets", []):
|
||||
if ASSET_NAME in asset.get("name", ""):
|
||||
url = asset["browser_download_url"]
|
||||
break
|
||||
else:
|
||||
print("ERROR: Matching asset not found.", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
print("Downloading:", url, file=sys.stderr)
|
||||
output_path = os.path.join("${TMPDIR}", "llama-swap.tar.gz")
|
||||
urllib.request.urlretrieve(url, output_path)
|
||||
print(output_path)
|
||||
EOF
|
||||
)
|
||||
|
||||
TARFILE=$(python3 -c "$PYTHON_SCRIPT")
|
||||
if [ ! -f "$TARFILE" ]; then
|
||||
error "Failed to download binary."
|
||||
fi
|
||||
|
||||
status "Extracting to /usr/local/bin"
|
||||
$SUDO tar -xzf "$TARFILE" -C /usr/local/bin llama-swap
|
||||
}
|
||||
download_binary
|
||||
|
||||
configure_systemd() {
|
||||
if ! id llama-swap >/dev/null 2>&1; then
|
||||
status "Creating llama-swap user..."
|
||||
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/llama-swap llama-swap
|
||||
fi
|
||||
if getent group render >/dev/null 2>&1; then
|
||||
status "Adding llama-swap user to render group..."
|
||||
$SUDO usermod -a -G render llama-swap
|
||||
fi
|
||||
if getent group video >/dev/null 2>&1; then
|
||||
status "Adding llama-swap user to video group..."
|
||||
$SUDO usermod -a -G video llama-swap
|
||||
fi
|
||||
if getent group docker >/dev/null 2>&1; then
|
||||
status "Adding llama-swap user to docker group..."
|
||||
$SUDO usermod -a -G docker llama-swap
|
||||
fi
|
||||
|
||||
status "Adding current user to llama-swap group..."
|
||||
$SUDO usermod -a -G llama-swap "$(whoami)"
|
||||
|
||||
if [ ! -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||
status "Creating default config.yaml..."
|
||||
cat <<EOF | $SUDO -u llama-swap tee /usr/share/llama-swap/config.yaml >/dev/null
|
||||
# default 15s likely to fail for default models due to downloading models
|
||||
healthCheckTimeout: 60
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
cmd: |
|
||||
docker run
|
||||
--rm
|
||||
-p \${PORT}:8080
|
||||
--name qwen2.5
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
cmdStop: docker stop qwen2.5
|
||||
|
||||
"smollm2":
|
||||
cmd: |
|
||||
docker run
|
||||
--rm
|
||||
-p \${PORT}:8080
|
||||
--name smollm2
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
cmdStop: docker stop smollm2
|
||||
EOF
|
||||
fi
|
||||
|
||||
status "Creating llama-swap systemd service..."
|
||||
cat <<EOF | $SUDO tee /etc/systemd/system/llama-swap.service >/dev/null
|
||||
[Unit]
|
||||
Description=llama-swap
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
User=llama-swap
|
||||
Group=llama-swap
|
||||
|
||||
# set this to match your environment
|
||||
ExecStart=/usr/local/bin/llama-swap --config /usr/share/llama-swap/config.yaml --watch-config -listen ${LLAMA_SWAP_DEFAULT_ADDRESS}
|
||||
|
||||
Restart=on-failure
|
||||
RestartSec=3
|
||||
StartLimitBurst=3
|
||||
StartLimitInterval=30
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
|
||||
case $SYSTEMCTL_RUNNING in
|
||||
running|degraded)
|
||||
status "Enabling and starting llama-swap service..."
|
||||
$SUDO systemctl daemon-reload
|
||||
$SUDO systemctl enable llama-swap
|
||||
|
||||
start_service() { $SUDO systemctl restart llama-swap; }
|
||||
trap start_service EXIT
|
||||
;;
|
||||
*)
|
||||
warning "systemd is not running"
|
||||
if [ "$IS_WSL2" = true ]; then
|
||||
warning "see https://learn.microsoft.com/en-us/windows/wsl/systemd#how-to-enable-systemd to enable it"
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
if available systemctl; then
|
||||
configure_systemd
|
||||
fi
|
||||
|
||||
install_success() {
|
||||
status "The llama-swap API is now available at http://${LLAMA_SWAP_DEFAULT_ADDRESS}"
|
||||
status 'Customize the config file at /usr/share/llama-swap/config.yaml.'
|
||||
status 'Install complete.'
|
||||
}
|
||||
|
||||
# WSL2 only supports GPUs via nvidia passthrough
|
||||
# so check for nvidia-smi to determine if GPU is available
|
||||
if [ "$IS_WSL2" = true ]; then
|
||||
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
|
||||
status "Nvidia GPU detected."
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
install_success
|
||||
@@ -0,0 +1,68 @@
|
||||
#!/bin/sh
|
||||
# This script uninstalls llama-swap on Linux.
|
||||
# It removes the binary, systemd service, config.yaml (optional), and llama-swap user and group.
|
||||
|
||||
set -eu
|
||||
|
||||
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||
|
||||
status() { echo ">>> $*" >&2; }
|
||||
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||
|
||||
available() { command -v $1 >/dev/null; }
|
||||
|
||||
SUDO=
|
||||
if [ "$(id -u)" -ne 0 ]; then
|
||||
if ! available sudo; then
|
||||
error "This script requires superuser permissions. Please re-run as root."
|
||||
fi
|
||||
|
||||
SUDO="sudo"
|
||||
fi
|
||||
|
||||
configure_systemd() {
|
||||
status "Stopping llama-swap service..."
|
||||
$SUDO systemctl stop llama-swap
|
||||
|
||||
status "Disabling llama-swap service..."
|
||||
$SUDO systemctl disable llama-swap
|
||||
}
|
||||
if available systemctl; then
|
||||
configure_systemd
|
||||
fi
|
||||
|
||||
if available llama-swap; then
|
||||
status "Removing llama-swap binary..."
|
||||
$SUDO rm $(which llama-swap)
|
||||
fi
|
||||
|
||||
if [ -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||
while true; do
|
||||
printf "Delete config.yaml (/usr/share/llama-swap/config.yaml)? [y/N] " >&2
|
||||
read answer
|
||||
case "$answer" in
|
||||
[Yy]* )
|
||||
$SUDO rm -r /usr/share/llama-swap
|
||||
break
|
||||
;;
|
||||
[Nn]* | "" )
|
||||
break
|
||||
;;
|
||||
* )
|
||||
echo "Invalid input. Please enter y or n."
|
||||
;;
|
||||
esac
|
||||
done
|
||||
fi
|
||||
|
||||
if id llama-swap >/dev/null 2>&1; then
|
||||
status "Removing llama-swap user..."
|
||||
$SUDO userdel llama-swap
|
||||
fi
|
||||
|
||||
if getent group llama-swap >/dev/null 2>&1; then
|
||||
status "Removing llama-swap group..."
|
||||
$SUDO groupdel llama-swap
|
||||
fi
|
||||
@@ -0,0 +1,25 @@
|
||||
.vite
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
@@ -0,0 +1,54 @@
|
||||
# React + TypeScript + Vite
|
||||
|
||||
This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
|
||||
|
||||
Currently, two official plugins are available:
|
||||
|
||||
- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) for Fast Refresh
|
||||
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
|
||||
|
||||
## Expanding the ESLint configuration
|
||||
|
||||
If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
|
||||
|
||||
```js
|
||||
export default tseslint.config({
|
||||
extends: [
|
||||
// Remove ...tseslint.configs.recommended and replace with this
|
||||
...tseslint.configs.recommendedTypeChecked,
|
||||
// Alternatively, use this for stricter rules
|
||||
...tseslint.configs.strictTypeChecked,
|
||||
// Optionally, add this for stylistic rules
|
||||
...tseslint.configs.stylisticTypeChecked,
|
||||
],
|
||||
languageOptions: {
|
||||
// other options...
|
||||
parserOptions: {
|
||||
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
||||
tsconfigRootDir: import.meta.dirname,
|
||||
},
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
|
||||
|
||||
```js
|
||||
// eslint.config.js
|
||||
import reactX from 'eslint-plugin-react-x'
|
||||
import reactDom from 'eslint-plugin-react-dom'
|
||||
|
||||
export default tseslint.config({
|
||||
plugins: {
|
||||
// Add the react-x and react-dom plugins
|
||||
'react-x': reactX,
|
||||
'react-dom': reactDom,
|
||||
},
|
||||
rules: {
|
||||
// other rules...
|
||||
// Enable its recommended typescript rules
|
||||
...reactX.configs['recommended-typescript'].rules,
|
||||
...reactDom.configs.recommended.rules,
|
||||
},
|
||||
})
|
||||
```
|
||||
@@ -0,0 +1,28 @@
|
||||
import js from '@eslint/js'
|
||||
import globals from 'globals'
|
||||
import reactHooks from 'eslint-plugin-react-hooks'
|
||||
import reactRefresh from 'eslint-plugin-react-refresh'
|
||||
import tseslint from 'typescript-eslint'
|
||||
|
||||
export default tseslint.config(
|
||||
{ ignores: ['dist'] },
|
||||
{
|
||||
extends: [js.configs.recommended, ...tseslint.configs.recommended],
|
||||
files: ['**/*.{ts,tsx}'],
|
||||
languageOptions: {
|
||||
ecmaVersion: 2020,
|
||||
globals: globals.browser,
|
||||
},
|
||||
plugins: {
|
||||
'react-hooks': reactHooks,
|
||||
'react-refresh': reactRefresh,
|
||||
},
|
||||
rules: {
|
||||
...reactHooks.configs.recommended.rules,
|
||||
'react-refresh/only-export-components': [
|
||||
'warn',
|
||||
{ allowConstantExport: true },
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,13 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<link rel="icon" type="image/png" href="/favicon.ico" />
|
||||
<title>llama-swap</title>
|
||||
</head>
|
||||
<body >
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
Generated
+4028
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,33 @@
|
||||
{
|
||||
"name": "ui",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "tsc -b && vite build --emptyOutDir",
|
||||
"lint": "eslint .",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@tailwindcss/vite": "^4.1.8",
|
||||
"@tanstack/react-query": "^5.80.6",
|
||||
"react": "^19.1.0",
|
||||
"react-dom": "^19.1.0",
|
||||
"react-router-dom": "^7.6.2",
|
||||
"tailwindcss": "^4.1.8"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.25.0",
|
||||
"@types/react": "^19.1.2",
|
||||
"@types/react-dom": "^19.1.2",
|
||||
"@vitejs/plugin-react": "^4.4.1",
|
||||
"eslint": "^9.25.0",
|
||||
"eslint-plugin-react-hooks": "^5.2.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.19",
|
||||
"globals": "^16.0.0",
|
||||
"typescript": "~5.8.3",
|
||||
"typescript-eslint": "^8.30.1",
|
||||
"vite": "^6.3.5"
|
||||
}
|
||||
}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
@@ -0,0 +1,6 @@
|
||||
#root {
|
||||
max-width: 1280px;
|
||||
margin: 0 auto;
|
||||
padding: 2rem;
|
||||
text-align: center;
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
import { BrowserRouter as Router, Routes, Route, Navigate, NavLink } from "react-router-dom";
|
||||
import { useTheme } from "./contexts/ThemeProvider";
|
||||
import { APIProvider } from "./contexts/APIProvider";
|
||||
import LogViewerPage from "./pages/LogViewer";
|
||||
import ModelPage from "./pages/Models";
|
||||
|
||||
function App() {
|
||||
const theme = useTheme();
|
||||
return (
|
||||
<Router basename="/ui/">
|
||||
<APIProvider>
|
||||
<div>
|
||||
<nav className="bg-surface border-b border-border p-4">
|
||||
<div className="flex items-center justify-between mx-auto px-4">
|
||||
<h1>llama-swap</h1>
|
||||
<div className="flex space-x-4">
|
||||
<NavLink to="/" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Logs
|
||||
</NavLink>
|
||||
|
||||
<NavLink to="/models" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Models
|
||||
</NavLink>
|
||||
<button className="btn btn--sm" onClick={theme.toggleTheme}>
|
||||
{theme.isDarkMode ? "🌙" : "☀️"}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<main className="mx-auto py-4 px-4">
|
||||
<Routes>
|
||||
<Route path="/" element={<LogViewerPage />} />
|
||||
<Route path="/models" element={<ModelPage />} />
|
||||
<Route path="*" element={<Navigate to="/" replace />} />
|
||||
</Routes>
|
||||
</main>
|
||||
</div>
|
||||
</APIProvider>
|
||||
</Router>
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 12 KiB |
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="35.93" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 228"><path fill="#00D8FF" d="M210.483 73.824a171.49 171.49 0 0 0-8.24-2.597c.465-1.9.893-3.777 1.273-5.621c6.238-30.281 2.16-54.676-11.769-62.708c-13.355-7.7-35.196.329-57.254 19.526a171.23 171.23 0 0 0-6.375 5.848a155.866 155.866 0 0 0-4.241-3.917C100.759 3.829 77.587-4.822 63.673 3.233C50.33 10.957 46.379 33.89 51.995 62.588a170.974 170.974 0 0 0 1.892 8.48c-3.28.932-6.445 1.924-9.474 2.98C17.309 83.498 0 98.307 0 113.668c0 15.865 18.582 31.778 46.812 41.427a145.52 145.52 0 0 0 6.921 2.165a167.467 167.467 0 0 0-2.01 9.138c-5.354 28.2-1.173 50.591 12.134 58.266c13.744 7.926 36.812-.22 59.273-19.855a145.567 145.567 0 0 0 5.342-4.923a168.064 168.064 0 0 0 6.92 6.314c21.758 18.722 43.246 26.282 56.54 18.586c13.731-7.949 18.194-32.003 12.4-61.268a145.016 145.016 0 0 0-1.535-6.842c1.62-.48 3.21-.974 4.76-1.488c29.348-9.723 48.443-25.443 48.443-41.52c0-15.417-17.868-30.326-45.517-39.844Zm-6.365 70.984c-1.4.463-2.836.91-4.3 1.345c-3.24-10.257-7.612-21.163-12.963-32.432c5.106-11 9.31-21.767 12.459-31.957c2.619.758 5.16 1.557 7.61 2.4c23.69 8.156 38.14 20.213 38.14 29.504c0 9.896-15.606 22.743-40.946 31.14Zm-10.514 20.834c2.562 12.94 2.927 24.64 1.23 33.787c-1.524 8.219-4.59 13.698-8.382 15.893c-8.067 4.67-25.32-1.4-43.927-17.412a156.726 156.726 0 0 1-6.437-5.87c7.214-7.889 14.423-17.06 21.459-27.246c12.376-1.098 24.068-2.894 34.671-5.345a134.17 134.17 0 0 1 1.386 6.193ZM87.276 214.515c-7.882 2.783-14.16 2.863-17.955.675c-8.075-4.657-11.432-22.636-6.853-46.752a156.923 156.923 0 0 1 1.869-8.499c10.486 2.32 22.093 3.988 34.498 4.994c7.084 9.967 14.501 19.128 21.976 27.15a134.668 134.668 0 0 1-4.877 4.492c-9.933 8.682-19.886 14.842-28.658 17.94ZM50.35 144.747c-12.483-4.267-22.792-9.812-29.858-15.863c-6.35-5.437-9.555-10.836-9.555-15.216c0-9.322 13.897-21.212 37.076-29.293c2.813-.98 5.757-1.905 8.812-2.773c3.204 10.42 7.406 21.315 12.477 32.332c-5.137 11.18-9.399 22.249-12.634 32.792a134.718 134.718 0 0 1-6.318-1.979Zm12.378-84.26c-4.811-24.587-1.616-43.134 6.425-47.789c8.564-4.958 27.502 2.111 47.463 19.835a144.318 144.318 0 0 1 3.841 3.545c-7.438 7.987-14.787 17.08-21.808 26.988c-12.04 1.116-23.565 2.908-34.161 5.309a160.342 160.342 0 0 1-1.76-7.887Zm110.427 27.268a347.8 347.8 0 0 0-7.785-12.803c8.168 1.033 15.994 2.404 23.343 4.08c-2.206 7.072-4.956 14.465-8.193 22.045a381.151 381.151 0 0 0-7.365-13.322Zm-45.032-43.861c5.044 5.465 10.096 11.566 15.065 18.186a322.04 322.04 0 0 0-30.257-.006c4.974-6.559 10.069-12.652 15.192-18.18ZM82.802 87.83a323.167 323.167 0 0 0-7.227 13.238c-3.184-7.553-5.909-14.98-8.134-22.152c7.304-1.634 15.093-2.97 23.209-3.984a321.524 321.524 0 0 0-7.848 12.897Zm8.081 65.352c-8.385-.936-16.291-2.203-23.593-3.793c2.26-7.3 5.045-14.885 8.298-22.6a321.187 321.187 0 0 0 7.257 13.246c2.594 4.48 5.28 8.868 8.038 13.147Zm37.542 31.03c-5.184-5.592-10.354-11.779-15.403-18.433c4.902.192 9.899.29 14.978.29c5.218 0 10.376-.117 15.453-.343c-4.985 6.774-10.018 12.97-15.028 18.486Zm52.198-57.817c3.422 7.8 6.306 15.345 8.596 22.52c-7.422 1.694-15.436 3.058-23.88 4.071a382.417 382.417 0 0 0 7.859-13.026a347.403 347.403 0 0 0 7.425-13.565Zm-16.898 8.101a358.557 358.557 0 0 1-12.281 19.815a329.4 329.4 0 0 1-23.444.823c-7.967 0-15.716-.248-23.178-.732a310.202 310.202 0 0 1-12.513-19.846h.001a307.41 307.41 0 0 1-10.923-20.627a310.278 310.278 0 0 1 10.89-20.637l-.001.001a307.318 307.318 0 0 1 12.413-19.761c7.613-.576 15.42-.876 23.31-.876H128c7.926 0 15.743.303 23.354.883a329.357 329.357 0 0 1 12.335 19.695a358.489 358.489 0 0 1 11.036 20.54a329.472 329.472 0 0 1-11 20.722Zm22.56-122.124c8.572 4.944 11.906 24.881 6.52 51.026c-.344 1.668-.73 3.367-1.15 5.09c-10.622-2.452-22.155-4.275-34.23-5.408c-7.034-10.017-14.323-19.124-21.64-27.008a160.789 160.789 0 0 1 5.888-5.4c18.9-16.447 36.564-22.941 44.612-18.3ZM128 90.808c12.625 0 22.86 10.235 22.86 22.86s-10.235 22.86-22.86 22.86s-22.86-10.235-22.86-22.86s10.235-22.86 22.86-22.86Z"></path></svg>
|
||||
|
After Width: | Height: | Size: 4.0 KiB |
@@ -0,0 +1,174 @@
|
||||
import { useRef, createContext, useState, useContext, useEffect, useCallback, useMemo, type ReactNode } from "react";
|
||||
|
||||
type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
|
||||
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
||||
|
||||
export interface Model {
|
||||
id: string;
|
||||
state: ModelStatus;
|
||||
}
|
||||
|
||||
interface APIProviderType {
|
||||
models: Model[];
|
||||
listModels: () => Promise<Model[]>;
|
||||
unloadAllModels: () => Promise<void>;
|
||||
enableProxyLogs: (enabled: boolean) => void;
|
||||
enableUpstreamLogs: (enabled: boolean) => void;
|
||||
enableModelUpdates: (enabled: boolean) => void;
|
||||
proxyLogs: string;
|
||||
upstreamLogs: string;
|
||||
}
|
||||
|
||||
const APIContext = createContext<APIProviderType | undefined>(undefined);
|
||||
type APIProviderProps = {
|
||||
children: ReactNode;
|
||||
};
|
||||
|
||||
export function APIProvider({ children }: APIProviderProps) {
|
||||
const [proxyLogs, setProxyLogs] = useState("");
|
||||
const [upstreamLogs, setUpstreamLogs] = useState("");
|
||||
const proxyEventSource = useRef<EventSource | null>(null);
|
||||
const upstreamEventSource = useRef<EventSource | null>(null);
|
||||
|
||||
const [models, setModels] = useState<Model[]>([]);
|
||||
const modelStatusEventSource = useRef<EventSource | null>(null);
|
||||
|
||||
const appendLog = useCallback((newData: string, setter: React.Dispatch<React.SetStateAction<string>>) => {
|
||||
setter((prev) => {
|
||||
const updatedLog = prev + newData;
|
||||
return updatedLog.length > LOG_LENGTH_LIMIT ? updatedLog.slice(-LOG_LENGTH_LIMIT) : updatedLog;
|
||||
});
|
||||
}, []);
|
||||
|
||||
const handleProxyMessage = useCallback(
|
||||
(e: MessageEvent) => {
|
||||
appendLog(e.data, setProxyLogs);
|
||||
},
|
||||
[proxyLogs, appendLog]
|
||||
);
|
||||
|
||||
const handleUpstreamMessage = useCallback(
|
||||
(e: MessageEvent) => {
|
||||
appendLog(e.data, setUpstreamLogs);
|
||||
},
|
||||
[appendLog]
|
||||
);
|
||||
|
||||
const enableProxyLogs = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (enabled) {
|
||||
const eventSource = new EventSource("/logs/streamSSE/proxy");
|
||||
eventSource.onmessage = handleProxyMessage;
|
||||
proxyEventSource.current = eventSource;
|
||||
} else {
|
||||
proxyEventSource.current?.close();
|
||||
proxyEventSource.current = null;
|
||||
}
|
||||
},
|
||||
[handleProxyMessage]
|
||||
);
|
||||
|
||||
const enableUpstreamLogs = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (enabled) {
|
||||
const eventSource = new EventSource("/logs/streamSSE/upstream");
|
||||
eventSource.onmessage = handleUpstreamMessage;
|
||||
upstreamEventSource.current = eventSource;
|
||||
} else {
|
||||
upstreamEventSource.current?.close();
|
||||
upstreamEventSource.current = null;
|
||||
}
|
||||
},
|
||||
[upstreamEventSource, handleUpstreamMessage]
|
||||
);
|
||||
|
||||
const enableModelUpdates = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (enabled) {
|
||||
const eventSource = new EventSource("/api/modelsSSE");
|
||||
eventSource.onmessage = (e: MessageEvent) => {
|
||||
try {
|
||||
const models = JSON.parse(e.data) as Model[];
|
||||
setModels(models);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
};
|
||||
modelStatusEventSource.current = eventSource;
|
||||
} else {
|
||||
modelStatusEventSource.current?.close();
|
||||
modelStatusEventSource.current = null;
|
||||
}
|
||||
},
|
||||
[setModels]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
proxyEventSource.current?.close();
|
||||
upstreamEventSource.current?.close();
|
||||
modelStatusEventSource.current?.close();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const listModels = useCallback(async (): Promise<Model[]> => {
|
||||
try {
|
||||
const response = await fetch("/api/models/");
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
const data = await response.json();
|
||||
return data || [];
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch models:", error);
|
||||
return []; // Return empty array as fallback
|
||||
}
|
||||
}, []);
|
||||
|
||||
const unloadAllModels = useCallback(async () => {
|
||||
try {
|
||||
const response = await fetch(`/api/models/unload/`, {
|
||||
method: "POST",
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to unload models: ${response.status}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to unload models:", error);
|
||||
throw error; // Re-throw to let calling code handle it
|
||||
}
|
||||
}, []);
|
||||
|
||||
const value = useMemo(
|
||||
() => ({
|
||||
models,
|
||||
listModels,
|
||||
unloadAllModels,
|
||||
enableProxyLogs,
|
||||
enableUpstreamLogs,
|
||||
enableModelUpdates,
|
||||
proxyLogs,
|
||||
upstreamLogs,
|
||||
}),
|
||||
[
|
||||
models,
|
||||
listModels,
|
||||
unloadAllModels,
|
||||
enableProxyLogs,
|
||||
enableUpstreamLogs,
|
||||
enableModelUpdates,
|
||||
proxyLogs,
|
||||
upstreamLogs,
|
||||
]
|
||||
);
|
||||
|
||||
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
|
||||
}
|
||||
|
||||
export function useAPI() {
|
||||
const context = useContext(APIContext);
|
||||
if (context === undefined) {
|
||||
throw new Error("useAPI must be used within an APIProvider");
|
||||
}
|
||||
return context;
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
import { createContext, useContext, useEffect, type ReactNode } from "react";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
|
||||
type ThemeContextType = {
|
||||
isDarkMode: boolean;
|
||||
toggleTheme: () => void;
|
||||
};
|
||||
|
||||
const ThemeContext = createContext<ThemeContextType | undefined>(undefined);
|
||||
|
||||
type ThemeProviderProps = {
|
||||
children: ReactNode;
|
||||
};
|
||||
|
||||
export function ThemeProvider({ children }: ThemeProviderProps) {
|
||||
const [isDarkMode, setIsDarkMode] = usePersistentState<boolean>("theme", false);
|
||||
|
||||
useEffect(() => {
|
||||
document.documentElement.setAttribute("data-theme", isDarkMode ? "dark" : "light");
|
||||
}, [isDarkMode]);
|
||||
|
||||
const toggleTheme = () => setIsDarkMode((prev) => !prev);
|
||||
|
||||
return <ThemeContext.Provider value={{ isDarkMode, toggleTheme }}>{children}</ThemeContext.Provider>;
|
||||
}
|
||||
|
||||
export function useTheme(): ThemeContextType {
|
||||
const context = useContext(ThemeContext);
|
||||
if (context === undefined) {
|
||||
throw new Error("useTheme must be used within a ThemeProvider");
|
||||
}
|
||||
return context;
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
|
||||
export function usePersistentState<T>(key: string, initialValue: T): [T, (value: T | ((prevState: T) => T)) => void] {
|
||||
const [state, setState] = useState<T>(() => {
|
||||
if (typeof window === "undefined") return initialValue;
|
||||
try {
|
||||
const saved = localStorage.getItem(key);
|
||||
return saved !== null ? JSON.parse(saved) : initialValue;
|
||||
} catch (e) {
|
||||
console.error(`Error parsing stored value for ${key}`, e);
|
||||
return initialValue;
|
||||
}
|
||||
});
|
||||
|
||||
const setPersistentState = useCallback(
|
||||
(value: T | ((prevState: T) => T)) => {
|
||||
setState((prev) => {
|
||||
const nextValue = typeof value === "function" ? (value as (prevState: T) => T)(prev) : value;
|
||||
try {
|
||||
localStorage.setItem(key, JSON.stringify(nextValue));
|
||||
} catch (e) {
|
||||
console.error(`Error saving value for ${key}`, e);
|
||||
}
|
||||
return nextValue;
|
||||
});
|
||||
},
|
||||
[key]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
try {
|
||||
localStorage.setItem(key, JSON.stringify(state));
|
||||
} catch (e) {
|
||||
console.error(`Error saving value for ${key}`, e);
|
||||
}
|
||||
}, [key, state]);
|
||||
|
||||
return [state, setPersistentState];
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
@import "tailwindcss";
|
||||
@custom-variant dark (&:where([data-theme=dark], [data-theme=dark] *));
|
||||
|
||||
@theme {
|
||||
--color-background: rgba(252, 252, 249, 1);
|
||||
--color-surface: rgba(255, 255, 253, 1);
|
||||
|
||||
/* text colors */
|
||||
--color-txtmain: rgba(19, 52, 59, 1);
|
||||
--color-txtsecondary: rgba(98, 108, 113, 1);
|
||||
--color-navlink-active: rgba(245, 245, 245, 1);
|
||||
|
||||
--color-primary: rgba(50, 184, 198, 1);
|
||||
|
||||
--color-primary-hover: rgba(29, 116, 128, 1);
|
||||
--color-primary-active: rgba(26, 104, 115, 1);
|
||||
--color-secondary: rgba(94, 82, 64, 0.12);
|
||||
--color-secondary-hover: rgba(94, 82, 64, 0.2);
|
||||
--color-secondary-active: rgba(94, 82, 64, 0.25);
|
||||
--color-border: rgba(94, 82, 64, 0.3);
|
||||
--color-btn-primary-text: rgba(252, 252, 249, 1);
|
||||
--color-card-border: rgba(94, 82, 64, 0.12);
|
||||
--color-card-border-inner: rgba(94, 82, 64, 0.12);
|
||||
--color-error: rgba(192, 21, 47, 1);
|
||||
--color-success: rgba(33, 128, 141, 1);
|
||||
--color-warning: rgb(244, 155, 0);
|
||||
--color-info: rgba(98, 108, 113, 1);
|
||||
--color-focus-ring: rgba(33, 128, 141, 0.4);
|
||||
--color-select-caret: rgba(19, 52, 59, 0.8);
|
||||
--color-btn-border: rgba(94, 82, 64, 0.7);
|
||||
}
|
||||
|
||||
@layer theme {
|
||||
/* over ride theme for dark mode */
|
||||
[data-theme="dark"] {
|
||||
--color-background: rgba(31, 33, 33, 1);
|
||||
--color-surface: rgba(38, 40, 40, 1);
|
||||
/* text colors */
|
||||
--color-txtmain: rgba(245, 245, 245, 1);
|
||||
--color-txtsecondary: rgba(167, 169, 169, 0.7);
|
||||
|
||||
--color-navlink-active: rgba(245, 245, 245, 1);
|
||||
|
||||
--color-primary: rgba(33, 128, 141, 1);
|
||||
--color-primary-hover: rgba(45, 166, 178, 1);
|
||||
--color-primary-active: rgba(41, 150, 161, 1);
|
||||
--color-secondary: rgba(119, 124, 124, 0.15);
|
||||
--color-secondary-hover: rgba(119, 124, 124, 0.25);
|
||||
--color-secondary-active: rgba(119, 124, 124, 0.3);
|
||||
--color-border: rgba(119, 124, 124, 0.3);
|
||||
--color-error: rgba(255, 84, 89, 1);
|
||||
--color-success: rgba(50, 184, 198, 1);
|
||||
--color-warning: rgb(244, 155, 0);
|
||||
--color-info: rgba(167, 169, 169, 1);
|
||||
--color-focus-ring: rgba(50, 184, 198, 0.4);
|
||||
--color-btn-primary-text: rgba(19, 52, 59, 1);
|
||||
--color-card-border: rgba(119, 124, 124, 0.2);
|
||||
--color-card-border-inner: rgba(119, 124, 124, 0.15);
|
||||
--shadow-inset-sm: inset 0 1px 0 rgba(255, 255, 255, 0.1), inset 0 -1px 0 rgba(0, 0, 0, 0.15);
|
||||
--button-border-secondary: rgba(119, 124, 124, 0.2);
|
||||
}
|
||||
}
|
||||
|
||||
@layer base {
|
||||
body {
|
||||
/* example of how colors using theme colors*/
|
||||
@apply bg-background text-txtmain;
|
||||
}
|
||||
|
||||
h1 {
|
||||
@apply text-4xl text-txtmain font-bold pb-4;
|
||||
}
|
||||
h2 {
|
||||
@apply text-3xl text-txtmain font-bold pb-4;
|
||||
}
|
||||
h3 {
|
||||
@apply text-2xl text-txtmain font-bold pb-4;
|
||||
}
|
||||
h4 {
|
||||
@apply text-xl text-txtmain font-bold pb-4;
|
||||
}
|
||||
h5 {
|
||||
@apply text-lg text-txtmain font-bold pb-4;
|
||||
}
|
||||
h6 {
|
||||
@apply text-base text-txtmain font-bold pb-4;
|
||||
}
|
||||
}
|
||||
|
||||
/* define CSS classes here for specific types of components */
|
||||
@layer components {
|
||||
.container {
|
||||
@apply px-4;
|
||||
}
|
||||
|
||||
/* Navigation Header */
|
||||
|
||||
.navlink {
|
||||
@apply text-txtsecondary hover:bg-secondary hover:text-txtmain rounded-lg p-2;
|
||||
}
|
||||
.navlink.active {
|
||||
@apply bg-primary text-navlink-active;
|
||||
}
|
||||
|
||||
/* Card component */
|
||||
.card {
|
||||
@apply bg-surface rounded-lg border border-card-border shadow-sm overflow-hidden p-4;
|
||||
}
|
||||
|
||||
.card:hover {
|
||||
@apply shadow-md;
|
||||
}
|
||||
|
||||
.card__body {
|
||||
@apply p-4;
|
||||
}
|
||||
|
||||
.card__header,
|
||||
.card__footer {
|
||||
@apply p-4 border-b border-card-border-inner;
|
||||
}
|
||||
|
||||
/* Status Badges */
|
||||
.status {
|
||||
@apply inline-block px-2 py-1 text-xs font-medium rounded-full;
|
||||
}
|
||||
|
||||
.status--ready {
|
||||
@apply bg-success/10 text-success;
|
||||
}
|
||||
|
||||
.status--starting,
|
||||
.status--stopping {
|
||||
@apply bg-warning/10 text-warning;
|
||||
}
|
||||
|
||||
.status--stopped {
|
||||
@apply bg-error/10 text-error;
|
||||
}
|
||||
|
||||
/* Buttons */
|
||||
.btn {
|
||||
@apply bg-surface p-2 px-4 text-sm rounded-full border border-2 transition-colors duration-200 border-btn-border;
|
||||
}
|
||||
|
||||
.btn--sm {
|
||||
@apply px-2 py-0.5 text-xs;
|
||||
}
|
||||
|
||||
.btn:disabled {
|
||||
@apply opacity-50 cursor-not-allowed;
|
||||
}
|
||||
}
|
||||
|
||||
@layer utilities {
|
||||
.ml-2 {
|
||||
margin-left: 0.5rem;
|
||||
}
|
||||
|
||||
.my-8 {
|
||||
margin-top: 2rem;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
import { StrictMode } from "react";
|
||||
import { createRoot } from "react-dom/client";
|
||||
import "./index.css";
|
||||
import App from "./App.tsx";
|
||||
import { ThemeProvider } from "./contexts/ThemeProvider";
|
||||
|
||||
createRoot(document.getElementById("root")!).render(
|
||||
<StrictMode>
|
||||
<ThemeProvider>
|
||||
<App />
|
||||
</ThemeProvider>
|
||||
</StrictMode>
|
||||
);
|
||||
@@ -0,0 +1,160 @@
|
||||
import { useState, useEffect, useRef, useMemo, useCallback } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
|
||||
const LogViewer = () => {
|
||||
const { proxyLogs, upstreamLogs, enableProxyLogs, enableUpstreamLogs } = useAPI();
|
||||
|
||||
useEffect(() => {
|
||||
enableProxyLogs(true);
|
||||
enableUpstreamLogs(true);
|
||||
return () => {
|
||||
enableProxyLogs(false);
|
||||
enableUpstreamLogs(false);
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-5">
|
||||
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
|
||||
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface LogPanelProps {
|
||||
id: string;
|
||||
title: string;
|
||||
logData: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
|
||||
const [filterRegex, setFilterRegex] = useState("");
|
||||
const [panelState, setPanelState] = usePersistentState<"hide" | "small" | "max">(
|
||||
`logPanel-${id}-panelState`,
|
||||
"small"
|
||||
);
|
||||
const [fontSize, setFontSize] = usePersistentState<"xxs" | "xs" | "small" | "normal">(
|
||||
`logPanel-${id}-fontSize`,
|
||||
"normal"
|
||||
);
|
||||
const [wrapText, setTextWrap] = usePersistentState(`logPanel-${id}-wrapText`, false);
|
||||
|
||||
const textWrapClass = useMemo(() => {
|
||||
return wrapText ? "whitespace-pre-wrap" : "whitespace-pre";
|
||||
}, [wrapText]);
|
||||
|
||||
const toggleFontSize = useCallback(() => {
|
||||
setFontSize((prev) => {
|
||||
switch (prev) {
|
||||
case "xxs":
|
||||
return "xs";
|
||||
case "xs":
|
||||
return "small";
|
||||
case "small":
|
||||
return "normal";
|
||||
case "normal":
|
||||
return "xxs";
|
||||
}
|
||||
});
|
||||
}, []);
|
||||
|
||||
const togglePanelState = useCallback(() => {
|
||||
setPanelState((prev) => {
|
||||
if (prev === "small") return "max";
|
||||
if (prev === "hide") return "small";
|
||||
return "hide";
|
||||
});
|
||||
}, []);
|
||||
|
||||
const fontSizeClass = useMemo(() => {
|
||||
switch (fontSize) {
|
||||
case "xxs":
|
||||
return "text-[0.5rem]"; // 0.5rem (8px)
|
||||
case "xs":
|
||||
return "text-[0.75rem]"; // 0.75rem (12px)
|
||||
case "small":
|
||||
return "text-[0.875rem]"; // 0.875rem (14px)
|
||||
case "normal":
|
||||
return "text-base"; // 1rem (16px)
|
||||
}
|
||||
}, [fontSize]);
|
||||
|
||||
const filteredLogs = useMemo(() => {
|
||||
if (!filterRegex) return logData;
|
||||
try {
|
||||
const regex = new RegExp(filterRegex, "i");
|
||||
const lines = logData.split("\n");
|
||||
const filtered = lines.filter((line) => regex.test(line));
|
||||
return filtered.join("\n");
|
||||
} catch (e) {
|
||||
return logData; // Return unfiltered if regex is invalid
|
||||
}
|
||||
}, [logData, filterRegex]);
|
||||
|
||||
// auto scroll to bottom
|
||||
const preTagRef = useRef<HTMLPreElement>(null);
|
||||
useEffect(() => {
|
||||
if (!preTagRef.current) return;
|
||||
preTagRef.current.scrollTop = preTagRef.current.scrollHeight;
|
||||
}, [filteredLogs]);
|
||||
|
||||
return (
|
||||
<div className={`bg-surface border border-border rounded-lg overflow-hidden flex flex-col ${className || ""}`}>
|
||||
<div className="p-4 border-b border-border bg-secondary">
|
||||
<div className="flex flex-col md:flex-row md:items-center md:justify-between gap-4">
|
||||
{/* Title - Always full width on mobile, normal on desktop */}
|
||||
<div className="w-full md:w-auto" onClick={togglePanelState}>
|
||||
<h3 className="m-0 text-lg">{title}</h3>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col sm:flex-row gap-4 w-full md:w-auto">
|
||||
{/* Sizing Buttons - Stacks vertically on mobile */}
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<button className="btn" onClick={togglePanelState}>
|
||||
size: {panelState}
|
||||
</button>
|
||||
<button className="btn" onClick={toggleFontSize}>
|
||||
font: {fontSize}
|
||||
</button>
|
||||
<button className="btn" onClick={() => setTextWrap((prev) => !prev)}>
|
||||
{wrapText ? "wrap" : "wrap off"}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Filtering Options - Full width on mobile, normal on desktop */}
|
||||
<div className="flex flex-1 min-w-0 gap-2">
|
||||
<input
|
||||
type="text"
|
||||
className="flex-1 min-w-[120px] text-sm border p-2 rounded"
|
||||
placeholder="Filter logs..."
|
||||
value={filterRegex}
|
||||
onChange={(e) => setFilterRegex(e.target.value)}
|
||||
/>
|
||||
<button className="btn" onClick={() => setFilterRegex("")}>
|
||||
Clear
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{panelState !== "hide" && (
|
||||
<div className="flex-1 bg-background font-mono text-sm leading-[1.4] p-3">
|
||||
<pre
|
||||
ref={preTagRef}
|
||||
className={`flex-1 p-4 overflow-y-auto whitespace-pre min-h-0 ${textWrapClass} ${fontSizeClass}`}
|
||||
style={{
|
||||
maxHeight: panelState === "max" ? "1500px" : "500px",
|
||||
}}
|
||||
>
|
||||
{filteredLogs}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default LogViewer;
|
||||
@@ -0,0 +1,74 @@
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
import { LogPanel } from "./LogViewer";
|
||||
|
||||
export default function ModelsPage() {
|
||||
const { models, enableModelUpdates, unloadAllModels, upstreamLogs, enableUpstreamLogs } = useAPI();
|
||||
const [isUnloading, setIsUnloading] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
enableModelUpdates(true);
|
||||
enableUpstreamLogs(true);
|
||||
return () => {
|
||||
enableModelUpdates(false);
|
||||
enableUpstreamLogs(false);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const handleUnloadAllModels = useCallback(async () => {
|
||||
setIsUnloading(true);
|
||||
try {
|
||||
await unloadAllModels();
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
} finally {
|
||||
// at least give it a second to show the unloading message
|
||||
setTimeout(() => {
|
||||
setIsUnloading(false);
|
||||
}, 1000);
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="h-screen">
|
||||
<div className="flex flex-col md:flex-row gap-4">
|
||||
{/* Left Column */}
|
||||
<div className="w-full md:w-1/2 flex items-top">
|
||||
<div className="card w-full">
|
||||
<h2 className="">Models</h2>
|
||||
<button className="btn" onClick={handleUnloadAllModels} disabled={isUnloading}>
|
||||
{isUnloading ? "Unloading..." : "Unload All Models"}
|
||||
</button>
|
||||
<table className="w-full mt-4">
|
||||
<thead>
|
||||
<tr className="border-b border-primary">
|
||||
<th className="text-left p-2">Name</th>
|
||||
<th className="text-left p-2">State</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{models.map((model) => (
|
||||
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
|
||||
<td className="p-2">
|
||||
<a href={`/upstream/${model.id}/`} className="underline" target="top">
|
||||
{model.id}
|
||||
</a>
|
||||
</td>
|
||||
<td className="p-2">
|
||||
<span className={`status status--${model.state}`}>{model.state}</span>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right Column */}
|
||||
<div className="w-full md:w-1/2 flex items-top">
|
||||
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} className="h-full" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Vendored
+1
@@ -0,0 +1 @@
|
||||
/// <reference types="vite/client" />
|
||||
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
|
||||
"target": "ES2020",
|
||||
"useDefineForClassFields": true,
|
||||
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
||||
"module": "ESNext",
|
||||
"skipLibCheck": true,
|
||||
|
||||
/* Bundler mode */
|
||||
"moduleResolution": "bundler",
|
||||
"allowImportingTsExtensions": true,
|
||||
"verbatimModuleSyntax": true,
|
||||
"moduleDetection": "force",
|
||||
"noEmit": true,
|
||||
"jsx": "react-jsx",
|
||||
|
||||
/* Linting */
|
||||
"strict": true,
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true,
|
||||
"erasableSyntaxOnly": true,
|
||||
"noFallthroughCasesInSwitch": true,
|
||||
"noUncheckedSideEffectImports": true
|
||||
},
|
||||
"include": ["src"]
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"files": [],
|
||||
"references": [
|
||||
{ "path": "./tsconfig.app.json" },
|
||||
{ "path": "./tsconfig.node.json" }
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
|
||||
"target": "ES2022",
|
||||
"lib": ["ES2023"],
|
||||
"module": "ESNext",
|
||||
"skipLibCheck": true,
|
||||
|
||||
/* Bundler mode */
|
||||
"moduleResolution": "bundler",
|
||||
"allowImportingTsExtensions": true,
|
||||
"verbatimModuleSyntax": true,
|
||||
"moduleDetection": "force",
|
||||
"noEmit": true,
|
||||
|
||||
/* Linting */
|
||||
"strict": true,
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true,
|
||||
"erasableSyntaxOnly": true,
|
||||
"noFallthroughCasesInSwitch": true,
|
||||
"noUncheckedSideEffectImports": true
|
||||
},
|
||||
"include": ["vite.config.ts"]
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
import { defineConfig } from "vite";
|
||||
import react from "@vitejs/plugin-react";
|
||||
import tailwindcss from "@tailwindcss/vite";
|
||||
|
||||
// https://vite.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react(), tailwindcss()],
|
||||
base: "/ui/",
|
||||
build: {
|
||||
outDir: "../proxy/ui_dist",
|
||||
assetsDir: "assets",
|
||||
},
|
||||
server: {
|
||||
proxy: {
|
||||
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
|
||||
"/logs": "http://localhost:8080",
|
||||
"/upstream": "http://localhost:8080",
|
||||
},
|
||||
},
|
||||
});
|
||||
Reference in New Issue
Block a user