Add Peer Model Support (#438)
This PR allows a single llama-swap to be the central proxy for models served by other inference servers. The peer servers can be another llama-swap or any API that supports the /v1/* inference endpoint. Updates: #433, #299 Closes: #296
This commit is contained in:
+137
-90
@@ -50,6 +50,9 @@ type ProxyManager struct {
|
||||
buildDate string
|
||||
commit string
|
||||
version string
|
||||
|
||||
// peer proxy see: #296, #433
|
||||
peerProxy *PeerProxy
|
||||
}
|
||||
|
||||
func New(proxyConfig config.Config) *ProxyManager {
|
||||
@@ -133,6 +136,12 @@ func New(proxyConfig config.Config) *ProxyManager {
|
||||
maxMetrics = proxyConfig.MetricsMaxInMemory
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
|
||||
if err != nil {
|
||||
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
|
||||
peerProxy = nil
|
||||
}
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: proxyConfig,
|
||||
ginEngine: gin.New(),
|
||||
@@ -151,6 +160,8 @@ func New(proxyConfig config.Config) *ProxyManager {
|
||||
buildDate: "unknown",
|
||||
commit: "abcd1234",
|
||||
version: "0",
|
||||
|
||||
peerProxy: peerProxy,
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
@@ -166,22 +177,29 @@ func New(proxyConfig config.Config) *ProxyManager {
|
||||
// do it in the background, don't block startup -- not sure if good idea yet
|
||||
go func() {
|
||||
discardWriter := &DiscardWriter{}
|
||||
for _, realModelName := range proxyConfig.Hooks.OnStartup.Preload {
|
||||
proxyLogger.Infof("Preloading model: %s", realModelName)
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
|
||||
modelID, ok := proxyConfig.RealModelName(preloadModelName)
|
||||
|
||||
if !ok {
|
||||
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyLogger.Infof("Preloading model: %s", modelID)
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
|
||||
if err != nil {
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
ModelName: modelID,
|
||||
Success: false,
|
||||
})
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
|
||||
continue
|
||||
} else {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
processGroup.ProxyRequest(realModelName, discardWriter, req)
|
||||
processGroup.ProxyRequest(modelID, discardWriter, req)
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
ModelName: modelID,
|
||||
Success: true,
|
||||
})
|
||||
}
|
||||
@@ -399,16 +417,10 @@ func (pm *ProxyManager) Shutdown() {
|
||||
pm.shutdownCancel()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
||||
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
|
||||
}
|
||||
|
||||
if processGroup.exclusive {
|
||||
@@ -420,54 +432,71 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
||||
}
|
||||
}
|
||||
|
||||
return processGroup, realModelName, nil
|
||||
return processGroup, nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := make([]gin.H, 0, len(pm.config.Models))
|
||||
createdTime := time.Now().Unix()
|
||||
|
||||
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
|
||||
record := gin.H{
|
||||
"id": modelId,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
// Add metadata if present
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
record["meta"] = gin.H{
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
newRecord := func(modelId string) gin.H {
|
||||
record := gin.H{
|
||||
"id": modelId,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
// Add metadata if present
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
record["meta"] = gin.H{
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
data = append(data, newRecord(id))
|
||||
data = append(data, newRecord(id, modelConfig))
|
||||
|
||||
// Include aliases
|
||||
if pm.config.IncludeAliasesInList {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias))
|
||||
data = append(data, newRecord(alias, modelConfig))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
// add peer models
|
||||
for _, modelID := range peer.Models {
|
||||
// Skip unlisted models if not showing them
|
||||
record := newRecord(modelID, config.ModelConfig{
|
||||
Name: fmt.Sprintf("%s: %s", peerID, modelID),
|
||||
Metadata: map[string]any{
|
||||
"peerID": peerID,
|
||||
},
|
||||
})
|
||||
|
||||
data = append(data, record)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by the "id" key
|
||||
sort.Slice(data, func(i, j int) bool {
|
||||
si, _ := data[i]["id"].(string)
|
||||
@@ -506,8 +535,8 @@ func (pm *ProxyManager) findModelInPath(path string) (searchName string, realNam
|
||||
searchModelName = searchModelName + "/" + part
|
||||
}
|
||||
|
||||
if real, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
return searchModelName, real, "/" + strings.Join(parts[i+1:], "/"), true
|
||||
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -517,23 +546,22 @@ func (pm *ProxyManager) findModelInPath(path string) (searchName string, realNam
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
upstreamPath := c.Param("upstreamPath")
|
||||
|
||||
searchModelName, modelName, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
||||
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
||||
|
||||
if !modelFound {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is exactly a model name with no additional path
|
||||
// and doesn't end with a trailing slash
|
||||
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
|
||||
// This ensures relative URLs in upstream responses resolve correctly and
|
||||
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
|
||||
// HTTP method (301 would downgrade to GET).
|
||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||
// Build new URL with query parameters preserved
|
||||
newPath := "/upstream/" + searchModelName + "/"
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
newPath += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// Use 308 for non-GET/HEAD requests to preserve method
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||
} else {
|
||||
@@ -542,7 +570,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
@@ -554,15 +582,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
|
||||
// attempt to record metrics if it is a POST request
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -581,41 +609,54 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
// Look for a matching local model first
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[modelID].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
@@ -628,19 +669,19 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
// issue #366 extract values that downstream handlers may need
|
||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -660,7 +701,13 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
@@ -678,7 +725,7 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
// # issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
useModelName := pm.config.Models[modelID].UseModelName
|
||||
|
||||
if useModelName != "" {
|
||||
fieldValue = useModelName
|
||||
@@ -749,9 +796,9 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||
|
||||
// Use the modified request for proxying
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||
if err := processGroup.ProxyRequest(modelID, c.Writer, modifiedReq); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user