From e6a9e210ba769d44a04a21a96934bbe6fd233831 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Sun, 21 Dec 2025 21:47:14 -0800 Subject: [PATCH] proxy: fix path bug in /logs/stream/{model_id} (#431) A {model_id} containing a forward slash trips up gin's path param parsing. This updates /logs/stream to work like /upstream where the model_id is built up in parts and searched for in the configuration. Updates #421 --- proxy/proxymanager.go | 74 +++++++++++++++---------------- proxy/proxymanager_loghandlers.go | 9 ++-- proxy/proxymanager_test.go | 4 +- 3 files changed, 45 insertions(+), 42 deletions(-) diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index c8ce0d42..b7e578df 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -266,7 +266,7 @@ func (pm *ProxyManager) setupGinEngine() { // in proxymanager_loghandlers.go pm.ginEngine.GET("/logs", pm.sendLogsHandlers) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) - pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler) + pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.streamLogsHandler) /** * User Interface Endpoints @@ -466,61 +466,61 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) { }) } -func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { - upstreamPath := c.Param("upstreamPath") - - // split the upstream path by / and search for the model name - parts := strings.Split(strings.TrimSpace(upstreamPath), "/") - if len(parts) == 0 { - pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path") - return - } - - modelFound := false +// findModelInPath searches for a valid model name in a path with slashes. +// It iteratively builds up path segments until it finds a matching model. +// Returns: (searchModelName, realModelName, remainingPath, found) +// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true) +func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) { + parts := strings.Split(strings.TrimSpace(path), "/") searchModelName := "" - var modelName, remainingPath string + for i, part := range parts { - if parts[i] == "" { + if part == "" { continue } if searchModelName == "" { searchModelName = part } else { - searchModelName = searchModelName + "/" + parts[i] + searchModelName = searchModelName + "/" + part } if real, ok := pm.config.RealModelName(searchModelName); ok { - modelName = real - remainingPath = "/" + strings.Join(parts[i+1:], "/") - modelFound = true - - // Check if this is exactly a model name with no additional path - // and doesn't end with a trailing slash - if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") { - // Build new URL with query parameters preserved - newPath := "/upstream/" + searchModelName + "/" - if c.Request.URL.RawQuery != "" { - newPath += "?" + c.Request.URL.RawQuery - } - - // Use 308 for non-GET/HEAD requests to preserve method - if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead { - c.Redirect(http.StatusMovedPermanently, newPath) - } else { - c.Redirect(http.StatusPermanentRedirect, newPath) - } - return - } - break + return searchModelName, real, "/" + strings.Join(parts[i+1:], "/"), true } } + return "", "", "", false +} + +func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { + upstreamPath := c.Param("upstreamPath") + + searchModelName, modelName, remainingPath, modelFound := pm.findModelInPath(upstreamPath) + if !modelFound { pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path") return } + // Check if this is exactly a model name with no additional path + // and doesn't end with a trailing slash + if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") { + // Build new URL with query parameters preserved + newPath := "/upstream/" + searchModelName + "/" + if c.Request.URL.RawQuery != "" { + newPath += "?" + c.Request.URL.RawQuery + } + + // Use 308 for non-GET/HEAD requests to preserve method + if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead { + c.Redirect(http.StatusMovedPermanently, newPath) + } else { + c.Redirect(http.StatusPermanentRedirect, newPath) + } + return + } + processGroup, realModelName, err := pm.swapProcessGroup(modelName) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) diff --git a/proxy/proxymanager_loghandlers.go b/proxy/proxymanager_loghandlers.go index d4a59e88..daeb786c 100644 --- a/proxy/proxymanager_loghandlers.go +++ b/proxy/proxymanager_loghandlers.go @@ -31,7 +31,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) { // prevent nginx from buffering streamed logs c.Header("X-Accel-Buffering", "no") - logMonitorId := c.Param("logMonitorID") + logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/") logger, err := pm.getLogger(logMonitorId) if err != nil { c.String(http.StatusBadRequest, err.Error()) @@ -92,8 +92,9 @@ func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) { case "upstream": return pm.upstreamLogger, nil default: - // search for a models specific logger - if name, found := pm.config.RealModelName(logMonitorId); found { + // search for a models specific logger using findModelInPath + // to handle model names with slashes (e.g., "author/model") + if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found { for _, group := range pm.processGroups { if process, found := group.GetMember(name); found { return process.Logger(), nil @@ -101,6 +102,6 @@ func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) { } } - return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'") + return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID") } } diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index c2f41f40..4ae024e6 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -1078,7 +1078,8 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) { config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), + "model1": getTestSimpleResponderConfig("model1"), + "author/model": getTestSimpleResponderConfig("author/model"), }, LogLevel: "error", }) @@ -1091,6 +1092,7 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) { "/logs/stream", "/logs/stream/proxy", "/logs/stream/upstream", + "/logs/stream/author/model", } for _, endpoint := range endpoints {