Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 04b4760e7e | |||
| 9fc5d5b5eb |
@@ -20,6 +20,7 @@ require (
|
|||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaC
|
|||||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
|
|||||||
+5
-1
@@ -5,6 +5,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/shlex"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,7 +82,10 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
|
|||||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
||||||
|
|
||||||
// Split the command into arguments
|
// Split the command into arguments
|
||||||
args := strings.Fields(cmdStr)
|
args, err := shlex.Split(cmdStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure the command is not empty
|
// Ensure the command is not empty
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
|
|||||||
+17
-8
@@ -148,17 +148,26 @@ func TestConfig_FindConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
// Test a simple command
|
|
||||||
args, err := SanitizeCommand("python model1.py")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"python", "model1.py"}, args)
|
|
||||||
|
|
||||||
// Test a command with spaces and newlines
|
// Test a command with spaces and newlines
|
||||||
args, err = SanitizeCommand(`python model1.py \
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
--arg1 value1 \
|
-a "double quotes" \
|
||||||
--arg2 value2`)
|
--arg2 'single quotes'
|
||||||
|
-s
|
||||||
|
--arg3 123 \
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"--arg2", "single quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", `"string in string"`,
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
// Test an empty command
|
// Test an empty command
|
||||||
args, err = SanitizeCommand("")
|
args, err = SanitizeCommand("")
|
||||||
|
|||||||
+21
-12
@@ -14,6 +14,10 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PROFILE_SPLIT_CHAR = ":"
|
||||||
|
)
|
||||||
|
|
||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
@@ -106,15 +110,15 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
// Check if requestedModel contains a /
|
// Check if requestedModel contains a /
|
||||||
groupName, modelName := "", requestedModel
|
profileName, modelName := "", requestedModel
|
||||||
if idx := strings.Index(requestedModel, "/"); idx != -1 {
|
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||||
groupName = requestedModel[:idx]
|
profileName = requestedModel[:idx]
|
||||||
modelName = requestedModel[idx+1:]
|
modelName = requestedModel[idx+1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
if groupName != "" {
|
if profileName != "" {
|
||||||
if _, found := pm.config.Profiles[groupName]; !found {
|
if _, found := pm.config.Profiles[profileName]; !found {
|
||||||
return nil, fmt.Errorf("model group not found %s", groupName)
|
return nil, fmt.Errorf("model group not found %s", profileName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,7 +129,8 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// exit early when already running, otherwise stop everything and swap
|
// exit early when already running, otherwise stop everything and swap
|
||||||
requestedProcessKey := groupName + "/" + realModelName
|
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||||
|
|
||||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||||
return process, nil
|
return process, nil
|
||||||
}
|
}
|
||||||
@@ -133,25 +138,25 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
// stop all running models
|
// stop all running models
|
||||||
pm.stopProcesses()
|
pm.stopProcesses()
|
||||||
|
|
||||||
if groupName == "" {
|
if profileName == "" {
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
if !found {
|
if !found {
|
||||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
processKey := groupName + "/" + modelID
|
processKey := ProcessKeyName(profileName, modelID)
|
||||||
pm.currentProcesses[processKey] = process
|
pm.currentProcesses[processKey] = process
|
||||||
} else {
|
} else {
|
||||||
for _, modelName := range pm.config.Profiles[groupName] {
|
for _, modelName := range pm.config.Profiles[profileName] {
|
||||||
if realModelName, found := pm.config.RealModelName(modelName); found {
|
if realModelName, found := pm.config.RealModelName(modelName); found {
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
if !found {
|
if !found {
|
||||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, groupName)
|
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
processKey := groupName + "/" + modelID
|
processKey := ProcessKeyName(profileName, modelID)
|
||||||
pm.currentProcesses[processKey] = process
|
pm.currentProcesses[processKey] = process
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -201,3 +206,7 @@ func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
|
|||||||
|
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ProcessKeyName(groupName, modelName string) string {
|
||||||
|
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||||
|
}
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), modelName)
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
|
||||||
_, exists := proxy.currentProcesses["/"+modelName]
|
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
|
||||||
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -43,21 +43,31 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||||
|
|
||||||
|
model1 := "path1/model1"
|
||||||
|
model2 := "path2/model2"
|
||||||
|
|
||||||
|
profileModel1 := ProcessKeyName("test", model1)
|
||||||
|
profileModel2 := ProcessKeyName("test", model2)
|
||||||
|
|
||||||
config := &Config{
|
config := &Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
model1: getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
model2: getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
Profiles: map[string][]string{
|
Profiles: map[string][]string{
|
||||||
"test": {"model1", "model2"},
|
"test": {model1, model2},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
for modelID, requestedModel := range map[string]string{"model1": "test/model1", "model2": "test/model2"} {
|
for modelID, requestedModel := range map[string]string{
|
||||||
|
"model1": profileModel1,
|
||||||
|
"model2": profileModel2,
|
||||||
|
} {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -69,11 +79,11 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
|
|
||||||
// make sure there's two loaded models
|
// make sure there's two loaded models
|
||||||
assert.Len(t, proxy.currentProcesses, 2)
|
assert.Len(t, proxy.currentProcesses, 2)
|
||||||
_, exists := proxy.currentProcesses["test/model1"]
|
_, exists := proxy.currentProcesses[profileModel1]
|
||||||
assert.True(t, exists, "expected test/model1 key in currentProcesses")
|
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
|
||||||
|
|
||||||
_, exists = proxy.currentProcesses["test/model2"]
|
_, exists = proxy.currentProcesses[profileModel2]
|
||||||
assert.True(t, exists, "expected test/model2 key in currentProcesses")
|
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
|
||||||
}
|
}
|
||||||
|
|
||||||
// When a request for a different model comes in ProxyManager should wait until
|
// When a request for a different model comes in ProxyManager should wait until
|
||||||
|
|||||||
Reference in New Issue
Block a user