Compare commits

...

5 Commits

Author SHA1 Message Date
Benson Wong 18c134624d Add Access-Control-Allow-Origin CORS header to /v1/models endpoint
- match behavior of llama.cpp where the Origin in request is used
- add test for listModelsHandler
2024-12-03 15:53:59 -08:00
Benson Wong da2326bdc7 add example: optimizing code generation 2024-12-03 10:25:43 -08:00
Benson Wong da46545630 fix profile example in README 2024-12-01 10:13:31 -08:00
Benson Wong 04b4760e7e change profile split character to : (colon) (#21)
- change from `/` to `:` for multiple models loaded as part of a profile
- breaking change now, but allows for more compatibility with other inference engines that may have model references like `coding:Qwen/Qwen-2.5-Coder-32B`
2024-12-01 09:10:50 -08:00
Benson Wong 9fc5d5b5eb improve cmd parsing (#22)
Switch from using a naive strings.Fields() to shlex.Split() for parsing the model startup command into a string[]. This makes parsing much more reliable around newlines, quotes, etc.
2024-12-01 09:02:58 -08:00
11 changed files with 309 additions and 38 deletions
+1 -1
View File
@@ -64,7 +64,7 @@ models:
#
# Tips:
# - each model must be listening on a unique address and port
# - the model name is in this format: "profile_name/model", like "coding/qwen"
# - the model name is in this format: "profile_name:model", like "coding:qwen"
# - the profile will load and unload all models in the profile at the same time
profiles:
coding:
+3 -6
View File
@@ -1,9 +1,6 @@
# Example Configurations
# Example Configs and Use Cases
Learning by example is best.
Here in the `examples/` folder are llama-swap configurations that can be used on your local LLM server.
## List
A collections of usecases and examples for getting the most out of llama-swap.
* [Speculative Decoding](speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
* [Optimizing Code Generation](benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
+123
View File
@@ -0,0 +1,123 @@
# Optimizing Code Generation with llama-swap
Finding the best mix of settings for your hardware can be time consuming. This example demonstrates using a custom configuration file to automate testing different scenarios to find the an optimal configuration.
The benchmark writes a snake game in Python, TypeScript, and Swift using the Qwen 2.5 Coder models. The experiments were done using a 3090 and a P40.
**Benchmark Scenarios**
Three scenarios are tested:
- 3090-only: Just the main model on the 3090
- 3090-with-draft: the main and draft models on the 3090
- 3090-P40-draft: the main model on the 3090 with the draft model offloaded to the P40
**Available Devices**
Use the following command to list available devices IDs for the configuration:
```
$ /mnt/nvme/llama-server/llama-server-f3252055 --list-devices
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: Tesla P40, compute capability 6.1, VMM: yes
Device 2: Tesla P40, compute capability 6.1, VMM: yes
Device 3: Tesla P40, compute capability 6.1, VMM: yes
Available devices:
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 406 MiB free)
CUDA1: Tesla P40 (24438 MiB, 22942 MiB free)
CUDA2: Tesla P40 (24438 MiB, 24144 MiB free)
CUDA3: Tesla P40 (24438 MiB, 24144 MiB free)
```
**Configuration**
The configuration file, `benchmark-config.yaml`, defines the three scenarios:
```yaml
models:
"3090-only":
proxy: "http://127.0.0.1:9503"
cmd: >
/mnt/nvme/llama-server/llama-server-f3252055
--host 127.0.0.1 --port 9503
--flash-attn
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--device CUDA0
--ctx-size 32768
--cache-type-k q8_0 --cache-type-v q8_0
"3090-with-draft":
proxy: "http://127.0.0.1:9503"
# --ctx-size 28500 max that can fit on 3090 after draft model
cmd: >
/mnt/nvme/llama-server/llama-server-f3252055
--host 127.0.0.1 --port 9503
--flash-attn
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--device CUDA0
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
-ngld 99
--draft-max 16
--draft-min 4
--draft-p-min 0.4
--device-draft CUDA0
--ctx-size 28500
--cache-type-k q8_0 --cache-type-v q8_0
"3090-P40-draft":
proxy: "http://127.0.0.1:9503"
cmd: >
/mnt/nvme/llama-server/llama-server-f3252055
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--device CUDA0
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
-ngld 99
--draft-max 16
--draft-min 4
--draft-p-min 0.4
--device-draft CUDA1
--ctx-size 32768
--cache-type-k q8_0 --cache-type-v q8_0
```
> Note in the `3090-with-draft` scenario the `--ctx-size` had to be reduced from 32768 to to accommodate the draft model.
**Running the Benchmark**
To run the benchmark, execute the following commands:
1. `llama-swap -config benchmark-config.yaml`
1. `./run-benchmark.sh http://localhost:8080 "3090-only" "3090-with-draft" "3090-P40-draft"`
The [benchmark script](run-benchmark.sh) generates a CSV output of the results, which can be converted to a Markdown table for readability.
**Results (tokens/second)**
| model | python | typescript | swift |
|-----------------|--------|------------|-------|
| 3090-only | 34.03 | 34.01 | 34.01 |
| 3090-with-draft | 106.65 | 70.48 | 57.89 |
| 3090-P40-draft | 81.54 | 60.35 | 46.50 |
Many different factors, like the programming language, can have big impacts on the performance gains. However, with a custom configuration file for benchmarking it is easy to test the different variations to discover what's best for your hardware.
Happy coding!
+43
View File
@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# This script generates a CSV file showing the token/second for generating a Snake Game in python, typescript and swift
# It was created to test the effects of speculative decoding and the various draft settings on performance.
#
# Writing code with a low temperature seems to provide fairly consistent logic.
#
# Usage: ./benchmark.sh <url> <model1> [model2 ...]
# Example: ./benchmark.sh http://localhost:8080 model1 model2
if [ "$#" -lt 2 ]; then
echo "Usage: $0 <url> <model1> [model2 ...]"
exit 1
fi
url=$1; shift
echo "model,python,typescript,swift"
for model in "$@"; do
echo -n "$model,"
for lang in "python" "typescript" "swift"; do
response=$(curl -s --url "$url/v1/chat/completions" -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"temperature\": 0.1, \"model\":\"$model\"}")
if [ $? -ne 0 ]; then
time="error"
else
time=$(curl -s --url "$url/logs" | grep -oE '\d+(?:\.\d+)? tokens per second' | awk '{print $1}' | tail -n 1)
if [ $? -ne 0 ]; then
time="error"
fi
fi
if [ "$lang" != "swift" ]; then
echo -n "$time,"
else
echo -n "$time"
fi
done
echo ""
done
+1
View File
@@ -20,6 +20,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
+2
View File
@@ -24,6 +24,8 @@ github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaC
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
+5 -1
View File
@@ -5,6 +5,7 @@ import (
"os"
"strings"
"github.com/google/shlex"
"gopkg.in/yaml.v3"
)
@@ -81,7 +82,10 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
// Split the command into arguments
args := strings.Fields(cmdStr)
args, err := shlex.Split(cmdStr)
if err != nil {
return nil, err
}
// Ensure the command is not empty
if len(args) == 0 {
+17 -8
View File
@@ -148,17 +148,26 @@ func TestConfig_FindConfig(t *testing.T) {
}
func TestConfig_SanitizeCommand(t *testing.T) {
// Test a simple command
args, err := SanitizeCommand("python model1.py")
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py"}, args)
// Test a command with spaces and newlines
args, err = SanitizeCommand(`python model1.py \
--arg1 value1 \
--arg2 value2`)
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
--arg2 'single quotes'
-s
--arg3 123 \
--arg4 '"string in string"'
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"--arg2", "single quotes",
"-s",
"--arg3", "123",
"--arg4", `"string in string"`,
"-c", `'single quoted'`,
}, args)
// Test an empty command
args, err = SanitizeCommand("")
+1 -1
View File
@@ -101,7 +101,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
// issue #19
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
if testing.Short() {
t.Skip("skipping long test")
t.Skip("skipping slow test")
}
expectedMessage := "12345"
+25 -12
View File
@@ -14,6 +14,10 @@ import (
"github.com/gin-gonic/gin"
)
const (
PROFILE_SPLIT_CHAR = ":"
)
type ProxyManager struct {
sync.Mutex
@@ -94,6 +98,10 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
// Set the Content-Type header to application/json
c.Header("Content-Type", "application/json")
if origin := c.Request.Header.Get("Origin"); origin != "" {
c.Header("Access-Control-Allow-Origin", origin)
}
// Encode the data as JSON and write it to the response writer
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
@@ -106,15 +114,15 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
defer pm.Unlock()
// Check if requestedModel contains a /
groupName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, "/"); idx != -1 {
groupName = requestedModel[:idx]
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
if groupName != "" {
if _, found := pm.config.Profiles[groupName]; !found {
return nil, fmt.Errorf("model group not found %s", groupName)
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
return nil, fmt.Errorf("model group not found %s", profileName)
}
}
@@ -125,7 +133,8 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
}
// exit early when already running, otherwise stop everything and swap
requestedProcessKey := groupName + "/" + realModelName
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
return process, nil
}
@@ -133,25 +142,25 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
// stop all running models
pm.stopProcesses()
if groupName == "" {
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := groupName + "/" + modelID
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
for _, modelName := range pm.config.Profiles[groupName] {
for _, modelName := range pm.config.Profiles[profileName] {
if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, groupName)
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := groupName + "/" + modelID
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
}
@@ -201,3 +210,7 @@ func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
}
func ProcessKeyName(groupName, modelName string) string {
return groupName + PROFILE_SPLIT_CHAR + modelName
}
+88 -9
View File
@@ -2,6 +2,7 @@ package proxy
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
@@ -33,7 +34,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses["/"+modelName]
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
}
@@ -43,21 +44,31 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
}
func TestProxyManager_SwapMultiProcess(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileModel1 := ProcessKeyName("test", model1)
profileModel2 := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {"model1", "model2"},
"test": {model1, model2},
},
}
proxy := New(config)
defer proxy.StopProcesses()
for modelID, requestedModel := range map[string]string{"model1": "test/model1", "model2": "test/model2"} {
for modelID, requestedModel := range map[string]string{
"model1": profileModel1,
"model2": profileModel2,
} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
@@ -69,11 +80,11 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
// make sure there's two loaded models
assert.Len(t, proxy.currentProcesses, 2)
_, exists := proxy.currentProcesses["test/model1"]
assert.True(t, exists, "expected test/model1 key in currentProcesses")
_, exists := proxy.currentProcesses[profileModel1]
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
_, exists = proxy.currentProcesses["test/model2"]
assert.True(t, exists, "expected test/model2 key in currentProcesses")
_, exists = proxy.currentProcesses[profileModel2]
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
}
// When a request for a different model comes in ProxyManager should wait until
@@ -131,3 +142,71 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
assert.Equal(t, key, result)
}
}
func TestProxyManager_ListModelsHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
}
proxy := New(config)
// Create a test request
req := httptest.NewRequest("GET", "/v1/models", nil)
req.Header.Add("Origin", "i-am-the-origin")
w := httptest.NewRecorder()
// Call the listModelsHandler
proxy.HandlerFunc(w, req)
// Check the response status code
assert.Equal(t, http.StatusOK, w.Code)
// Check for Access-Control-Allow-Origin
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
// Parse the JSON response
var response struct {
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse JSON response: %v", err)
}
// Check the number of models returned
assert.Len(t, response.Data, 3)
// Check the details of each model
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
"model3": {},
}
for _, model := range response.Data {
modelID, ok := model["id"].(string)
assert.True(t, ok, "model ID should be a string")
_, exists := expectedModels[modelID]
assert.True(t, exists, "unexpected model ID: %s", modelID)
delete(expectedModels, modelID)
object, ok := model["object"].(string)
assert.True(t, ok, "object should be a string")
assert.Equal(t, "model", object)
created, ok := model["created"].(float64)
assert.True(t, ok, "created should be a number")
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
ownedBy, ok := model["owned_by"].(string)
assert.True(t, ok, "owned_by should be a string")
assert.Equal(t, "llama-swap", ownedBy)
}
// Ensure all expected models were returned
assert.Empty(t, expectedModels, "not all expected models were returned")
}