Compare commits

..

1 Commits

Author SHA1 Message Date
Benson Wong ab5a048584 bug fix with missing early return statements fix #112 2025-05-04 07:34:24 -07:00
5 changed files with 11 additions and 61 deletions
+1 -2
View File
@@ -15,8 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
#platform: [intel, cuda, vulkan, cpu, musa]
platform: [cuda, vulkan, cpu, musa]
platform: [intel, cuda, vulkan, cpu, musa]
fail-fast: false
steps:
- name: Checkout code
+4 -14
View File
@@ -33,17 +33,14 @@ func main() {
// Set up the handler function using the provided response message
r.POST("/v1/chat/completions", func(c *gin.Context) {
c.Header("Content-Type", "application/json")
c.Header("Content-Type", "text/plain")
// add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait)
}
c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
"h_content_length": c.Request.Header.Get("Content-Length"),
})
c.String(200, *responseMessage)
})
// for issue #62 to check model name strips profile slug
@@ -66,11 +63,8 @@ func main() {
})
r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "application/json")
c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
})
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
})
// issue #41
@@ -110,10 +104,6 @@ func main() {
c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model,
// expose some header values for testing
"h_content_type": c.GetHeader("Content-Type"),
"h_content_length": c.GetHeader("Content-Length"),
})
})
-7
View File
@@ -8,7 +8,6 @@ import (
"net/http"
"net/url"
"os/exec"
"strconv"
"strings"
"sync"
"syscall"
@@ -440,12 +439,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return
}
req.Header = r.Header.Clone()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
+5 -9
View File
@@ -381,6 +381,11 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
@@ -400,11 +405,6 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
return
}
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
@@ -478,10 +478,6 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
+1 -29
View File
@@ -8,7 +8,6 @@ import (
"mime/multipart"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
@@ -166,9 +165,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
mu.Lock()
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
results[key] = response["responseMessage"]
results[key] = w.Body.String()
mu.Unlock()
}(key)
@@ -445,7 +442,6 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "TheExpectedModel", response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
}
// Test useModelName in configuration sends overrides what is sent to upstream
@@ -596,27 +592,3 @@ func TestProxyManager_Upstream(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
}
func TestProxyManager_ChatContentLength(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.Equal(t, "81", response["h_content_length"])
assert.Equal(t, "model1", response["responseMessage"])
}