Compare commits

...

4 Commits

Author SHA1 Message Date
Benson Wong 21d7973d11 Improve content-length handling (#115)
ref: See #114

* Improve content-length handling
- Content length was not always being sent
- Add tests for content-length
2025-05-05 10:46:26 -07:00
Yi Hong Ang cc450e9c5f fix issue where proxy is still proxying with chunked transfer-encoding (#114) 2025-05-05 10:00:03 -07:00
Benson Wong 27465fe053 bug fix with missing early return statements fix #112 2025-05-05 09:32:44 -07:00
Benson Wong 9667989727 Disabling intel container build since it's been broken for weeks. 2025-05-04 21:39:42 -07:00
5 changed files with 67 additions and 11 deletions
+2 -1
View File
@@ -15,7 +15,8 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
platform: [intel, cuda, vulkan, cpu, musa] #platform: [intel, cuda, vulkan, cpu, musa]
platform: [cuda, vulkan, cpu, musa]
fail-fast: false fail-fast: false
steps: steps:
- name: Checkout code - name: Checkout code
+14 -4
View File
@@ -33,14 +33,17 @@ func main() {
// Set up the handler function using the provided response message // Set up the handler function using the provided response message
r.POST("/v1/chat/completions", func(c *gin.Context) { r.POST("/v1/chat/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain") c.Header("Content-Type", "application/json")
// add a wait to simulate a slow query // add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil { if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait) time.Sleep(wait)
} }
c.String(200, *responseMessage) c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
"h_content_length": c.Request.Header.Get("Content-Length"),
})
}) })
// for issue #62 to check model name strips profile slug // for issue #62 to check model name strips profile slug
@@ -63,8 +66,11 @@ func main() {
}) })
r.POST("/v1/completions", func(c *gin.Context) { r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain") c.Header("Content-Type", "application/json")
c.String(200, *responseMessage) c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
})
}) })
// issue #41 // issue #41
@@ -104,6 +110,10 @@ func main() {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize), "text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model, "model": model,
// expose some header values for testing
"h_content_type": c.GetHeader("Content-Type"),
"h_content_length": c.GetHeader("Content-Length"),
}) })
}) })
+7
View File
@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os/exec" "os/exec"
"strconv"
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
@@ -439,6 +440,12 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
req.Header = r.Header.Clone() req.Header = r.Header.Clone()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway) http.Error(w, err.Error(), http.StatusBadGateway)
+15 -5
View File
@@ -305,6 +305,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
processGroup, _, err := pm.swapProcessGroup(requestedModel) processGroup, _, err := pm.swapProcessGroup(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
} }
// rewrite the path // rewrite the path
@@ -347,11 +348,13 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
requestedModel := gjson.GetBytes(bodyBytes, "model").String() requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" { if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
} }
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
} }
// issue #69 allow custom model names to be sent to upstream // issue #69 allow custom model names to be sent to upstream
@@ -373,15 +376,11 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
} }
} }
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Parse multipart form // Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
@@ -398,8 +397,14 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
} }
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values // Copy all form values
for key, values := range c.Request.MultipartForm.Value { for key, values := range c.Request.MultipartForm.Value {
for _, value := range values { for _, value := range values {
@@ -473,10 +478,15 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modifiedReq.Header = c.Request.Header.Clone() modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying // Use the modified request for proxying
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil { if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
} }
} }
+29 -1
View File
@@ -8,6 +8,7 @@ import (
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -165,7 +166,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
mu.Lock() mu.Lock()
results[key] = w.Body.String() var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
results[key] = response["responseMessage"]
mu.Unlock() mu.Unlock()
}(key) }(key)
@@ -442,6 +445,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "TheExpectedModel", response["model"]) assert.Equal(t, "TheExpectedModel", response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
} }
// Test useModelName in configuration sends overrides what is sent to upstream // Test useModelName in configuration sends overrides what is sent to upstream
@@ -592,3 +596,27 @@ func TestProxyManager_Upstream(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String()) assert.Equal(t, "model1", rec.Body.String())
} }
func TestProxyManager_ChatContentLength(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.Equal(t, "81", response["h_content_length"])
assert.Equal(t, "model1", response["responseMessage"])
}