proxy: add swap matrix with solver-based model swapping (#646)
Add a new swap matrix to supersede groups for running concurrent models. The matrix uses a solver that picks the lowest cost evictions to make a requested model available. This simple approach along with a very basic DSL grammar can enable very complex swapping scenarios. - add DSL parser for set expressions with & (AND), | (OR), (), +ref - add MatrixConfig structs, validation, and topological sort for +ref - add MatrixSolver with cost-minimizing swap decisions - add Matrix runtime integrating solver with Process lifecycle - integrate matrix into ProxyManager with if-branches at all endpoints - update config.example.yaml and config-schema.json with matrix schema - config enforces groups XOR matrix (cannot use both) fixes #643
This commit is contained in:
+99
-34
@@ -77,6 +77,9 @@ type ProxyManager struct {
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
|
||||
// matrix-based swap (mutually exclusive with processGroups)
|
||||
matrix *Matrix
|
||||
|
||||
inFlightCounter *InflightCounter
|
||||
|
||||
// shutdown signaling
|
||||
@@ -203,10 +206,14 @@ func New(proxyConfig config.Config) *ProxyManager {
|
||||
peerProxy: peerProxy,
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
for groupID := range proxyConfig.Groups {
|
||||
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
|
||||
pm.processGroups[groupID] = processGroup
|
||||
// create either matrix or process groups (mutually exclusive)
|
||||
if proxyConfig.Matrix != nil {
|
||||
pm.matrix = NewMatrix(proxyConfig, proxyLogger, upstreamLogger)
|
||||
} else {
|
||||
for groupID := range proxyConfig.Groups {
|
||||
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
|
||||
pm.processGroups[groupID] = processGroup
|
||||
}
|
||||
}
|
||||
|
||||
pm.setupGinEngine()
|
||||
@@ -225,18 +232,29 @@ func New(proxyConfig config.Config) *ProxyManager {
|
||||
}
|
||||
|
||||
proxyLogger.Infof("Preloading model: %s", modelID)
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
|
||||
if err != nil {
|
||||
var preloadErr error
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
|
||||
if pm.matrix != nil {
|
||||
preloadErr = pm.matrix.ProxyRequest(modelID, discardWriter, req)
|
||||
} else {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
preloadErr = err
|
||||
} else {
|
||||
preloadErr = processGroup.ProxyRequest(modelID, discardWriter, req)
|
||||
}
|
||||
}
|
||||
|
||||
if preloadErr != nil {
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: modelID,
|
||||
Success: false,
|
||||
})
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, preloadErr)
|
||||
continue
|
||||
} else {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
processGroup.ProxyRequest(modelID, discardWriter, req)
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: modelID,
|
||||
Success: true,
|
||||
@@ -453,6 +471,11 @@ func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
if pm.matrix != nil {
|
||||
pm.matrix.StopProcesses(strategy)
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, processGroup := range pm.processGroups {
|
||||
@@ -473,6 +496,12 @@ func (pm *ProxyManager) Shutdown() {
|
||||
|
||||
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
||||
|
||||
if pm.matrix != nil {
|
||||
pm.matrix.Shutdown()
|
||||
pm.shutdownCancel()
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
// Send shutdown signal to all process in groups
|
||||
for _, processGroup := range pm.processGroups {
|
||||
@@ -639,10 +668,16 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
var handler func(string, http.ResponseWriter, *http.Request) error
|
||||
if pm.matrix != nil {
|
||||
handler = pm.matrix.ProxyRequest
|
||||
} else {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
handler = processGroup.ProxyRequest
|
||||
}
|
||||
|
||||
// rewrite the path
|
||||
@@ -651,13 +686,13 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
|
||||
// attempt to record metrics if it is a POST request
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
|
||||
if err := handler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
@@ -683,10 +718,16 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
var localHandler func(string, http.ResponseWriter, *http.Request) error
|
||||
if pm.matrix != nil {
|
||||
localHandler = pm.matrix.ProxyRequest
|
||||
} else {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
localHandler = processGroup.ProxyRequest
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
@@ -737,7 +778,7 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
nextHandler = localHandler
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
@@ -823,15 +864,19 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
if pm.matrix != nil {
|
||||
nextHandler = pm.matrix.ProxyRequest
|
||||
} else {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
}
|
||||
|
||||
useModelName = pm.config.Models[modelID].UseModelName
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
@@ -942,14 +987,18 @@ func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
||||
var modelID string
|
||||
|
||||
if realModelID, found := pm.config.RealModelName(requestedModel); found {
|
||||
processGroup, err := pm.swapProcessGroup(realModelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
modelID = realModelID
|
||||
if pm.matrix != nil {
|
||||
nextHandler = pm.matrix.ProxyRequest
|
||||
} else {
|
||||
processGroup, err := pm.swapProcessGroup(realModelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
}
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
modelID = requestedModel
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
@@ -1048,9 +1097,9 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
context.Header("Content-Type", "application/json")
|
||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||
|
||||
for _, processGroup := range pm.processGroups {
|
||||
for _, process := range processGroup.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
if pm.matrix != nil {
|
||||
for _, modelID := range pm.matrix.RunningModels() {
|
||||
if process, ok := pm.matrix.GetProcess(modelID); ok {
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
@@ -1062,6 +1111,22 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, processGroup := range pm.processGroups {
|
||||
for _, process := range processGroup.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
"cmd": process.config.Cmd,
|
||||
"proxy": process.config.Proxy,
|
||||
"ttl": process.config.UnloadAfter,
|
||||
"name": process.config.Name,
|
||||
"description": process.config.Description,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Put the results under the `running` key.
|
||||
|
||||
Reference in New Issue
Block a user