Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ab5a048584 |
@@ -15,8 +15,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
#platform: [intel, cuda, vulkan, cpu, musa]
|
platform: [intel, cuda, vulkan, cpu, musa]
|
||||||
platform: [cuda, vulkan, cpu, musa]
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
|
|||||||
@@ -70,14 +70,6 @@ healthCheckTimeout: 60
|
|||||||
# Valid log levels: debug, info (default), warn, error
|
# Valid log levels: debug, info (default), warn, error
|
||||||
logLevel: info
|
logLevel: info
|
||||||
|
|
||||||
# Automatic Port Values
|
|
||||||
# use ${PORT} in model.cmd and model.proxy to use an automatic port number
|
|
||||||
# when you use ${PORT} you can omit a custom model.proxy value, as it will
|
|
||||||
# default to http://localhost:${PORT}
|
|
||||||
|
|
||||||
# override the default port (5800) for automatic port values
|
|
||||||
startPort: 10001
|
|
||||||
|
|
||||||
# define valid model values and the upstream server start
|
# define valid model values and the upstream server start
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
@@ -91,7 +83,6 @@ models:
|
|||||||
- "CUDA_VISIBLE_DEVICES=0"
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
|
|
||||||
# where to reach the server started by cmd, make sure the ports match
|
# where to reach the server started by cmd, make sure the ports match
|
||||||
# can be omitted if you use an automatic ${PORT} in cmd
|
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
# aliases names to use this model for
|
# aliases names to use this model for
|
||||||
@@ -118,14 +109,14 @@ models:
|
|||||||
# but they can still be requested as normal
|
# but they can still be requested as normal
|
||||||
"qwen-unlisted":
|
"qwen-unlisted":
|
||||||
unlisted: true
|
unlisted: true
|
||||||
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||||
|
|
||||||
# Docker Support (v26.1.4+ required!)
|
# Docker Support (v26.1.4+ required!)
|
||||||
"docker-llama":
|
"docker-llama":
|
||||||
proxy: "http://127.0.0.1:${PORT}"
|
proxy: "http://127.0.0.1:9790"
|
||||||
cmd: >
|
cmd: >
|
||||||
docker run --name dockertest
|
docker run --name dockertest
|
||||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||||
ghcr.io/ggerganov/llama.cpp:server
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
|
|||||||
@@ -33,17 +33,14 @@ func main() {
|
|||||||
|
|
||||||
// Set up the handler function using the provided response message
|
// Set up the handler function using the provided response message
|
||||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "text/plain")
|
||||||
|
|
||||||
// add a wait to simulate a slow query
|
// add a wait to simulate a slow query
|
||||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
time.Sleep(wait)
|
time.Sleep(wait)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.String(200, *responseMessage)
|
||||||
"responseMessage": *responseMessage,
|
|
||||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// for issue #62 to check model name strips profile slug
|
// for issue #62 to check model name strips profile slug
|
||||||
@@ -66,11 +63,8 @@ func main() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r.POST("/v1/completions", func(c *gin.Context) {
|
r.POST("/v1/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "text/plain")
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.String(200, *responseMessage)
|
||||||
"responseMessage": *responseMessage,
|
|
||||||
})
|
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// issue #41
|
// issue #41
|
||||||
@@ -110,10 +104,6 @@ func main() {
|
|||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
||||||
// expose some header values for testing
|
|
||||||
"h_content_type": c.GetHeader("Content-Type"),
|
|
||||||
"h_content_length": c.GetHeader("Content-Length"),
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
+1
-51
@@ -2,10 +2,8 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/shlex"
|
"github.com/google/shlex"
|
||||||
@@ -64,9 +62,6 @@ type Config struct {
|
|||||||
|
|
||||||
// map aliases to actual model IDs
|
// map aliases to actual model IDs
|
||||||
aliases map[string]string
|
aliases map[string]string
|
||||||
|
|
||||||
// automatic port assignments
|
|
||||||
StartPort int `yaml:"startPort"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) RealModelName(search string) (string, bool) {
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
@@ -88,16 +83,7 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfig(path string) (Config, error) {
|
func LoadConfig(path string) (Config, error) {
|
||||||
file, err := os.Open(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
return LoadConfigFromReader(file)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|
||||||
data, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Config{}, err
|
return Config{}, err
|
||||||
}
|
}
|
||||||
@@ -112,50 +98,14 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
config.HealthCheckTimeout = 15
|
config.HealthCheckTimeout = 15
|
||||||
}
|
}
|
||||||
|
|
||||||
// set default port ranges
|
|
||||||
if config.StartPort == 0 {
|
|
||||||
// default to 5800
|
|
||||||
config.StartPort = 5800
|
|
||||||
} else if config.StartPort < 1 {
|
|
||||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populate the aliases map
|
// Populate the aliases map
|
||||||
config.aliases = make(map[string]string)
|
config.aliases = make(map[string]string)
|
||||||
for modelName, modelConfig := range config.Models {
|
for modelName, modelConfig := range config.Models {
|
||||||
for _, alias := range modelConfig.Aliases {
|
for _, alias := range modelConfig.Aliases {
|
||||||
if _, found := config.aliases[alias]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
|
||||||
}
|
|
||||||
config.aliases[alias] = modelName
|
config.aliases[alias] = modelName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// iterate over the models and replace any ${PORT} with the next available port
|
|
||||||
// Get and sort all model IDs first, makes testing more consistent
|
|
||||||
modelIds := make([]string, 0, len(config.Models))
|
|
||||||
for modelId := range config.Models {
|
|
||||||
modelIds = append(modelIds, modelId)
|
|
||||||
}
|
|
||||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
|
||||||
|
|
||||||
// iterate over the sorted models
|
|
||||||
nextPort := config.StartPort
|
|
||||||
for _, modelId := range modelIds {
|
|
||||||
modelConfig := config.Models[modelId]
|
|
||||||
if strings.Contains(modelConfig.Cmd, "${PORT}") {
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
|
|
||||||
if modelConfig.Proxy == "" {
|
|
||||||
modelConfig.Proxy = fmt.Sprintf("http://localhost:%d", nextPort)
|
|
||||||
} else {
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", strconv.Itoa(nextPort))
|
|
||||||
}
|
|
||||||
nextPort++
|
|
||||||
config.Models[modelId] = modelConfig
|
|
||||||
} else if modelConfig.Proxy == "" {
|
|
||||||
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config = AddDefaultGroupToConfig(config)
|
config = AddDefaultGroupToConfig(config)
|
||||||
// check that members are all unique in the groups
|
// check that members are all unique in the groups
|
||||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||||
|
|||||||
+12
-102
@@ -3,7 +3,6 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -44,7 +43,6 @@ models:
|
|||||||
checkEndpoint: "/"
|
checkEndpoint: "/"
|
||||||
model4:
|
model4:
|
||||||
cmd: path/to/cmd --arg1 one
|
cmd: path/to/cmd --arg1 one
|
||||||
proxy: "http://localhost:8082"
|
|
||||||
checkEndpoint: "/"
|
checkEndpoint: "/"
|
||||||
|
|
||||||
healthCheckTimeout: 15
|
healthCheckTimeout: 15
|
||||||
@@ -75,7 +73,6 @@ groups:
|
|||||||
}
|
}
|
||||||
|
|
||||||
expected := Config{
|
expected := Config{
|
||||||
StartPort: 5800,
|
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": {
|
"model1": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
@@ -100,7 +97,6 @@ groups:
|
|||||||
},
|
},
|
||||||
"model4": {
|
"model4": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
Proxy: "http://localhost:8082",
|
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -142,6 +138,14 @@ groups:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||||
|
// Create a temporary YAML file for testing
|
||||||
|
tempDir, err := os.MkdirTemp("", "test-config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||||
content := `
|
content := `
|
||||||
models:
|
models:
|
||||||
model1:
|
model1:
|
||||||
@@ -167,35 +171,15 @@ groups:
|
|||||||
exclusive: false
|
exclusive: false
|
||||||
members: ["model2"]
|
members: ["model2"]
|
||||||
`
|
`
|
||||||
// Load the config and verify
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
|
|
||||||
// a Contains as order of the map is not guaranteed
|
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||||
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
|
t.Fatalf("Failed to write temporary file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8080"
|
|
||||||
aliases:
|
|
||||||
- m1
|
|
||||||
model2:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
aliases:
|
|
||||||
- m1
|
|
||||||
- m2
|
|
||||||
`
|
|
||||||
// Load the config and verify
|
// Load the config and verify
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
_, err = LoadConfig(tempFile)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
// this is a contains because it could be `model1` or `model2` depending on the order
|
|
||||||
// go decided on the order of the map
|
|
||||||
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||||
@@ -285,77 +269,3 @@ func TestConfig_SanitizeCommand(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, args)
|
assert.Nil(t, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
|
||||||
|
|
||||||
t.Run("Default Port Ranges", func(t *testing.T) {
|
|
||||||
content := ``
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 5800, config.StartPort)
|
|
||||||
})
|
|
||||||
t.Run("User specific port ranges", func(t *testing.T) {
|
|
||||||
content := `startPort: 1000`
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 1000, config.StartPort)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Invalid start port", func(t *testing.T) {
|
|
||||||
content := `startPort: abcd`
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("start port must be greater than 1", func(t *testing.T) {
|
|
||||||
content := `startPort: -99`
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Automatic port assignments", func(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
startPort: 5800
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: svr --port ${PORT}
|
|
||||||
model2:
|
|
||||||
cmd: svr --port ${PORT}
|
|
||||||
proxy: "http://172.11.22.33:${PORT}"
|
|
||||||
model3:
|
|
||||||
cmd: svr --port 1999
|
|
||||||
proxy: "http://1.2.3.4:1999"
|
|
||||||
`
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 5800, config.StartPort)
|
|
||||||
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
|
|
||||||
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
|
|
||||||
|
|
||||||
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
|
|
||||||
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
|
|
||||||
|
|
||||||
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
|
|
||||||
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: svr --port 111
|
|
||||||
`
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -440,12 +439,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Header = r.Header.Clone()
|
req.Header = r.Header.Clone()
|
||||||
|
|
||||||
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
|
||||||
if err == nil {
|
|
||||||
req.ContentLength = contentLength
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
|
|||||||
@@ -381,6 +381,11 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||||
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||||
|
// Create a new buffer for the reconstructed request
|
||||||
|
var requestBuffer bytes.Buffer
|
||||||
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||||
|
|
||||||
// Parse multipart form
|
// Parse multipart form
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
@@ -400,11 +405,6 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
|
||||||
// Create a new buffer for the reconstructed request
|
|
||||||
var requestBuffer bytes.Buffer
|
|
||||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
|
||||||
|
|
||||||
// Copy all form values
|
// Copy all form values
|
||||||
for key, values := range c.Request.MultipartForm.Value {
|
for key, values := range c.Request.MultipartForm.Value {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
@@ -478,10 +478,6 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
modifiedReq.Header = c.Request.Header.Clone()
|
modifiedReq.Header = c.Request.Header.Clone()
|
||||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||||
|
|
||||||
// set the content length of the body
|
|
||||||
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
|
||||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
|
||||||
|
|
||||||
// Use the modified request for proxying
|
// Use the modified request for proxying
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -166,9 +165,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
|
|
||||||
var response map[string]string
|
results[key] = w.Body.String()
|
||||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
||||||
results[key] = response["responseMessage"]
|
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
}(key)
|
}(key)
|
||||||
|
|
||||||
@@ -445,7 +442,6 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "TheExpectedModel", response["model"])
|
assert.Equal(t, "TheExpectedModel", response["model"])
|
||||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||||
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||||
@@ -596,27 +592,3 @@ func TestProxyManager_Upstream(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, "model1", rec.Body.String())
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_ChatContentLength(t *testing.T) {
|
|
||||||
config := AddDefaultGroupToConfig(Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
},
|
|
||||||
LogLevel: "error",
|
|
||||||
})
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
var response map[string]string
|
|
||||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
||||||
assert.Equal(t, "81", response["h_content_length"])
|
|
||||||
assert.Equal(t, "model1", response["responseMessage"])
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user