Compare commits

..

2 Commits

Author SHA1 Message Date
Benson Wong 014a2fa9a3 fix bug checking incorrect error 2025-03-20 15:26:39 -07:00
Benson Wong 5ceaef6144 add override for windows 2025-03-20 13:21:03 -07:00
21 changed files with 179 additions and 1024 deletions
+9 -14
View File
@@ -1,7 +1,4 @@
![llama-swap header image](header2.png)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
![llama-swap header image](header.jpeg)
# llama-swap
@@ -67,8 +64,8 @@ models:
# Default (and minimum) is 15 seconds
healthCheckTimeout: 60
# Valid log levels: debug, info (default), warn, error
logLevel: info
# Write HTTP logs (useful for troubleshooting), defaults to false
logRequests: true
# define valid model values and the upstream server start
models:
@@ -219,15 +216,9 @@ Of course, CLI access is also supported:
# sends up to the last 10KB of logs
curl http://host/logs'
# streams combined logs
# streams logs
curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time'
@@ -269,4 +260,8 @@ WantedBy=multi-user.target
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date)
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
</picture>
+3 -3
View File
@@ -1,9 +1,9 @@
# Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds
healthCheckTimeout: 90
healthCheckTimeout: 15
# valid log levels: debug, info (default), warn, error
logLevel: debug
# Log HTTP requests helpful for troubleshoot, defaults to False
logRequests: true
models:
"llama":
-153
View File
@@ -1,153 +0,0 @@
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
## Here's what you you need:
- aider - [installation docs](https://aider.chat/docs/install.html)
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
- 24GB VRAM video card
## Running aider
The goal is getting this command line to work:
```sh
aider --architect \
--no-show-model-warnings \
--model openai/QwQ \
--editor-model openai/qwen-coder-32B \
--model-settings-file aider.model.settings.yml \
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1" \
```
Set `--openai-api-base` to the IP and port where your llama-swap is running.
## Create an aider model settings file
```yaml
# aider.model.settings.yml
#
# !!! important: model names must match llama-swap configuration names !!!
#
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
```
## llama-swap configuration
```yaml
# config.yaml
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
models:
"qwen-coder-32B":
proxy: "http://127.0.0.1:8999"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--cache-type-k q8_0 --cache-type-v q8_0
-ngl 99
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
"QwQ":
proxy: "http://127.0.0.1:9503"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
--min-p 0.01 --top-k 40 --top-p 0.95
-ngl 99
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
```
## Advanced, Dual GPU Configuration
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
In llama-swap's configuration file:
1. add a `profiles` section with `aider` as the profile name
2. using the `env` field to specify the GPU IDs for each model
```yaml
# config.yaml
# Add a profile for aider
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=0"
proxy: "http://127.0.0.1:8999"
cmd: /path/to/llama-server ...
"QwQ":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
cmd: /path/to/llama-server ...
```
Append the profile tag, `aider:`, to the model names in the model settings file
```yaml
# aider.model.settings.yml
- name: "openai/aider:QwQ"
weak_model_name: "openai/aider:qwen-coder-32B-aider"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
- name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
```
Run aider with:
```sh
$ aider --architect \
--no-show-model-warnings \
--model openai/aider:QwQ \
--editor-model openai/aider:qwen-coder-32B \
--config aider.conf.yml \
--model-settings-file aider.model.settings.yml
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1"
```
@@ -1,28 +0,0 @@
# this makes use of llama-swap's profile feature to
# keep the architect and editor models in VRAM on different GPUs
- name: "openai/aider:QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B"
- name: "openai/aider:qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/aider:qwen-coder-32B"
@@ -1,26 +0,0 @@
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
-49
View File
@@ -1,49 +0,0 @@
healthCheckTimeout: 300
logLevel: debug
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
env:
- "CUDA_VISIBLE_DEVICES=0"
aliases:
- coder
proxy: "http://127.0.0.1:8999"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--ctx-size-draft 16000
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
-ngl 99 -ngld 99
--draft-max 16 --draft-min 4 --draft-p-min 0.4
--cache-type-k q8_0 --cache-type-v q8_0
"QwQ":
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6
--repeat-penalty 1.1
--dry-multiplier 0.5
--min-p 0.01
--top-k 40
--top-p 0.95
-ngl 99 -ngld 99
+1 -1
View File
@@ -37,7 +37,7 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/net v0.37.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
-2
View File
@@ -86,8 +86,6 @@ golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 351 KiB

-1
View File
@@ -27,7 +27,6 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
Models map[string]ModelConfig `yaml:"models"`
Profiles map[string][]string `yaml:"profiles"`
+78 -191
View File
@@ -12,65 +12,32 @@
flex-direction: column;
font-family: "Courier New", Courier, monospace;
}
.log-container {
display: flex;
flex: 1;
gap: 0.5em;
#log-controls {
margin: 0.5em;
min-height: 0;
}
.log-column {
display: flex;
flex-direction: column;
align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1;
min-width: 0;
transition: flex 0.3s ease;
}
.log-column.minimized {
flex: 0.1;
max-width: 50px;
border: 1px solid #777;
color: green;
#log-controls input:focus {
outline: none; /* Ensures no outline is shown when the input is focused */
}
.log-controls {
display: grid;
grid-template-columns: 1fr auto;
gap: 0.5em;
margin-bottom: 0.5em;
}
.log-controls input {
width: 100%;
padding: 4px;
}
.log-controls input:focus {
outline: none;
}
.log-stream {
#log-stream {
flex: 1;
margin: 0.5em;
padding: 1em;
background: #f4f4f4;
overflow-y: auto;
white-space: pre-wrap;
word-wrap: break-word;
min-height: 0;
white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; /* Ensures long words wrap */
}
.regex-error {
background-color: #ff0000 !important;
}
/* Make headers clickable and show pointer cursor */
h2 {
cursor: pointer;
user-select: none;
margin: 0 0 0.5em 0;
padding: 0.5em;
}
h2:hover {
background-color: rgba(0, 0, 0, 0.05);
}
/* Dark mode styles */
@media (prefers-color-scheme: dark) {
body {
@@ -78,181 +45,101 @@
color: #fff;
}
.log-stream {
#log-stream {
background: #444;
color: #fff;
}
.log-controls input {
#log-controls input {
background: #555;
color: #fff;
border: 1px solid #777;
}
.log-controls button {
#log-controls button {
background: #555;
color: #fff;
border: 1px solid #777;
}
h2:hover {
background-color: rgba(255, 255, 255, 0.1);
}
}
/* Hide content when minimized */
.log-column.minimized .log-controls,
.log-column.minimized .log-stream {
display: none;
}
.log-column.minimized h2 {
writing-mode: vertical-rl;
text-orientation: mixed;
transform: rotate(180deg);
white-space: nowrap;
margin: auto;
}
</style>
</head>
<body>
<div class="log-container">
<div class="log-column">
<h2>Proxy Logs</h2>
<div class="log-controls">
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
<button id="proxy-clear-button">clear</button>
</div>
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
</div>
<div class="log-column minimized">
<h2>Upstream Logs</h2>
<div class="log-controls">
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
<button id="upstream-clear-button">clear</button>
</div>
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
</div>
<pre id="log-stream">Waiting for logs...</pre>
<div id="log-controls">
<input type="text" id="filter-input" placeholder="regex filter">
<button id="clear-button">clear</button>
</div>
<script>
class LogStream {
constructor(streamElement, filterInput, clearButton, endpoint) {
this.streamElement = streamElement;
this.filterInput = filterInput;
this.clearButton = clearButton;
this.endpoint = endpoint;
this.logData = "";
this.regexFilter = null;
this.eventSource = null;
const logStream = document.getElementById('log-stream');
const filterInput = document.getElementById('filter-input');
var logData = "";
let regexFilter = null;
this.initialize();
}
function setupEventSource() {
if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
initialize() {
this.filterInput.addEventListener('input', () => this.updateFilter());
this.clearButton.addEventListener('click', () => {
this.filterInput.value = "";
this.regexFilter = null;
this.render();
});
this.setupEventSource();
}
setupEventSource() {
if (typeof(EventSource) === "undefined") {
this.logData = "SSE Not supported by this browser.";
this.render();
return;
}
const connect = () => {
this.eventSource = new EventSource(this.endpoint);
this.eventSource.onmessage = (event) => {
this.logData += event.data;
this.render();
};
this.eventSource.onerror = (err) => {
// Close the current connection
this.eventSource.close();
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
this.render();
// Attempt to reconnect after 5 seconds
setTimeout(() => {
this.logData += "Attempting to reconnect...\n";
this.render();
connect();
}, 5000);
};
eventSource.onmessage = function(event) {
logData += event.data;
render()
};
// Initial connection
connect();
}
render() {
let content = this.logData;
if (this.regexFilter) {
const lines = content.split('\n');
const filteredLines = lines.filter(line => this.regexFilter.test(line));
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
}
this.streamElement.textContent = content;
this.streamElement.scrollTop = this.streamElement.scrollHeight;
}
updateFilter() {
const pattern = this.filterInput.value.trim();
this.filterInput.classList.remove('regex-error');
if (!pattern) {
this.regexFilter = null;
this.render();
return;
}
try {
this.regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
this.regexFilter = null;
this.filterInput.classList.add('regex-error');
return;
}
this.render();
eventSource.onerror = function(err) {
logData = "EventSource failed: " + err.message;
};
} else {
logData = "SSE Not supported by this browser."
}
}
// Initialize both log streams
document.addEventListener('DOMContentLoaded', () => {
new LogStream(
document.getElementById('proxy-log-stream'),
document.getElementById('proxy-filter-input'),
document.getElementById('proxy-clear-button'),
"/logs/streamSSE/proxy"
);
new LogStream(
document.getElementById('upstream-log-stream'),
document.getElementById('upstream-filter-input'),
document.getElementById('upstream-clear-button'),
"/logs/streamSSE/upstream"
);
// Initialize clickable headers
document.querySelectorAll('h2').forEach(header => {
header.addEventListener('click', () => {
const column = header.closest('.log-column');
column.classList.toggle('minimized');
// poor-ai's react ¯\_(ツ)_/¯
function render() {
if (regexFilter) {
const lines = logData.split('\n');
const filteredLines = lines.filter(line => {
return regexFilter === null || regexFilter.test(line);
});
});
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
}
logStream.scrollTop = logStream.scrollHeight;
}
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try {
regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
regexFilter = null;
filterInput.classList.add('regex-error');
return
}
} else {
regexFilter = null;
}
render();
}
filterInput.addEventListener('input', updateFilter);
document.getElementById('clear-button').addEventListener('click', () => {
filterInput.value = "";
regexFilter = null;
render();
});
setupEventSource();
updateFilter();
</script>
</body>
</html>
-90
View File
@@ -2,21 +2,11 @@ package proxy
import (
"container/ring"
"fmt"
"io"
"os"
"sync"
)
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct {
clients map[chan []byte]bool
mu sync.RWMutex
@@ -25,10 +15,6 @@ type LogMonitor struct {
// typically this can be os.Stdout
stdout io.Writer
// logging levels
level LogLevel
prefix string
}
func NewLogMonitor() *LogMonitor {
@@ -40,8 +26,6 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
level: LevelInfo,
prefix: "",
}
}
@@ -110,77 +94,3 @@ func (w *LogMonitor) broadcast(msg []byte) {
}
}
}
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+28 -81
View File
@@ -30,15 +30,10 @@ const (
)
type Process struct {
ID string
config ModelConfig
cmd *exec.Cmd
// for p.cmd.Wait() select { ... }
cmdWaitChan chan error
processLogger *LogMonitor
proxyLogger *LogMonitor
ID string
config ModelConfig
cmd *exec.Cmd
logMonitor *LogMonitor
healthCheckTimeout int
healthCheckLoopInterval time.Duration
@@ -58,15 +53,13 @@ type Process struct {
shutdownCancel context.CancelFunc
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background())
return &Process{
ID: ID,
config: config,
cmd: nil,
cmdWaitChan: make(chan error, 1),
processLogger: processLogger,
proxyLogger: proxyLogger,
logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
state: StateStopped,
@@ -75,11 +68,6 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo
}
}
// LogMonitor returns the log monitor associated with the process.
func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// custom error types for swapping state
var (
ErrExpectedStateMismatch = errors.New("expected state mismatch")
@@ -93,17 +81,14 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
defer p.stateMutex.Unlock()
if p.state != expectedState {
p.proxyLogger.Warnf("swapState() Unexpected current state %s, expected %s", p.state, expectedState)
return p.state, ErrExpectedStateMismatch
}
if !isValidTransition(p.state, newState) {
p.proxyLogger.Warnf("swapState() Invalid state transition from %s to %s", p.state, newState)
return p.state, ErrInvalidStateTransition
}
p.state = newState
p.proxyLogger.Debugf("swapState() State transitioned from %s to %s", expectedState, newState)
return p.state, nil
}
@@ -167,8 +152,8 @@ func (p *Process) start() error {
defer p.waitStarting.Done()
p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.processLogger
p.cmd.Stderr = p.processLogger
p.cmd.Stdout = p.logMonitor
p.cmd.Stderr = p.logMonitor
p.cmd.Env = p.config.Env
err = p.cmd.Start()
@@ -184,13 +169,6 @@ func (p *Process) start() error {
return fmt.Errorf("start() failed: %v", err)
}
// Capture the exit error for later signaling
go func() {
exitErr := p.cmd.Wait()
p.proxyLogger.Debugf("cmd.Wait() returned for [%s] error: %v", p.ID, exitErr)
p.cmdWaitChan <- exitErr
}()
// One of three things can happen at this stage:
// 1. The command exits unexpectedly
// 2. The health check fails
@@ -234,34 +212,17 @@ func (p *Process) start() error {
}
case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown")
case exitErr := <-p.cmdWaitChan:
if exitErr != nil {
p.proxyLogger.Warnf("upstream command exited prematurely with error: %v", exitErr)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
} else {
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
}
} else {
p.proxyLogger.Warnf("upstream command exited prematurely with no error")
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited prematurely with no error AND state swap failed: %v, current state: %v", err, curState)
} else {
return fmt.Errorf("upstream command exited prematurely with no error")
}
}
default:
if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("Health check passed on %s", healthURL)
cancelHealthCheck()
break loop
} else {
if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime)
p.proxyLogger.Infof("Connection refused on %s, giving up in %.0fs", healthURL, ttl.Seconds())
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
} else {
p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err)
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
}
}
}
@@ -285,7 +246,7 @@ func (p *Process) start() error {
p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration {
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.Stop()
return
}
@@ -303,11 +264,10 @@ func (p *Process) start() error {
func (p *Process) Stop() {
// wait for any inflight requests before proceeding
p.inFlightRequests.Wait()
p.proxyLogger.Debugf("Stopping process [%s]", p.ID)
// calling Stop() when state is invalid is a no-op
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState)
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState)
return
}
@@ -315,7 +275,7 @@ func (p *Process) Stop() {
p.stopCommand(5 * time.Second)
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState)
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState)
}
}
@@ -331,45 +291,40 @@ func (p *Process) Shutdown() {
// stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("Process [%s] stopCommand took %v", p.ID, time.Since(stopStartTime))
}()
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil {
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
return
}
if err := p.terminateProcess(); err != nil {
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
}
p.cmd.Process.Signal(syscall.SIGTERM)
select {
case <-sigtermTimeout.Done():
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
p.cmd.Process.Kill()
case err := <-p.cmdWaitChan:
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
// because if we make it here then the cmd has been successfully running and made it
// through the health check. There is a possibility that ithe cmd crashed after the health check
// succeeded but that's not a case llama-swap is handling for now.
case err := <-sigtermNormal:
if err != nil {
if errno, ok := err.(syscall.Errno); ok {
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") {
p.proxyLogger.Infof("Process [%s] stopped OK", p.ID)
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") {
p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID)
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
} else {
p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
}
} else {
p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err)
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
}
}
}
@@ -401,8 +356,6 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
requestBeginTime := time.Now()
var startDuration time.Duration
// prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState()
@@ -419,13 +372,11 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
// start the process on demand
if p.CurrentState() != StateReady {
beginStartTime := time.Now()
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway)
return
}
startDuration = time.Since(beginStartTime)
}
proxyTo := p.config.Proxy
@@ -469,8 +420,4 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return
}
}
totalTime := time.Since(requestBeginTime)
p.proxyLogger.Debugf("Process [%s] request %s - start: %v, total: %v",
p.ID, r.RequestURI, startDuration, totalTime)
}
-9
View File
@@ -1,9 +0,0 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}
-14
View File
@@ -1,14 +0,0 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}
+14 -43
View File
@@ -2,6 +2,7 @@ package proxy
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
@@ -12,26 +13,13 @@ import (
"github.com/stretchr/testify/assert"
)
var (
debugLogger = NewLogMonitorWriter(os.Stdout)
)
func init() {
// flip to help with debugging tests
if false {
debugLogger.SetLogLevel(LevelDebug)
} else {
debugLogger.SetLogLevel(LevelError)
}
}
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
// Create a process
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil)
@@ -64,10 +52,11 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
// are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop()
var wg sync.WaitGroup
@@ -95,7 +84,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health",
}
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
process := NewProcess("broken", 1, config, NewLogMonitor())
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
@@ -120,7 +109,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
defer process.Stop()
// this should take 4 seconds
@@ -162,7 +151,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop()
for i := 0; i < 100; i++ {
@@ -189,7 +178,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, debugLogger, debugLogger)
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop()
results := map[string]string{
@@ -266,8 +255,9 @@ func TestProcess_SwapState(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
p.state = test.currentState
p := &Process{
state: test.currentState,
}
resultState, err := p.swapState(test.expectedState, test.newState)
if err != nil && test.expectedError == nil {
@@ -292,6 +282,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
t.Skip("skipping long shutdown test")
}
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong
@@ -299,7 +290,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
// make it a lot faster
process.healthCheckLoopInterval = time.Second
@@ -320,23 +311,3 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState())
}
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping Exit Interrupts Health Check test")
}
// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
process.healthCheckLoopInterval = time.Second // make it faster
err := process.start()
assert.Equal(t, "upstream command exited prematurely with no error", err.Error())
assert.Equal(t, process.CurrentState(), StateFailed)
}
+37 -82
View File
@@ -7,7 +7,6 @@ import (
"io"
"mime/multipart"
"net/http"
"os"
"sort"
"strconv"
"strings"
@@ -28,101 +27,59 @@ type ProxyManager struct {
config *Config
currentProcesses map[string]*Process
logMonitor *LogMonitor
ginEngine *gin.Engine
// logging
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
}
func New(config *Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
}
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
upstreamLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
upstreamLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{
config: config,
currentProcesses: make(map[string]*Process),
logMonitor: NewLogMonitor(),
ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
}
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
if config.LogRequests {
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
clientIP,
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
}
// see: issue: #81, #77 and #42 for CORS issues
// see: https://github.com/mostlygeek/llama-swap/issues/42
// respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
// allow whatever the client requested by default
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
c.Header("Access-Control-Allow-Headers", sanitized)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.AbortWithStatus(204)
return
}
c.Next()
@@ -147,8 +104,6 @@ func New(config *Config) *ProxyManager {
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
@@ -308,20 +263,19 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
return process, nil
}
// stop all running models
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
pm.stopProcesses()
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
@@ -332,7 +286,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
@@ -420,6 +374,7 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return
}
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
+8 -37
View File
@@ -9,6 +9,7 @@ import (
)
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html
@@ -27,7 +28,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
}
} else {
c.Header("Content-Type", "text/plain")
history := pm.muxLogger.GetHistory()
history := pm.logMonitor.GetHistory()
_, err := c.Writer.Write(history)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
@@ -41,14 +42,8 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
ch := pm.logMonitor.Subscribe()
defer pm.logMonitor.Unsubscribe(ch)
notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher)
@@ -61,7 +56,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// Send history first if not skipped
if !skipHistory {
history := logger.GetHistory()
history := pm.logMonitor.GetHistory()
if len(history) != 0 {
c.Writer.Write(history)
flusher.Flush()
@@ -90,21 +85,15 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
ch := pm.logMonitor.Subscribe()
defer pm.logMonitor.Unsubscribe(ch)
notify := c.Request.Context().Done()
// Send history first if not skipped
_, skipHistory := c.GetQuery("no-history")
if !skipHistory {
history := logger.GetHistory()
history := pm.logMonitor.GetHistory()
if len(history) != 0 {
c.SSEvent("message", string(history))
c.Writer.Flush()
@@ -122,21 +111,3 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
}
}
}
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return logger, nil
}
+1 -80
View File
@@ -22,7 +22,6 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -63,7 +62,6 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
Profiles: map[string][]string{
"test": {model1, model2},
},
LogLevel: "error",
}
proxy := New(config)
@@ -105,7 +103,6 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -156,7 +153,6 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -234,7 +230,6 @@ func TestProxyManager_ProfileNonMember(t *testing.T) {
Profiles: map[string][]string{
"test": {model1},
},
LogLevel: "error",
}
proxy := New(config)
@@ -283,7 +278,6 @@ func TestProxyManager_Shutdown(t *testing.T) {
"model2": model2Config,
"model3": model3Config,
},
LogLevel: "error",
}
proxy := New(config)
@@ -319,7 +313,6 @@ func TestProxyManager_Unload(t *testing.T) {
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -346,7 +339,6 @@ func TestProxyManager_StripProfileSlug(t *testing.T) {
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -373,7 +365,6 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
LogLevel: "error",
}
// Define a helper struct to parse the JSON response.
@@ -481,7 +472,6 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -590,8 +580,6 @@ func TestProxyManager_UseModelName(t *testing.T) {
Models: map[string]ModelConfig{
"model1": modelConfig,
},
LogLevel: "error",
}
proxy := New(config)
@@ -651,72 +639,5 @@ func TestProxyManager_UseModelName(t *testing.T) {
assert.Equal(t, upstreamModelName, response["model"])
})
}
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
}
tests := []struct {
name string
method string
requestHeaders map[string]string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS with no headers",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses()
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
proxy.ginEngine.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
}
})
}
}
-43
View File
@@ -1,43 +0,0 @@
package proxy
import (
"strings"
)
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
-77
View File
@@ -1,77 +0,0 @@
package proxy
import "testing"
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "single valid value",
input: "content-type",
expected: "content-type",
},
{
name: "multiple valid values",
input: "content-type, authorization, x-requested-with",
expected: "content-type, authorization, x-requested-with",
},
{
name: "values with extra spaces",
input: " content-type , authorization ",
expected: "content-type, authorization",
},
{
name: "values with tabs",
input: "content-type,\tauthorization",
expected: "content-type, authorization",
},
{
name: "values with invalid characters",
input: "content-type, auth\n, x-requested-with\r",
expected: "content-type, auth, x-requested-with",
},
{
name: "empty values in list",
input: "content-type,,authorization",
expected: "content-type, authorization",
},
{
name: "leading and trailing commas",
input: ",content-type,authorization,",
expected: "content-type, authorization",
},
{
name: "mixed valid and invalid values",
input: "content-type, \x00invalid, x-requested-with",
expected: "content-type, x-requested-with",
},
{
name: "mixed case values",
input: "Content-Type, my-Valid-Header, Another-hEader",
expected: "Content-Type, my-Valid-Header, Another-hEader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeAccessControlRequestHeaderValues(tt.input)
if got != tt.expected {
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
tt.input, got, tt.expected)
}
})
}
}