proxy: fix path bug in /logs/stream/{model_id} (#431)

A {model_id} containing a forward slash trips up gin's path param
parsing. This updates /logs/stream to work like /upstream where the
model_id is built up in parts and searched for in the configuration.

Updates #421
This commit is contained in:
Benson Wong
2025-12-21 21:47:14 -08:00
committed by GitHub
parent d3f329f924
commit e6a9e210ba
3 changed files with 45 additions and 42 deletions
+37 -37
View File
@@ -266,7 +266,7 @@ func (pm *ProxyManager) setupGinEngine() {
// in proxymanager_loghandlers.go
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.streamLogsHandler)
/**
* User Interface Endpoints
@@ -466,61 +466,61 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
})
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
upstreamPath := c.Param("upstreamPath")
// split the upstream path by / and search for the model name
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
if len(parts) == 0 {
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
return
}
modelFound := false
// findModelInPath searches for a valid model name in a path with slashes.
// It iteratively builds up path segments until it finds a matching model.
// Returns: (searchModelName, realModelName, remainingPath, found)
// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
parts := strings.Split(strings.TrimSpace(path), "/")
searchModelName := ""
var modelName, remainingPath string
for i, part := range parts {
if parts[i] == "" {
if part == "" {
continue
}
if searchModelName == "" {
searchModelName = part
} else {
searchModelName = searchModelName + "/" + parts[i]
searchModelName = searchModelName + "/" + part
}
if real, ok := pm.config.RealModelName(searchModelName); ok {
modelName = real
remainingPath = "/" + strings.Join(parts[i+1:], "/")
modelFound = true
// Check if this is exactly a model name with no additional path
// and doesn't end with a trailing slash
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
// Build new URL with query parameters preserved
newPath := "/upstream/" + searchModelName + "/"
if c.Request.URL.RawQuery != "" {
newPath += "?" + c.Request.URL.RawQuery
}
// Use 308 for non-GET/HEAD requests to preserve method
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
c.Redirect(http.StatusMovedPermanently, newPath)
} else {
c.Redirect(http.StatusPermanentRedirect, newPath)
}
return
}
break
return searchModelName, real, "/" + strings.Join(parts[i+1:], "/"), true
}
}
return "", "", "", false
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
upstreamPath := c.Param("upstreamPath")
searchModelName, modelName, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
if !modelFound {
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
return
}
// Check if this is exactly a model name with no additional path
// and doesn't end with a trailing slash
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
// Build new URL with query parameters preserved
newPath := "/upstream/" + searchModelName + "/"
if c.Request.URL.RawQuery != "" {
newPath += "?" + c.Request.URL.RawQuery
}
// Use 308 for non-GET/HEAD requests to preserve method
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
c.Redirect(http.StatusMovedPermanently, newPath)
} else {
c.Redirect(http.StatusPermanentRedirect, newPath)
}
return
}
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
+5 -4
View File
@@ -31,7 +31,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// prevent nginx from buffering streamed logs
c.Header("X-Accel-Buffering", "no")
logMonitorId := c.Param("logMonitorID")
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
@@ -92,8 +92,9 @@ func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
case "upstream":
return pm.upstreamLogger, nil
default:
// search for a models specific logger
if name, found := pm.config.RealModelName(logMonitorId); found {
// search for a models specific logger using findModelInPath
// to handle model names with slashes (e.g., "author/model")
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
for _, group := range pm.processGroups {
if process, found := group.GetMember(name); found {
return process.Logger(), nil
@@ -101,6 +102,6 @@ func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
}
}
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
}
}
+3 -1
View File
@@ -1078,7 +1078,8 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model1": getTestSimpleResponderConfig("model1"),
"author/model": getTestSimpleResponderConfig("author/model"),
},
LogLevel: "error",
})
@@ -1091,6 +1092,7 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
"/logs/stream",
"/logs/stream/proxy",
"/logs/stream/upstream",
"/logs/stream/author/model",
}
for _, endpoint := range endpoints {