Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6516532568 | |||
| d58a8b85bf | |||
| caf9e98b1e | |||
| 539278343b | |||
| 00b738cd0f |
@@ -33,7 +33,7 @@ test: proxy/ui_dist/placeholder.txt
|
|||||||
|
|
||||||
# for CI - full test (takes longer)
|
# for CI - full test (takes longer)
|
||||||
test-all: proxy/ui_dist/placeholder.txt
|
test-all: proxy/ui_dist/placeholder.txt
|
||||||
go test -count=1 ./proxy/...
|
go test -race -count=1 ./proxy/...
|
||||||
|
|
||||||
ui/node_modules:
|
ui/node_modules:
|
||||||
cd ui && npm install
|
cd ui && npm install
|
||||||
|
|||||||
@@ -0,0 +1,397 @@
|
|||||||
|
# Improve macro-in-macro support
|
||||||
|
|
||||||
|
**Status: COMPLETED ✅**
|
||||||
|
|
||||||
|
## Title
|
||||||
|
|
||||||
|
Fix macro substitution ordering by preserving definition order using ordered YAML parsing
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The current macro implementation uses `map[string]any` which does not preserve insertion order. This causes issues when macros reference other macros - if macro `B` contains `${A}` but `B` is processed before `A`, the reference won't be substituted, leading to "unknown macro" errors.
|
||||||
|
|
||||||
|
**Goal:** Ensure macros are substituted in definition order (LIFO - last in, first out) to allow macros to reliably reference previously-defined macros.
|
||||||
|
|
||||||
|
**Outcomes:**
|
||||||
|
- Macros can reference other macros defined earlier in the config
|
||||||
|
- Macro substitution is deterministic and order-dependent
|
||||||
|
- Single-pass substitution prevents circular dependencies
|
||||||
|
- Use `yaml.Node` from `gopkg.in/yaml.v3` to preserve macro definition order
|
||||||
|
- All existing tests pass
|
||||||
|
- New tests validate substitution order and self-reference detection
|
||||||
|
|
||||||
|
## Design Requirements
|
||||||
|
|
||||||
|
### 1. YAML Parsing Strategy
|
||||||
|
- **Continue using:** `gopkg.in/yaml.v3` (current library)
|
||||||
|
- **Use:** `yaml.Node` for ordered parsing of macros
|
||||||
|
- **Reason:** `yaml.Node` preserves document structure and order, avoiding need for migration
|
||||||
|
|
||||||
|
### 2. Data Structure Changes
|
||||||
|
|
||||||
|
#### Current Implementation (config.go:19)
|
||||||
|
```go
|
||||||
|
type MacroList map[string]any
|
||||||
|
```
|
||||||
|
|
||||||
|
#### New Implementation
|
||||||
|
```go
|
||||||
|
type MacroList []MacroEntry
|
||||||
|
|
||||||
|
type MacroEntry struct {
|
||||||
|
Name string
|
||||||
|
Value any
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Implementation Note:** Parse macros using `yaml.Node` to extract key-value pairs in document order, then construct the ordered `MacroList`.
|
||||||
|
|
||||||
|
### 3. Macro Substitution Order Rules
|
||||||
|
|
||||||
|
The substitution must follow this hierarchy (from most specific to least):
|
||||||
|
|
||||||
|
1. **Reserved macros** (last): `PORT`, `MODEL_ID` - substituted last, highest priority
|
||||||
|
2. **Model-level macros** (middle): Defined in specific model config, overrides global
|
||||||
|
3. **Global macros** (first): Defined at config root level
|
||||||
|
|
||||||
|
Within each level, macros are substituted in **reverse definition order** (LIFO):
|
||||||
|
- The last macro defined is substituted first
|
||||||
|
- This allows later macros to reference earlier ones
|
||||||
|
- Single-pass substitution prevents circular dependencies
|
||||||
|
|
||||||
|
### 4. Macro Reference Rules
|
||||||
|
|
||||||
|
**Allowed:**
|
||||||
|
- Macro can reference any macro defined **before** it (earlier in the file)
|
||||||
|
- Model macros can reference global macros
|
||||||
|
- Macros can reference reserved macros (`${PORT}`, `${MODEL_ID}`)
|
||||||
|
|
||||||
|
**Prohibited:**
|
||||||
|
- Macro cannot reference itself (e.g., `foo: "value ${foo}"`)
|
||||||
|
- Macro cannot reference macros defined **after** it
|
||||||
|
- No circular references (prevented by single-pass, ordered substitution)
|
||||||
|
|
||||||
|
### 5. Validation Requirements
|
||||||
|
|
||||||
|
Add validation to detect:
|
||||||
|
- **Self-references:** Macro value contains reference to its own name
|
||||||
|
- **Unknown macros:** After substitution, any remaining `${...}` references
|
||||||
|
|
||||||
|
Error messages should be clear:
|
||||||
|
```
|
||||||
|
macro 'foo' contains self-reference
|
||||||
|
unknown macro '${bar}' in model.cmd
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. Implementation Changes
|
||||||
|
|
||||||
|
#### Files to Modify
|
||||||
|
|
||||||
|
1. **[proxy/config/config.go](proxy/config/config.go)**
|
||||||
|
- Line 19: Change `MacroList` type definition
|
||||||
|
- Line 69: Update `Macros MacroList` field
|
||||||
|
- Line 153-157: Update macro validation loop to work with ordered structure
|
||||||
|
- Line 175-188: Update model-level macro validation
|
||||||
|
- Line 181-188: **NEW** Implement proper macro merging respecting order
|
||||||
|
- Line 193-202: **NEW** Implement ordered macro substitution in LIFO order
|
||||||
|
- Line 389-415: Update `validateMacro` to detect self-references
|
||||||
|
- Line 420-475: Update `substituteMetadataMacros` to accept ordered MacroList
|
||||||
|
|
||||||
|
2. **[proxy/config/model_config.go](proxy/config/model_config.go)**
|
||||||
|
- Line 33: Update `Macros MacroList` field type
|
||||||
|
|
||||||
|
3. **All test files**
|
||||||
|
- Update test fixtures to use ordered macro definitions
|
||||||
|
- Ensure tests specify macro order explicitly
|
||||||
|
|
||||||
|
#### Core Algorithm
|
||||||
|
|
||||||
|
Replace the macro substitution logic in [config.go:181-252](proxy/config/config.go#L181-L252) with:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Merge global config and model macros. Model macros take precedence
|
||||||
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+2)
|
||||||
|
|
||||||
|
// Add global macros first
|
||||||
|
for _, entry := range config.Macros {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add model macros (can override global)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
// Remove any existing global macro with same name
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry // Override
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add reserved MODEL_ID macro at the end
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
|
||||||
|
// Check if PORT macro is needed
|
||||||
|
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||||
|
// enforce ${PORT} used in both cmd and proxy
|
||||||
|
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add PORT macro to the end (highest priority)
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "PORT", Value: nextPort})
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single-pass substitution: Substitute all macros in LIFO order (last defined first)
|
||||||
|
// This allows later macros to reference earlier ones
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
// Substitute in command fields
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in metadata (recursive)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
var err error
|
||||||
|
modelConfig.Metadata, err = substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Add this new helper function to replace `substituteMetadataMacros`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||||
|
// This is called once per macro, allowing LIFO substitution order
|
||||||
|
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
macroStr := fmt.Sprintf("%v", macroValue)
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check if this is a direct macro substitution
|
||||||
|
if v == macroSlug {
|
||||||
|
return macroValue, nil
|
||||||
|
}
|
||||||
|
// Handle string interpolation
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
// Recursively process map values
|
||||||
|
newMap := make(map[string]any)
|
||||||
|
for key, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newMap[key] = newVal
|
||||||
|
}
|
||||||
|
return newMap, nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
// Recursively process slice elements
|
||||||
|
newSlice := make([]any, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newSlice[i] = newVal
|
||||||
|
}
|
||||||
|
return newSlice, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Return scalar types as-is
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. Self-Reference Detection
|
||||||
|
|
||||||
|
Add to `validateMacro` function:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func validateMacro(name string, value any) error {
|
||||||
|
// ... existing validation ...
|
||||||
|
|
||||||
|
// Check for self-reference
|
||||||
|
if str, ok := value.(string); ok {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(str, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing Plan
|
||||||
|
|
||||||
|
### 1. Migration Tests
|
||||||
|
- **Test:** All existing macro tests still pass after YAML library migration
|
||||||
|
- **Files:** All `*_test.go` files with macro tests
|
||||||
|
|
||||||
|
### 2. Macro Order Tests
|
||||||
|
|
||||||
|
#### Test: Macro-in-macro substitution order
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"A": "value-A"
|
||||||
|
"B": "prefix-${A}-suffix"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "echo ${B}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"echo prefix-value-A-suffix"`
|
||||||
|
|
||||||
|
#### Test: LIFO substitution order
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"base": "/models"
|
||||||
|
"path": "${base}/llama"
|
||||||
|
"full": "${path}/model.gguf"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "load ${full}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"load /models/llama/model.gguf"`
|
||||||
|
|
||||||
|
#### Test: Model macro overrides global
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"tag": "global"
|
||||||
|
"msg": "value-${tag}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
macros:
|
||||||
|
"tag": "model-level"
|
||||||
|
cmd: "echo ${msg}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"echo value-model-level"` (model macro overrides global)
|
||||||
|
|
||||||
|
### 3. Reserved Macro Tests
|
||||||
|
|
||||||
|
#### Test: MODEL_ID substituted in macro
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||||
|
|
||||||
|
models:
|
||||||
|
my-model:
|
||||||
|
cmd: "${podman-llama} -m model.gguf"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf"`
|
||||||
|
|
||||||
|
### 4. Error Detection Tests
|
||||||
|
|
||||||
|
#### Test: Self-reference detection
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"recursive": "value-${recursive}"
|
||||||
|
```
|
||||||
|
**Expected:** Error: `macro 'recursive' contains self-reference`
|
||||||
|
|
||||||
|
#### Test: Undefined macro reference
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"A": "value-${UNDEFINED}"
|
||||||
|
```
|
||||||
|
**Expected:** Error: `unknown macro '${UNDEFINED}' found in macros.A` (or similar)
|
||||||
|
|
||||||
|
### 5. Regression Tests
|
||||||
|
- Run all existing macro tests: `TestConfig_MacroReplacement`, `TestConfig_MacroReservedNames`, etc.
|
||||||
|
- Ensure all pass without modification (except test fixtures if needed)
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
### Phase 1: Data Structure Changes
|
||||||
|
- [ ] Implement custom `UnmarshalYAML` method for `MacroList` that uses `yaml.Node`
|
||||||
|
- [ ] Define new ordered `MacroList` type as `[]MacroEntry`
|
||||||
|
- [ ] Update `MacroList` type definition in [config.go](proxy/config/config.go#L19)
|
||||||
|
- [ ] Update `Config.Macros` field type in [config.go](proxy/config/config.go#L69)
|
||||||
|
- [ ] Update `ModelConfig.Macros` field type in [model_config.go](proxy/config/model_config.go#L33)
|
||||||
|
- [ ] Implement helper functions:
|
||||||
|
- [ ] `func (ml MacroList) Get(name string) (any, bool)` - lookup by name
|
||||||
|
- [ ] `func (ml MacroList) Set(name string, value any) MacroList` - add/override entry
|
||||||
|
- [ ] `func (ml MacroList) ToMap() map[string]any` - convert to map if needed
|
||||||
|
|
||||||
|
### Phase 2: Macro Validation Updates
|
||||||
|
- [ ] Update macro validation loop at [config.go:153-157](proxy/config/config.go#L153-L157)
|
||||||
|
- [ ] Update model macro validation at [config.go:175-179](proxy/config/config.go#L175-L179)
|
||||||
|
- [ ] Add self-reference detection to `validateMacro` function [config.go:389](proxy/config/config.go#L389)
|
||||||
|
- [ ] Test self-reference detection with new test case
|
||||||
|
|
||||||
|
### Phase 3: Macro Substitution Algorithm
|
||||||
|
- [ ] Implement ordered macro merging (global → model → reserved) at [config.go:181-188](proxy/config/config.go#L181-L188)
|
||||||
|
- [ ] Implement single-pass LIFO substitution loop (reverse iteration) at [config.go:193-202](proxy/config/config.go#L193-L202)
|
||||||
|
- [ ] Substitute in all string fields (cmd, cmdStop, proxy, checkEndpoint, stripParams)
|
||||||
|
- [ ] Substitute in metadata within same loop
|
||||||
|
- [ ] Ensure `MODEL_ID` is added to merged macros before substitution
|
||||||
|
- [ ] Ensure `PORT` is added after port assignment (if needed)
|
||||||
|
- [ ] Replace `substituteMetadataMacros` with new `substituteMacroInValue` function that processes one macro at a time [config.go:420](proxy/config/config.go#L420)
|
||||||
|
- [ ] Remove old metadata substitution code that was separate from main loop [config.go:245-251](proxy/config/config.go#L245-L251)
|
||||||
|
|
||||||
|
### Phase 4: Testing
|
||||||
|
- [ ] Run `make test-dev` - fix any static checking errors
|
||||||
|
- [ ] Add test: macro-in-macro basic substitution
|
||||||
|
- [ ] Add test: LIFO substitution order with 3+ macro levels
|
||||||
|
- [ ] Add test: MODEL_ID in global macro used by model
|
||||||
|
- [ ] Add test: PORT in global macro used by model
|
||||||
|
- [ ] Add test: model macro overrides global macro in substitution
|
||||||
|
- [ ] Add test: self-reference detection error
|
||||||
|
- [ ] Add test: undefined macro reference error
|
||||||
|
- [ ] Verify all existing macro tests pass: `TestConfig_Macro*`
|
||||||
|
- [ ] Run `make test-all` - ensure all tests including concurrency tests pass
|
||||||
|
|
||||||
|
### Phase 5: Documentation
|
||||||
|
- [ ] Update plan status in this file (mark completed)
|
||||||
|
- [ ] Update CLAUDE.md if macro behavior needs documentation
|
||||||
|
- [ ] Verify no new error messages need user documentation
|
||||||
|
|
||||||
|
## Bug Example (Original Issue)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"podman-llama": >
|
||||||
|
podman run --name ${MODEL_ID}
|
||||||
|
--init --rm -p ${PORT}:8080 -v /home/alex/ai/models:/models:z --gpus=all
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server-cuda
|
||||||
|
|
||||||
|
"standard-options": >
|
||||||
|
--no-mmap --jinja
|
||||||
|
|
||||||
|
"kv8": >
|
||||||
|
-fa on -ctk q8_0 -ctv q8_0
|
||||||
|
```
|
||||||
|
|
||||||
|
**Current Bug:**
|
||||||
|
- During macro substitution, if `${MODEL_ID}` is processed before `${podman-llama}`, the `${MODEL_ID}` reference inside `podman-llama` remains unsubstituted
|
||||||
|
- Results in error: `unknown macro '${MODEL_ID}' found in model.cmd`
|
||||||
|
|
||||||
|
**After Fix:**
|
||||||
|
- Macros substituted in LIFO order: `kv8` → `standard-options` → `podman-llama`
|
||||||
|
- `MODEL_ID` is a reserved macro, substituted last (after all user macros)
|
||||||
|
- `${MODEL_ID}` inside `podman-llama` is correctly replaced with the model name
|
||||||
+9
-4
@@ -43,14 +43,19 @@ startPort: 10001
|
|||||||
# - macro names are strings and must be less than 64 characters
|
# - macro names are strings and must be less than 64 characters
|
||||||
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||||
# - macro values must be less than 1024 characters
|
# - macro values can be numbers, bools, or strings
|
||||||
#
|
# - macros can contain other macros, but they must be defined before they are used
|
||||||
# Important: do not nest macros inside other macros; expansion is single-pass
|
|
||||||
macros:
|
macros:
|
||||||
|
# Example of a multi-line macro
|
||||||
"latest-llama": >
|
"latest-llama": >
|
||||||
/path/to/llama-server/llama-server-ec9e0301
|
/path/to/llama-server/llama-server-ec9e0301
|
||||||
--port ${PORT}
|
--port ${PORT}
|
||||||
"default_ctx": "4096"
|
|
||||||
|
"default_ctx": 4096
|
||||||
|
|
||||||
|
# Example of macro-in-macro usage. macros can contain other macros
|
||||||
|
# but they must be previously declared.
|
||||||
|
"default_args": "--ctx-size ${default_ctx}"
|
||||||
|
|
||||||
# models: a dictionary of model configurations
|
# models: a dictionary of model configurations
|
||||||
# - required
|
# - required
|
||||||
|
|||||||
+29
-3
@@ -28,7 +28,9 @@ var (
|
|||||||
func main() {
|
func main() {
|
||||||
// Define a command-line flag for the port
|
// Define a command-line flag for the port
|
||||||
configPath := flag.String("config", "config.yaml", "config file name")
|
configPath := flag.String("config", "config.yaml", "config file name")
|
||||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
listenStr := flag.String("listen", "", "listen ip/port")
|
||||||
|
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||||
|
keyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||||
showVersion := flag.Bool("version", false, "show version of build")
|
showVersion := flag.Bool("version", false, "show version of build")
|
||||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||||
|
|
||||||
@@ -55,6 +57,23 @@ func main() {
|
|||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate TLS flags.
|
||||||
|
var useTLS = (*certFile != "" && *keyFile != "")
|
||||||
|
if (*certFile != "" && *keyFile == "") ||
|
||||||
|
(*certFile == "" && *keyFile != "") {
|
||||||
|
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default ports.
|
||||||
|
if *listenStr == "" {
|
||||||
|
defaultPort := ":8080"
|
||||||
|
if useTLS {
|
||||||
|
defaultPort = ":8443"
|
||||||
|
}
|
||||||
|
listenStr = &defaultPort
|
||||||
|
}
|
||||||
|
|
||||||
// Setup channels for server management
|
// Setup channels for server management
|
||||||
exitChan := make(chan struct{})
|
exitChan := make(chan struct{})
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
@@ -167,9 +186,16 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
var err error
|
||||||
|
if useTLS {
|
||||||
|
fmt.Printf("llama-swap listening with TLS on https://%s\n", *listenStr)
|
||||||
|
err = srv.ListenAndServeTLS(*certFile, *keyFile)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("llama-swap listening on http://%s\n", *listenStr)
|
||||||
|
err = srv.ListenAndServe()
|
||||||
|
}
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
log.Fatalf("Fatal server error: %v\n", err)
|
log.Fatalf("Fatal server error: %v\n", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
+185
-59
@@ -3,11 +3,11 @@ package config
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/billziss-gh/golib/shlex"
|
"github.com/billziss-gh/golib/shlex"
|
||||||
@@ -16,7 +16,60 @@ import (
|
|||||||
|
|
||||||
const DEFAULT_GROUP_ID = "(default)"
|
const DEFAULT_GROUP_ID = "(default)"
|
||||||
|
|
||||||
type MacroList map[string]any
|
type MacroEntry struct {
|
||||||
|
Name string
|
||||||
|
Value any
|
||||||
|
}
|
||||||
|
|
||||||
|
type MacroList []MacroEntry
|
||||||
|
|
||||||
|
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
|
||||||
|
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
if value.Kind != yaml.MappingNode {
|
||||||
|
return fmt.Errorf("macros must be a mapping")
|
||||||
|
}
|
||||||
|
|
||||||
|
// yaml.Node.Content for a mapping contains alternating key/value nodes
|
||||||
|
entries := make([]MacroEntry, 0, len(value.Content)/2)
|
||||||
|
for i := 0; i < len(value.Content); i += 2 {
|
||||||
|
keyNode := value.Content[i]
|
||||||
|
valueNode := value.Content[i+1]
|
||||||
|
|
||||||
|
var name string
|
||||||
|
if err := keyNode.Decode(&name); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode macro name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var val any
|
||||||
|
if err := valueNode.Decode(&val); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries = append(entries, MacroEntry{Name: name, Value: val})
|
||||||
|
}
|
||||||
|
|
||||||
|
*ml = entries
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a macro value by name
|
||||||
|
func (ml MacroList) Get(name string) (any, bool) {
|
||||||
|
for _, entry := range ml {
|
||||||
|
if entry.Name == name {
|
||||||
|
return entry.Value, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMap converts MacroList to a map (for backward compatibility if needed)
|
||||||
|
func (ml MacroList) ToMap() map[string]any {
|
||||||
|
result := make(map[string]any, len(ml))
|
||||||
|
for _, entry := range ml {
|
||||||
|
result[entry.Name] = entry.Value
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
type GroupConfig struct {
|
type GroupConfig struct {
|
||||||
Swap bool `yaml:"swap"`
|
Swap bool `yaml:"swap"`
|
||||||
@@ -150,8 +203,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
- name can not be any reserved macros: PORT, MODEL_ID
|
- name can not be any reserved macros: PORT, MODEL_ID
|
||||||
- macro values must be less than 1024 characters
|
- macro values must be less than 1024 characters
|
||||||
*/
|
*/
|
||||||
for macroName, macroValue := range config.Macros {
|
for _, macro := range config.Macros {
|
||||||
if err = validateMacro(macroName, macroValue); err != nil {
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
return Config{}, err
|
return Config{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -172,49 +225,88 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||||
|
|
||||||
// validate model macros
|
// validate model macros
|
||||||
for macroName, macroValue := range modelConfig.Macros {
|
for _, macro := range modelConfig.Macros {
|
||||||
if err = validateMacro(macroName, macroValue); err != nil {
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge global config and model macros. Model macros take precedence
|
// Merge global config and model macros. Model macros take precedence
|
||||||
mergedMacros := make(MacroList)
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
|
||||||
for k, v := range config.Macros {
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
mergedMacros[k] = v
|
|
||||||
|
// Add global macros first
|
||||||
|
mergedMacros = append(mergedMacros, config.Macros...)
|
||||||
|
|
||||||
|
// Add model macros (can override global)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
// Remove any existing global macro with same name
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry // Override
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
}
|
}
|
||||||
for k, v := range modelConfig.Macros {
|
|
||||||
mergedMacros[k] = v
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mergedMacros["MODEL_ID"] = modelId
|
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
|
||||||
|
// This allows later macros to reference earlier ones
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
|
// Substitute in command fields
|
||||||
for macroName, macroValue := range mergedMacros {
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
|
||||||
// Convert macro value to string for command/string field substitution
|
|
||||||
macroStr := fmt.Sprintf("%v", macroValue)
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in metadata (recursive)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
var err error
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// enforce ${PORT} used in both cmd and proxy
|
// Final pass: check if PORT macro is needed after macro expansion
|
||||||
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
// ${PORT} is a resource on the local machine so a new port is only allocated
|
||||||
|
// if it is required in either cmd or proxy keys
|
||||||
|
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||||
|
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||||
|
if cmdHasPort || proxyHasPort { // either has it
|
||||||
|
if !cmdHasPort && proxyHasPort { // but both don't have it
|
||||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
|
// Add PORT macro and substitute it
|
||||||
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
|
||||||
nextPortStr := strconv.Itoa(nextPort)
|
macroSlug := "${PORT}"
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
|
macroStr := fmt.Sprintf("%v", nextPort)
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
|
|
||||||
|
|
||||||
// add port to merged macros so it can be used in metadata
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
mergedMacros["PORT"] = nextPort
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute PORT in metadata
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
var err error
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
nextPort++
|
nextPort++
|
||||||
}
|
}
|
||||||
@@ -235,19 +327,27 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
if macroName == "PID" && fieldName == "cmdStop" {
|
if macroName == "PID" && fieldName == "cmdStop" {
|
||||||
continue // this is ok, has to be replaced by process later
|
continue // this is ok, has to be replaced by process later
|
||||||
}
|
}
|
||||||
if _, exists := config.Macros[macroName]; !exists {
|
// Reserved macros are always valid (they should have been substituted already)
|
||||||
|
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||||
|
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
// Any other macro is unknown
|
||||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for unknown macros in metadata
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply macro substitution to metadata
|
// Validate the proxy URL.
|
||||||
if len(modelConfig.Metadata) > 0 {
|
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||||
substitutedMetadata, err := substituteMetadataMacros(modelConfig.Metadata, mergedMacros)
|
return Config{}, fmt.Errorf(
|
||||||
if err != nil {
|
"model %s: invalid proxy URL: %w", modelId, err,
|
||||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
)
|
||||||
}
|
|
||||||
modelConfig.Metadata = substitutedMetadata.(map[string]any)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Models[modelId] = modelConfig
|
config.Models[modelId] = modelConfig
|
||||||
@@ -400,6 +500,11 @@ func validateMacro(name string, value any) error {
|
|||||||
if len(v) >= 1024 {
|
if len(v) >= 1024 {
|
||||||
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
||||||
}
|
}
|
||||||
|
// Check for self-reference
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||||
// These types are allowed
|
// These types are allowed
|
||||||
default:
|
default:
|
||||||
@@ -414,41 +519,62 @@ func validateMacro(name string, value any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// substituteMetadataMacros recursively substitutes macros in metadata structures
|
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
|
||||||
// Direct substitution (key: ${macro}) preserves the macro's type
|
func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||||
// Interpolated substitution (key: "text ${macro}") converts to string
|
switch v := value.(type) {
|
||||||
func substituteMetadataMacros(value any, macros MacroList) (any, error) {
|
case string:
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Scalar types don't contain macros
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||||
|
// This is called once per macro, allowing LIFO substitution order
|
||||||
|
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
macroStr := fmt.Sprintf("%v", macroValue)
|
||||||
|
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case string:
|
case string:
|
||||||
// Check if this is a direct macro substitution
|
// Check if this is a direct macro substitution
|
||||||
if strings.HasPrefix(v, "${") && strings.HasSuffix(v, "}") && strings.Count(v, "${") == 1 {
|
if v == macroSlug {
|
||||||
macroName := v[2 : len(v)-1]
|
|
||||||
if macroValue, exists := macros[macroName]; exists {
|
|
||||||
return macroValue, nil
|
return macroValue, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown macro '${%s}' in metadata", macroName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle string interpolation
|
// Handle string interpolation
|
||||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
if strings.Contains(v, macroSlug) {
|
||||||
result := v
|
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||||
for _, match := range matches {
|
|
||||||
macroName := match[1]
|
|
||||||
macroValue, exists := macros[macroName]
|
|
||||||
if !exists {
|
|
||||||
return nil, fmt.Errorf("unknown macro '${%s}' in metadata", macroName)
|
|
||||||
}
|
}
|
||||||
// Convert macro value to string for interpolation
|
return v, nil
|
||||||
macroStr := fmt.Sprintf("%v", macroValue)
|
|
||||||
result = strings.ReplaceAll(result, match[0], macroStr)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
|
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
// Recursively process map values
|
// Recursively process map values
|
||||||
newMap := make(map[string]any)
|
newMap := make(map[string]any)
|
||||||
for key, val := range v {
|
for key, val := range v {
|
||||||
newVal, err := substituteMetadataMacros(val, macros)
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -460,7 +586,7 @@ func substituteMetadataMacros(value any, macros MacroList) (any, error) {
|
|||||||
// Recursively process slice elements
|
// Recursively process slice elements
|
||||||
newSlice := make([]any, len(v))
|
newSlice := make([]any, len(v))
|
||||||
for i, val := range v {
|
for i, val := range v {
|
||||||
newVal, err := substituteMetadataMacros(val, macros)
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ groups:
|
|||||||
LogLevel: "info",
|
LogLevel: "info",
|
||||||
StartPort: 5800,
|
StartPort: 5800,
|
||||||
Macros: MacroList{
|
Macros: MacroList{
|
||||||
"svr-path": "path/to/server",
|
{"svr-path", "path/to/server"},
|
||||||
},
|
},
|
||||||
Hooks: HooksConfig{
|
Hooks: HooksConfig{
|
||||||
OnStartup: HookOnStartup{
|
OnStartup: HookOnStartup{
|
||||||
|
|||||||
@@ -213,7 +213,9 @@ models:
|
|||||||
`
|
`
|
||||||
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
assert.NoError(t, err)
|
if !assert.NoError(t, err) {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three --overridden success", strings.Join(sanitizedCmd, " "))
|
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three --overridden success", strings.Join(sanitizedCmd, " "))
|
||||||
@@ -321,7 +323,7 @@ macros:
|
|||||||
models:
|
models:
|
||||||
model1:
|
model1:
|
||||||
cmd: "${svr-path} --port ${PORT}"
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
proxy: "http://localhost:${unknownMacro}"
|
proxy: "http://${unknownMacro}:${PORT}"
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -503,7 +505,9 @@ models:
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
|
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
|
||||||
|
|
||||||
assert.Equal(t, "docker stop ${MODEL_ID}", config.Macros["docker-stop"])
|
dockerStopMacro, found := config.Macros.Get("docker-stop")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "docker stop ${MODEL_ID}", dockerStopMacro)
|
||||||
|
|
||||||
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
|
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ groups:
|
|||||||
LogLevel: "info",
|
LogLevel: "info",
|
||||||
StartPort: 5800,
|
StartPort: 5800,
|
||||||
Macros: MacroList{
|
Macros: MacroList{
|
||||||
"svr-path": "path/to/server",
|
{"svr-path", "path/to/server"},
|
||||||
},
|
},
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": {
|
"model1": {
|
||||||
|
|||||||
@@ -0,0 +1,123 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test macro-in-macro basic substitution
|
||||||
|
func TestConfig_MacroInMacroBasic(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"A": "value-A"
|
||||||
|
"B": "prefix-${A}-suffix"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${B}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test LIFO substitution order with 3+ macro levels
|
||||||
|
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"base": "/models"
|
||||||
|
"path": "${base}/llama"
|
||||||
|
"full": "${path}/model.gguf"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: load ${full}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MODEL_ID in global macro used by model
|
||||||
|
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||||
|
|
||||||
|
models:
|
||||||
|
my-model:
|
||||||
|
cmd: ${podman-llama} -m model.gguf
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test model macro overrides global macro in substitution
|
||||||
|
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"tag": "global"
|
||||||
|
"msg": "value-${tag}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
macros:
|
||||||
|
"tag": "model-level"
|
||||||
|
cmd: echo ${msg}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test self-reference detection error
|
||||||
|
func TestConfig_SelfReferenceDetection(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"recursive": "value-${recursive}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${recursive}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "recursive")
|
||||||
|
assert.Contains(t, err.Error(), "self-reference")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test undefined macro reference error
|
||||||
|
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"A": "value-${UNDEFINED}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${A}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||||
|
}
|
||||||
+86
-64
@@ -4,14 +4,14 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -42,8 +42,10 @@ type Process struct {
|
|||||||
ID string
|
ID string
|
||||||
config config.ModelConfig
|
config config.ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
|
reverseProxy *httputil.ReverseProxy
|
||||||
|
|
||||||
// PR #155 called to cancel the upstream process
|
// PR #155 called to cancel the upstream process
|
||||||
|
cmdMutex sync.RWMutex
|
||||||
cancelUpstream context.CancelFunc
|
cancelUpstream context.CancelFunc
|
||||||
|
|
||||||
// closed when command exits
|
// closed when command exits
|
||||||
@@ -55,12 +57,14 @@ type Process struct {
|
|||||||
healthCheckTimeout int
|
healthCheckTimeout int
|
||||||
healthCheckLoopInterval time.Duration
|
healthCheckLoopInterval time.Duration
|
||||||
|
|
||||||
|
lastRequestHandledMutex sync.RWMutex
|
||||||
lastRequestHandled time.Time
|
lastRequestHandled time.Time
|
||||||
|
|
||||||
stateMutex sync.RWMutex
|
stateMutex sync.RWMutex
|
||||||
state ProcessState
|
state ProcessState
|
||||||
|
|
||||||
inFlightRequests sync.WaitGroup
|
inFlightRequests sync.WaitGroup
|
||||||
|
inFlightRequestsCount atomic.Int32
|
||||||
|
|
||||||
// used to block on multiple start() calls
|
// used to block on multiple start() calls
|
||||||
waitStarting sync.WaitGroup
|
waitStarting sync.WaitGroup
|
||||||
@@ -81,10 +85,29 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr
|
|||||||
concurrentLimit = config.ConcurrencyLimit
|
concurrentLimit = config.ConcurrencyLimit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Setup the reverse proxy.
|
||||||
|
proxyURL, err := url.Parse(config.Proxy)
|
||||||
|
if err != nil {
|
||||||
|
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reverseProxy *httputil.ReverseProxy
|
||||||
|
if proxyURL != nil {
|
||||||
|
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||||
|
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
|
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||||
|
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||||
|
resp.Header.Set("X-Accel-Buffering", "no")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
cmd: nil,
|
cmd: nil,
|
||||||
|
reverseProxy: reverseProxy,
|
||||||
cancelUpstream: nil,
|
cancelUpstream: nil,
|
||||||
processLogger: processLogger,
|
processLogger: processLogger,
|
||||||
proxyLogger: proxyLogger,
|
proxyLogger: proxyLogger,
|
||||||
@@ -107,6 +130,20 @@ func (p *Process) LogMonitor() *LogMonitor {
|
|||||||
return p.processLogger
|
return p.processLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
|
||||||
|
func (p *Process) setLastRequestHandled(t time.Time) {
|
||||||
|
p.lastRequestHandledMutex.Lock()
|
||||||
|
defer p.lastRequestHandledMutex.Unlock()
|
||||||
|
p.lastRequestHandled = t
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
|
||||||
|
func (p *Process) getLastRequestHandled() time.Time {
|
||||||
|
p.lastRequestHandledMutex.RLock()
|
||||||
|
defer p.lastRequestHandledMutex.RUnlock()
|
||||||
|
return p.lastRequestHandled
|
||||||
|
}
|
||||||
|
|
||||||
// custom error types for swapping state
|
// custom error types for swapping state
|
||||||
var (
|
var (
|
||||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||||
@@ -130,6 +167,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
|||||||
}
|
}
|
||||||
|
|
||||||
p.state = newState
|
p.state = newState
|
||||||
|
|
||||||
|
// Atomically increment waitStarting when entering StateStarting
|
||||||
|
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
|
||||||
|
if newState == StateStarting {
|
||||||
|
p.waitStarting.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||||
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||||
return p.state, nil
|
return p.state, nil
|
||||||
@@ -158,6 +202,15 @@ func (p *Process) CurrentState() ProcessState {
|
|||||||
return p.state
|
return p.state
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// forceState forces the process state to the new state with mutex protection.
|
||||||
|
// This should only be used in exceptional cases where the normal state transition
|
||||||
|
// validation via swapState() cannot be used.
|
||||||
|
func (p *Process) forceState(newState ProcessState) {
|
||||||
|
p.stateMutex.Lock()
|
||||||
|
defer p.stateMutex.Unlock()
|
||||||
|
p.state = newState
|
||||||
|
}
|
||||||
|
|
||||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||||
// it is a private method because starting is automatic but stopping can be called
|
// it is a private method because starting is automatic but stopping can be called
|
||||||
// at any time.
|
// at any time.
|
||||||
@@ -191,7 +244,7 @@ func (p *Process) start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
p.waitStarting.Add(1)
|
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
|
||||||
defer p.waitStarting.Done()
|
defer p.waitStarting.Done()
|
||||||
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||||
|
|
||||||
@@ -201,8 +254,11 @@ func (p *Process) start() error {
|
|||||||
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||||
|
|
||||||
|
p.cmdMutex.Lock()
|
||||||
p.cancelUpstream = ctxCancelUpstream
|
p.cancelUpstream = ctxCancelUpstream
|
||||||
p.cmdWaitChan = make(chan struct{})
|
p.cmdWaitChan = make(chan struct{})
|
||||||
|
p.cmdMutex.Unlock()
|
||||||
|
|
||||||
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||||
|
|
||||||
@@ -212,7 +268,7 @@ func (p *Process) start() error {
|
|||||||
// Set process state to failed
|
// Set process state to failed
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||||
p.state = StateStopped // force it into a stopped state
|
p.forceState(StateStopped) // force it into a stopped state
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||||
strings.Join(args, " "), err, curState, swapErr,
|
strings.Join(args, " "), err, curState, swapErr,
|
||||||
@@ -285,10 +341,12 @@ func (p *Process) start() error {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for all inflight requests to complete and ticker
|
// skip the TTL check if there are inflight requests
|
||||||
p.inFlightRequests.Wait()
|
if p.inFlightRequestsCount.Load() != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
if time.Since(p.getLastRequestHandled()) > maxDuration {
|
||||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||||
p.Stop()
|
p.Stop()
|
||||||
return
|
return
|
||||||
@@ -344,7 +402,7 @@ func (p *Process) Shutdown() {
|
|||||||
|
|
||||||
p.stopCommand()
|
p.stopCommand()
|
||||||
// just force it to this state since there is no recovery from shutdown
|
// just force it to this state since there is no recovery from shutdown
|
||||||
p.state = StateShutdown
|
p.forceState(StateShutdown)
|
||||||
}
|
}
|
||||||
|
|
||||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||||
@@ -355,13 +413,18 @@ func (p *Process) stopCommand() {
|
|||||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if p.cancelUpstream == nil {
|
p.cmdMutex.RLock()
|
||||||
|
cancelUpstream := p.cancelUpstream
|
||||||
|
cmdWaitChan := p.cmdWaitChan
|
||||||
|
p.cmdMutex.RUnlock()
|
||||||
|
|
||||||
|
if cancelUpstream == nil {
|
||||||
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.cancelUpstream()
|
cancelUpstream()
|
||||||
<-p.cmdWaitChan
|
<-cmdWaitChan
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||||
@@ -418,8 +481,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
p.inFlightRequests.Add(1)
|
p.inFlightRequests.Add(1)
|
||||||
|
p.inFlightRequestsCount.Add(1)
|
||||||
defer func() {
|
defer func() {
|
||||||
p.lastRequestHandled = time.Now()
|
p.setLastRequestHandled(time.Now())
|
||||||
|
p.inFlightRequestsCount.Add(-1)
|
||||||
p.inFlightRequests.Done()
|
p.inFlightRequests.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -434,56 +499,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
startDuration = time.Since(beginStartTime)
|
startDuration = time.Since(beginStartTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
if p.reverseProxy != nil {
|
||||||
client := &http.Client{}
|
p.reverseProxy.ServeHTTP(w, r)
|
||||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
|
} else {
|
||||||
if err != nil {
|
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header = r.Header.Clone()
|
|
||||||
|
|
||||||
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
|
||||||
if err == nil {
|
|
||||||
req.ContentLength = contentLength
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
for k, vv := range resp.Header {
|
|
||||||
for _, v := range vv {
|
|
||||||
w.Header().Add(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
|
||||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
}
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
// faster than io.Copy when streaming
|
|
||||||
buf := make([]byte, 32*1024)
|
|
||||||
for {
|
|
||||||
n, err := resp.Body.Read(buf)
|
|
||||||
if n > 0 {
|
|
||||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if flusher, ok := w.(http.Flusher); ok {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
totalTime := time.Since(requestBeginTime)
|
totalTime := time.Since(requestBeginTime)
|
||||||
@@ -519,13 +538,16 @@ func (p *Process) waitForCmd() {
|
|||||||
case StateStopping:
|
case StateStopping:
|
||||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||||
p.state = StateStopped
|
p.forceState(StateStopped)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||||
p.state = StateStopped // force it to be in this state
|
p.forceState(StateStopped) // force it to be in this state
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.cmdMutex.Lock()
|
||||||
close(p.cmdWaitChan)
|
close(p.cmdWaitChan)
|
||||||
|
p.cmdMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||||
|
|||||||
@@ -436,7 +436,9 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||||
} else {
|
} else {
|
||||||
assert.Contains(t, w.Body.String(), "unexpected EOF")
|
// Upstream may be killed mid-response.
|
||||||
|
// Assert an incomplete or partial response.
|
||||||
|
assert.NotEqual(t, "12345", w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
close(waitChan)
|
close(waitChan)
|
||||||
|
|||||||
+70
-34
@@ -21,6 +21,32 @@ import (
|
|||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TestResponseRecorder adds CloseNotify to httptest.ResponseRecorder.
|
||||||
|
// "If you want to write your own tests around streams you will need a Recorder that can handle CloseNotifier."
|
||||||
|
// The tests can panic otherwise:
|
||||||
|
// panic: interface conversion: *httptest.ResponseRecorder is not http.CloseNotifier: missing method CloseNotify
|
||||||
|
// See: https://github.com/gin-gonic/gin/issues/1815
|
||||||
|
// TestResponseRecorder is taken from gin's own tests: https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765
|
||||||
|
type TestResponseRecorder struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
closeChannel chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TestResponseRecorder) CloseNotify() <-chan bool {
|
||||||
|
return r.closeChannel
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TestResponseRecorder) closeClient() {
|
||||||
|
r.closeChannel <- true
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateTestResponseRecorder() *TestResponseRecorder {
|
||||||
|
return &TestResponseRecorder{
|
||||||
|
httptest.NewRecorder(),
|
||||||
|
make(chan bool, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||||
config := config.AddDefaultGroupToConfig(config.Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
@@ -37,7 +63,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
for _, modelName := range []string{"model1", "model2"} {
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -74,7 +100,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
t.Run(requestedModel, func(t *testing.T) {
|
t.Run(requestedModel, func(t *testing.T) {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -116,7 +142,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
|||||||
for _, requestedModel := range tests {
|
for _, requestedModel := range tests {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -159,7 +185,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
@@ -212,7 +238,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
// Create a test request
|
// Create a test request
|
||||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
req.Header.Add("Origin", "i-am-the-origin")
|
req.Header.Add("Origin", "i-am-the-origin")
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
// Call the listModelsHandler
|
// Call the listModelsHandler
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
@@ -311,7 +337,7 @@ models:
|
|||||||
proxy := New(processedConfig)
|
proxy := New(processedConfig)
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -387,7 +413,7 @@ func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
|||||||
|
|
||||||
// Request models list
|
// Request models list
|
||||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -448,7 +474,7 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
@@ -476,12 +502,12 @@ func TestProxyManager_Unload(t *testing.T) {
|
|||||||
proxy := New(conf)
|
proxy := New(conf)
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||||
req = httptest.NewRequest("GET", "/unload", nil)
|
req = httptest.NewRequest("GET", "/unload", nil)
|
||||||
w = httptest.NewRecorder()
|
w = CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Equal(t, w.Body.String(), "OK")
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
@@ -519,7 +545,7 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
|||||||
for _, modelName := range []string{"model1", "model2"} {
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -527,7 +553,7 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
|||||||
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
|
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
|
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
if !assert.Equal(t, w.Body.String(), "OK") {
|
if !assert.Equal(t, w.Body.String(), "OK") {
|
||||||
@@ -571,7 +597,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("no models loaded", func(t *testing.T) {
|
t.Run("no models loaded", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/running", nil)
|
req := httptest.NewRequest("GET", "/running", nil)
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -589,13 +615,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
// Load just a model.
|
// Load just a model.
|
||||||
reqBody := `{"model":"model1"}`
|
reqBody := `{"model":"model1"}`
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
// Simulate browser call for the `/running` endpoint.
|
// Simulate browser call for the `/running` endpoint.
|
||||||
req = httptest.NewRequest("GET", "/running", nil)
|
req = httptest.NewRequest("GET", "/running", nil)
|
||||||
w = httptest.NewRecorder()
|
w = CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
var response RunningResponse
|
var response RunningResponse
|
||||||
@@ -647,7 +673,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
// Create the request with the multipart form data
|
// Create the request with the multipart form data
|
||||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
|
|
||||||
// Verify the response
|
// Verify the response
|
||||||
@@ -682,7 +708,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -716,7 +742,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
// Create the request with the multipart form data
|
// Create the request with the multipart form data
|
||||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
|
|
||||||
// Verify the response
|
// Verify the response
|
||||||
@@ -784,7 +810,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||||
@@ -812,7 +838,7 @@ models:
|
|||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
t.Run("main model name", func(t *testing.T) {
|
t.Run("main model name", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, "model1", rec.Body.String())
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
@@ -820,7 +846,7 @@ models:
|
|||||||
|
|
||||||
t.Run("model alias", func(t *testing.T) {
|
t.Run("model alias", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
|
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, "model1", rec.Body.String())
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
@@ -841,7 +867,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
|||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -869,7 +895,7 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
|||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
|
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -900,7 +926,7 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
|||||||
// Make a non-streaming request
|
// Make a non-streaming request
|
||||||
reqBody := `{"model":"model1", "stream": false}`
|
reqBody := `{"model":"model1", "stream": false}`
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -935,7 +961,7 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
|||||||
// Make a streaming request
|
// Make a streaming request
|
||||||
reqBody := `{"model":"model1", "stream": true}`
|
reqBody := `{"model":"model1", "stream": true}`
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -967,7 +993,7 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
|
|||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
req := httptest.NewRequest("GET", "/health", nil)
|
req := httptest.NewRequest("GET", "/health", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, "OK", rec.Body.String())
|
assert.Equal(t, "OK", rec.Body.String())
|
||||||
@@ -988,7 +1014,7 @@ func TestProxyManager_CompletionEndpoint(t *testing.T) {
|
|||||||
|
|
||||||
reqBody := `{"model":"model1"}`
|
reqBody := `{"model":"model1"}`
|
||||||
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -1075,18 +1101,28 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
|||||||
|
|
||||||
for _, endpoint := range endpoints {
|
for _, endpoint := range endpoints {
|
||||||
t.Run(endpoint, func(t *testing.T) {
|
t.Run(endpoint, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", endpoint, nil)
|
req := httptest.NewRequest("GET", endpoint, nil)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
|
|
||||||
// We don't need the handler to fully complete, just to set the headers
|
// Run handler in goroutine and wait for context timeout
|
||||||
// so run it in a goroutine and check the headers after a short delay
|
done := make(chan struct{})
|
||||||
go proxy.ServeHTTP(rec, req)
|
go func() {
|
||||||
time.Sleep(10 * time.Millisecond) // give it time to start and write headers
|
defer close(done)
|
||||||
|
proxy.ServeHTTP(rec, req)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for either the handler to complete or context to timeout
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
// At this point, the handler has either finished or been cancelled
|
||||||
|
// Wait for the goroutine to fully exit before reading
|
||||||
|
<-done
|
||||||
|
|
||||||
|
// Now it's safe to read from rec - no more concurrent writes
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
|
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
|
||||||
})
|
})
|
||||||
@@ -1109,7 +1145,7 @@ func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testin
|
|||||||
reqBody := `{"model":"streaming-model"}`
|
reqBody := `{"model":"streaming-model"}`
|
||||||
// simple-responder will return text/event-stream when stream=true is in the query
|
// simple-responder will return text/event-stream when stream=true is in the query
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import { useTheme } from "../contexts/ThemeProvider";
|
|||||||
import ConnectionStatusIcon from "./ConnectionStatus";
|
import ConnectionStatusIcon from "./ConnectionStatus";
|
||||||
|
|
||||||
export function Header() {
|
export function Header() {
|
||||||
const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle } = useTheme();
|
const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle, isNarrow } = useTheme();
|
||||||
const handleTitleChange = useCallback(
|
const handleTitleChange = useCallback(
|
||||||
(newTitle: string) => {
|
(newTitle: string) => {
|
||||||
setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap");
|
setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap");
|
||||||
@@ -17,7 +17,7 @@ export function Header() {
|
|||||||
`text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 ${isActive ? "font-semibold" : ""}`;
|
`text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 ${isActive ? "font-semibold" : ""}`;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<header className="flex items-center justify-between bg-surface border-b border-border p-2 px-4 h-[75px]">
|
<header className={`flex items-center justify-between bg-surface border-b border-border px-4 ${isNarrow ? "py-1 h-[60px]" : "p-2 h-[75px]"}`}>
|
||||||
{screenWidth !== "xs" && screenWidth !== "sm" && (
|
{screenWidth !== "xs" && screenWidth !== "sm" && (
|
||||||
<h1
|
<h1
|
||||||
contentEditable
|
contentEditable
|
||||||
|
|||||||
+48
-2
@@ -4,7 +4,7 @@ import { LogPanel } from "./LogViewer";
|
|||||||
import { usePersistentState } from "../hooks/usePersistentState";
|
import { usePersistentState } from "../hooks/usePersistentState";
|
||||||
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||||
import { useTheme } from "../contexts/ThemeProvider";
|
import { useTheme } from "../contexts/ThemeProvider";
|
||||||
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine } from "react-icons/ri";
|
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine, RiMenuFill } from "react-icons/ri";
|
||||||
|
|
||||||
export default function ModelsPage() {
|
export default function ModelsPage() {
|
||||||
const { isNarrow } = useTheme();
|
const { isNarrow } = useTheme();
|
||||||
@@ -38,9 +38,11 @@ export default function ModelsPage() {
|
|||||||
|
|
||||||
function ModelsPanel() {
|
function ModelsPanel() {
|
||||||
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
|
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
|
||||||
|
const { isNarrow } = useTheme();
|
||||||
const [isUnloading, setIsUnloading] = useState(false);
|
const [isUnloading, setIsUnloading] = useState(false);
|
||||||
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
||||||
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
|
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
|
||||||
|
const [menuOpen, setMenuOpen] = useState(false);
|
||||||
|
|
||||||
const filteredModels = useMemo(() => {
|
const filteredModels = useMemo(() => {
|
||||||
return models.filter((model) => showUnlisted || !model.unlisted);
|
return models.filter((model) => showUnlisted || !model.unlisted);
|
||||||
@@ -66,7 +68,50 @@ function ModelsPanel() {
|
|||||||
return (
|
return (
|
||||||
<div className="card h-full flex flex-col">
|
<div className="card h-full flex flex-col">
|
||||||
<div className="shrink-0">
|
<div className="shrink-0">
|
||||||
<h2>Models</h2>
|
<div className="flex justify-between items-baseline">
|
||||||
|
<h2 className={isNarrow ? "text-xl" : ""}>Models</h2>
|
||||||
|
{isNarrow && (
|
||||||
|
<div className="relative">
|
||||||
|
<button className="btn text-base flex items-center gap-2 py-1" onClick={() => setMenuOpen(!menuOpen)}>
|
||||||
|
<RiMenuFill size="20" />
|
||||||
|
</button>
|
||||||
|
{menuOpen && (
|
||||||
|
<div className="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
|
||||||
|
<button
|
||||||
|
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||||
|
onClick={() => {
|
||||||
|
toggleIdorName();
|
||||||
|
setMenuOpen(false);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "Show Name" : "Show ID"}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||||
|
onClick={() => {
|
||||||
|
setShowUnlisted(!showUnlisted);
|
||||||
|
setMenuOpen(false);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{showUnlisted ? <RiEyeOffFill size="20" /> : <RiEyeFill size="20" />}{" "}
|
||||||
|
{showUnlisted ? "Hide Unlisted" : "Show Unlisted"}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||||
|
onClick={() => {
|
||||||
|
handleUnloadAllModels();
|
||||||
|
setMenuOpen(false);
|
||||||
|
}}
|
||||||
|
disabled={isUnloading}
|
||||||
|
>
|
||||||
|
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
{!isNarrow && (
|
||||||
<div className="flex justify-between">
|
<div className="flex justify-between">
|
||||||
<div className="flex gap-2">
|
<div className="flex gap-2">
|
||||||
<button
|
<button
|
||||||
@@ -93,6 +138,7 @@ function ModelsPanel() {
|
|||||||
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
|
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex-1 overflow-y-auto">
|
<div className="flex-1 overflow-y-auto">
|
||||||
|
|||||||
Reference in New Issue
Block a user