Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6516532568 | |||
| d58a8b85bf | |||
| caf9e98b1e | |||
| 539278343b | |||
| 00b738cd0f | |||
| 70930e4e91 | |||
| 1f6179110c | |||
| 216c40b951 | |||
| 9e3d491c85 | |||
| 1a84926505 |
@@ -0,0 +1,43 @@
|
||||
# Project: llama-swap
|
||||
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and react for UI (ui/)
|
||||
|
||||
## Testing
|
||||
|
||||
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors.
|
||||
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
### Plan Improvements
|
||||
|
||||
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
|
||||
|
||||
When the user asks to improve a plan follow these guidelines for expanding and improving it.
|
||||
|
||||
- Identify any inconsistencies.
|
||||
- Expand plans out to be detailed specification of requirements and changes to be made.
|
||||
- Plans should have at least these sections:
|
||||
- Title - very short, describes changes
|
||||
- Overview: A more detailed summary of goal and outcomes desired
|
||||
- Design Requirements: Detailed descriptions of what needs to be done
|
||||
- Testing Plan: Tests to be implemented
|
||||
- Checklist: A detailed list of changes to be made
|
||||
|
||||
Look for "plan expansion" as explicit instructions to improve a plan.
|
||||
|
||||
### Implementation of plans
|
||||
|
||||
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
|
||||
|
||||
## General Rules
|
||||
|
||||
- when summarizing changes only include details that require further action (action items)
|
||||
- when there are no action items, just say "Done."
|
||||
@@ -23,11 +23,17 @@ proxy/ui_dist/placeholder.txt:
|
||||
mkdir -p proxy/ui_dist
|
||||
touch $@
|
||||
|
||||
test: proxy/ui_dist/placeholder.txt
|
||||
go test -short -v -count=1 ./proxy
|
||||
# use cached test results while developing
|
||||
test-dev: proxy/ui_dist/placeholder.txt
|
||||
go test -short ./proxy/...
|
||||
staticcheck ./proxy/... || true
|
||||
|
||||
test: proxy/ui_dist/placeholder.txt
|
||||
go test -short -count=1 ./proxy/...
|
||||
|
||||
# for CI - full test (takes longer)
|
||||
test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -v -count=1 ./proxy
|
||||
go test -race -count=1 ./proxy/...
|
||||
|
||||
ui/node_modules:
|
||||
cd ui && npm install
|
||||
@@ -81,4 +87,4 @@ release:
|
||||
git tag "$$new_tag";
|
||||
|
||||
# Phony targets
|
||||
.PHONY: all clean ui mac linux windows simple-responder
|
||||
.PHONY: all clean ui mac linux windows simple-responder test test-all test-dev
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
# Add Model Metadata Support with Typed Macros
|
||||
|
||||
## Overview
|
||||
|
||||
Implement support for arbitrary metadata on model configurations that can be exposed through the `/v1/models` API endpoint. This feature extends the existing macro system to support scalar types (string, int, float, bool) instead of only strings, enabling type-safe metadata values.
|
||||
|
||||
The metadata will be schemaless, allowing users to define any key-value pairs they need. Macro substitution will work within metadata values, preserving types when macros are used directly and converting to strings when macros are interpolated within strings.
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### 1. Enhanced Macro System
|
||||
|
||||
**Current State:**
|
||||
|
||||
- Macros are defined as `map[string]string` at both global and model levels
|
||||
- Only string substitution is supported
|
||||
- Macros are replaced in: `cmd`, `cmdStop`, `proxy`, `checkEndpoint`, `filters.stripParams`
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Change `MacroList` type from `map[string]string` to `map[string]any`
|
||||
- Support scalar types: `string`, `int`, `float64`, `bool`
|
||||
- Implement type-preserving macro substitution:
|
||||
- Direct macro usage (`key: ${macro}`) preserves the macro's type
|
||||
- Interpolated usage (`key: "text ${macro}"`) converts to string
|
||||
- Add validation to ensure macro values are scalar types only
|
||||
- Update existing macro substitution logic in [proxy/config/config.go](proxy/config/config.go) to handle `any` types
|
||||
|
||||
**Implementation Details:**
|
||||
|
||||
- Create a generic helper function to perform macro substitution that:
|
||||
- Takes a value of type `any`
|
||||
- Recursively processes maps, slices, and scalar values
|
||||
- Replaces `${macro_name}` patterns with macro values
|
||||
- Preserves types for direct substitution
|
||||
- Converts to strings for interpolated substitution
|
||||
- Update `validateMacro()` function to accept `any` type and validate scalar types
|
||||
- Maintain backward compatibility with existing string-only macros
|
||||
|
||||
### 2. Metadata Field in ModelConfig
|
||||
|
||||
**Location:** [proxy/config/model_config.go](proxy/config/model_config.go)
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Add `Metadata map[string]any` field to `ModelConfig` struct
|
||||
- Support YAML unmarshaling of arbitrary structures (maps, arrays, scalars)
|
||||
- Apply macro substitution to metadata values during config loading
|
||||
|
||||
**Schema Requirements:**
|
||||
|
||||
- Metadata is optional (default: empty/nil map)
|
||||
- Supports nested structures (objects within objects, arrays, etc.)
|
||||
- All string values within metadata undergo macro substitution
|
||||
- Type preservation rules apply as described above
|
||||
|
||||
### 3. Macro Substitution in Metadata
|
||||
|
||||
**Location:** [proxy/config/config.go](proxy/config/config.go) in `LoadConfigFromReader()`
|
||||
|
||||
**Process Flow:**
|
||||
|
||||
1. After loading YAML configuration
|
||||
2. After model-level and global macro merging
|
||||
3. Apply macro substitution to `ModelConfig.Metadata` field
|
||||
4. Use the same merged macros available to `cmd`, `proxy`, etc.
|
||||
5. Process recursively through all nested structures
|
||||
|
||||
**Substitution Rules:**
|
||||
|
||||
- `port: ${PORT}` → keeps integer type from PORT macro
|
||||
- `temperature: ${temp}` → keeps float type from temp macro
|
||||
- `note: "Running on ${PORT}"` → converts to string `"Running on 10001"`
|
||||
- Arrays and nested objects are processed recursively
|
||||
- Unknown macros should cause configuration load error (consistent with existing behavior)
|
||||
|
||||
### 4. API Response Updates
|
||||
|
||||
**Location:** [proxy/proxymanager.go:350](proxy/proxymanager.go#L350) `listModelsHandler()`
|
||||
|
||||
**Current Behavior:**
|
||||
|
||||
- Returns model records with: `id`, `object`, `created`, `owned_by`
|
||||
- Optionally includes: `name`, `description`
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Add metadata to each model record under the key `llamaswap_meta`
|
||||
- Only include `llamaswap_meta` if metadata is non-empty
|
||||
- Preserve all types when marshaling to JSON
|
||||
- Maintain existing sorting by model ID
|
||||
|
||||
**Example Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "llama",
|
||||
"object": "model",
|
||||
"created": 1234567890,
|
||||
"owned_by": "llama-swap",
|
||||
"name": "llama 3.1 8B",
|
||||
"description": "A small but capable model",
|
||||
"llamaswap_meta": {
|
||||
"port": 10001,
|
||||
"temperature": 0.7,
|
||||
"note": "The llama is running on port 10001 temp=0.7, context=16384",
|
||||
"a_list": [1, 1.23, "macros are OK in list and dictionary types: llama"],
|
||||
"an_obj": {
|
||||
"a": "1",
|
||||
"b": 2,
|
||||
"c": [0.7, false, "model: llama"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Validation and Error Handling
|
||||
|
||||
**Macro Validation:**
|
||||
|
||||
- Extend `validateMacro()` to accept values of type `any`
|
||||
- Verify macro values are scalar types: `string`, `int`, `float64`, `bool`
|
||||
- Reject complex types (maps, slices, structs) as macro values
|
||||
- Maintain existing validation for macro names and lengths
|
||||
|
||||
**Configuration Loading:**
|
||||
|
||||
- Fail fast if unknown macros are found in metadata
|
||||
- Provide clear error messages indicating which model and field contains errors
|
||||
- Ensure macros in metadata follow same rules as macros in cmd/proxy fields
|
||||
|
||||
## Testing Plan
|
||||
|
||||
### Test 1: Model-Level Macros with Different Types
|
||||
|
||||
**File:** [proxy/config/model_config_test.go](proxy/config/model_config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Define model with macros of each scalar type
|
||||
- Verify metadata correctly substitutes and preserves types
|
||||
- Test direct substitution (`port: ${PORT}`)
|
||||
- Test string interpolation (`note: "Port is ${PORT}"`)
|
||||
- Verify nested objects and arrays work correctly
|
||||
|
||||
### Test 2: Global and Model Macro Precedence
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Define same macro at global and model level with different types
|
||||
- Verify model-level macro takes precedence
|
||||
- Test metadata uses correct macro value
|
||||
- Verify type is preserved from the winning macro
|
||||
|
||||
### Test 3: Macro Validation
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Test that complex types (maps, arrays) are rejected as macro values
|
||||
- Verify error message includes: macro name and type that was rejected
|
||||
- Test that scalar types (string, int, float, bool) are accepted
|
||||
- Each type should load without error
|
||||
- Test macro name validation still works with `any` types
|
||||
- Invalid characters, reserved names, length limits should still be enforced
|
||||
|
||||
### Test 4: Metadata in API Response
|
||||
|
||||
**File:** [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||
|
||||
**Existing Test:** `TestProxyManager_ListModelsHandler`
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Model with metadata → verify `llamaswap_meta` key appears
|
||||
- Model without metadata → verify `llamaswap_meta` key is absent
|
||||
- Verify all types are correctly marshaled to JSON
|
||||
- Verify nested structures are preserved
|
||||
- Verify macro substitution has occurred before serialization
|
||||
|
||||
### Test 5: Unknown Macros in Metadata
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Use undefined macro in metadata
|
||||
- Verify configuration loading fails with clear error
|
||||
- Error should indicate model name and that macro is undefined
|
||||
|
||||
### Test 6: Recursive Substitution
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Metadata with deeply nested structures
|
||||
- Arrays containing objects with macros
|
||||
- Objects containing arrays with macros
|
||||
- Mixed string interpolation and direct substitution at various nesting levels
|
||||
|
||||
## Checklist
|
||||
|
||||
### Configuration Schema Changes
|
||||
|
||||
- [x] Change `MacroList` type from `map[string]string` to `map[string]any` in [proxy/config/config.go:19](proxy/config/config.go#L19)
|
||||
- [x] Add `Metadata map[string]any` field to `ModelConfig` struct in [proxy/config/model_config.go:37](proxy/config/model_config.go#L37)
|
||||
- [x] Update `validateMacro()` function signature to accept `any` type for values
|
||||
- [x] Add validation logic to ensure macro values are scalar types only
|
||||
|
||||
### Macro Substitution Logic
|
||||
|
||||
- [x] Create generic recursive function `substituteMetadataMacros()` to handle `any` types
|
||||
- [x] Implement type-preserving direct substitution logic
|
||||
- [x] Implement string interpolation with type conversion
|
||||
- [x] Handle maps: recursively process all values
|
||||
- [x] Handle slices: recursively process all elements
|
||||
- [x] Handle scalar types: perform string-based macro substitution if value is string
|
||||
- [x] Integrate macro substitution into `LoadConfigFromReader()` after existing macro expansion
|
||||
- [x] Update existing macro substitution calls to use merged macros with correct types
|
||||
|
||||
### API Response Changes
|
||||
|
||||
- [x] Modify `listModelsHandler()` in [proxy/proxymanager.go:350](proxy/proxymanager.go#L350)
|
||||
- [x] Add `llamaswap_meta` field to model records when metadata exists
|
||||
- [x] Ensure empty metadata results in omitted `llamaswap_meta` key
|
||||
- [x] Verify JSON marshaling preserves all types correctly
|
||||
|
||||
### Testing - Config Package
|
||||
|
||||
- [x] Add test for string macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for int macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for float macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for bool macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for string interpolation in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for model-level macro precedence: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for nested structures in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for unknown macro in metadata (should error): [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for invalid macro type validation: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
### Testing - Model Config Package
|
||||
|
||||
- [x] Add test cases to [proxy/config/model_config_test.go](proxy/config/model_config_test.go) for metadata unmarshaling
|
||||
- [x] Test metadata with various scalar types
|
||||
- [x] Test metadata with nested objects and arrays
|
||||
|
||||
### Testing - Proxy Manager
|
||||
|
||||
- [x] Update `TestProxyManager_ListModelsHandler` in [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||
- [x] Add test case for model with metadata
|
||||
- [x] Add test case for model without metadata
|
||||
- [x] Verify `llamaswap_meta` key presence/absence
|
||||
- [x] Verify type preservation in JSON output
|
||||
- [x] Verify macro substitution has occurred
|
||||
|
||||
### Documentation
|
||||
|
||||
- [x] Verify [config.example.yaml](config.example.yaml) already has complete metadata examples (lines 149-171)
|
||||
- [x] No additional documentation needed per project instructions
|
||||
|
||||
## Known Issues and Considerations
|
||||
|
||||
### Inconsistencies
|
||||
|
||||
None identified. The plan references the correct existing example in [config.example.yaml:149-171](config.example.yaml#L149-L171).
|
||||
|
||||
### Design Decisions
|
||||
|
||||
1. **Why `llamaswap_meta` instead of merging into record?**
|
||||
|
||||
- Avoids potential collisions with OpenAI API standard fields
|
||||
- Makes it clear this is llama-swap specific metadata
|
||||
- Easier for clients to distinguish standard vs. custom fields
|
||||
|
||||
2. **Why support nested structures?**
|
||||
|
||||
- Provides maximum flexibility for users
|
||||
- Aligns with the schemaless design principle
|
||||
- Example config already demonstrates this capability
|
||||
|
||||
3. **Why validate macro types?**
|
||||
- Prevents confusing behavior (e.g., substituting a map)
|
||||
- Makes configuration errors explicit at load time
|
||||
- Simpler implementation and testing
|
||||
@@ -0,0 +1,397 @@
|
||||
# Improve macro-in-macro support
|
||||
|
||||
**Status: COMPLETED ✅**
|
||||
|
||||
## Title
|
||||
|
||||
Fix macro substitution ordering by preserving definition order using ordered YAML parsing
|
||||
|
||||
## Overview
|
||||
|
||||
The current macro implementation uses `map[string]any` which does not preserve insertion order. This causes issues when macros reference other macros - if macro `B` contains `${A}` but `B` is processed before `A`, the reference won't be substituted, leading to "unknown macro" errors.
|
||||
|
||||
**Goal:** Ensure macros are substituted in definition order (LIFO - last in, first out) to allow macros to reliably reference previously-defined macros.
|
||||
|
||||
**Outcomes:**
|
||||
- Macros can reference other macros defined earlier in the config
|
||||
- Macro substitution is deterministic and order-dependent
|
||||
- Single-pass substitution prevents circular dependencies
|
||||
- Use `yaml.Node` from `gopkg.in/yaml.v3` to preserve macro definition order
|
||||
- All existing tests pass
|
||||
- New tests validate substitution order and self-reference detection
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### 1. YAML Parsing Strategy
|
||||
- **Continue using:** `gopkg.in/yaml.v3` (current library)
|
||||
- **Use:** `yaml.Node` for ordered parsing of macros
|
||||
- **Reason:** `yaml.Node` preserves document structure and order, avoiding need for migration
|
||||
|
||||
### 2. Data Structure Changes
|
||||
|
||||
#### Current Implementation (config.go:19)
|
||||
```go
|
||||
type MacroList map[string]any
|
||||
```
|
||||
|
||||
#### New Implementation
|
||||
```go
|
||||
type MacroList []MacroEntry
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
```
|
||||
|
||||
**Implementation Note:** Parse macros using `yaml.Node` to extract key-value pairs in document order, then construct the ordered `MacroList`.
|
||||
|
||||
### 3. Macro Substitution Order Rules
|
||||
|
||||
The substitution must follow this hierarchy (from most specific to least):
|
||||
|
||||
1. **Reserved macros** (last): `PORT`, `MODEL_ID` - substituted last, highest priority
|
||||
2. **Model-level macros** (middle): Defined in specific model config, overrides global
|
||||
3. **Global macros** (first): Defined at config root level
|
||||
|
||||
Within each level, macros are substituted in **reverse definition order** (LIFO):
|
||||
- The last macro defined is substituted first
|
||||
- This allows later macros to reference earlier ones
|
||||
- Single-pass substitution prevents circular dependencies
|
||||
|
||||
### 4. Macro Reference Rules
|
||||
|
||||
**Allowed:**
|
||||
- Macro can reference any macro defined **before** it (earlier in the file)
|
||||
- Model macros can reference global macros
|
||||
- Macros can reference reserved macros (`${PORT}`, `${MODEL_ID}`)
|
||||
|
||||
**Prohibited:**
|
||||
- Macro cannot reference itself (e.g., `foo: "value ${foo}"`)
|
||||
- Macro cannot reference macros defined **after** it
|
||||
- No circular references (prevented by single-pass, ordered substitution)
|
||||
|
||||
### 5. Validation Requirements
|
||||
|
||||
Add validation to detect:
|
||||
- **Self-references:** Macro value contains reference to its own name
|
||||
- **Unknown macros:** After substitution, any remaining `${...}` references
|
||||
|
||||
Error messages should be clear:
|
||||
```
|
||||
macro 'foo' contains self-reference
|
||||
unknown macro '${bar}' in model.cmd
|
||||
```
|
||||
|
||||
### 6. Implementation Changes
|
||||
|
||||
#### Files to Modify
|
||||
|
||||
1. **[proxy/config/config.go](proxy/config/config.go)**
|
||||
- Line 19: Change `MacroList` type definition
|
||||
- Line 69: Update `Macros MacroList` field
|
||||
- Line 153-157: Update macro validation loop to work with ordered structure
|
||||
- Line 175-188: Update model-level macro validation
|
||||
- Line 181-188: **NEW** Implement proper macro merging respecting order
|
||||
- Line 193-202: **NEW** Implement ordered macro substitution in LIFO order
|
||||
- Line 389-415: Update `validateMacro` to detect self-references
|
||||
- Line 420-475: Update `substituteMetadataMacros` to accept ordered MacroList
|
||||
|
||||
2. **[proxy/config/model_config.go](proxy/config/model_config.go)**
|
||||
- Line 33: Update `Macros MacroList` field type
|
||||
|
||||
3. **All test files**
|
||||
- Update test fixtures to use ordered macro definitions
|
||||
- Ensure tests specify macro order explicitly
|
||||
|
||||
#### Core Algorithm
|
||||
|
||||
Replace the macro substitution logic in [config.go:181-252](proxy/config/config.go#L181-L252) with:
|
||||
|
||||
```go
|
||||
// Merge global config and model macros. Model macros take precedence
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+2)
|
||||
|
||||
// Add global macros first
|
||||
for _, entry := range config.Macros {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
|
||||
// Add model macros (can override global)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
// Remove any existing global macro with same name
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry // Override
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Add reserved MODEL_ID macro at the end
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
|
||||
// Check if PORT macro is needed
|
||||
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||
// enforce ${PORT} used in both cmd and proxy
|
||||
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// Add PORT macro to the end (highest priority)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "PORT", Value: nextPort})
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Single-pass substitution: Substitute all macros in LIFO order (last defined first)
|
||||
// This allows later macros to reference earlier ones
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
// Substitute in command fields
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in metadata (recursive)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
modelConfig.Metadata, err = substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Add this new helper function to replace `substituteMetadataMacros`:
|
||||
|
||||
```go
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 7. Self-Reference Detection
|
||||
|
||||
Add to `validateMacro` function:
|
||||
|
||||
```go
|
||||
func validateMacro(name string, value any) error {
|
||||
// ... existing validation ...
|
||||
|
||||
// Check for self-reference
|
||||
if str, ok := value.(string); ok {
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(str, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
## Testing Plan
|
||||
|
||||
### 1. Migration Tests
|
||||
- **Test:** All existing macro tests still pass after YAML library migration
|
||||
- **Files:** All `*_test.go` files with macro tests
|
||||
|
||||
### 2. Macro Order Tests
|
||||
|
||||
#### Test: Macro-in-macro substitution order
|
||||
```yaml
|
||||
macros:
|
||||
"A": "value-A"
|
||||
"B": "prefix-${A}-suffix"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: "echo ${B}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"echo prefix-value-A-suffix"`
|
||||
|
||||
#### Test: LIFO substitution order
|
||||
```yaml
|
||||
macros:
|
||||
"base": "/models"
|
||||
"path": "${base}/llama"
|
||||
"full": "${path}/model.gguf"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: "load ${full}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"load /models/llama/model.gguf"`
|
||||
|
||||
#### Test: Model macro overrides global
|
||||
```yaml
|
||||
macros:
|
||||
"tag": "global"
|
||||
"msg": "value-${tag}"
|
||||
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
"tag": "model-level"
|
||||
cmd: "echo ${msg}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"echo value-model-level"` (model macro overrides global)
|
||||
|
||||
### 3. Reserved Macro Tests
|
||||
|
||||
#### Test: MODEL_ID substituted in macro
|
||||
```yaml
|
||||
macros:
|
||||
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: "${podman-llama} -m model.gguf"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf"`
|
||||
|
||||
### 4. Error Detection Tests
|
||||
|
||||
#### Test: Self-reference detection
|
||||
```yaml
|
||||
macros:
|
||||
"recursive": "value-${recursive}"
|
||||
```
|
||||
**Expected:** Error: `macro 'recursive' contains self-reference`
|
||||
|
||||
#### Test: Undefined macro reference
|
||||
```yaml
|
||||
macros:
|
||||
"A": "value-${UNDEFINED}"
|
||||
```
|
||||
**Expected:** Error: `unknown macro '${UNDEFINED}' found in macros.A` (or similar)
|
||||
|
||||
### 5. Regression Tests
|
||||
- Run all existing macro tests: `TestConfig_MacroReplacement`, `TestConfig_MacroReservedNames`, etc.
|
||||
- Ensure all pass without modification (except test fixtures if needed)
|
||||
|
||||
## Checklist
|
||||
|
||||
### Phase 1: Data Structure Changes
|
||||
- [ ] Implement custom `UnmarshalYAML` method for `MacroList` that uses `yaml.Node`
|
||||
- [ ] Define new ordered `MacroList` type as `[]MacroEntry`
|
||||
- [ ] Update `MacroList` type definition in [config.go](proxy/config/config.go#L19)
|
||||
- [ ] Update `Config.Macros` field type in [config.go](proxy/config/config.go#L69)
|
||||
- [ ] Update `ModelConfig.Macros` field type in [model_config.go](proxy/config/model_config.go#L33)
|
||||
- [ ] Implement helper functions:
|
||||
- [ ] `func (ml MacroList) Get(name string) (any, bool)` - lookup by name
|
||||
- [ ] `func (ml MacroList) Set(name string, value any) MacroList` - add/override entry
|
||||
- [ ] `func (ml MacroList) ToMap() map[string]any` - convert to map if needed
|
||||
|
||||
### Phase 2: Macro Validation Updates
|
||||
- [ ] Update macro validation loop at [config.go:153-157](proxy/config/config.go#L153-L157)
|
||||
- [ ] Update model macro validation at [config.go:175-179](proxy/config/config.go#L175-L179)
|
||||
- [ ] Add self-reference detection to `validateMacro` function [config.go:389](proxy/config/config.go#L389)
|
||||
- [ ] Test self-reference detection with new test case
|
||||
|
||||
### Phase 3: Macro Substitution Algorithm
|
||||
- [ ] Implement ordered macro merging (global → model → reserved) at [config.go:181-188](proxy/config/config.go#L181-L188)
|
||||
- [ ] Implement single-pass LIFO substitution loop (reverse iteration) at [config.go:193-202](proxy/config/config.go#L193-L202)
|
||||
- [ ] Substitute in all string fields (cmd, cmdStop, proxy, checkEndpoint, stripParams)
|
||||
- [ ] Substitute in metadata within same loop
|
||||
- [ ] Ensure `MODEL_ID` is added to merged macros before substitution
|
||||
- [ ] Ensure `PORT` is added after port assignment (if needed)
|
||||
- [ ] Replace `substituteMetadataMacros` with new `substituteMacroInValue` function that processes one macro at a time [config.go:420](proxy/config/config.go#L420)
|
||||
- [ ] Remove old metadata substitution code that was separate from main loop [config.go:245-251](proxy/config/config.go#L245-L251)
|
||||
|
||||
### Phase 4: Testing
|
||||
- [ ] Run `make test-dev` - fix any static checking errors
|
||||
- [ ] Add test: macro-in-macro basic substitution
|
||||
- [ ] Add test: LIFO substitution order with 3+ macro levels
|
||||
- [ ] Add test: MODEL_ID in global macro used by model
|
||||
- [ ] Add test: PORT in global macro used by model
|
||||
- [ ] Add test: model macro overrides global macro in substitution
|
||||
- [ ] Add test: self-reference detection error
|
||||
- [ ] Add test: undefined macro reference error
|
||||
- [ ] Verify all existing macro tests pass: `TestConfig_Macro*`
|
||||
- [ ] Run `make test-all` - ensure all tests including concurrency tests pass
|
||||
|
||||
### Phase 5: Documentation
|
||||
- [ ] Update plan status in this file (mark completed)
|
||||
- [ ] Update CLAUDE.md if macro behavior needs documentation
|
||||
- [ ] Verify no new error messages need user documentation
|
||||
|
||||
## Bug Example (Original Issue)
|
||||
|
||||
```yaml
|
||||
macros:
|
||||
"podman-llama": >
|
||||
podman run --name ${MODEL_ID}
|
||||
--init --rm -p ${PORT}:8080 -v /home/alex/ai/models:/models:z --gpus=all
|
||||
ghcr.io/ggml-org/llama.cpp:server-cuda
|
||||
|
||||
"standard-options": >
|
||||
--no-mmap --jinja
|
||||
|
||||
"kv8": >
|
||||
-fa on -ctk q8_0 -ctv q8_0
|
||||
```
|
||||
|
||||
**Current Bug:**
|
||||
- During macro substitution, if `${MODEL_ID}` is processed before `${podman-llama}`, the `${MODEL_ID}` reference inside `podman-llama` remains unsubstituted
|
||||
- Results in error: `unknown macro '${MODEL_ID}' found in model.cmd`
|
||||
|
||||
**After Fix:**
|
||||
- Macros substituted in LIFO order: `kv8` → `standard-options` → `podman-llama`
|
||||
- `MODEL_ID` is a reserved macro, substituted last (after all user macros)
|
||||
- `${MODEL_ID}` inside `podman-llama` is correctly replaced with the model name
|
||||
+50
-4
@@ -38,13 +38,25 @@ startPort: 10001
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
# - used in a model's cmd, cmdStop, proxy and checkEndpoint
|
||||
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||
# - useful for reducing common configuration settings
|
||||
# - macro names are strings and must be less than 64 characters
|
||||
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
/path/to/llama-server/llama-server-ec9e0301
|
||||
--port ${PORT}
|
||||
|
||||
"default_ctx": 4096
|
||||
|
||||
# Example of macro-in-macro usage. macros can contain other macros
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
@@ -55,6 +67,14 @@ models:
|
||||
|
||||
# keys are the model names used in API requests
|
||||
"llama":
|
||||
# macros: a dictionary of string substitutions specific to this model
|
||||
# - optional, default: empty dictionary
|
||||
# - macros defined here override macros defined in the global macros section
|
||||
# - model level macros follow the same rules as global macros
|
||||
macros:
|
||||
"default_ctx": 16384
|
||||
"temp": 0.7
|
||||
|
||||
# cmd: the command to run to start the inference server.
|
||||
# - required
|
||||
# - it is just a string, similar to what you would run on the CLI
|
||||
@@ -64,6 +84,8 @@ models:
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/llama-8B-Q4_K_M.gguf
|
||||
--ctx-size ${default_ctx}
|
||||
--temperature ${temp}
|
||||
|
||||
# name: a display name for the model
|
||||
# - optional, default: empty string
|
||||
@@ -119,15 +141,39 @@ models:
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - only strip_params is currently supported
|
||||
# - only stripParams is currently supported
|
||||
filters:
|
||||
# strip_params: a comma separated list of parameters to remove from the request
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for server side enforcement of sampling parameters
|
||||
# - the `model` parameter can never be removed
|
||||
# - can be any JSON key in the request body
|
||||
# - recommended to stick to sampling parameters
|
||||
strip_params: "temperature, top_p, top_k"
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
# - metadata is only passed through in /v1/models responses
|
||||
metadata:
|
||||
# port will remain an integer
|
||||
port: ${PORT}
|
||||
|
||||
# the ${temp} macro will remain a float
|
||||
temperature: ${temp}
|
||||
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||
|
||||
a_list:
|
||||
- 1
|
||||
- 1.23
|
||||
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||
|
||||
an_obj:
|
||||
a: "1"
|
||||
b: 2
|
||||
# objects can contain complex types with macro substitution
|
||||
# becomes: c: [0.7, false, "model: llama"]
|
||||
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||
|
||||
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||
# - optional, default: 0
|
||||
|
||||
+36
-9
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -27,7 +28,9 @@ var (
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
||||
listenStr := flag.String("listen", "", "listen ip/port")
|
||||
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||
keyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||
|
||||
@@ -38,13 +41,13 @@ func main() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
config, err := proxy.LoadConfig(*configPath)
|
||||
conf, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(config.Profiles) > 0 {
|
||||
if len(conf.Profiles) > 0 {
|
||||
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||
}
|
||||
|
||||
@@ -54,6 +57,23 @@ func main() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
// Validate TLS flags.
|
||||
var useTLS = (*certFile != "" && *keyFile != "")
|
||||
if (*certFile != "" && *keyFile == "") ||
|
||||
(*certFile == "" && *keyFile != "") {
|
||||
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default ports.
|
||||
if *listenStr == "" {
|
||||
defaultPort := ":8080"
|
||||
if useTLS {
|
||||
defaultPort = ":8443"
|
||||
}
|
||||
listenStr = &defaultPort
|
||||
}
|
||||
|
||||
// Setup channels for server management
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
@@ -67,7 +87,7 @@ func main() {
|
||||
// Support for watching config and reloading when it changes
|
||||
reloadProxyManager := func() {
|
||||
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
config, err = proxy.LoadConfig(*configPath)
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||
return
|
||||
@@ -75,7 +95,7 @@ func main() {
|
||||
|
||||
fmt.Println("Configuration Changed")
|
||||
currentPM.Shutdown()
|
||||
srv.Handler = proxy.New(config)
|
||||
srv.Handler = proxy.New(conf)
|
||||
fmt.Println("Configuration Reloaded")
|
||||
|
||||
// wait a few seconds and tell any UI to reload
|
||||
@@ -85,12 +105,12 @@ func main() {
|
||||
})
|
||||
})
|
||||
} else {
|
||||
config, err = proxy.LoadConfig(*configPath)
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
srv.Handler = proxy.New(config)
|
||||
srv.Handler = proxy.New(conf)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,9 +186,16 @@ func main() {
|
||||
}()
|
||||
|
||||
// Start server
|
||||
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
var err error
|
||||
if useTLS {
|
||||
fmt.Printf("llama-swap listening with TLS on https://%s\n", *listenStr)
|
||||
err = srv.ListenAndServeTLS(*certFile, *keyFile)
|
||||
} else {
|
||||
fmt.Printf("llama-swap listening on http://%s\n", *listenStr)
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Fatal server error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
-460
@@ -1,460 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// #179 for /v1/models
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
|
||||
// Model filters see issue #174
|
||||
Filters ModelFilters `yaml:"filters"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelConfig ModelConfig
|
||||
defaults := rawModelConfig{
|
||||
Cmd: "",
|
||||
CmdStop: "",
|
||||
Proxy: "http://localhost:${PORT}",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: 0,
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
if runtime.GOOS == "windows" {
|
||||
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters see issue #174
|
||||
type ModelFilters struct {
|
||||
StripParams string `yaml:"strip_params"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{
|
||||
StripParams: "",
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelFilters(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
if f.StripParams == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
if trimmed == "model" || trimmed == "" {
|
||||
continue
|
||||
}
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
// sort cleaned
|
||||
slices.Sort(cleaned)
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type HooksConfig struct {
|
||||
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||
}
|
||||
|
||||
type HookOnStartup struct {
|
||||
Preload []string `yaml:"preload"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros map[string]string `yaml:"macros"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
|
||||
// hooks, see: #209
|
||||
Hooks HooksConfig `yaml:"hooks"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// default configuration values
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
MetricsMaxInMemory: 1000,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
// set a minimum of 15 seconds
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
/* check macro constraint rules:
|
||||
|
||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||
- names must be less than 64 characters (no reason, just cause)
|
||||
- name can not be any reserved macros: PORT, MODEL_ID
|
||||
- macro values must be less than 1024 characters
|
||||
*/
|
||||
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
for macroName, macroValue := range config.Macros {
|
||||
if len(macroName) >= 64 {
|
||||
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName)
|
||||
}
|
||||
if !macroNameRegex.MatchString(macroName) {
|
||||
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
|
||||
}
|
||||
if len(macroValue) >= 1024 {
|
||||
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
|
||||
}
|
||||
switch macroName {
|
||||
case "PORT":
|
||||
case "MODEL_ID":
|
||||
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields before macro expansion
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
|
||||
for macroName, macroValue := range config.Macros {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
|
||||
}
|
||||
|
||||
// enforce ${PORT} used in both cmd and proxy
|
||||
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
|
||||
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||
nextPortStr := strconv.Itoa(nextPort)
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
|
||||
nextPort++
|
||||
}
|
||||
|
||||
if strings.Contains(modelConfig.Cmd, "${MODEL_ID}") || strings.Contains(modelConfig.CmdStop, "${MODEL_ID}") {
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${MODEL_ID}", modelId)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${MODEL_ID}", modelId)
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // this is ok, has to be replaced by process later
|
||||
}
|
||||
if _, exists := config.Macros[macroName]; !exists {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
|
||||
// clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName, _ := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
@@ -0,0 +1,601 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
|
||||
type MacroList []MacroEntry
|
||||
|
||||
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
|
||||
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
|
||||
if value.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("macros must be a mapping")
|
||||
}
|
||||
|
||||
// yaml.Node.Content for a mapping contains alternating key/value nodes
|
||||
entries := make([]MacroEntry, 0, len(value.Content)/2)
|
||||
for i := 0; i < len(value.Content); i += 2 {
|
||||
keyNode := value.Content[i]
|
||||
valueNode := value.Content[i+1]
|
||||
|
||||
var name string
|
||||
if err := keyNode.Decode(&name); err != nil {
|
||||
return fmt.Errorf("failed to decode macro name: %w", err)
|
||||
}
|
||||
|
||||
var val any
|
||||
if err := valueNode.Decode(&val); err != nil {
|
||||
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
|
||||
}
|
||||
|
||||
entries = append(entries, MacroEntry{Name: name, Value: val})
|
||||
}
|
||||
|
||||
*ml = entries
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a macro value by name
|
||||
func (ml MacroList) Get(name string) (any, bool) {
|
||||
for _, entry := range ml {
|
||||
if entry.Name == name {
|
||||
return entry.Value, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ToMap converts MacroList to a map (for backward compatibility if needed)
|
||||
func (ml MacroList) ToMap() map[string]any {
|
||||
result := make(map[string]any, len(ml))
|
||||
for _, entry := range ml {
|
||||
result[entry.Name] = entry.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type HooksConfig struct {
|
||||
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||
}
|
||||
|
||||
type HookOnStartup struct {
|
||||
Preload []string `yaml:"preload"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
|
||||
// hooks, see: #209
|
||||
Hooks HooksConfig `yaml:"hooks"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// default configuration values
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
MetricsMaxInMemory: 1000,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
// set a minimum of 15 seconds
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
/* check macro constraint rules:
|
||||
|
||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||
- names must be less than 64 characters (no reason, just cause)
|
||||
- name can not be any reserved macros: PORT, MODEL_ID
|
||||
- macro values must be less than 1024 characters
|
||||
*/
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields before macro expansion
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Merge global config and model macros. Model macros take precedence
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
|
||||
// Add global macros first
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (can override global)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
// Remove any existing global macro with same name
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry // Override
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
|
||||
// This allows later macros to reference earlier ones
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
// Substitute in command fields
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in metadata (recursive)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Final pass: check if PORT macro is needed after macro expansion
|
||||
// ${PORT} is a resource on the local machine so a new port is only allocated
|
||||
// if it is required in either cmd or proxy keys
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort { // either has it
|
||||
if !cmdHasPort && proxyHasPort { // but both don't have it
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// Add PORT macro and substitute it
|
||||
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
|
||||
// Substitute PORT in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // this is ok, has to be replaced by process later
|
||||
}
|
||||
// Reserved macros are always valid (they should have been substituted already)
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
// Any other macro is unknown
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unknown macros in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the proxy URL.
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf(
|
||||
"model %s: invalid proxy URL: %w", modelId, err,
|
||||
)
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
|
||||
// clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if len(v) >= 1024 {
|
||||
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
||||
}
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
|
||||
func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -163,8 +163,8 @@ groups:
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
StartPort: 5800,
|
||||
Macros: map[string]string{
|
||||
"svr-path": "path/to/server",
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
Hooks: HooksConfig{
|
||||
OnStartup: HookOnStartup{
|
||||
@@ -1,4 +1,4 @@
|
||||
package proxy
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
@@ -65,18 +65,6 @@ models:
|
||||
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
||||
}
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizedCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_FindConfig(t *testing.T) {
|
||||
|
||||
// TODO?
|
||||
@@ -207,30 +195,91 @@ macros:
|
||||
argOne: "--arg1"
|
||||
argTwo: "--arg2"
|
||||
autoPort: "--port ${PORT}"
|
||||
overriddenByModelMacro: failed
|
||||
|
||||
models:
|
||||
model1:
|
||||
macros:
|
||||
overriddenByModelMacro: success
|
||||
cmd: |
|
||||
${svr-path} ${argTwo}
|
||||
# the automatic ${PORT} is replaced
|
||||
${autoPort}
|
||||
${argOne}
|
||||
--arg3 three
|
||||
--overridden ${overriddenByModelMacro}
|
||||
cmdStop: |
|
||||
/path/to/stop.sh --port ${PORT} ${argTwo}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " "))
|
||||
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three --overridden success", strings.Join(sanitizedCmd, " "))
|
||||
|
||||
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
|
||||
}
|
||||
|
||||
func TestConfig_MacroReservedNames(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config string
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "global macro named PORT",
|
||||
config: `
|
||||
macros:
|
||||
PORT: "1111"
|
||||
`,
|
||||
expectedError: "macro name 'PORT' is reserved",
|
||||
},
|
||||
{
|
||||
name: "global macro named MODEL_ID",
|
||||
config: `
|
||||
macros:
|
||||
MODEL_ID: model1
|
||||
`,
|
||||
expectedError: "macro name 'MODEL_ID' is reserved",
|
||||
},
|
||||
{
|
||||
name: "model macro named PORT",
|
||||
config: `
|
||||
models:
|
||||
model1:
|
||||
macros:
|
||||
PORT: 1111
|
||||
`,
|
||||
expectedError: "model model1: macro name 'PORT' is reserved",
|
||||
},
|
||||
|
||||
{
|
||||
name: "model macro named MODEL_ID",
|
||||
config: `
|
||||
models:
|
||||
model1:
|
||||
macros:
|
||||
MODEL_ID: model1
|
||||
`,
|
||||
expectedError: "model model1: macro name 'MODEL_ID' is reserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.config))
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, tt.expectedError, err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -274,7 +323,7 @@ macros:
|
||||
models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
proxy: "http://localhost:${unknownMacro}"
|
||||
proxy: "http://${unknownMacro}:${PORT}"
|
||||
`,
|
||||
},
|
||||
{
|
||||
@@ -288,6 +337,20 @@ models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
checkEndpoint: "http://localhost:${unknownMacro}/health"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "unknown macro in filters.stripParams",
|
||||
field: "filters.stripParams",
|
||||
content: `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
filters:
|
||||
stripParams: "model,${unknownMacro}"
|
||||
`,
|
||||
},
|
||||
}
|
||||
@@ -295,38 +358,13 @@ models:
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
|
||||
}
|
||||
//t.Log(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ModelFilters(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
default_strip: "temperature, top_p"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
strip_params: "model, top_k, ${default_strip}, , ,"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
modelConfig, ok := config.Models["model1"]
|
||||
if !assert.True(t, ok) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// make sure `model` and enmpty strings are not in the list
|
||||
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripComments(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -467,7 +505,9 @@ models:
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
|
||||
|
||||
assert.Equal(t, "docker stop ${MODEL_ID}", config.Macros["docker-stop"])
|
||||
dockerStopMacro, found := config.Macros.Get("docker-stop")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "docker stop ${MODEL_ID}", dockerStopMacro)
|
||||
|
||||
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
|
||||
assert.NoError(t, err)
|
||||
@@ -481,3 +521,243 @@ models:
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/path/to/server -p 9000 -hf author/model:F16", strings.Join(sanitizedCmd3, " "))
|
||||
}
|
||||
|
||||
func TestConfig_TypedMacrosInMetadata(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
PORT_NUM: 10001
|
||||
TEMP: 0.7
|
||||
ENABLED: true
|
||||
NAME: "llama model"
|
||||
CTX: 16384
|
||||
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
metadata:
|
||||
port: ${PORT_NUM}
|
||||
temperature: ${TEMP}
|
||||
enabled: ${ENABLED}
|
||||
model_name: ${NAME}
|
||||
context: ${CTX}
|
||||
note: "Running on port ${PORT_NUM} with temp ${TEMP} and context ${CTX}"
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
meta := config.Models["test-model"].Metadata
|
||||
assert.NotNil(t, meta)
|
||||
|
||||
// Verify direct substitution preserves types
|
||||
assert.Equal(t, 10001, meta["port"])
|
||||
assert.Equal(t, 0.7, meta["temperature"])
|
||||
assert.Equal(t, true, meta["enabled"])
|
||||
assert.Equal(t, "llama model", meta["model_name"])
|
||||
assert.Equal(t, 16384, meta["context"])
|
||||
|
||||
// Verify string interpolation converts to string
|
||||
assert.Equal(t, "Running on port 10001 with temp 0.7 and context 16384", meta["note"])
|
||||
}
|
||||
|
||||
func TestConfig_NestedStructuresInMetadata(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
TEMP: 0.7
|
||||
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
metadata:
|
||||
config:
|
||||
port: ${PORT}
|
||||
temperature: ${TEMP}
|
||||
tags: ["model:${MODEL_ID}", "port:${PORT}"]
|
||||
nested:
|
||||
deep:
|
||||
value: ${TEMP}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
meta := config.Models["test-model"].Metadata
|
||||
assert.NotNil(t, meta)
|
||||
|
||||
// Verify nested objects
|
||||
configMap := meta["config"].(map[string]any)
|
||||
assert.Equal(t, 10000, configMap["port"])
|
||||
assert.Equal(t, 0.7, configMap["temperature"])
|
||||
|
||||
// Verify arrays
|
||||
tags := meta["tags"].([]any)
|
||||
assert.Equal(t, "model:test-model", tags[0])
|
||||
assert.Equal(t, "port:10000", tags[1])
|
||||
|
||||
// Verify deeply nested structures
|
||||
nested := meta["nested"].(map[string]any)
|
||||
deep := nested["deep"].(map[string]any)
|
||||
assert.Equal(t, 0.7, deep["value"])
|
||||
}
|
||||
|
||||
func TestConfig_ModelLevelMacroPrecedenceInMetadata(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
TEMP: 0.5
|
||||
GLOBAL_VAL: "global"
|
||||
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
macros:
|
||||
TEMP: 0.9
|
||||
LOCAL_VAL: "local"
|
||||
metadata:
|
||||
temperature: ${TEMP}
|
||||
global: ${GLOBAL_VAL}
|
||||
local: ${LOCAL_VAL}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
meta := config.Models["test-model"].Metadata
|
||||
assert.NotNil(t, meta)
|
||||
|
||||
// Model-level macro should override global
|
||||
assert.Equal(t, 0.9, meta["temperature"])
|
||||
// Global macro should be accessible
|
||||
assert.Equal(t, "global", meta["global"])
|
||||
// Model-level macro should be accessible
|
||||
assert.Equal(t, "local", meta["local"])
|
||||
}
|
||||
|
||||
func TestConfig_UnknownMacroInMetadata(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
metadata:
|
||||
value: ${UNKNOWN_MACRO}
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "test-model")
|
||||
assert.Contains(t, err.Error(), "UNKNOWN_MACRO")
|
||||
}
|
||||
|
||||
func TestConfig_InvalidMacroType(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
INVALID:
|
||||
nested: value
|
||||
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "INVALID")
|
||||
assert.Contains(t, err.Error(), "must be a scalar type")
|
||||
}
|
||||
|
||||
func TestConfig_MacroTypeValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "string macro",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
STR: "test"
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "int macro",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
NUM: 42
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "float macro",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
FLOAT: 3.14
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "bool macro",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
BOOL: true
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "array macro (invalid)",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
ARR: [1, 2, 3]
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "map macro (invalid)",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
MAP:
|
||||
key: value
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.yaml))
|
||||
if tt.shouldErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -155,8 +155,8 @@ groups:
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
StartPort: 5800,
|
||||
Macros: map[string]string{
|
||||
"svr-path": "path/to/server",
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
@@ -0,0 +1,123 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test macro-in-macro basic substitution
|
||||
func TestConfig_MacroInMacroBasic(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"A": "value-A"
|
||||
"B": "prefix-${A}-suffix"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${B}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test LIFO substitution order with 3+ macro levels
|
||||
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"base": "/models"
|
||||
"path": "${base}/llama"
|
||||
"full": "${path}/model.gguf"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: load ${full}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test MODEL_ID in global macro used by model
|
||||
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: ${podman-llama} -m model.gguf
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
|
||||
}
|
||||
|
||||
// Test model macro overrides global macro in substitution
|
||||
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"tag": "global"
|
||||
"msg": "value-${tag}"
|
||||
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
"tag": "model-level"
|
||||
cmd: echo ${msg}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test self-reference detection error
|
||||
func TestConfig_SelfReferenceDetection(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"recursive": "value-${recursive}"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${recursive}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "recursive")
|
||||
assert.Contains(t, err.Error(), "self-reference")
|
||||
}
|
||||
|
||||
// Test undefined macro reference error
|
||||
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"A": "value-${UNDEFINED}"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${A}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// #179 for /v1/models
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
|
||||
// Model filters see issue #174
|
||||
Filters ModelFilters `yaml:"filters"`
|
||||
|
||||
// Macros: see #264
|
||||
// Model level macros take precedence over the global macros
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// Metadata: see #264
|
||||
// Arbitrary metadata that can be exposed through the API
|
||||
Metadata map[string]any `yaml:"metadata"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelConfig ModelConfig
|
||||
defaults := rawModelConfig{
|
||||
Cmd: "",
|
||||
CmdStop: "",
|
||||
Proxy: "http://localhost:${PORT}",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: 0,
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
if runtime.GOOS == "windows" {
|
||||
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters see issue #174
|
||||
type ModelFilters struct {
|
||||
StripParams string `yaml:"stripParams"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{
|
||||
StripParams: "",
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to unmarshal with the old field name for backwards compatibility
|
||||
if defaults.StripParams == "" {
|
||||
var legacy struct {
|
||||
StripParams string `yaml:"strip_params"`
|
||||
}
|
||||
if legacyErr := unmarshal(&legacy); legacyErr != nil {
|
||||
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
|
||||
}
|
||||
defaults.StripParams = legacy.StripParams
|
||||
}
|
||||
|
||||
*m = ModelFilters(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
if f.StripParams == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
if trimmed == "model" || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
// sort cleaned
|
||||
slices.Sort(cleaned)
|
||||
return cleaned, nil
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizedCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_ModelFilters(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
default_strip: "temperature, top_p"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
# macros inserted and list is cleaned of duplicates and empty strings
|
||||
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||
# check for strip_params (legacy field name) compatibility
|
||||
legacy:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
for modelId, modelConfig := range config.Models {
|
||||
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
|
||||
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||
if assert.NoError(t, err) {
|
||||
// model has been removed
|
||||
// empty strings have been removed
|
||||
// duplicates have been removed
|
||||
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -65,18 +66,18 @@ func getTestPort() int {
|
||||
return port
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, simpleResponderPath, port, expectedMessage, port)
|
||||
|
||||
var cfg ModelConfig
|
||||
var cfg config.ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||
@@ -38,7 +39,7 @@ type MetricsMonitor struct {
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewMetricsMonitor(config *Config) *MetricsMonitor {
|
||||
func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
|
||||
maxMetrics := config.MetricsMaxInMemory
|
||||
if maxMetrics <= 0 {
|
||||
maxMetrics = 1000 // Default fallback
|
||||
|
||||
+94
-71
@@ -4,18 +4,19 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
@@ -38,11 +39,13 @@ const (
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
ID string
|
||||
config config.ModelConfig
|
||||
cmd *exec.Cmd
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
|
||||
// PR #155 called to cancel the upstream process
|
||||
cmdMutex sync.RWMutex
|
||||
cancelUpstream context.CancelFunc
|
||||
|
||||
// closed when command exits
|
||||
@@ -54,12 +57,14 @@ type Process struct {
|
||||
healthCheckTimeout int
|
||||
healthCheckLoopInterval time.Duration
|
||||
|
||||
lastRequestHandled time.Time
|
||||
lastRequestHandledMutex sync.RWMutex
|
||||
lastRequestHandled time.Time
|
||||
|
||||
stateMutex sync.RWMutex
|
||||
state ProcessState
|
||||
|
||||
inFlightRequests sync.WaitGroup
|
||||
inFlightRequests sync.WaitGroup
|
||||
inFlightRequestsCount atomic.Int32
|
||||
|
||||
// used to block on multiple start() calls
|
||||
waitStarting sync.WaitGroup
|
||||
@@ -74,16 +79,35 @@ type Process struct {
|
||||
failedStartCount int
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||
concurrentLimit := 10
|
||||
if config.ConcurrencyLimit > 0 {
|
||||
concurrentLimit = config.ConcurrencyLimit
|
||||
}
|
||||
|
||||
// Setup the reverse proxy.
|
||||
proxyURL, err := url.Parse(config.Proxy)
|
||||
if err != nil {
|
||||
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
|
||||
}
|
||||
|
||||
var reverseProxy *httputil.ReverseProxy
|
||||
if proxyURL != nil {
|
||||
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
reverseProxy: reverseProxy,
|
||||
cancelUpstream: nil,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
@@ -106,6 +130,20 @@ func (p *Process) LogMonitor() *LogMonitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
|
||||
func (p *Process) setLastRequestHandled(t time.Time) {
|
||||
p.lastRequestHandledMutex.Lock()
|
||||
defer p.lastRequestHandledMutex.Unlock()
|
||||
p.lastRequestHandled = t
|
||||
}
|
||||
|
||||
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
|
||||
func (p *Process) getLastRequestHandled() time.Time {
|
||||
p.lastRequestHandledMutex.RLock()
|
||||
defer p.lastRequestHandledMutex.RUnlock()
|
||||
return p.lastRequestHandled
|
||||
}
|
||||
|
||||
// custom error types for swapping state
|
||||
var (
|
||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||
@@ -129,6 +167,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
||||
}
|
||||
|
||||
p.state = newState
|
||||
|
||||
// Atomically increment waitStarting when entering StateStarting
|
||||
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
|
||||
if newState == StateStarting {
|
||||
p.waitStarting.Add(1)
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||
return p.state, nil
|
||||
@@ -157,6 +202,15 @@ func (p *Process) CurrentState() ProcessState {
|
||||
return p.state
|
||||
}
|
||||
|
||||
// forceState forces the process state to the new state with mutex protection.
|
||||
// This should only be used in exceptional cases where the normal state transition
|
||||
// validation via swapState() cannot be used.
|
||||
func (p *Process) forceState(newState ProcessState) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
p.state = newState
|
||||
}
|
||||
|
||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||
// it is a private method because starting is automatic but stopping can be called
|
||||
// at any time.
|
||||
@@ -190,7 +244,7 @@ func (p *Process) start() error {
|
||||
}
|
||||
}
|
||||
|
||||
p.waitStarting.Add(1)
|
||||
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
|
||||
defer p.waitStarting.Done()
|
||||
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||
|
||||
@@ -200,8 +254,11 @@ func (p *Process) start() error {
|
||||
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
p.cancelUpstream = ctxCancelUpstream
|
||||
p.cmdWaitChan = make(chan struct{})
|
||||
p.cmdMutex.Unlock()
|
||||
|
||||
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||
|
||||
@@ -211,7 +268,7 @@ func (p *Process) start() error {
|
||||
// Set process state to failed
|
||||
if err != nil {
|
||||
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||
p.state = StateStopped // force it into a stopped state
|
||||
p.forceState(StateStopped) // force it into a stopped state
|
||||
return fmt.Errorf(
|
||||
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
strings.Join(args, " "), err, curState, swapErr,
|
||||
@@ -284,10 +341,12 @@ func (p *Process) start() error {
|
||||
return
|
||||
}
|
||||
|
||||
// wait for all inflight requests to complete and ticker
|
||||
p.inFlightRequests.Wait()
|
||||
// skip the TTL check if there are inflight requests
|
||||
if p.inFlightRequestsCount.Load() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
if time.Since(p.getLastRequestHandled()) > maxDuration {
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
@@ -343,7 +402,7 @@ func (p *Process) Shutdown() {
|
||||
|
||||
p.stopCommand()
|
||||
// just force it to this state since there is no recovery from shutdown
|
||||
p.state = StateShutdown
|
||||
p.forceState(StateShutdown)
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
@@ -354,13 +413,18 @@ func (p *Process) stopCommand() {
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
}()
|
||||
|
||||
if p.cancelUpstream == nil {
|
||||
p.cmdMutex.RLock()
|
||||
cancelUpstream := p.cancelUpstream
|
||||
cmdWaitChan := p.cmdWaitChan
|
||||
p.cmdMutex.RUnlock()
|
||||
|
||||
if cancelUpstream == nil {
|
||||
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
p.cancelUpstream()
|
||||
<-p.cmdWaitChan
|
||||
cancelUpstream()
|
||||
<-cmdWaitChan
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
@@ -417,8 +481,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
p.inFlightRequestsCount.Add(1)
|
||||
defer func() {
|
||||
p.lastRequestHandled = time.Now()
|
||||
p.setLastRequestHandled(time.Now())
|
||||
p.inFlightRequestsCount.Add(-1)
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
@@ -433,56 +499,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
startDuration = time.Since(beginStartTime)
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.Header = r.Header.Clone()
|
||||
|
||||
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
||||
if err == nil {
|
||||
req.ContentLength = contentLength
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for k, vv := range resp.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// faster than io.Copy when streaming
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
||||
return
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
if p.reverseProxy != nil {
|
||||
p.reverseProxy.ServeHTTP(w, r)
|
||||
} else {
|
||||
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
totalTime := time.Since(requestBeginTime)
|
||||
@@ -518,13 +538,16 @@ func (p *Process) waitForCmd() {
|
||||
case StateStopping:
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||
p.state = StateStopped
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
default:
|
||||
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||
p.state = StateStopped // force it to be in this state
|
||||
p.forceState(StateStopped) // force it to be in this state
|
||||
}
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
close(p.cmdWaitChan)
|
||||
p.cmdMutex.Unlock()
|
||||
}
|
||||
|
||||
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||
@@ -539,7 +562,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
|
||||
if p.config.CmdStop != "" {
|
||||
// replace ${PID} with the pid of the process
|
||||
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||
if err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||
return err
|
||||
|
||||
+15
-12
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -90,7 +91,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
config := ModelConfig{
|
||||
config := config.ModelConfig{
|
||||
Cmd: "nonexistent-command",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
@@ -325,7 +326,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||
|
||||
// should run and exit but interrupt the long checkHealthTimeout
|
||||
checkHealthTimeout := 5
|
||||
config := ModelConfig{
|
||||
config := config.ModelConfig{
|
||||
Cmd: "sleep 1",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
@@ -402,7 +403,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
port := getTestPort()
|
||||
|
||||
config := ModelConfig{
|
||||
conf := config.ModelConfig{
|
||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||
// to force the process to exit
|
||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||
@@ -410,7 +411,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// reduce to make testing go faster
|
||||
@@ -435,7 +436,9 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||
} else {
|
||||
assert.Contains(t, w.Body.String(), "unexpected EOF")
|
||||
// Upstream may be killed mid-response.
|
||||
// Assert an incomplete or partial response.
|
||||
assert.NotEqual(t, "12345", w.Body.String())
|
||||
}
|
||||
|
||||
close(waitChan)
|
||||
@@ -450,15 +453,15 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProcess_StopCmd(t *testing.T) {
|
||||
config := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
conf := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
conf.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
} else {
|
||||
config.CmdStop = "kill -TERM ${PID}"
|
||||
conf.CmdStop = "kill -TERM ${PID}"
|
||||
}
|
||||
|
||||
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
|
||||
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
@@ -470,15 +473,15 @@ func TestProcess_StopCmd(t *testing.T) {
|
||||
|
||||
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||
expectedMessage := "test_env_not_emptied"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// ensure that the the default config does not blank out the inherited environment
|
||||
configWEnv := config
|
||||
configWEnv := conf
|
||||
|
||||
// ensure the additiona variables are appended to the process' environment
|
||||
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||
|
||||
process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger)
|
||||
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
|
||||
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||
|
||||
process1.start()
|
||||
|
||||
+27
-2
@@ -5,12 +5,14 @@ import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type ProcessGroup struct {
|
||||
sync.Mutex
|
||||
|
||||
config Config
|
||||
config config.Config
|
||||
id string
|
||||
swap bool
|
||||
exclusive bool
|
||||
@@ -24,7 +26,7 @@ type ProcessGroup struct {
|
||||
lastUsedProcess string
|
||||
}
|
||||
|
||||
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||
groupConfig, ok := config.Groups[id]
|
||||
if !ok {
|
||||
panic("Unable to find configuration for group id: " + id)
|
||||
@@ -86,6 +88,29 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
pg.Lock()
|
||||
|
||||
process, exists := pg.processes[modelID]
|
||||
if !exists {
|
||||
pg.Unlock()
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
if pg.lastUsedProcess == modelID {
|
||||
pg.lastUsedProcess = ""
|
||||
}
|
||||
pg.Unlock()
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||
pg.Lock()
|
||||
defer pg.Unlock()
|
||||
|
||||
@@ -7,19 +7,20 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
"model4": getTestSimpleResponderConfig("model4"),
|
||||
"model5": getTestSimpleResponderConfig("model5"),
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
@@ -34,7 +35,7 @@ var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||
})
|
||||
|
||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model5"))
|
||||
}
|
||||
|
||||
@@ -48,9 +49,9 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||
// and multiple requests are made in parallel, only one process is running at a time.
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
// use the same listening so if a model is already running, it will fail
|
||||
// this is a way to test that swap isolation is working
|
||||
// properly when there are parallel requests made at the
|
||||
@@ -61,7 +62,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||
|
||||
+29
-3
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -27,7 +28,7 @@ const (
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
config Config
|
||||
config config.Config
|
||||
ginEngine *gin.Engine
|
||||
|
||||
// logging
|
||||
@@ -44,7 +45,7 @@ type ProxyManager struct {
|
||||
shutdownCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func New(config Config) *ProxyManager {
|
||||
func New(config config.Config) *ProxyManager {
|
||||
// set up loggers
|
||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
@@ -228,7 +229,6 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
c.Redirect(http.StatusFound, "/ui/models")
|
||||
})
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||
@@ -370,6 +370,13 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
// Add metadata if present
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
record["meta"] = gin.H{
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
|
||||
data = append(data, record)
|
||||
}
|
||||
|
||||
@@ -420,6 +427,24 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
modelName = real
|
||||
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||
modelFound = true
|
||||
|
||||
// Check if this is exactly a model name with no additional path
|
||||
// and doesn't end with a trailing slash
|
||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||
// Build new URL with query parameters preserved
|
||||
newPath := "/upstream/" + searchModelName + "/"
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
newPath += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// Use 308 for non-GET/HEAD requests to preserve method
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -439,6 +464,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
c.Request.URL.Path = remainingPath
|
||||
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,8 +3,10 @@ package proxy
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
@@ -23,6 +25,7 @@ func addApiHandlers(pm *ProxyManager) {
|
||||
apiGroup := pm.ginEngine.Group("/api")
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
}
|
||||
@@ -202,3 +205,25 @@ func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonData)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
||||
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
+259
-84
@@ -16,14 +16,41 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TestResponseRecorder adds CloseNotify to httptest.ResponseRecorder.
|
||||
// "If you want to write your own tests around streams you will need a Recorder that can handle CloseNotifier."
|
||||
// The tests can panic otherwise:
|
||||
// panic: interface conversion: *httptest.ResponseRecorder is not http.CloseNotifier: missing method CloseNotify
|
||||
// See: https://github.com/gin-gonic/gin/issues/1815
|
||||
// TestResponseRecorder is taken from gin's own tests: https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765
|
||||
type TestResponseRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
closeChannel chan bool
|
||||
}
|
||||
|
||||
func (r *TestResponseRecorder) CloseNotify() <-chan bool {
|
||||
return r.closeChannel
|
||||
}
|
||||
|
||||
func (r *TestResponseRecorder) closeClient() {
|
||||
r.closeChannel <- true
|
||||
}
|
||||
|
||||
func CreateTestResponseRecorder() *TestResponseRecorder {
|
||||
return &TestResponseRecorder{
|
||||
httptest.NewRecorder(),
|
||||
make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
@@ -36,7 +63,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
for _, modelName := range []string{"model1", "model2"} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -44,14 +71,14 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
@@ -73,7 +100,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
t.Run(requestedModel, func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -89,14 +116,14 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
// Test that a persistent group is not affected by the swapping behaviour of
|
||||
// other groups.
|
||||
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
// the forever group is persistent and should not be affected by model1
|
||||
"forever": {
|
||||
Swap: true,
|
||||
@@ -115,7 +142,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
||||
for _, requestedModel := range tests {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -133,9 +160,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
@@ -158,7 +185,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
@@ -196,9 +223,9 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||
model2Config.Description = " "
|
||||
|
||||
config := Config{
|
||||
config := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
@@ -211,7 +238,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
req.Header.Add("Origin", "i-am-the-origin")
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
// Call the listModelsHandler
|
||||
proxy.ServeHTTP(w, req)
|
||||
@@ -281,15 +308,99 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler_WithMetadata(t *testing.T) {
|
||||
// Process config through LoadConfigFromReader to apply macro substitution
|
||||
configYaml := `
|
||||
healthCheckTimeout: 15
|
||||
logLevel: error
|
||||
startPort: 10000
|
||||
models:
|
||||
model1:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
macros:
|
||||
PORT_NUM: 10001
|
||||
TEMP: 0.7
|
||||
NAME: "llama"
|
||||
metadata:
|
||||
port: ${PORT_NUM}
|
||||
temperature: ${TEMP}
|
||||
enabled: true
|
||||
note: "Running on port ${PORT_NUM}"
|
||||
nested:
|
||||
value: ${TEMP}
|
||||
model2:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`
|
||||
processedConfig, err := config.LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(processedConfig)
|
||||
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response struct {
|
||||
Data []map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, response.Data, 2)
|
||||
|
||||
// Find model1 and model2 in response
|
||||
var model1Data, model2Data map[string]any
|
||||
for _, model := range response.Data {
|
||||
if model["id"] == "model1" {
|
||||
model1Data = model
|
||||
} else if model["id"] == "model2" {
|
||||
model2Data = model
|
||||
}
|
||||
}
|
||||
|
||||
// Verify model1 has llamaswap_meta
|
||||
assert.NotNil(t, model1Data)
|
||||
meta, exists := model1Data["meta"]
|
||||
if !assert.True(t, exists, "model1 should have meta key") {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
metaMap := meta.(map[string]any)
|
||||
|
||||
lsmeta, exists := metaMap["llamaswap"]
|
||||
if !assert.True(t, exists, "model1 should have meta.llamaswap key") {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
lsmetamap := lsmeta.(map[string]any)
|
||||
|
||||
// Verify type preservation
|
||||
assert.Equal(t, float64(10001), lsmetamap["port"]) // JSON numbers are float64
|
||||
assert.Equal(t, 0.7, lsmetamap["temperature"])
|
||||
assert.Equal(t, true, lsmetamap["enabled"])
|
||||
// Verify string interpolation
|
||||
assert.Equal(t, "Running on port 10001", lsmetamap["note"])
|
||||
// Verify nested structure
|
||||
nested := lsmetamap["nested"].(map[string]any)
|
||||
assert.Equal(t, 0.7, nested["value"])
|
||||
|
||||
// Verify model2 does NOT have llamaswap_meta
|
||||
assert.NotNil(t, model2Data)
|
||||
_, exists = model2Data["llamaswap_meta"]
|
||||
assert.False(t, exists, "model2 should not have llamaswap_meta")
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
||||
// Intentionally add models in non-sorted order and with an unlisted model
|
||||
config := Config{
|
||||
config := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"zeta": getTestSimpleResponderConfig("zeta"),
|
||||
"alpha": getTestSimpleResponderConfig("alpha"),
|
||||
"beta": getTestSimpleResponderConfig("beta"),
|
||||
"hidden": func() ModelConfig {
|
||||
"hidden": func() config.ModelConfig {
|
||||
mc := getTestSimpleResponderConfig("hidden")
|
||||
mc.Unlisted = true
|
||||
return mc
|
||||
@@ -302,7 +413,7 @@ func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
||||
|
||||
// Request models list
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -337,15 +448,15 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||
model3Config.Proxy = "http://localhost:10003/"
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": model3Config,
|
||||
},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"test": {
|
||||
Swap: false,
|
||||
Members: []string{"model1", "model2", "model3"},
|
||||
@@ -363,7 +474,7 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
defer wg.Done()
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
||||
proxy.ServeHTTP(w, req)
|
||||
@@ -380,38 +491,92 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_Unload(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
proxy := New(conf)
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||
req = httptest.NewRequest("GET", "/unload", nil)
|
||||
w = httptest.NewRecorder()
|
||||
w = CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, w.Body.String(), "OK")
|
||||
|
||||
// give it a bit of time to stop
|
||||
<-time.After(time.Millisecond * 250)
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||
select {
|
||||
case <-proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
|
||||
// good
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for model1 to stop")
|
||||
}
|
||||
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
||||
const testGroupId = "testGroup"
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
testGroupId: {
|
||||
Swap: false,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
// start both model
|
||||
for _, modelName := range []string{"model1", "model2"} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model1"].CurrentState())
|
||||
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
if !assert.Equal(t, w.Body.String(), "OK") {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan:
|
||||
// good
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for model1 to stop")
|
||||
}
|
||||
|
||||
assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped)
|
||||
assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady)
|
||||
}
|
||||
|
||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Shared configuration
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
@@ -432,7 +597,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
t.Run("no models loaded", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/running", nil)
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -450,13 +615,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Load just a model.
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Simulate browser call for the `/running` endpoint.
|
||||
req = httptest.NewRequest("GET", "/running", nil)
|
||||
w = httptest.NewRecorder()
|
||||
w = CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
var response RunningResponse
|
||||
@@ -474,9 +639,9 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -508,7 +673,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
// Verify the response
|
||||
@@ -527,15 +692,15 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||
modelConfig.UseModelName = upstreamModelName
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": modelConfig,
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
proxy := New(conf)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
requestedModel := "model1"
|
||||
@@ -543,7 +708,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -577,7 +742,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
// Verify the response
|
||||
@@ -590,9 +755,9 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -645,7 +810,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
@@ -666,14 +831,14 @@ models:
|
||||
aliases: [model-alias]
|
||||
`, getSimpleResponderPath())
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configStr))
|
||||
config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
t.Run("main model name", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
@@ -681,7 +846,7 @@ models:
|
||||
|
||||
t.Run("model alias", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
@@ -689,9 +854,9 @@ models:
|
||||
}
|
||||
|
||||
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -702,7 +867,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -714,14 +879,14 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
|
||||
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
modelConfig := getTestSimpleResponderConfig("model1")
|
||||
modelConfig.Filters = ModelFilters{
|
||||
modelConfig.Filters = config.ModelFilters{
|
||||
StripParams: "temperature, model, stream",
|
||||
}
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
LogLevel: "error",
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": modelConfig,
|
||||
},
|
||||
})
|
||||
@@ -730,7 +895,7 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -747,9 +912,9 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -761,7 +926,7 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
||||
// Make a non-streaming request
|
||||
reqBody := `{"model":"model1", "stream": false}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -782,9 +947,9 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -796,7 +961,7 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
||||
// Make a streaming request
|
||||
reqBody := `{"model":"model1", "stream": true}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -817,9 +982,9 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -828,7 +993,7 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "OK", rec.Body.String())
|
||||
@@ -836,9 +1001,9 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||
|
||||
// Ensure the custom llama-server /completion endpoint proxies correctly
|
||||
func TestProxyManager_CompletionEndpoint(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -849,7 +1014,7 @@ func TestProxyManager_CompletionEndpoint(t *testing.T) {
|
||||
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -882,7 +1047,7 @@ models:
|
||||
`, "${simpleresponderpath}", simpleResponderPath, -1)
|
||||
|
||||
// Create a test model configuration
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configStr))
|
||||
config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
if !assert.NoError(t, err, "Invalid configuration") {
|
||||
return
|
||||
}
|
||||
@@ -916,9 +1081,9 @@ models:
|
||||
}
|
||||
|
||||
func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -936,18 +1101,28 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
t.Run(endpoint, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest("GET", endpoint, nil)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
rec := CreateTestResponseRecorder()
|
||||
|
||||
// We don't need the handler to fully complete, just to set the headers
|
||||
// so run it in a goroutine and check the headers after a short delay
|
||||
go proxy.ServeHTTP(rec, req)
|
||||
time.Sleep(10 * time.Millisecond) // give it time to start and write headers
|
||||
// Run handler in goroutine and wait for context timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
proxy.ServeHTTP(rec, req)
|
||||
}()
|
||||
|
||||
// Wait for either the handler to complete or context to timeout
|
||||
<-ctx.Done()
|
||||
|
||||
// At this point, the handler has either finished or been cancelled
|
||||
// Wait for the goroutine to fully exit before reading
|
||||
<-done
|
||||
|
||||
// Now it's safe to read from rec - no more concurrent writes
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
|
||||
})
|
||||
@@ -955,9 +1130,9 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"streaming-model": getTestSimpleResponderConfig("streaming-model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -970,7 +1145,7 @@ func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testin
|
||||
reqBody := `{"model":"streaming-model"}`
|
||||
// simple-responder will return text/event-stream when stream=true is in the query
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||
rec := httptest.NewRecorder()
|
||||
rec := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import { useTheme } from "../contexts/ThemeProvider";
|
||||
import ConnectionStatusIcon from "./ConnectionStatus";
|
||||
|
||||
export function Header() {
|
||||
const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle } = useTheme();
|
||||
const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle, isNarrow } = useTheme();
|
||||
const handleTitleChange = useCallback(
|
||||
(newTitle: string) => {
|
||||
setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap");
|
||||
@@ -17,7 +17,7 @@ export function Header() {
|
||||
`text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 ${isActive ? "font-semibold" : ""}`;
|
||||
|
||||
return (
|
||||
<header className="flex items-center justify-between bg-surface border-b border-border p-2 px-4 h-[75px]">
|
||||
<header className={`flex items-center justify-between bg-surface border-b border-border px-4 ${isNarrow ? "py-1 h-[60px]" : "p-2 h-[75px]"}`}>
|
||||
{screenWidth !== "xs" && screenWidth !== "sm" && (
|
||||
<h1
|
||||
contentEditable
|
||||
|
||||
@@ -16,6 +16,7 @@ interface APIProviderType {
|
||||
models: Model[];
|
||||
listModels: () => Promise<Model[]>;
|
||||
unloadAllModels: () => Promise<void>;
|
||||
unloadSingleModel: (model: string) => Promise<void>;
|
||||
loadModel: (model: string) => Promise<void>;
|
||||
enableAPIEvents: (enabled: boolean) => void;
|
||||
proxyLogs: string;
|
||||
@@ -177,7 +178,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
|
||||
const unloadAllModels = useCallback(async () => {
|
||||
try {
|
||||
const response = await fetch(`/api/models/unload/`, {
|
||||
const response = await fetch(`/api/models/unload`, {
|
||||
method: "POST",
|
||||
});
|
||||
if (!response.ok) {
|
||||
@@ -189,6 +190,20 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
}
|
||||
}, []);
|
||||
|
||||
const unloadSingleModel = useCallback(async (model: string) => {
|
||||
try {
|
||||
const response = await fetch(`/api/models/unload/${model}`, {
|
||||
method: "POST",
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to unload model: ${response.status}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to unload model", model, error);
|
||||
throw error;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const loadModel = useCallback(async (model: string) => {
|
||||
try {
|
||||
const response = await fetch(`/upstream/${model}/`, {
|
||||
@@ -208,6 +223,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
models,
|
||||
listModels,
|
||||
unloadAllModels,
|
||||
unloadSingleModel,
|
||||
loadModel,
|
||||
enableAPIEvents,
|
||||
proxyLogs,
|
||||
|
||||
+82
-30
@@ -4,7 +4,7 @@ import { LogPanel } from "./LogViewer";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||
import { useTheme } from "../contexts/ThemeProvider";
|
||||
import { RiEyeFill, RiEyeOffFill, RiStopCircleLine, RiSwapBoxFill } from "react-icons/ri";
|
||||
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine, RiMenuFill } from "react-icons/ri";
|
||||
|
||||
export default function ModelsPage() {
|
||||
const { isNarrow } = useTheme();
|
||||
@@ -37,10 +37,12 @@ export default function ModelsPage() {
|
||||
}
|
||||
|
||||
function ModelsPanel() {
|
||||
const { models, loadModel, unloadAllModels } = useAPI();
|
||||
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
|
||||
const { isNarrow } = useTheme();
|
||||
const [isUnloading, setIsUnloading] = useState(false);
|
||||
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
||||
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
|
||||
const [menuOpen, setMenuOpen] = useState(false);
|
||||
|
||||
const filteredModels = useMemo(() => {
|
||||
return models.filter((model) => showUnlisted || !model.unlisted);
|
||||
@@ -66,33 +68,77 @@ function ModelsPanel() {
|
||||
return (
|
||||
<div className="card h-full flex flex-col">
|
||||
<div className="shrink-0">
|
||||
<h2>Models</h2>
|
||||
<div className="flex justify-between">
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
className="btn text-base flex items-center gap-2"
|
||||
onClick={toggleIdorName}
|
||||
style={{ lineHeight: "1.2" }}
|
||||
>
|
||||
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "ID" : "Name"}
|
||||
</button>
|
||||
<div className="flex justify-between items-baseline">
|
||||
<h2 className={isNarrow ? "text-xl" : ""}>Models</h2>
|
||||
{isNarrow && (
|
||||
<div className="relative">
|
||||
<button className="btn text-base flex items-center gap-2 py-1" onClick={() => setMenuOpen(!menuOpen)}>
|
||||
<RiMenuFill size="20" />
|
||||
</button>
|
||||
{menuOpen && (
|
||||
<div className="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
|
||||
<button
|
||||
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onClick={() => {
|
||||
toggleIdorName();
|
||||
setMenuOpen(false);
|
||||
}}
|
||||
>
|
||||
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "Show Name" : "Show ID"}
|
||||
</button>
|
||||
<button
|
||||
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onClick={() => {
|
||||
setShowUnlisted(!showUnlisted);
|
||||
setMenuOpen(false);
|
||||
}}
|
||||
>
|
||||
{showUnlisted ? <RiEyeOffFill size="20" /> : <RiEyeFill size="20" />}{" "}
|
||||
{showUnlisted ? "Hide Unlisted" : "Show Unlisted"}
|
||||
</button>
|
||||
<button
|
||||
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onClick={() => {
|
||||
handleUnloadAllModels();
|
||||
setMenuOpen(false);
|
||||
}}
|
||||
disabled={isUnloading}
|
||||
>
|
||||
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{!isNarrow && (
|
||||
<div className="flex justify-between">
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
className="btn text-base flex items-center gap-2"
|
||||
onClick={toggleIdorName}
|
||||
style={{ lineHeight: "1.2" }}
|
||||
>
|
||||
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "ID" : "Name"}
|
||||
</button>
|
||||
|
||||
<button
|
||||
className="btn text-base flex items-center gap-2"
|
||||
onClick={() => setShowUnlisted(!showUnlisted)}
|
||||
style={{ lineHeight: "1.2" }}
|
||||
>
|
||||
{showUnlisted ? <RiEyeFill size="20" /> : <RiEyeOffFill size="20" />} unlisted
|
||||
</button>
|
||||
</div>
|
||||
<button
|
||||
className="btn text-base flex items-center gap-2"
|
||||
onClick={() => setShowUnlisted(!showUnlisted)}
|
||||
style={{ lineHeight: "1.2" }}
|
||||
onClick={handleUnloadAllModels}
|
||||
disabled={isUnloading}
|
||||
>
|
||||
{showUnlisted ? <RiEyeFill size="20" /> : <RiEyeOffFill size="20" />} unlisted
|
||||
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
|
||||
</button>
|
||||
</div>
|
||||
<button
|
||||
className="btn text-base flex items-center gap-2"
|
||||
onClick={handleUnloadAllModels}
|
||||
disabled={isUnloading}
|
||||
>
|
||||
<RiStopCircleLine size="24" /> {isUnloading ? "Unloading..." : "Unload"}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
@@ -119,13 +165,19 @@ function ModelsPanel() {
|
||||
)}
|
||||
</td>
|
||||
<td className="w-12">
|
||||
<button
|
||||
className="btn btn--sm"
|
||||
disabled={model.state !== "stopped"}
|
||||
onClick={() => loadModel(model.id)}
|
||||
>
|
||||
Load
|
||||
</button>
|
||||
{model.state === "stopped" ? (
|
||||
<button className="btn btn--sm" onClick={() => loadModel(model.id)}>
|
||||
Load
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="btn btn--sm"
|
||||
onClick={() => unloadSingleModel(model.id)}
|
||||
disabled={model.state !== "ready"}
|
||||
>
|
||||
Unload
|
||||
</button>
|
||||
)}
|
||||
</td>
|
||||
<td className="w-20">
|
||||
<span className={`w-16 text-center status status--${model.state}`}>{model.state}</span>
|
||||
|
||||
@@ -15,6 +15,7 @@ export default defineConfig({
|
||||
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
|
||||
"/logs": "http://localhost:8080",
|
||||
"/upstream": "http://localhost:8080",
|
||||
"/unload": "http://localhost:8080",
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user