proxy: add swap matrix with solver-based model swapping (#646)

Add a new swap matrix to supersede groups for running concurrent models.
The matrix uses a solver that picks the lowest cost evictions to make a
requested model available. This simple approach along with a very basic
DSL grammar can enable very complex swapping scenarios.

- add DSL parser for set expressions with & (AND), | (OR), (), +ref
- add MatrixConfig structs, validation, and topological sort for +ref
- add MatrixSolver with cost-minimizing swap decisions
- add Matrix runtime integrating solver with Process lifecycle
- integrate matrix into ProxyManager with if-branches at all endpoints
- update config.example.yaml and config-schema.json with matrix schema
- config enforces groups XOR matrix (cannot use both)

fixes #643
This commit is contained in:
Benson Wong
2026-04-14 21:55:30 -07:00
committed by GitHub
parent 40e39f7a86
commit 35193f82f1
13 changed files with 2080 additions and 186 deletions
+99 -34
View File
@@ -77,6 +77,9 @@ type ProxyManager struct {
processGroups map[string]*ProcessGroup
// matrix-based swap (mutually exclusive with processGroups)
matrix *Matrix
inFlightCounter *InflightCounter
// shutdown signaling
@@ -203,10 +206,14 @@ func New(proxyConfig config.Config) *ProxyManager {
peerProxy: peerProxy,
}
// create the process groups
for groupID := range proxyConfig.Groups {
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
// create either matrix or process groups (mutually exclusive)
if proxyConfig.Matrix != nil {
pm.matrix = NewMatrix(proxyConfig, proxyLogger, upstreamLogger)
} else {
for groupID := range proxyConfig.Groups {
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
}
pm.setupGinEngine()
@@ -225,18 +232,29 @@ func New(proxyConfig config.Config) *ProxyManager {
}
proxyLogger.Infof("Preloading model: %s", modelID)
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
var preloadErr error
req, _ := http.NewRequest("GET", "/", nil)
if pm.matrix != nil {
preloadErr = pm.matrix.ProxyRequest(modelID, discardWriter, req)
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
preloadErr = err
} else {
preloadErr = processGroup.ProxyRequest(modelID, discardWriter, req)
}
}
if preloadErr != nil {
event.Emit(ModelPreloadedEvent{
ModelName: modelID,
Success: false,
})
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, preloadErr)
continue
} else {
req, _ := http.NewRequest("GET", "/", nil)
processGroup.ProxyRequest(modelID, discardWriter, req)
event.Emit(ModelPreloadedEvent{
ModelName: modelID,
Success: true,
@@ -453,6 +471,11 @@ func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
pm.Lock()
defer pm.Unlock()
if pm.matrix != nil {
pm.matrix.StopProcesses(strategy)
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, processGroup := range pm.processGroups {
@@ -473,6 +496,12 @@ func (pm *ProxyManager) Shutdown() {
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
if pm.matrix != nil {
pm.matrix.Shutdown()
pm.shutdownCancel()
return
}
var wg sync.WaitGroup
// Send shutdown signal to all process in groups
for _, processGroup := range pm.processGroups {
@@ -639,10 +668,16 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return
}
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
var handler func(string, http.ResponseWriter, *http.Request) error
if pm.matrix != nil {
handler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
handler = processGroup.ProxyRequest
}
// rewrite the path
@@ -651,13 +686,13 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
// attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
return
}
} else {
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
if err := handler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
return
@@ -683,10 +718,16 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
modelID, found := pm.config.RealModelName(requestedModel)
if found {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
var localHandler func(string, http.ResponseWriter, *http.Request) error
if pm.matrix != nil {
localHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
localHandler = processGroup.ProxyRequest
}
// issue #69 allow custom model names to be sent to upstream
@@ -737,7 +778,7 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
nextHandler = localHandler
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
@@ -823,15 +864,19 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modelID, found := pm.config.RealModelName(requestedModel)
if found {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
if pm.matrix != nil {
nextHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
nextHandler = processGroup.ProxyRequest
}
useModelName = pm.config.Models[modelID].UseModelName
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
@@ -942,14 +987,18 @@ func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
var modelID string
if realModelID, found := pm.config.RealModelName(requestedModel); found {
processGroup, err := pm.swapProcessGroup(realModelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
modelID = realModelID
if pm.matrix != nil {
nextHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(realModelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
nextHandler = processGroup.ProxyRequest
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
modelID = requestedModel
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
@@ -1048,9 +1097,9 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, processGroup := range pm.processGroups {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady {
if pm.matrix != nil {
for _, modelID := range pm.matrix.RunningModels() {
if process, ok := pm.matrix.GetProcess(modelID); ok {
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
@@ -1062,6 +1111,22 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
})
}
}
} else {
for _, processGroup := range pm.processGroups {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady {
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
"cmd": process.config.Cmd,
"proxy": process.config.Proxy,
"ttl": process.config.UnloadAfter,
"name": process.config.Name,
"description": process.config.Description,
})
}
}
}
}
// Put the results under the `running` key.