Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 00b738cd0f | |||
| 70930e4e91 | |||
| 1f6179110c | |||
| 216c40b951 |
@@ -0,0 +1,43 @@
|
|||||||
|
# Project: llama-swap
|
||||||
|
|
||||||
|
## Project Description:
|
||||||
|
|
||||||
|
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||||
|
|
||||||
|
## Tech stack
|
||||||
|
|
||||||
|
- golang
|
||||||
|
- typescript, vite and react for UI (ui/)
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors.
|
||||||
|
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
|
||||||
|
|
||||||
|
## Workflow Tasks
|
||||||
|
|
||||||
|
### Plan Improvements
|
||||||
|
|
||||||
|
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
|
||||||
|
|
||||||
|
When the user asks to improve a plan follow these guidelines for expanding and improving it.
|
||||||
|
|
||||||
|
- Identify any inconsistencies.
|
||||||
|
- Expand plans out to be detailed specification of requirements and changes to be made.
|
||||||
|
- Plans should have at least these sections:
|
||||||
|
- Title - very short, describes changes
|
||||||
|
- Overview: A more detailed summary of goal and outcomes desired
|
||||||
|
- Design Requirements: Detailed descriptions of what needs to be done
|
||||||
|
- Testing Plan: Tests to be implemented
|
||||||
|
- Checklist: A detailed list of changes to be made
|
||||||
|
|
||||||
|
Look for "plan expansion" as explicit instructions to improve a plan.
|
||||||
|
|
||||||
|
### Implementation of plans
|
||||||
|
|
||||||
|
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
|
||||||
|
|
||||||
|
## General Rules
|
||||||
|
|
||||||
|
- when summarizing changes only include details that require further action (action items)
|
||||||
|
- when there are no action items, just say "Done."
|
||||||
@@ -23,11 +23,17 @@ proxy/ui_dist/placeholder.txt:
|
|||||||
mkdir -p proxy/ui_dist
|
mkdir -p proxy/ui_dist
|
||||||
touch $@
|
touch $@
|
||||||
|
|
||||||
test: proxy/ui_dist/placeholder.txt
|
# use cached test results while developing
|
||||||
go test -short -v -count=1 ./proxy
|
test-dev: proxy/ui_dist/placeholder.txt
|
||||||
|
go test -short ./proxy/...
|
||||||
|
staticcheck ./proxy/... || true
|
||||||
|
|
||||||
|
test: proxy/ui_dist/placeholder.txt
|
||||||
|
go test -short -count=1 ./proxy/...
|
||||||
|
|
||||||
|
# for CI - full test (takes longer)
|
||||||
test-all: proxy/ui_dist/placeholder.txt
|
test-all: proxy/ui_dist/placeholder.txt
|
||||||
go test -v -count=1 ./proxy
|
go test -count=1 ./proxy/...
|
||||||
|
|
||||||
ui/node_modules:
|
ui/node_modules:
|
||||||
cd ui && npm install
|
cd ui && npm install
|
||||||
@@ -81,4 +87,4 @@ release:
|
|||||||
git tag "$$new_tag";
|
git tag "$$new_tag";
|
||||||
|
|
||||||
# Phony targets
|
# Phony targets
|
||||||
.PHONY: all clean ui mac linux windows simple-responder
|
.PHONY: all clean ui mac linux windows simple-responder test test-all test-dev
|
||||||
|
|||||||
@@ -0,0 +1,292 @@
|
|||||||
|
# Add Model Metadata Support with Typed Macros
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Implement support for arbitrary metadata on model configurations that can be exposed through the `/v1/models` API endpoint. This feature extends the existing macro system to support scalar types (string, int, float, bool) instead of only strings, enabling type-safe metadata values.
|
||||||
|
|
||||||
|
The metadata will be schemaless, allowing users to define any key-value pairs they need. Macro substitution will work within metadata values, preserving types when macros are used directly and converting to strings when macros are interpolated within strings.
|
||||||
|
|
||||||
|
## Design Requirements
|
||||||
|
|
||||||
|
### 1. Enhanced Macro System
|
||||||
|
|
||||||
|
**Current State:**
|
||||||
|
|
||||||
|
- Macros are defined as `map[string]string` at both global and model levels
|
||||||
|
- Only string substitution is supported
|
||||||
|
- Macros are replaced in: `cmd`, `cmdStop`, `proxy`, `checkEndpoint`, `filters.stripParams`
|
||||||
|
|
||||||
|
**Required Changes:**
|
||||||
|
|
||||||
|
- Change `MacroList` type from `map[string]string` to `map[string]any`
|
||||||
|
- Support scalar types: `string`, `int`, `float64`, `bool`
|
||||||
|
- Implement type-preserving macro substitution:
|
||||||
|
- Direct macro usage (`key: ${macro}`) preserves the macro's type
|
||||||
|
- Interpolated usage (`key: "text ${macro}"`) converts to string
|
||||||
|
- Add validation to ensure macro values are scalar types only
|
||||||
|
- Update existing macro substitution logic in [proxy/config/config.go](proxy/config/config.go) to handle `any` types
|
||||||
|
|
||||||
|
**Implementation Details:**
|
||||||
|
|
||||||
|
- Create a generic helper function to perform macro substitution that:
|
||||||
|
- Takes a value of type `any`
|
||||||
|
- Recursively processes maps, slices, and scalar values
|
||||||
|
- Replaces `${macro_name}` patterns with macro values
|
||||||
|
- Preserves types for direct substitution
|
||||||
|
- Converts to strings for interpolated substitution
|
||||||
|
- Update `validateMacro()` function to accept `any` type and validate scalar types
|
||||||
|
- Maintain backward compatibility with existing string-only macros
|
||||||
|
|
||||||
|
### 2. Metadata Field in ModelConfig
|
||||||
|
|
||||||
|
**Location:** [proxy/config/model_config.go](proxy/config/model_config.go)
|
||||||
|
|
||||||
|
**Required Changes:**
|
||||||
|
|
||||||
|
- Add `Metadata map[string]any` field to `ModelConfig` struct
|
||||||
|
- Support YAML unmarshaling of arbitrary structures (maps, arrays, scalars)
|
||||||
|
- Apply macro substitution to metadata values during config loading
|
||||||
|
|
||||||
|
**Schema Requirements:**
|
||||||
|
|
||||||
|
- Metadata is optional (default: empty/nil map)
|
||||||
|
- Supports nested structures (objects within objects, arrays, etc.)
|
||||||
|
- All string values within metadata undergo macro substitution
|
||||||
|
- Type preservation rules apply as described above
|
||||||
|
|
||||||
|
### 3. Macro Substitution in Metadata
|
||||||
|
|
||||||
|
**Location:** [proxy/config/config.go](proxy/config/config.go) in `LoadConfigFromReader()`
|
||||||
|
|
||||||
|
**Process Flow:**
|
||||||
|
|
||||||
|
1. After loading YAML configuration
|
||||||
|
2. After model-level and global macro merging
|
||||||
|
3. Apply macro substitution to `ModelConfig.Metadata` field
|
||||||
|
4. Use the same merged macros available to `cmd`, `proxy`, etc.
|
||||||
|
5. Process recursively through all nested structures
|
||||||
|
|
||||||
|
**Substitution Rules:**
|
||||||
|
|
||||||
|
- `port: ${PORT}` → keeps integer type from PORT macro
|
||||||
|
- `temperature: ${temp}` → keeps float type from temp macro
|
||||||
|
- `note: "Running on ${PORT}"` → converts to string `"Running on 10001"`
|
||||||
|
- Arrays and nested objects are processed recursively
|
||||||
|
- Unknown macros should cause configuration load error (consistent with existing behavior)
|
||||||
|
|
||||||
|
### 4. API Response Updates
|
||||||
|
|
||||||
|
**Location:** [proxy/proxymanager.go:350](proxy/proxymanager.go#L350) `listModelsHandler()`
|
||||||
|
|
||||||
|
**Current Behavior:**
|
||||||
|
|
||||||
|
- Returns model records with: `id`, `object`, `created`, `owned_by`
|
||||||
|
- Optionally includes: `name`, `description`
|
||||||
|
|
||||||
|
**Required Changes:**
|
||||||
|
|
||||||
|
- Add metadata to each model record under the key `llamaswap_meta`
|
||||||
|
- Only include `llamaswap_meta` if metadata is non-empty
|
||||||
|
- Preserve all types when marshaling to JSON
|
||||||
|
- Maintain existing sorting by model ID
|
||||||
|
|
||||||
|
**Example Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "llama",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1234567890,
|
||||||
|
"owned_by": "llama-swap",
|
||||||
|
"name": "llama 3.1 8B",
|
||||||
|
"description": "A small but capable model",
|
||||||
|
"llamaswap_meta": {
|
||||||
|
"port": 10001,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"note": "The llama is running on port 10001 temp=0.7, context=16384",
|
||||||
|
"a_list": [1, 1.23, "macros are OK in list and dictionary types: llama"],
|
||||||
|
"an_obj": {
|
||||||
|
"a": "1",
|
||||||
|
"b": 2,
|
||||||
|
"c": [0.7, false, "model: llama"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Validation and Error Handling
|
||||||
|
|
||||||
|
**Macro Validation:**
|
||||||
|
|
||||||
|
- Extend `validateMacro()` to accept values of type `any`
|
||||||
|
- Verify macro values are scalar types: `string`, `int`, `float64`, `bool`
|
||||||
|
- Reject complex types (maps, slices, structs) as macro values
|
||||||
|
- Maintain existing validation for macro names and lengths
|
||||||
|
|
||||||
|
**Configuration Loading:**
|
||||||
|
|
||||||
|
- Fail fast if unknown macros are found in metadata
|
||||||
|
- Provide clear error messages indicating which model and field contains errors
|
||||||
|
- Ensure macros in metadata follow same rules as macros in cmd/proxy fields
|
||||||
|
|
||||||
|
## Testing Plan
|
||||||
|
|
||||||
|
### Test 1: Model-Level Macros with Different Types
|
||||||
|
|
||||||
|
**File:** [proxy/config/model_config_test.go](proxy/config/model_config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Define model with macros of each scalar type
|
||||||
|
- Verify metadata correctly substitutes and preserves types
|
||||||
|
- Test direct substitution (`port: ${PORT}`)
|
||||||
|
- Test string interpolation (`note: "Port is ${PORT}"`)
|
||||||
|
- Verify nested objects and arrays work correctly
|
||||||
|
|
||||||
|
### Test 2: Global and Model Macro Precedence
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Define same macro at global and model level with different types
|
||||||
|
- Verify model-level macro takes precedence
|
||||||
|
- Test metadata uses correct macro value
|
||||||
|
- Verify type is preserved from the winning macro
|
||||||
|
|
||||||
|
### Test 3: Macro Validation
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Test that complex types (maps, arrays) are rejected as macro values
|
||||||
|
- Verify error message includes: macro name and type that was rejected
|
||||||
|
- Test that scalar types (string, int, float, bool) are accepted
|
||||||
|
- Each type should load without error
|
||||||
|
- Test macro name validation still works with `any` types
|
||||||
|
- Invalid characters, reserved names, length limits should still be enforced
|
||||||
|
|
||||||
|
### Test 4: Metadata in API Response
|
||||||
|
|
||||||
|
**File:** [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||||
|
|
||||||
|
**Existing Test:** `TestProxyManager_ListModelsHandler`
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Model with metadata → verify `llamaswap_meta` key appears
|
||||||
|
- Model without metadata → verify `llamaswap_meta` key is absent
|
||||||
|
- Verify all types are correctly marshaled to JSON
|
||||||
|
- Verify nested structures are preserved
|
||||||
|
- Verify macro substitution has occurred before serialization
|
||||||
|
|
||||||
|
### Test 5: Unknown Macros in Metadata
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Use undefined macro in metadata
|
||||||
|
- Verify configuration loading fails with clear error
|
||||||
|
- Error should indicate model name and that macro is undefined
|
||||||
|
|
||||||
|
### Test 6: Recursive Substitution
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Metadata with deeply nested structures
|
||||||
|
- Arrays containing objects with macros
|
||||||
|
- Objects containing arrays with macros
|
||||||
|
- Mixed string interpolation and direct substitution at various nesting levels
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
### Configuration Schema Changes
|
||||||
|
|
||||||
|
- [x] Change `MacroList` type from `map[string]string` to `map[string]any` in [proxy/config/config.go:19](proxy/config/config.go#L19)
|
||||||
|
- [x] Add `Metadata map[string]any` field to `ModelConfig` struct in [proxy/config/model_config.go:37](proxy/config/model_config.go#L37)
|
||||||
|
- [x] Update `validateMacro()` function signature to accept `any` type for values
|
||||||
|
- [x] Add validation logic to ensure macro values are scalar types only
|
||||||
|
|
||||||
|
### Macro Substitution Logic
|
||||||
|
|
||||||
|
- [x] Create generic recursive function `substituteMetadataMacros()` to handle `any` types
|
||||||
|
- [x] Implement type-preserving direct substitution logic
|
||||||
|
- [x] Implement string interpolation with type conversion
|
||||||
|
- [x] Handle maps: recursively process all values
|
||||||
|
- [x] Handle slices: recursively process all elements
|
||||||
|
- [x] Handle scalar types: perform string-based macro substitution if value is string
|
||||||
|
- [x] Integrate macro substitution into `LoadConfigFromReader()` after existing macro expansion
|
||||||
|
- [x] Update existing macro substitution calls to use merged macros with correct types
|
||||||
|
|
||||||
|
### API Response Changes
|
||||||
|
|
||||||
|
- [x] Modify `listModelsHandler()` in [proxy/proxymanager.go:350](proxy/proxymanager.go#L350)
|
||||||
|
- [x] Add `llamaswap_meta` field to model records when metadata exists
|
||||||
|
- [x] Ensure empty metadata results in omitted `llamaswap_meta` key
|
||||||
|
- [x] Verify JSON marshaling preserves all types correctly
|
||||||
|
|
||||||
|
### Testing - Config Package
|
||||||
|
|
||||||
|
- [x] Add test for string macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for int macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for float macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for bool macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for string interpolation in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for model-level macro precedence: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for nested structures in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for unknown macro in metadata (should error): [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for invalid macro type validation: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
### Testing - Model Config Package
|
||||||
|
|
||||||
|
- [x] Add test cases to [proxy/config/model_config_test.go](proxy/config/model_config_test.go) for metadata unmarshaling
|
||||||
|
- [x] Test metadata with various scalar types
|
||||||
|
- [x] Test metadata with nested objects and arrays
|
||||||
|
|
||||||
|
### Testing - Proxy Manager
|
||||||
|
|
||||||
|
- [x] Update `TestProxyManager_ListModelsHandler` in [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||||
|
- [x] Add test case for model with metadata
|
||||||
|
- [x] Add test case for model without metadata
|
||||||
|
- [x] Verify `llamaswap_meta` key presence/absence
|
||||||
|
- [x] Verify type preservation in JSON output
|
||||||
|
- [x] Verify macro substitution has occurred
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
|
||||||
|
- [x] Verify [config.example.yaml](config.example.yaml) already has complete metadata examples (lines 149-171)
|
||||||
|
- [x] No additional documentation needed per project instructions
|
||||||
|
|
||||||
|
## Known Issues and Considerations
|
||||||
|
|
||||||
|
### Inconsistencies
|
||||||
|
|
||||||
|
None identified. The plan references the correct existing example in [config.example.yaml:149-171](config.example.yaml#L149-L171).
|
||||||
|
|
||||||
|
### Design Decisions
|
||||||
|
|
||||||
|
1. **Why `llamaswap_meta` instead of merging into record?**
|
||||||
|
|
||||||
|
- Avoids potential collisions with OpenAI API standard fields
|
||||||
|
- Makes it clear this is llama-swap specific metadata
|
||||||
|
- Easier for clients to distinguish standard vs. custom fields
|
||||||
|
|
||||||
|
2. **Why support nested structures?**
|
||||||
|
|
||||||
|
- Provides maximum flexibility for users
|
||||||
|
- Aligns with the schemaless design principle
|
||||||
|
- Example config already demonstrates this capability
|
||||||
|
|
||||||
|
3. **Why validate macro types?**
|
||||||
|
- Prevents confusing behavior (e.g., substituting a map)
|
||||||
|
- Makes configuration errors explicit at load time
|
||||||
|
- Simpler implementation and testing
|
||||||
@@ -0,0 +1,397 @@
|
|||||||
|
# Improve macro-in-macro support
|
||||||
|
|
||||||
|
**Status: COMPLETED ✅**
|
||||||
|
|
||||||
|
## Title
|
||||||
|
|
||||||
|
Fix macro substitution ordering by preserving definition order using ordered YAML parsing
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The current macro implementation uses `map[string]any` which does not preserve insertion order. This causes issues when macros reference other macros - if macro `B` contains `${A}` but `B` is processed before `A`, the reference won't be substituted, leading to "unknown macro" errors.
|
||||||
|
|
||||||
|
**Goal:** Ensure macros are substituted in definition order (LIFO - last in, first out) to allow macros to reliably reference previously-defined macros.
|
||||||
|
|
||||||
|
**Outcomes:**
|
||||||
|
- Macros can reference other macros defined earlier in the config
|
||||||
|
- Macro substitution is deterministic and order-dependent
|
||||||
|
- Single-pass substitution prevents circular dependencies
|
||||||
|
- Use `yaml.Node` from `gopkg.in/yaml.v3` to preserve macro definition order
|
||||||
|
- All existing tests pass
|
||||||
|
- New tests validate substitution order and self-reference detection
|
||||||
|
|
||||||
|
## Design Requirements
|
||||||
|
|
||||||
|
### 1. YAML Parsing Strategy
|
||||||
|
- **Continue using:** `gopkg.in/yaml.v3` (current library)
|
||||||
|
- **Use:** `yaml.Node` for ordered parsing of macros
|
||||||
|
- **Reason:** `yaml.Node` preserves document structure and order, avoiding need for migration
|
||||||
|
|
||||||
|
### 2. Data Structure Changes
|
||||||
|
|
||||||
|
#### Current Implementation (config.go:19)
|
||||||
|
```go
|
||||||
|
type MacroList map[string]any
|
||||||
|
```
|
||||||
|
|
||||||
|
#### New Implementation
|
||||||
|
```go
|
||||||
|
type MacroList []MacroEntry
|
||||||
|
|
||||||
|
type MacroEntry struct {
|
||||||
|
Name string
|
||||||
|
Value any
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Implementation Note:** Parse macros using `yaml.Node` to extract key-value pairs in document order, then construct the ordered `MacroList`.
|
||||||
|
|
||||||
|
### 3. Macro Substitution Order Rules
|
||||||
|
|
||||||
|
The substitution must follow this hierarchy (from most specific to least):
|
||||||
|
|
||||||
|
1. **Reserved macros** (last): `PORT`, `MODEL_ID` - substituted last, highest priority
|
||||||
|
2. **Model-level macros** (middle): Defined in specific model config, overrides global
|
||||||
|
3. **Global macros** (first): Defined at config root level
|
||||||
|
|
||||||
|
Within each level, macros are substituted in **reverse definition order** (LIFO):
|
||||||
|
- The last macro defined is substituted first
|
||||||
|
- This allows later macros to reference earlier ones
|
||||||
|
- Single-pass substitution prevents circular dependencies
|
||||||
|
|
||||||
|
### 4. Macro Reference Rules
|
||||||
|
|
||||||
|
**Allowed:**
|
||||||
|
- Macro can reference any macro defined **before** it (earlier in the file)
|
||||||
|
- Model macros can reference global macros
|
||||||
|
- Macros can reference reserved macros (`${PORT}`, `${MODEL_ID}`)
|
||||||
|
|
||||||
|
**Prohibited:**
|
||||||
|
- Macro cannot reference itself (e.g., `foo: "value ${foo}"`)
|
||||||
|
- Macro cannot reference macros defined **after** it
|
||||||
|
- No circular references (prevented by single-pass, ordered substitution)
|
||||||
|
|
||||||
|
### 5. Validation Requirements
|
||||||
|
|
||||||
|
Add validation to detect:
|
||||||
|
- **Self-references:** Macro value contains reference to its own name
|
||||||
|
- **Unknown macros:** After substitution, any remaining `${...}` references
|
||||||
|
|
||||||
|
Error messages should be clear:
|
||||||
|
```
|
||||||
|
macro 'foo' contains self-reference
|
||||||
|
unknown macro '${bar}' in model.cmd
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. Implementation Changes
|
||||||
|
|
||||||
|
#### Files to Modify
|
||||||
|
|
||||||
|
1. **[proxy/config/config.go](proxy/config/config.go)**
|
||||||
|
- Line 19: Change `MacroList` type definition
|
||||||
|
- Line 69: Update `Macros MacroList` field
|
||||||
|
- Line 153-157: Update macro validation loop to work with ordered structure
|
||||||
|
- Line 175-188: Update model-level macro validation
|
||||||
|
- Line 181-188: **NEW** Implement proper macro merging respecting order
|
||||||
|
- Line 193-202: **NEW** Implement ordered macro substitution in LIFO order
|
||||||
|
- Line 389-415: Update `validateMacro` to detect self-references
|
||||||
|
- Line 420-475: Update `substituteMetadataMacros` to accept ordered MacroList
|
||||||
|
|
||||||
|
2. **[proxy/config/model_config.go](proxy/config/model_config.go)**
|
||||||
|
- Line 33: Update `Macros MacroList` field type
|
||||||
|
|
||||||
|
3. **All test files**
|
||||||
|
- Update test fixtures to use ordered macro definitions
|
||||||
|
- Ensure tests specify macro order explicitly
|
||||||
|
|
||||||
|
#### Core Algorithm
|
||||||
|
|
||||||
|
Replace the macro substitution logic in [config.go:181-252](proxy/config/config.go#L181-L252) with:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Merge global config and model macros. Model macros take precedence
|
||||||
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+2)
|
||||||
|
|
||||||
|
// Add global macros first
|
||||||
|
for _, entry := range config.Macros {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add model macros (can override global)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
// Remove any existing global macro with same name
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry // Override
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add reserved MODEL_ID macro at the end
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
|
||||||
|
// Check if PORT macro is needed
|
||||||
|
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||||
|
// enforce ${PORT} used in both cmd and proxy
|
||||||
|
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add PORT macro to the end (highest priority)
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "PORT", Value: nextPort})
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single-pass substitution: Substitute all macros in LIFO order (last defined first)
|
||||||
|
// This allows later macros to reference earlier ones
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
// Substitute in command fields
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in metadata (recursive)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
var err error
|
||||||
|
modelConfig.Metadata, err = substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Add this new helper function to replace `substituteMetadataMacros`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||||
|
// This is called once per macro, allowing LIFO substitution order
|
||||||
|
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
macroStr := fmt.Sprintf("%v", macroValue)
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check if this is a direct macro substitution
|
||||||
|
if v == macroSlug {
|
||||||
|
return macroValue, nil
|
||||||
|
}
|
||||||
|
// Handle string interpolation
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
// Recursively process map values
|
||||||
|
newMap := make(map[string]any)
|
||||||
|
for key, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newMap[key] = newVal
|
||||||
|
}
|
||||||
|
return newMap, nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
// Recursively process slice elements
|
||||||
|
newSlice := make([]any, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newSlice[i] = newVal
|
||||||
|
}
|
||||||
|
return newSlice, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Return scalar types as-is
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. Self-Reference Detection
|
||||||
|
|
||||||
|
Add to `validateMacro` function:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func validateMacro(name string, value any) error {
|
||||||
|
// ... existing validation ...
|
||||||
|
|
||||||
|
// Check for self-reference
|
||||||
|
if str, ok := value.(string); ok {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(str, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing Plan
|
||||||
|
|
||||||
|
### 1. Migration Tests
|
||||||
|
- **Test:** All existing macro tests still pass after YAML library migration
|
||||||
|
- **Files:** All `*_test.go` files with macro tests
|
||||||
|
|
||||||
|
### 2. Macro Order Tests
|
||||||
|
|
||||||
|
#### Test: Macro-in-macro substitution order
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"A": "value-A"
|
||||||
|
"B": "prefix-${A}-suffix"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "echo ${B}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"echo prefix-value-A-suffix"`
|
||||||
|
|
||||||
|
#### Test: LIFO substitution order
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"base": "/models"
|
||||||
|
"path": "${base}/llama"
|
||||||
|
"full": "${path}/model.gguf"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "load ${full}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"load /models/llama/model.gguf"`
|
||||||
|
|
||||||
|
#### Test: Model macro overrides global
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"tag": "global"
|
||||||
|
"msg": "value-${tag}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
macros:
|
||||||
|
"tag": "model-level"
|
||||||
|
cmd: "echo ${msg}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"echo value-model-level"` (model macro overrides global)
|
||||||
|
|
||||||
|
### 3. Reserved Macro Tests
|
||||||
|
|
||||||
|
#### Test: MODEL_ID substituted in macro
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||||
|
|
||||||
|
models:
|
||||||
|
my-model:
|
||||||
|
cmd: "${podman-llama} -m model.gguf"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf"`
|
||||||
|
|
||||||
|
### 4. Error Detection Tests
|
||||||
|
|
||||||
|
#### Test: Self-reference detection
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"recursive": "value-${recursive}"
|
||||||
|
```
|
||||||
|
**Expected:** Error: `macro 'recursive' contains self-reference`
|
||||||
|
|
||||||
|
#### Test: Undefined macro reference
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"A": "value-${UNDEFINED}"
|
||||||
|
```
|
||||||
|
**Expected:** Error: `unknown macro '${UNDEFINED}' found in macros.A` (or similar)
|
||||||
|
|
||||||
|
### 5. Regression Tests
|
||||||
|
- Run all existing macro tests: `TestConfig_MacroReplacement`, `TestConfig_MacroReservedNames`, etc.
|
||||||
|
- Ensure all pass without modification (except test fixtures if needed)
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
### Phase 1: Data Structure Changes
|
||||||
|
- [ ] Implement custom `UnmarshalYAML` method for `MacroList` that uses `yaml.Node`
|
||||||
|
- [ ] Define new ordered `MacroList` type as `[]MacroEntry`
|
||||||
|
- [ ] Update `MacroList` type definition in [config.go](proxy/config/config.go#L19)
|
||||||
|
- [ ] Update `Config.Macros` field type in [config.go](proxy/config/config.go#L69)
|
||||||
|
- [ ] Update `ModelConfig.Macros` field type in [model_config.go](proxy/config/model_config.go#L33)
|
||||||
|
- [ ] Implement helper functions:
|
||||||
|
- [ ] `func (ml MacroList) Get(name string) (any, bool)` - lookup by name
|
||||||
|
- [ ] `func (ml MacroList) Set(name string, value any) MacroList` - add/override entry
|
||||||
|
- [ ] `func (ml MacroList) ToMap() map[string]any` - convert to map if needed
|
||||||
|
|
||||||
|
### Phase 2: Macro Validation Updates
|
||||||
|
- [ ] Update macro validation loop at [config.go:153-157](proxy/config/config.go#L153-L157)
|
||||||
|
- [ ] Update model macro validation at [config.go:175-179](proxy/config/config.go#L175-L179)
|
||||||
|
- [ ] Add self-reference detection to `validateMacro` function [config.go:389](proxy/config/config.go#L389)
|
||||||
|
- [ ] Test self-reference detection with new test case
|
||||||
|
|
||||||
|
### Phase 3: Macro Substitution Algorithm
|
||||||
|
- [ ] Implement ordered macro merging (global → model → reserved) at [config.go:181-188](proxy/config/config.go#L181-L188)
|
||||||
|
- [ ] Implement single-pass LIFO substitution loop (reverse iteration) at [config.go:193-202](proxy/config/config.go#L193-L202)
|
||||||
|
- [ ] Substitute in all string fields (cmd, cmdStop, proxy, checkEndpoint, stripParams)
|
||||||
|
- [ ] Substitute in metadata within same loop
|
||||||
|
- [ ] Ensure `MODEL_ID` is added to merged macros before substitution
|
||||||
|
- [ ] Ensure `PORT` is added after port assignment (if needed)
|
||||||
|
- [ ] Replace `substituteMetadataMacros` with new `substituteMacroInValue` function that processes one macro at a time [config.go:420](proxy/config/config.go#L420)
|
||||||
|
- [ ] Remove old metadata substitution code that was separate from main loop [config.go:245-251](proxy/config/config.go#L245-L251)
|
||||||
|
|
||||||
|
### Phase 4: Testing
|
||||||
|
- [ ] Run `make test-dev` - fix any static checking errors
|
||||||
|
- [ ] Add test: macro-in-macro basic substitution
|
||||||
|
- [ ] Add test: LIFO substitution order with 3+ macro levels
|
||||||
|
- [ ] Add test: MODEL_ID in global macro used by model
|
||||||
|
- [ ] Add test: PORT in global macro used by model
|
||||||
|
- [ ] Add test: model macro overrides global macro in substitution
|
||||||
|
- [ ] Add test: self-reference detection error
|
||||||
|
- [ ] Add test: undefined macro reference error
|
||||||
|
- [ ] Verify all existing macro tests pass: `TestConfig_Macro*`
|
||||||
|
- [ ] Run `make test-all` - ensure all tests including concurrency tests pass
|
||||||
|
|
||||||
|
### Phase 5: Documentation
|
||||||
|
- [ ] Update plan status in this file (mark completed)
|
||||||
|
- [ ] Update CLAUDE.md if macro behavior needs documentation
|
||||||
|
- [ ] Verify no new error messages need user documentation
|
||||||
|
|
||||||
|
## Bug Example (Original Issue)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"podman-llama": >
|
||||||
|
podman run --name ${MODEL_ID}
|
||||||
|
--init --rm -p ${PORT}:8080 -v /home/alex/ai/models:/models:z --gpus=all
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server-cuda
|
||||||
|
|
||||||
|
"standard-options": >
|
||||||
|
--no-mmap --jinja
|
||||||
|
|
||||||
|
"kv8": >
|
||||||
|
-fa on -ctk q8_0 -ctv q8_0
|
||||||
|
```
|
||||||
|
|
||||||
|
**Current Bug:**
|
||||||
|
- During macro substitution, if `${MODEL_ID}` is processed before `${podman-llama}`, the `${MODEL_ID}` reference inside `podman-llama` remains unsubstituted
|
||||||
|
- Results in error: `unknown macro '${MODEL_ID}' found in model.cmd`
|
||||||
|
|
||||||
|
**After Fix:**
|
||||||
|
- Macros substituted in LIFO order: `kv8` → `standard-options` → `podman-llama`
|
||||||
|
- `MODEL_ID` is a reserved macro, substituted last (after all user macros)
|
||||||
|
- `${MODEL_ID}` inside `podman-llama` is correctly replaced with the model name
|
||||||
+50
-4
@@ -38,13 +38,25 @@ startPort: 10001
|
|||||||
# macros: a dictionary of string substitutions
|
# macros: a dictionary of string substitutions
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - macros are reusable snippets
|
# - macros are reusable snippets
|
||||||
# - used in a model's cmd, cmdStop, proxy and checkEndpoint
|
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||||
# - useful for reducing common configuration settings
|
# - useful for reducing common configuration settings
|
||||||
|
# - macro names are strings and must be less than 64 characters
|
||||||
|
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||||
|
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||||
|
# - macro values can be numbers, bools, or strings
|
||||||
|
# - macros can contain other macros, but they must be defined before they are used
|
||||||
macros:
|
macros:
|
||||||
|
# Example of a multi-line macro
|
||||||
"latest-llama": >
|
"latest-llama": >
|
||||||
/path/to/llama-server/llama-server-ec9e0301
|
/path/to/llama-server/llama-server-ec9e0301
|
||||||
--port ${PORT}
|
--port ${PORT}
|
||||||
|
|
||||||
|
"default_ctx": 4096
|
||||||
|
|
||||||
|
# Example of macro-in-macro usage. macros can contain other macros
|
||||||
|
# but they must be previously declared.
|
||||||
|
"default_args": "--ctx-size ${default_ctx}"
|
||||||
|
|
||||||
# models: a dictionary of model configurations
|
# models: a dictionary of model configurations
|
||||||
# - required
|
# - required
|
||||||
# - each key is the model's ID, used in API requests
|
# - each key is the model's ID, used in API requests
|
||||||
@@ -55,6 +67,14 @@ models:
|
|||||||
|
|
||||||
# keys are the model names used in API requests
|
# keys are the model names used in API requests
|
||||||
"llama":
|
"llama":
|
||||||
|
# macros: a dictionary of string substitutions specific to this model
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - macros defined here override macros defined in the global macros section
|
||||||
|
# - model level macros follow the same rules as global macros
|
||||||
|
macros:
|
||||||
|
"default_ctx": 16384
|
||||||
|
"temp": 0.7
|
||||||
|
|
||||||
# cmd: the command to run to start the inference server.
|
# cmd: the command to run to start the inference server.
|
||||||
# - required
|
# - required
|
||||||
# - it is just a string, similar to what you would run on the CLI
|
# - it is just a string, similar to what you would run on the CLI
|
||||||
@@ -64,6 +84,8 @@ models:
|
|||||||
# ${latest-llama} is a macro that is defined above
|
# ${latest-llama} is a macro that is defined above
|
||||||
${latest-llama}
|
${latest-llama}
|
||||||
--model path/to/llama-8B-Q4_K_M.gguf
|
--model path/to/llama-8B-Q4_K_M.gguf
|
||||||
|
--ctx-size ${default_ctx}
|
||||||
|
--temperature ${temp}
|
||||||
|
|
||||||
# name: a display name for the model
|
# name: a display name for the model
|
||||||
# - optional, default: empty string
|
# - optional, default: empty string
|
||||||
@@ -119,15 +141,39 @@ models:
|
|||||||
|
|
||||||
# filters: a dictionary of filter settings
|
# filters: a dictionary of filter settings
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - only strip_params is currently supported
|
# - only stripParams is currently supported
|
||||||
filters:
|
filters:
|
||||||
# strip_params: a comma separated list of parameters to remove from the request
|
# stripParams: a comma separated list of parameters to remove from the request
|
||||||
# - optional, default: ""
|
# - optional, default: ""
|
||||||
# - useful for server side enforcement of sampling parameters
|
# - useful for server side enforcement of sampling parameters
|
||||||
# - the `model` parameter can never be removed
|
# - the `model` parameter can never be removed
|
||||||
# - can be any JSON key in the request body
|
# - can be any JSON key in the request body
|
||||||
# - recommended to stick to sampling parameters
|
# - recommended to stick to sampling parameters
|
||||||
strip_params: "temperature, top_p, top_k"
|
stripParams: "temperature, top_p, top_k"
|
||||||
|
|
||||||
|
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - while metadata can contains complex types it is recommended to keep it simple
|
||||||
|
# - metadata is only passed through in /v1/models responses
|
||||||
|
metadata:
|
||||||
|
# port will remain an integer
|
||||||
|
port: ${PORT}
|
||||||
|
|
||||||
|
# the ${temp} macro will remain a float
|
||||||
|
temperature: ${temp}
|
||||||
|
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||||
|
|
||||||
|
a_list:
|
||||||
|
- 1
|
||||||
|
- 1.23
|
||||||
|
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||||
|
|
||||||
|
an_obj:
|
||||||
|
a: "1"
|
||||||
|
b: 2
|
||||||
|
# objects can contain complex types with macro substitution
|
||||||
|
# becomes: c: [0.7, false, "model: llama"]
|
||||||
|
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||||
|
|
||||||
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||||
# - optional, default: 0
|
# - optional, default: 0
|
||||||
|
|||||||
+7
-6
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -38,13 +39,13 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := proxy.LoadConfig(*configPath)
|
conf, err := config.LoadConfig(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error loading config: %v\n", err)
|
fmt.Printf("Error loading config: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(config.Profiles) > 0 {
|
if len(conf.Profiles) > 0 {
|
||||||
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,7 +68,7 @@ func main() {
|
|||||||
// Support for watching config and reloading when it changes
|
// Support for watching config and reloading when it changes
|
||||||
reloadProxyManager := func() {
|
reloadProxyManager := func() {
|
||||||
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||||
config, err = proxy.LoadConfig(*configPath)
|
conf, err = config.LoadConfig(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||||
return
|
return
|
||||||
@@ -75,7 +76,7 @@ func main() {
|
|||||||
|
|
||||||
fmt.Println("Configuration Changed")
|
fmt.Println("Configuration Changed")
|
||||||
currentPM.Shutdown()
|
currentPM.Shutdown()
|
||||||
srv.Handler = proxy.New(config)
|
srv.Handler = proxy.New(conf)
|
||||||
fmt.Println("Configuration Reloaded")
|
fmt.Println("Configuration Reloaded")
|
||||||
|
|
||||||
// wait a few seconds and tell any UI to reload
|
// wait a few seconds and tell any UI to reload
|
||||||
@@ -85,12 +86,12 @@ func main() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
config, err = proxy.LoadConfig(*configPath)
|
conf, err = config.LoadConfig(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
srv.Handler = proxy.New(config)
|
srv.Handler = proxy.New(conf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
-460
@@ -1,460 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"runtime"
|
|
||||||
"slices"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/billziss-gh/golib/shlex"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
const DEFAULT_GROUP_ID = "(default)"
|
|
||||||
|
|
||||||
type ModelConfig struct {
|
|
||||||
Cmd string `yaml:"cmd"`
|
|
||||||
CmdStop string `yaml:"cmdStop"`
|
|
||||||
Proxy string `yaml:"proxy"`
|
|
||||||
Aliases []string `yaml:"aliases"`
|
|
||||||
Env []string `yaml:"env"`
|
|
||||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
|
||||||
UnloadAfter int `yaml:"ttl"`
|
|
||||||
Unlisted bool `yaml:"unlisted"`
|
|
||||||
UseModelName string `yaml:"useModelName"`
|
|
||||||
|
|
||||||
// #179 for /v1/models
|
|
||||||
Name string `yaml:"name"`
|
|
||||||
Description string `yaml:"description"`
|
|
||||||
|
|
||||||
// Limit concurrency of HTTP requests to process
|
|
||||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
|
||||||
|
|
||||||
// Model filters see issue #174
|
|
||||||
Filters ModelFilters `yaml:"filters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
||||||
type rawModelConfig ModelConfig
|
|
||||||
defaults := rawModelConfig{
|
|
||||||
Cmd: "",
|
|
||||||
CmdStop: "",
|
|
||||||
Proxy: "http://localhost:${PORT}",
|
|
||||||
Aliases: []string{},
|
|
||||||
Env: []string{},
|
|
||||||
CheckEndpoint: "/health",
|
|
||||||
UnloadAfter: 0,
|
|
||||||
Unlisted: false,
|
|
||||||
UseModelName: "",
|
|
||||||
ConcurrencyLimit: 0,
|
|
||||||
Name: "",
|
|
||||||
Description: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unmarshal(&defaults); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
*m = ModelConfig(defaults)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
|
||||||
return SanitizeCommand(m.Cmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModelFilters see issue #174
|
|
||||||
type ModelFilters struct {
|
|
||||||
StripParams string `yaml:"strip_params"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
||||||
type rawModelFilters ModelFilters
|
|
||||||
defaults := rawModelFilters{
|
|
||||||
StripParams: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unmarshal(&defaults); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
*m = ModelFilters(defaults)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
|
||||||
if f.StripParams == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
params := strings.Split(f.StripParams, ",")
|
|
||||||
cleaned := make([]string, 0, len(params))
|
|
||||||
|
|
||||||
for _, param := range params {
|
|
||||||
trimmed := strings.TrimSpace(param)
|
|
||||||
if trimmed == "model" || trimmed == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
cleaned = append(cleaned, trimmed)
|
|
||||||
}
|
|
||||||
|
|
||||||
// sort cleaned
|
|
||||||
slices.Sort(cleaned)
|
|
||||||
return cleaned, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type GroupConfig struct {
|
|
||||||
Swap bool `yaml:"swap"`
|
|
||||||
Exclusive bool `yaml:"exclusive"`
|
|
||||||
Persistent bool `yaml:"persistent"`
|
|
||||||
Members []string `yaml:"members"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// set default values for GroupConfig
|
|
||||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
||||||
type rawGroupConfig GroupConfig
|
|
||||||
defaults := rawGroupConfig{
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: true,
|
|
||||||
Persistent: false,
|
|
||||||
Members: []string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unmarshal(&defaults); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
*c = GroupConfig(defaults)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type HooksConfig struct {
|
|
||||||
OnStartup HookOnStartup `yaml:"on_startup"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type HookOnStartup struct {
|
|
||||||
Preload []string `yaml:"preload"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
|
||||||
LogRequests bool `yaml:"logRequests"`
|
|
||||||
LogLevel string `yaml:"logLevel"`
|
|
||||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
|
||||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
|
||||||
Profiles map[string][]string `yaml:"profiles"`
|
|
||||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
|
||||||
|
|
||||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
|
||||||
Macros map[string]string `yaml:"macros"`
|
|
||||||
|
|
||||||
// map aliases to actual model IDs
|
|
||||||
aliases map[string]string
|
|
||||||
|
|
||||||
// automatic port assignments
|
|
||||||
StartPort int `yaml:"startPort"`
|
|
||||||
|
|
||||||
// hooks, see: #209
|
|
||||||
Hooks HooksConfig `yaml:"hooks"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) RealModelName(search string) (string, bool) {
|
|
||||||
if _, found := c.Models[search]; found {
|
|
||||||
return search, true
|
|
||||||
} else if name, found := c.aliases[search]; found {
|
|
||||||
return name, found
|
|
||||||
} else {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
|
||||||
if realName, found := c.RealModelName(modelName); !found {
|
|
||||||
return ModelConfig{}, "", false
|
|
||||||
} else {
|
|
||||||
return c.Models[realName], realName, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadConfig(path string) (Config, error) {
|
|
||||||
file, err := os.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
return LoadConfigFromReader(file)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|
||||||
data, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// default configuration values
|
|
||||||
config := Config{
|
|
||||||
HealthCheckTimeout: 120,
|
|
||||||
StartPort: 5800,
|
|
||||||
LogLevel: "info",
|
|
||||||
MetricsMaxInMemory: 1000,
|
|
||||||
}
|
|
||||||
err = yaml.Unmarshal(data, &config)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.HealthCheckTimeout < 15 {
|
|
||||||
// set a minimum of 15 seconds
|
|
||||||
config.HealthCheckTimeout = 15
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.StartPort < 1 {
|
|
||||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populate the aliases map
|
|
||||||
config.aliases = make(map[string]string)
|
|
||||||
for modelName, modelConfig := range config.Models {
|
|
||||||
for _, alias := range modelConfig.Aliases {
|
|
||||||
if _, found := config.aliases[alias]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
|
||||||
}
|
|
||||||
config.aliases[alias] = modelName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* check macro constraint rules:
|
|
||||||
|
|
||||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
|
||||||
- names must be less than 64 characters (no reason, just cause)
|
|
||||||
- name can not be any reserved macros: PORT, MODEL_ID
|
|
||||||
- macro values must be less than 1024 characters
|
|
||||||
*/
|
|
||||||
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
|
||||||
for macroName, macroValue := range config.Macros {
|
|
||||||
if len(macroName) >= 64 {
|
|
||||||
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName)
|
|
||||||
}
|
|
||||||
if !macroNameRegex.MatchString(macroName) {
|
|
||||||
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
|
|
||||||
}
|
|
||||||
if len(macroValue) >= 1024 {
|
|
||||||
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
|
|
||||||
}
|
|
||||||
switch macroName {
|
|
||||||
case "PORT":
|
|
||||||
case "MODEL_ID":
|
|
||||||
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get and sort all model IDs first, makes testing more consistent
|
|
||||||
modelIds := make([]string, 0, len(config.Models))
|
|
||||||
for modelId := range config.Models {
|
|
||||||
modelIds = append(modelIds, modelId)
|
|
||||||
}
|
|
||||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
|
||||||
|
|
||||||
nextPort := config.StartPort
|
|
||||||
for _, modelId := range modelIds {
|
|
||||||
modelConfig := config.Models[modelId]
|
|
||||||
|
|
||||||
// Strip comments from command fields before macro expansion
|
|
||||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
|
||||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
|
||||||
|
|
||||||
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
|
|
||||||
for macroName, macroValue := range config.Macros {
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
|
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
|
|
||||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
|
|
||||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
// enforce ${PORT} used in both cmd and proxy
|
|
||||||
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
|
||||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
|
||||||
}
|
|
||||||
|
|
||||||
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
|
|
||||||
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
|
||||||
nextPortStr := strconv.Itoa(nextPort)
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
|
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
|
|
||||||
nextPort++
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Contains(modelConfig.Cmd, "${MODEL_ID}") || strings.Contains(modelConfig.CmdStop, "${MODEL_ID}") {
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${MODEL_ID}", modelId)
|
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${MODEL_ID}", modelId)
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure there are no unknown macros that have not been replaced
|
|
||||||
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
|
||||||
fieldMap := map[string]string{
|
|
||||||
"cmd": modelConfig.Cmd,
|
|
||||||
"cmdStop": modelConfig.CmdStop,
|
|
||||||
"proxy": modelConfig.Proxy,
|
|
||||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
|
||||||
}
|
|
||||||
|
|
||||||
for fieldName, fieldValue := range fieldMap {
|
|
||||||
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
|
|
||||||
for _, match := range matches {
|
|
||||||
macroName := match[1]
|
|
||||||
if macroName == "PID" && fieldName == "cmdStop" {
|
|
||||||
continue // this is ok, has to be replaced by process later
|
|
||||||
}
|
|
||||||
if _, exists := config.Macros[macroName]; !exists {
|
|
||||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Models[modelId] = modelConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
config = AddDefaultGroupToConfig(config)
|
|
||||||
// check that members are all unique in the groups
|
|
||||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
|
||||||
for groupID, groupConfig := range config.Groups {
|
|
||||||
prevSet := make(map[string]bool)
|
|
||||||
for _, member := range groupConfig.Members {
|
|
||||||
// Check for duplicates within this group
|
|
||||||
if _, found := prevSet[member]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
|
||||||
}
|
|
||||||
prevSet[member] = true
|
|
||||||
|
|
||||||
// Check if member is used in another group
|
|
||||||
if existingGroup, exists := memberUsage[member]; exists {
|
|
||||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
|
||||||
}
|
|
||||||
memberUsage[member] = groupID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clean up hooks preload
|
|
||||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
|
||||||
var toPreload []string
|
|
||||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
|
||||||
modelID = strings.TrimSpace(modelID)
|
|
||||||
if modelID == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if real, found := config.RealModelName(modelID); found {
|
|
||||||
toPreload = append(toPreload, real)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Hooks.OnStartup.Preload = toPreload
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewrites the yaml to include a default group with any orphaned models
|
|
||||||
func AddDefaultGroupToConfig(config Config) Config {
|
|
||||||
|
|
||||||
if config.Groups == nil {
|
|
||||||
config.Groups = make(map[string]GroupConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultGroup := GroupConfig{
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: true,
|
|
||||||
Members: []string{},
|
|
||||||
}
|
|
||||||
// if groups is empty, create a default group and put
|
|
||||||
// all models into it
|
|
||||||
if len(config.Groups) == 0 {
|
|
||||||
for modelName := range config.Models {
|
|
||||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// iterate over existing group members and add non-grouped models into the default group
|
|
||||||
for modelName, _ := range config.Models {
|
|
||||||
foundModel := false
|
|
||||||
found:
|
|
||||||
// search for the model in existing groups
|
|
||||||
for _, groupConfig := range config.Groups {
|
|
||||||
for _, member := range groupConfig.Members {
|
|
||||||
if member == modelName {
|
|
||||||
foundModel = true
|
|
||||||
break found
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !foundModel {
|
|
||||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
|
||||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
|
||||||
|
|
||||||
return config
|
|
||||||
}
|
|
||||||
|
|
||||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
|
||||||
var cleanedLines []string
|
|
||||||
for _, line := range strings.Split(cmdStr, "\n") {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
// Skip comment lines
|
|
||||||
if strings.HasPrefix(trimmed, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Handle trailing backslashes by replacing with space
|
|
||||||
if strings.HasSuffix(trimmed, "\\") {
|
|
||||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
|
||||||
} else {
|
|
||||||
cleanedLines = append(cleanedLines, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// put it back together
|
|
||||||
cmdStr = strings.Join(cleanedLines, "\n")
|
|
||||||
|
|
||||||
// Split the command into arguments
|
|
||||||
var args []string
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
args = shlex.Windows.Split(cmdStr)
|
|
||||||
} else {
|
|
||||||
args = shlex.Posix.Split(cmdStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the command is not empty
|
|
||||||
if len(args) == 0 {
|
|
||||||
return nil, fmt.Errorf("empty command")
|
|
||||||
}
|
|
||||||
|
|
||||||
return args, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func StripComments(cmdStr string) string {
|
|
||||||
var cleanedLines []string
|
|
||||||
for _, line := range strings.Split(cmdStr, "\n") {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
// Skip comment lines
|
|
||||||
if strings.HasPrefix(trimmed, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
cleanedLines = append(cleanedLines, line)
|
|
||||||
}
|
|
||||||
return strings.Join(cleanedLines, "\n")
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,593 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/billziss-gh/golib/shlex"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DEFAULT_GROUP_ID = "(default)"
|
||||||
|
|
||||||
|
type MacroEntry struct {
|
||||||
|
Name string
|
||||||
|
Value any
|
||||||
|
}
|
||||||
|
|
||||||
|
type MacroList []MacroEntry
|
||||||
|
|
||||||
|
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
|
||||||
|
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
if value.Kind != yaml.MappingNode {
|
||||||
|
return fmt.Errorf("macros must be a mapping")
|
||||||
|
}
|
||||||
|
|
||||||
|
// yaml.Node.Content for a mapping contains alternating key/value nodes
|
||||||
|
entries := make([]MacroEntry, 0, len(value.Content)/2)
|
||||||
|
for i := 0; i < len(value.Content); i += 2 {
|
||||||
|
keyNode := value.Content[i]
|
||||||
|
valueNode := value.Content[i+1]
|
||||||
|
|
||||||
|
var name string
|
||||||
|
if err := keyNode.Decode(&name); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode macro name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var val any
|
||||||
|
if err := valueNode.Decode(&val); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries = append(entries, MacroEntry{Name: name, Value: val})
|
||||||
|
}
|
||||||
|
|
||||||
|
*ml = entries
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a macro value by name
|
||||||
|
func (ml MacroList) Get(name string) (any, bool) {
|
||||||
|
for _, entry := range ml {
|
||||||
|
if entry.Name == name {
|
||||||
|
return entry.Value, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMap converts MacroList to a map (for backward compatibility if needed)
|
||||||
|
func (ml MacroList) ToMap() map[string]any {
|
||||||
|
result := make(map[string]any, len(ml))
|
||||||
|
for _, entry := range ml {
|
||||||
|
result[entry.Name] = entry.Value
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
type GroupConfig struct {
|
||||||
|
Swap bool `yaml:"swap"`
|
||||||
|
Exclusive bool `yaml:"exclusive"`
|
||||||
|
Persistent bool `yaml:"persistent"`
|
||||||
|
Members []string `yaml:"members"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||||
|
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// set default values for GroupConfig
|
||||||
|
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawGroupConfig GroupConfig
|
||||||
|
defaults := rawGroupConfig{
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Persistent: false,
|
||||||
|
Members: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*c = GroupConfig(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type HooksConfig struct {
|
||||||
|
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HookOnStartup struct {
|
||||||
|
Preload []string `yaml:"preload"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
|
LogRequests bool `yaml:"logRequests"`
|
||||||
|
LogLevel string `yaml:"logLevel"`
|
||||||
|
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||||
|
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||||
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
|
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||||
|
|
||||||
|
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||||
|
Macros MacroList `yaml:"macros"`
|
||||||
|
|
||||||
|
// map aliases to actual model IDs
|
||||||
|
aliases map[string]string
|
||||||
|
|
||||||
|
// automatic port assignments
|
||||||
|
StartPort int `yaml:"startPort"`
|
||||||
|
|
||||||
|
// hooks, see: #209
|
||||||
|
Hooks HooksConfig `yaml:"hooks"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
|
if _, found := c.Models[search]; found {
|
||||||
|
return search, true
|
||||||
|
} else if name, found := c.aliases[search]; found {
|
||||||
|
return name, found
|
||||||
|
} else {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||||
|
if realName, found := c.RealModelName(modelName); !found {
|
||||||
|
return ModelConfig{}, "", false
|
||||||
|
} else {
|
||||||
|
return c.Models[realName], realName, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfig(path string) (Config, error) {
|
||||||
|
file, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
return LoadConfigFromReader(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// default configuration values
|
||||||
|
config := Config{
|
||||||
|
HealthCheckTimeout: 120,
|
||||||
|
StartPort: 5800,
|
||||||
|
LogLevel: "info",
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
}
|
||||||
|
err = yaml.Unmarshal(data, &config)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.HealthCheckTimeout < 15 {
|
||||||
|
// set a minimum of 15 seconds
|
||||||
|
config.HealthCheckTimeout = 15
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.StartPort < 1 {
|
||||||
|
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate the aliases map
|
||||||
|
config.aliases = make(map[string]string)
|
||||||
|
for modelName, modelConfig := range config.Models {
|
||||||
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
if _, found := config.aliases[alias]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||||
|
}
|
||||||
|
config.aliases[alias] = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* check macro constraint rules:
|
||||||
|
|
||||||
|
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||||
|
- names must be less than 64 characters (no reason, just cause)
|
||||||
|
- name can not be any reserved macros: PORT, MODEL_ID
|
||||||
|
- macro values must be less than 1024 characters
|
||||||
|
*/
|
||||||
|
for _, macro := range config.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get and sort all model IDs first, makes testing more consistent
|
||||||
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
|
for modelId := range config.Models {
|
||||||
|
modelIds = append(modelIds, modelId)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||||
|
|
||||||
|
nextPort := config.StartPort
|
||||||
|
for _, modelId := range modelIds {
|
||||||
|
modelConfig := config.Models[modelId]
|
||||||
|
|
||||||
|
// Strip comments from command fields before macro expansion
|
||||||
|
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||||
|
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||||
|
|
||||||
|
// validate model macros
|
||||||
|
for _, macro := range modelConfig.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge global config and model macros. Model macros take precedence
|
||||||
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
|
||||||
|
// Add global macros first
|
||||||
|
mergedMacros = append(mergedMacros, config.Macros...)
|
||||||
|
|
||||||
|
// Add model macros (can override global)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
// Remove any existing global macro with same name
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry // Override
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
|
||||||
|
// This allows later macros to reference earlier ones
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
// Substitute in command fields
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in metadata (recursive)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
var err error
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final pass: check if PORT macro is needed after macro expansion
|
||||||
|
// ${PORT} is a resource on the local machine so a new port is only allocated
|
||||||
|
// if it is required in either cmd or proxy keys
|
||||||
|
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||||
|
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||||
|
if cmdHasPort || proxyHasPort { // either has it
|
||||||
|
if !cmdHasPort && proxyHasPort { // but both don't have it
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add PORT macro and substitute it
|
||||||
|
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
|
||||||
|
macroSlug := "${PORT}"
|
||||||
|
macroStr := fmt.Sprintf("%v", nextPort)
|
||||||
|
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute PORT in metadata
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
var err error
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure there are no unknown macros that have not been replaced
|
||||||
|
fieldMap := map[string]string{
|
||||||
|
"cmd": modelConfig.Cmd,
|
||||||
|
"cmdStop": modelConfig.CmdStop,
|
||||||
|
"proxy": modelConfig.Proxy,
|
||||||
|
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||||
|
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, fieldValue := range fieldMap {
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
if macroName == "PID" && fieldName == "cmdStop" {
|
||||||
|
continue // this is ok, has to be replaced by process later
|
||||||
|
}
|
||||||
|
// Reserved macros are always valid (they should have been substituted already)
|
||||||
|
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||||
|
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
// Any other macro is unknown
|
||||||
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for unknown macros in metadata
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Models[modelId] = modelConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AddDefaultGroupToConfig(config)
|
||||||
|
// check that members are all unique in the groups
|
||||||
|
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||||
|
for groupID, groupConfig := range config.Groups {
|
||||||
|
prevSet := make(map[string]bool)
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
// Check for duplicates within this group
|
||||||
|
if _, found := prevSet[member]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||||
|
}
|
||||||
|
prevSet[member] = true
|
||||||
|
|
||||||
|
// Check if member is used in another group
|
||||||
|
if existingGroup, exists := memberUsage[member]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||||
|
}
|
||||||
|
memberUsage[member] = groupID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clean up hooks preload
|
||||||
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
|
var toPreload []string
|
||||||
|
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
if modelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if real, found := config.RealModelName(modelID); found {
|
||||||
|
toPreload = append(toPreload, real)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Hooks.OnStartup.Preload = toPreload
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewrites the yaml to include a default group with any orphaned models
|
||||||
|
func AddDefaultGroupToConfig(config Config) Config {
|
||||||
|
|
||||||
|
if config.Groups == nil {
|
||||||
|
config.Groups = make(map[string]GroupConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultGroup := GroupConfig{
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{},
|
||||||
|
}
|
||||||
|
// if groups is empty, create a default group and put
|
||||||
|
// all models into it
|
||||||
|
if len(config.Groups) == 0 {
|
||||||
|
for modelName := range config.Models {
|
||||||
|
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// iterate over existing group members and add non-grouped models into the default group
|
||||||
|
for modelName := range config.Models {
|
||||||
|
foundModel := false
|
||||||
|
found:
|
||||||
|
// search for the model in existing groups
|
||||||
|
for _, groupConfig := range config.Groups {
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if member == modelName {
|
||||||
|
foundModel = true
|
||||||
|
break found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundModel {
|
||||||
|
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||||
|
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Handle trailing backslashes by replacing with space
|
||||||
|
if strings.HasSuffix(trimmed, "\\") {
|
||||||
|
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||||
|
} else {
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// put it back together
|
||||||
|
cmdStr = strings.Join(cleanedLines, "\n")
|
||||||
|
|
||||||
|
// Split the command into arguments
|
||||||
|
var args []string
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
args = shlex.Windows.Split(cmdStr)
|
||||||
|
} else {
|
||||||
|
args = shlex.Posix.Split(cmdStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the command is not empty
|
||||||
|
if len(args) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty command")
|
||||||
|
}
|
||||||
|
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func StripComments(cmdStr string) string {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
return strings.Join(cleanedLines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateMacro validates macro name and value constraints
|
||||||
|
func validateMacro(name string, value any) error {
|
||||||
|
if len(name) >= 64 {
|
||||||
|
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||||
|
}
|
||||||
|
if !macroNameRegex.MatchString(name) {
|
||||||
|
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that value is a scalar type
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
if len(v) >= 1024 {
|
||||||
|
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
||||||
|
}
|
||||||
|
// Check for self-reference
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||||
|
// These types are allowed
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "PORT", "MODEL_ID":
|
||||||
|
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
|
||||||
|
func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Scalar types don't contain macros
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||||
|
// This is called once per macro, allowing LIFO substitution order
|
||||||
|
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
macroStr := fmt.Sprintf("%v", macroValue)
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check if this is a direct macro substitution
|
||||||
|
if v == macroSlug {
|
||||||
|
return macroValue, nil
|
||||||
|
}
|
||||||
|
// Handle string interpolation
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
// Recursively process map values
|
||||||
|
newMap := make(map[string]any)
|
||||||
|
for key, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newMap[key] = newVal
|
||||||
|
}
|
||||||
|
return newMap, nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
// Recursively process slice elements
|
||||||
|
newSlice := make([]any, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newSlice[i] = newVal
|
||||||
|
}
|
||||||
|
return newSlice, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Return scalar types as-is
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build !windows
|
//go:build !windows
|
||||||
|
|
||||||
package proxy
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
@@ -163,8 +163,8 @@ groups:
|
|||||||
expected := Config{
|
expected := Config{
|
||||||
LogLevel: "info",
|
LogLevel: "info",
|
||||||
StartPort: 5800,
|
StartPort: 5800,
|
||||||
Macros: map[string]string{
|
Macros: MacroList{
|
||||||
"svr-path": "path/to/server",
|
{"svr-path", "path/to/server"},
|
||||||
},
|
},
|
||||||
Hooks: HooksConfig{
|
Hooks: HooksConfig{
|
||||||
OnStartup: HookOnStartup{
|
OnStartup: HookOnStartup{
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package proxy
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"slices"
|
"slices"
|
||||||
@@ -65,18 +65,6 @@ models:
|
|||||||
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
|
||||||
config := &ModelConfig{
|
|
||||||
Cmd: `python model1.py \
|
|
||||||
--arg1 value1 \
|
|
||||||
--arg2 value2`,
|
|
||||||
}
|
|
||||||
|
|
||||||
args, err := config.SanitizedCommand()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_FindConfig(t *testing.T) {
|
func TestConfig_FindConfig(t *testing.T) {
|
||||||
|
|
||||||
// TODO?
|
// TODO?
|
||||||
@@ -207,30 +195,91 @@ macros:
|
|||||||
argOne: "--arg1"
|
argOne: "--arg1"
|
||||||
argTwo: "--arg2"
|
argTwo: "--arg2"
|
||||||
autoPort: "--port ${PORT}"
|
autoPort: "--port ${PORT}"
|
||||||
|
overriddenByModelMacro: failed
|
||||||
|
|
||||||
models:
|
models:
|
||||||
model1:
|
model1:
|
||||||
|
macros:
|
||||||
|
overriddenByModelMacro: success
|
||||||
cmd: |
|
cmd: |
|
||||||
${svr-path} ${argTwo}
|
${svr-path} ${argTwo}
|
||||||
# the automatic ${PORT} is replaced
|
# the automatic ${PORT} is replaced
|
||||||
${autoPort}
|
${autoPort}
|
||||||
${argOne}
|
${argOne}
|
||||||
--arg3 three
|
--arg3 three
|
||||||
|
--overridden ${overriddenByModelMacro}
|
||||||
cmdStop: |
|
cmdStop: |
|
||||||
/path/to/stop.sh --port ${PORT} ${argTwo}
|
/path/to/stop.sh --port ${PORT} ${argTwo}
|
||||||
`
|
`
|
||||||
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
assert.NoError(t, err)
|
if !assert.NoError(t, err) {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " "))
|
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three --overridden success", strings.Join(sanitizedCmd, " "))
|
||||||
|
|
||||||
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
|
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
|
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroReservedNames(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config string
|
||||||
|
expectedError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "global macro named PORT",
|
||||||
|
config: `
|
||||||
|
macros:
|
||||||
|
PORT: "1111"
|
||||||
|
`,
|
||||||
|
expectedError: "macro name 'PORT' is reserved",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "global macro named MODEL_ID",
|
||||||
|
config: `
|
||||||
|
macros:
|
||||||
|
MODEL_ID: model1
|
||||||
|
`,
|
||||||
|
expectedError: "macro name 'MODEL_ID' is reserved",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model macro named PORT",
|
||||||
|
config: `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
macros:
|
||||||
|
PORT: 1111
|
||||||
|
`,
|
||||||
|
expectedError: "model model1: macro name 'PORT' is reserved",
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "model macro named MODEL_ID",
|
||||||
|
config: `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
macros:
|
||||||
|
MODEL_ID: model1
|
||||||
|
`,
|
||||||
|
expectedError: "model model1: macro name 'MODEL_ID' is reserved",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(tt.config))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Equal(t, tt.expectedError, err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
|
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -274,7 +323,7 @@ macros:
|
|||||||
models:
|
models:
|
||||||
model1:
|
model1:
|
||||||
cmd: "${svr-path} --port ${PORT}"
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
proxy: "http://localhost:${unknownMacro}"
|
proxy: "http://${unknownMacro}:${PORT}"
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -288,6 +337,20 @@ models:
|
|||||||
model1:
|
model1:
|
||||||
cmd: "${svr-path} --port ${PORT}"
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
checkEndpoint: "http://localhost:${unknownMacro}/health"
|
checkEndpoint: "http://localhost:${unknownMacro}/health"
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown macro in filters.stripParams",
|
||||||
|
field: "filters.stripParams",
|
||||||
|
content: `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
|
filters:
|
||||||
|
stripParams: "model,${unknownMacro}"
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -295,38 +358,13 @@ models:
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||||
assert.Error(t, err)
|
if assert.Error(t, err) {
|
||||||
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
|
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
|
||||||
|
}
|
||||||
//t.Log(err)
|
//t.Log(err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_ModelFilters(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
macros:
|
|
||||||
default_strip: "temperature, top_p"
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: path/to/cmd --port ${PORT}
|
|
||||||
filters:
|
|
||||||
strip_params: "model, top_k, ${default_strip}, , ,"
|
|
||||||
`
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
modelConfig, ok := config.Models["model1"]
|
|
||||||
if !assert.True(t, ok) {
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure `model` and enmpty strings are not in the list
|
|
||||||
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
|
||||||
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
|
||||||
if assert.NoError(t, err) {
|
|
||||||
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStripComments(t *testing.T) {
|
func TestStripComments(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -467,7 +505,9 @@ models:
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
|
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
|
||||||
|
|
||||||
assert.Equal(t, "docker stop ${MODEL_ID}", config.Macros["docker-stop"])
|
dockerStopMacro, found := config.Macros.Get("docker-stop")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "docker stop ${MODEL_ID}", dockerStopMacro)
|
||||||
|
|
||||||
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
|
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@@ -481,3 +521,243 @@ models:
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "/path/to/server -p 9000 -hf author/model:F16", strings.Join(sanitizedCmd3, " "))
|
assert.Equal(t, "/path/to/server -p 9000 -hf author/model:F16", strings.Join(sanitizedCmd3, " "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfig_TypedMacrosInMetadata(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
PORT_NUM: 10001
|
||||||
|
TEMP: 0.7
|
||||||
|
ENABLED: true
|
||||||
|
NAME: "llama model"
|
||||||
|
CTX: 16384
|
||||||
|
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
metadata:
|
||||||
|
port: ${PORT_NUM}
|
||||||
|
temperature: ${TEMP}
|
||||||
|
enabled: ${ENABLED}
|
||||||
|
model_name: ${NAME}
|
||||||
|
context: ${CTX}
|
||||||
|
note: "Running on port ${PORT_NUM} with temp ${TEMP} and context ${CTX}"
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
meta := config.Models["test-model"].Metadata
|
||||||
|
assert.NotNil(t, meta)
|
||||||
|
|
||||||
|
// Verify direct substitution preserves types
|
||||||
|
assert.Equal(t, 10001, meta["port"])
|
||||||
|
assert.Equal(t, 0.7, meta["temperature"])
|
||||||
|
assert.Equal(t, true, meta["enabled"])
|
||||||
|
assert.Equal(t, "llama model", meta["model_name"])
|
||||||
|
assert.Equal(t, 16384, meta["context"])
|
||||||
|
|
||||||
|
// Verify string interpolation converts to string
|
||||||
|
assert.Equal(t, "Running on port 10001 with temp 0.7 and context 16384", meta["note"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_NestedStructuresInMetadata(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
TEMP: 0.7
|
||||||
|
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
metadata:
|
||||||
|
config:
|
||||||
|
port: ${PORT}
|
||||||
|
temperature: ${TEMP}
|
||||||
|
tags: ["model:${MODEL_ID}", "port:${PORT}"]
|
||||||
|
nested:
|
||||||
|
deep:
|
||||||
|
value: ${TEMP}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
meta := config.Models["test-model"].Metadata
|
||||||
|
assert.NotNil(t, meta)
|
||||||
|
|
||||||
|
// Verify nested objects
|
||||||
|
configMap := meta["config"].(map[string]any)
|
||||||
|
assert.Equal(t, 10000, configMap["port"])
|
||||||
|
assert.Equal(t, 0.7, configMap["temperature"])
|
||||||
|
|
||||||
|
// Verify arrays
|
||||||
|
tags := meta["tags"].([]any)
|
||||||
|
assert.Equal(t, "model:test-model", tags[0])
|
||||||
|
assert.Equal(t, "port:10000", tags[1])
|
||||||
|
|
||||||
|
// Verify deeply nested structures
|
||||||
|
nested := meta["nested"].(map[string]any)
|
||||||
|
deep := nested["deep"].(map[string]any)
|
||||||
|
assert.Equal(t, 0.7, deep["value"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelLevelMacroPrecedenceInMetadata(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
TEMP: 0.5
|
||||||
|
GLOBAL_VAL: "global"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
macros:
|
||||||
|
TEMP: 0.9
|
||||||
|
LOCAL_VAL: "local"
|
||||||
|
metadata:
|
||||||
|
temperature: ${TEMP}
|
||||||
|
global: ${GLOBAL_VAL}
|
||||||
|
local: ${LOCAL_VAL}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
meta := config.Models["test-model"].Metadata
|
||||||
|
assert.NotNil(t, meta)
|
||||||
|
|
||||||
|
// Model-level macro should override global
|
||||||
|
assert.Equal(t, 0.9, meta["temperature"])
|
||||||
|
// Global macro should be accessible
|
||||||
|
assert.Equal(t, "global", meta["global"])
|
||||||
|
// Model-level macro should be accessible
|
||||||
|
assert.Equal(t, "local", meta["local"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UnknownMacroInMetadata(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
metadata:
|
||||||
|
value: ${UNKNOWN_MACRO}
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "test-model")
|
||||||
|
assert.Contains(t, err.Error(), "UNKNOWN_MACRO")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_InvalidMacroType(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
INVALID:
|
||||||
|
nested: value
|
||||||
|
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "INVALID")
|
||||||
|
assert.Contains(t, err.Error(), "must be a scalar type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroTypeValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
yaml string
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string macro",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
STR: "test"
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "int macro",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
NUM: 42
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "float macro",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
FLOAT: 3.14
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bool macro",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
BOOL: true
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array macro (invalid)",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
ARR: [1, 2, 3]
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "map macro (invalid)",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
MAP:
|
||||||
|
key: value
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(tt.yaml))
|
||||||
|
if tt.shouldErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build windows
|
//go:build windows
|
||||||
|
|
||||||
package proxy
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
@@ -155,8 +155,8 @@ groups:
|
|||||||
expected := Config{
|
expected := Config{
|
||||||
LogLevel: "info",
|
LogLevel: "info",
|
||||||
StartPort: 5800,
|
StartPort: 5800,
|
||||||
Macros: map[string]string{
|
Macros: MacroList{
|
||||||
"svr-path": "path/to/server",
|
{"svr-path", "path/to/server"},
|
||||||
},
|
},
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": {
|
"model1": {
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test macro-in-macro basic substitution
|
||||||
|
func TestConfig_MacroInMacroBasic(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"A": "value-A"
|
||||||
|
"B": "prefix-${A}-suffix"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${B}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test LIFO substitution order with 3+ macro levels
|
||||||
|
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"base": "/models"
|
||||||
|
"path": "${base}/llama"
|
||||||
|
"full": "${path}/model.gguf"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: load ${full}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MODEL_ID in global macro used by model
|
||||||
|
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||||
|
|
||||||
|
models:
|
||||||
|
my-model:
|
||||||
|
cmd: ${podman-llama} -m model.gguf
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test model macro overrides global macro in substitution
|
||||||
|
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"tag": "global"
|
||||||
|
"msg": "value-${tag}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
macros:
|
||||||
|
"tag": "model-level"
|
||||||
|
cmd: echo ${msg}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test self-reference detection error
|
||||||
|
func TestConfig_SelfReferenceDetection(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"recursive": "value-${recursive}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${recursive}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "recursive")
|
||||||
|
assert.Contains(t, err.Error(), "self-reference")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test undefined macro reference error
|
||||||
|
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"A": "value-${UNDEFINED}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${A}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||||
|
}
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelConfig struct {
|
||||||
|
Cmd string `yaml:"cmd"`
|
||||||
|
CmdStop string `yaml:"cmdStop"`
|
||||||
|
Proxy string `yaml:"proxy"`
|
||||||
|
Aliases []string `yaml:"aliases"`
|
||||||
|
Env []string `yaml:"env"`
|
||||||
|
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||||
|
UnloadAfter int `yaml:"ttl"`
|
||||||
|
Unlisted bool `yaml:"unlisted"`
|
||||||
|
UseModelName string `yaml:"useModelName"`
|
||||||
|
|
||||||
|
// #179 for /v1/models
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
Description string `yaml:"description"`
|
||||||
|
|
||||||
|
// Limit concurrency of HTTP requests to process
|
||||||
|
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||||
|
|
||||||
|
// Model filters see issue #174
|
||||||
|
Filters ModelFilters `yaml:"filters"`
|
||||||
|
|
||||||
|
// Macros: see #264
|
||||||
|
// Model level macros take precedence over the global macros
|
||||||
|
Macros MacroList `yaml:"macros"`
|
||||||
|
|
||||||
|
// Metadata: see #264
|
||||||
|
// Arbitrary metadata that can be exposed through the API
|
||||||
|
Metadata map[string]any `yaml:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawModelConfig ModelConfig
|
||||||
|
defaults := rawModelConfig{
|
||||||
|
Cmd: "",
|
||||||
|
CmdStop: "",
|
||||||
|
Proxy: "http://localhost:${PORT}",
|
||||||
|
Aliases: []string{},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
UnloadAfter: 0,
|
||||||
|
Unlisted: false,
|
||||||
|
UseModelName: "",
|
||||||
|
ConcurrencyLimit: 0,
|
||||||
|
Name: "",
|
||||||
|
Description: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*m = ModelConfig(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
|
return SanitizeCommand(m.Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelFilters see issue #174
|
||||||
|
type ModelFilters struct {
|
||||||
|
StripParams string `yaml:"stripParams"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawModelFilters ModelFilters
|
||||||
|
defaults := rawModelFilters{
|
||||||
|
StripParams: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to unmarshal with the old field name for backwards compatibility
|
||||||
|
if defaults.StripParams == "" {
|
||||||
|
var legacy struct {
|
||||||
|
StripParams string `yaml:"strip_params"`
|
||||||
|
}
|
||||||
|
if legacyErr := unmarshal(&legacy); legacyErr != nil {
|
||||||
|
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
|
||||||
|
}
|
||||||
|
defaults.StripParams = legacy.StripParams
|
||||||
|
}
|
||||||
|
|
||||||
|
*m = ModelFilters(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||||
|
if f.StripParams == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
params := strings.Split(f.StripParams, ",")
|
||||||
|
cleaned := make([]string, 0, len(params))
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, param := range params {
|
||||||
|
trimmed := strings.TrimSpace(param)
|
||||||
|
if trimmed == "model" || trimmed == "" || seen[trimmed] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[trimmed] = true
|
||||||
|
cleaned = append(cleaned, trimmed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort cleaned
|
||||||
|
slices.Sort(cleaned)
|
||||||
|
return cleaned, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||||
|
config := &ModelConfig{
|
||||||
|
Cmd: `python model1.py \
|
||||||
|
--arg1 value1 \
|
||||||
|
--arg2 value2`,
|
||||||
|
}
|
||||||
|
|
||||||
|
args, err := config.SanitizedCommand()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelFilters(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
default_strip: "temperature, top_p"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
# macros inserted and list is cleaned of duplicates and empty strings
|
||||||
|
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||||
|
# check for strip_params (legacy field name) compatibility
|
||||||
|
legacy:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
for modelId, modelConfig := range config.Models {
|
||||||
|
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
|
||||||
|
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||||
|
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
// model has been removed
|
||||||
|
// empty strings have been removed
|
||||||
|
// duplicates have been removed
|
||||||
|
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,18 +66,18 @@ func getTestPort() int {
|
|||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||||
// Create a YAML string with just the values we want to set
|
// Create a YAML string with just the values we want to set
|
||||||
yamlStr := fmt.Sprintf(`
|
yamlStr := fmt.Sprintf(`
|
||||||
cmd: '%s --port %d --silent --respond %s'
|
cmd: '%s --port %d --silent --respond %s'
|
||||||
proxy: "http://127.0.0.1:%d"
|
proxy: "http://127.0.0.1:%d"
|
||||||
`, simpleResponderPath, port, expectedMessage, port)
|
`, simpleResponderPath, port, expectedMessage, port)
|
||||||
|
|
||||||
var cfg ModelConfig
|
var cfg config.ModelConfig
|
||||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||||
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||||
@@ -38,7 +39,7 @@ type MetricsMonitor struct {
|
|||||||
nextID int
|
nextID int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMetricsMonitor(config *Config) *MetricsMonitor {
|
func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
|
||||||
maxMetrics := config.MetricsMaxInMemory
|
maxMetrics := config.MetricsMaxInMemory
|
||||||
if maxMetrics <= 0 {
|
if maxMetrics <= 0 {
|
||||||
maxMetrics = 1000 // Default fallback
|
maxMetrics = 1000 // Default fallback
|
||||||
|
|||||||
+4
-3
@@ -16,6 +16,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProcessState string
|
type ProcessState string
|
||||||
@@ -39,7 +40,7 @@ const (
|
|||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config config.ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
|
|
||||||
// PR #155 called to cancel the upstream process
|
// PR #155 called to cancel the upstream process
|
||||||
@@ -74,7 +75,7 @@ type Process struct {
|
|||||||
failedStartCount int
|
failedStartCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||||
concurrentLimit := 10
|
concurrentLimit := 10
|
||||||
if config.ConcurrencyLimit > 0 {
|
if config.ConcurrencyLimit > 0 {
|
||||||
concurrentLimit = config.ConcurrencyLimit
|
concurrentLimit = config.ConcurrencyLimit
|
||||||
@@ -539,7 +540,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
|||||||
|
|
||||||
if p.config.CmdStop != "" {
|
if p.config.CmdStop != "" {
|
||||||
// replace ${PID} with the pid of the process
|
// replace ${PID} with the pid of the process
|
||||||
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
+12
-11
@@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -90,7 +91,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
|||||||
// test that the automatic start returns the expected error type
|
// test that the automatic start returns the expected error type
|
||||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||||
// Create a process configuration
|
// Create a process configuration
|
||||||
config := ModelConfig{
|
config := config.ModelConfig{
|
||||||
Cmd: "nonexistent-command",
|
Cmd: "nonexistent-command",
|
||||||
Proxy: "http://127.0.0.1:9913",
|
Proxy: "http://127.0.0.1:9913",
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
@@ -325,7 +326,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
|||||||
|
|
||||||
// should run and exit but interrupt the long checkHealthTimeout
|
// should run and exit but interrupt the long checkHealthTimeout
|
||||||
checkHealthTimeout := 5
|
checkHealthTimeout := 5
|
||||||
config := ModelConfig{
|
config := config.ModelConfig{
|
||||||
Cmd: "sleep 1",
|
Cmd: "sleep 1",
|
||||||
Proxy: "http://127.0.0.1:9913",
|
Proxy: "http://127.0.0.1:9913",
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
@@ -402,7 +403,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
binaryPath := getSimpleResponderPath()
|
binaryPath := getSimpleResponderPath()
|
||||||
port := getTestPort()
|
port := getTestPort()
|
||||||
|
|
||||||
config := ModelConfig{
|
conf := config.ModelConfig{
|
||||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||||
// to force the process to exit
|
// to force the process to exit
|
||||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||||
@@ -410,7 +411,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
// reduce to make testing go faster
|
// reduce to make testing go faster
|
||||||
@@ -450,15 +451,15 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcess_StopCmd(t *testing.T) {
|
func TestProcess_StopCmd(t *testing.T) {
|
||||||
config := getTestSimpleResponderConfig("test_stop_cmd")
|
conf := getTestSimpleResponderConfig("test_stop_cmd")
|
||||||
|
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
config.CmdStop = "taskkill /f /t /pid ${PID}"
|
conf.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
} else {
|
} else {
|
||||||
config.CmdStop = "kill -TERM ${PID}"
|
conf.CmdStop = "kill -TERM ${PID}"
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
|
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
err := process.start()
|
err := process.start()
|
||||||
@@ -470,15 +471,15 @@ func TestProcess_StopCmd(t *testing.T) {
|
|||||||
|
|
||||||
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||||
expectedMessage := "test_env_not_emptied"
|
expectedMessage := "test_env_not_emptied"
|
||||||
config := getTestSimpleResponderConfig(expectedMessage)
|
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
// ensure that the the default config does not blank out the inherited environment
|
// ensure that the the default config does not blank out the inherited environment
|
||||||
configWEnv := config
|
configWEnv := conf
|
||||||
|
|
||||||
// ensure the additiona variables are appended to the process' environment
|
// ensure the additiona variables are appended to the process' environment
|
||||||
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||||
|
|
||||||
process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger)
|
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
|
||||||
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||||
|
|
||||||
process1.start()
|
process1.start()
|
||||||
|
|||||||
@@ -5,12 +5,14 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProcessGroup struct {
|
type ProcessGroup struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
config Config
|
config config.Config
|
||||||
id string
|
id string
|
||||||
swap bool
|
swap bool
|
||||||
exclusive bool
|
exclusive bool
|
||||||
@@ -24,7 +26,7 @@ type ProcessGroup struct {
|
|||||||
lastUsedProcess string
|
lastUsedProcess string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||||
groupConfig, ok := config.Groups[id]
|
groupConfig, ok := config.Groups[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Unable to find configuration for group id: " + id)
|
panic("Unable to find configuration for group id: " + id)
|
||||||
|
|||||||
@@ -7,19 +7,20 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
"model4": getTestSimpleResponderConfig("model4"),
|
"model4": getTestSimpleResponderConfig("model4"),
|
||||||
"model5": getTestSimpleResponderConfig("model5"),
|
"model5": getTestSimpleResponderConfig("model5"),
|
||||||
},
|
},
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
"G1": {
|
"G1": {
|
||||||
Swap: true,
|
Swap: true,
|
||||||
Exclusive: true,
|
Exclusive: true,
|
||||||
@@ -34,7 +35,7 @@ var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
|||||||
})
|
})
|
||||||
|
|
||||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||||
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||||
assert.True(t, pg.HasMember("model5"))
|
assert.True(t, pg.HasMember("model5"))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,9 +49,9 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
|||||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||||
// and multiple requests are made in parallel, only one process is running at a time.
|
// and multiple requests are made in parallel, only one process is running at a time.
|
||||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||||
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
// use the same listening so if a model is already running, it will fail
|
// use the same listening so if a model is already running, it will fail
|
||||||
// this is a way to test that swap isolation is working
|
// this is a way to test that swap isolation is working
|
||||||
// properly when there are parallel requests made at the
|
// properly when there are parallel requests made at the
|
||||||
@@ -61,7 +62,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
|||||||
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||||
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||||
},
|
},
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
"G1": {
|
"G1": {
|
||||||
Swap: true,
|
Swap: true,
|
||||||
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||||
|
|||||||
+10
-2
@@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -27,7 +28,7 @@ const (
|
|||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
config Config
|
config config.Config
|
||||||
ginEngine *gin.Engine
|
ginEngine *gin.Engine
|
||||||
|
|
||||||
// logging
|
// logging
|
||||||
@@ -44,7 +45,7 @@ type ProxyManager struct {
|
|||||||
shutdownCancel context.CancelFunc
|
shutdownCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config Config) *ProxyManager {
|
func New(config config.Config) *ProxyManager {
|
||||||
// set up loggers
|
// set up loggers
|
||||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||||
@@ -369,6 +370,13 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|||||||
record["description"] = desc
|
record["description"] = desc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add metadata if present
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
record["meta"] = gin.H{
|
||||||
|
"llamaswap": modelConfig.Metadata,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
data = append(data, record)
|
data = append(data, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+140
-55
@@ -16,14 +16,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
@@ -44,14 +45,14 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
"G1": {
|
"G1": {
|
||||||
Swap: true,
|
Swap: true,
|
||||||
Exclusive: false,
|
Exclusive: false,
|
||||||
@@ -89,14 +90,14 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
// Test that a persistent group is not affected by the swapping behaviour of
|
// Test that a persistent group is not affected by the swapping behaviour of
|
||||||
// other groups.
|
// other groups.
|
||||||
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
|
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
// the forever group is persistent and should not be affected by model1
|
// the forever group is persistent and should not be affected by model1
|
||||||
"forever": {
|
"forever": {
|
||||||
Swap: true,
|
Swap: true,
|
||||||
@@ -133,9 +134,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
t.Skip("skipping slow test")
|
t.Skip("skipping slow test")
|
||||||
}
|
}
|
||||||
|
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
@@ -196,9 +197,9 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
model2Config.Name = " " // empty whitespace only strings will get ignored
|
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||||
model2Config.Description = " "
|
model2Config.Description = " "
|
||||||
|
|
||||||
config := Config{
|
config := config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": model1Config,
|
"model1": model1Config,
|
||||||
"model2": model2Config,
|
"model2": model2Config,
|
||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
@@ -281,15 +282,99 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ListModelsHandler_WithMetadata(t *testing.T) {
|
||||||
|
// Process config through LoadConfigFromReader to apply macro substitution
|
||||||
|
configYaml := `
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
logLevel: error
|
||||||
|
startPort: 10000
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
macros:
|
||||||
|
PORT_NUM: 10001
|
||||||
|
TEMP: 0.7
|
||||||
|
NAME: "llama"
|
||||||
|
metadata:
|
||||||
|
port: ${PORT_NUM}
|
||||||
|
temperature: ${TEMP}
|
||||||
|
enabled: true
|
||||||
|
note: "Running on port ${PORT_NUM}"
|
||||||
|
nested:
|
||||||
|
value: ${TEMP}
|
||||||
|
model2:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`
|
||||||
|
processedConfig, err := config.LoadConfigFromReader(strings.NewReader(configYaml))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
proxy := New(processedConfig)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var response struct {
|
||||||
|
Data []map[string]any `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, response.Data, 2)
|
||||||
|
|
||||||
|
// Find model1 and model2 in response
|
||||||
|
var model1Data, model2Data map[string]any
|
||||||
|
for _, model := range response.Data {
|
||||||
|
if model["id"] == "model1" {
|
||||||
|
model1Data = model
|
||||||
|
} else if model["id"] == "model2" {
|
||||||
|
model2Data = model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify model1 has llamaswap_meta
|
||||||
|
assert.NotNil(t, model1Data)
|
||||||
|
meta, exists := model1Data["meta"]
|
||||||
|
if !assert.True(t, exists, "model1 should have meta key") {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
metaMap := meta.(map[string]any)
|
||||||
|
|
||||||
|
lsmeta, exists := metaMap["llamaswap"]
|
||||||
|
if !assert.True(t, exists, "model1 should have meta.llamaswap key") {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
lsmetamap := lsmeta.(map[string]any)
|
||||||
|
|
||||||
|
// Verify type preservation
|
||||||
|
assert.Equal(t, float64(10001), lsmetamap["port"]) // JSON numbers are float64
|
||||||
|
assert.Equal(t, 0.7, lsmetamap["temperature"])
|
||||||
|
assert.Equal(t, true, lsmetamap["enabled"])
|
||||||
|
// Verify string interpolation
|
||||||
|
assert.Equal(t, "Running on port 10001", lsmetamap["note"])
|
||||||
|
// Verify nested structure
|
||||||
|
nested := lsmetamap["nested"].(map[string]any)
|
||||||
|
assert.Equal(t, 0.7, nested["value"])
|
||||||
|
|
||||||
|
// Verify model2 does NOT have llamaswap_meta
|
||||||
|
assert.NotNil(t, model2Data)
|
||||||
|
_, exists = model2Data["llamaswap_meta"]
|
||||||
|
assert.False(t, exists, "model2 should not have llamaswap_meta")
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
||||||
// Intentionally add models in non-sorted order and with an unlisted model
|
// Intentionally add models in non-sorted order and with an unlisted model
|
||||||
config := Config{
|
config := config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"zeta": getTestSimpleResponderConfig("zeta"),
|
"zeta": getTestSimpleResponderConfig("zeta"),
|
||||||
"alpha": getTestSimpleResponderConfig("alpha"),
|
"alpha": getTestSimpleResponderConfig("alpha"),
|
||||||
"beta": getTestSimpleResponderConfig("beta"),
|
"beta": getTestSimpleResponderConfig("beta"),
|
||||||
"hidden": func() ModelConfig {
|
"hidden": func() config.ModelConfig {
|
||||||
mc := getTestSimpleResponderConfig("hidden")
|
mc := getTestSimpleResponderConfig("hidden")
|
||||||
mc.Unlisted = true
|
mc.Unlisted = true
|
||||||
return mc
|
return mc
|
||||||
@@ -337,15 +422,15 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||||
model3Config.Proxy = "http://localhost:10003/"
|
model3Config.Proxy = "http://localhost:10003/"
|
||||||
|
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": model1Config,
|
"model1": model1Config,
|
||||||
"model2": model2Config,
|
"model2": model2Config,
|
||||||
"model3": model3Config,
|
"model3": model3Config,
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
"test": {
|
"test": {
|
||||||
Swap: false,
|
Swap: false,
|
||||||
Members: []string{"model1", "model2", "model3"},
|
Members: []string{"model1", "model2", "model3"},
|
||||||
@@ -380,21 +465,21 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_Unload(t *testing.T) {
|
func TestProxyManager_Unload(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(conf)
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||||
req = httptest.NewRequest("GET", "/unload", nil)
|
req = httptest.NewRequest("GET", "/unload", nil)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
@@ -402,23 +487,23 @@ func TestProxyManager_Unload(t *testing.T) {
|
|||||||
assert.Equal(t, w.Body.String(), "OK")
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
|
case <-proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
|
||||||
// good
|
// good
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
t.Fatal("timeout waiting for model1 to stop")
|
t.Fatal("timeout waiting for model1 to stop")
|
||||||
}
|
}
|
||||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
||||||
const testGroupId = "testGroup"
|
const testGroupId = "testGroup"
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
testGroupId: {
|
testGroupId: {
|
||||||
Swap: false,
|
Swap: false,
|
||||||
Members: []string{"model1", "model2"},
|
Members: []string{"model1", "model2"},
|
||||||
@@ -463,9 +548,9 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
|||||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||||
// Shared configuration
|
// Shared configuration
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
@@ -528,9 +613,9 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -581,15 +666,15 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||||
modelConfig.UseModelName = upstreamModelName
|
modelConfig.UseModelName = upstreamModelName
|
||||||
|
|
||||||
config := AddDefaultGroupToConfig(Config{
|
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": modelConfig,
|
"model1": modelConfig,
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(conf)
|
||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
requestedModel := "model1"
|
requestedModel := "model1"
|
||||||
@@ -644,9 +729,9 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -720,7 +805,7 @@ models:
|
|||||||
aliases: [model-alias]
|
aliases: [model-alias]
|
||||||
`, getSimpleResponderPath())
|
`, getSimpleResponderPath())
|
||||||
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(configStr))
|
config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
@@ -743,9 +828,9 @@ models:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_ChatContentLength(t *testing.T) {
|
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -768,14 +853,14 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
|||||||
|
|
||||||
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||||
modelConfig := getTestSimpleResponderConfig("model1")
|
modelConfig := getTestSimpleResponderConfig("model1")
|
||||||
modelConfig.Filters = ModelFilters{
|
modelConfig.Filters = config.ModelFilters{
|
||||||
StripParams: "temperature, model, stream",
|
StripParams: "temperature, model, stream",
|
||||||
}
|
}
|
||||||
|
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": modelConfig,
|
"model1": modelConfig,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -801,9 +886,9 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -836,9 +921,9 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -871,9 +956,9 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -890,9 +975,9 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
|
|||||||
|
|
||||||
// Ensure the custom llama-server /completion endpoint proxies correctly
|
// Ensure the custom llama-server /completion endpoint proxies correctly
|
||||||
func TestProxyManager_CompletionEndpoint(t *testing.T) {
|
func TestProxyManager_CompletionEndpoint(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -936,7 +1021,7 @@ models:
|
|||||||
`, "${simpleresponderpath}", simpleResponderPath, -1)
|
`, "${simpleresponderpath}", simpleResponderPath, -1)
|
||||||
|
|
||||||
// Create a test model configuration
|
// Create a test model configuration
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(configStr))
|
config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||||
if !assert.NoError(t, err, "Invalid configuration") {
|
if !assert.NoError(t, err, "Invalid configuration") {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -970,9 +1055,9 @@ models:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -1009,9 +1094,9 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) {
|
func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"streaming-model": getTestSimpleResponderConfig("streaming-model"),
|
"streaming-model": getTestSimpleResponderConfig("streaming-model"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
|
|||||||
Reference in New Issue
Block a user