Files
llama-swap/proxy/proxymanager.go
T
Benson Wong 565c44766d config,proxy: add new configuration logToStdout (#432)
The new logToStdout option controls what is logged to stdout. The
default has been changed to just the proxy logs, which contain swap and
http request logs.

There are four supported settings: none, proxy, upstream, both. The
"both" setting is the legacy setting where everything was spewed to
stdout.
2025-12-21 22:23:31 -08:00

812 lines
24 KiB
Go

package proxy
import (
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
// PROFILE_SPLIT_CHAR is the ":" separator string. It is not referenced in the
// visible part of this file.
// NOTE(review): presumably it separates a profile name from a model name in
// a "profile:model" identifier -- confirm against its callers.
PROFILE_SPLIT_CHAR = ":"
)
// proxyCtxKey is the string-based key type used when attaching values to a
// request context. proxyInferenceHandler stores the "streaming" and "model"
// values under keys of this type.
type proxyCtxKey string
// ProxyManager is the top-level HTTP handler of the proxy. It owns the gin
// engine, the loggers, the metrics monitor and the process groups that New
// builds from the supplied configuration.
type ProxyManager struct {
// Embedded lock; StopProcesses, Shutdown and SetVersion hold it.
sync.Mutex
// config is the configuration value handed to New.
config config.Config
// ginEngine routes every request; ServeHTTP delegates to it.
ginEngine *gin.Engine
// logging
// proxyLogger and upstreamLogger are created in New. Depending on the
// LogToStdout setting each one writes either into muxLogger or to io.Discard.
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
// muxLogger writes to os.Stdout, or to io.Discard for config.LogToStdoutNone.
muxLogger *LogMonitor
// metricsMonitor wraps POST requests in proxyToUpstream and
// proxyInferenceHandler when it is non-nil.
metricsMonitor *metricsMonitor
// processGroups is keyed by group ID, one entry per group in config.Groups.
processGroups map[string]*ProcessGroup
// shutdown signaling
// shutdownCancel is called at the end of Shutdown and cancels shutdownCtx.
shutdownCtx context.Context
shutdownCancel context.CancelFunc
// version info
// New fills these with placeholder values; SetVersion overwrites them.
buildDate string
commit string
version string
}
// New builds a ProxyManager from proxyConfig.
//
// It creates the loggers according to LogToStdout, LogLevel and LogTimeFormat,
// creates one ProcessGroup per entry in proxyConfig.Groups, registers the HTTP
// routes and, when Hooks.OnStartup.Preload lists any models, starts a
// background goroutine that preloads them one after another. A
// ModelPreloadedEvent is emitted for every preload attempt; Success is true
// only when both the process group swap and the warm-up request returned
// without error.
func New(proxyConfig config.Config) *ProxyManager {
	// set up loggers
	var muxLogger, upstreamLogger, proxyLogger *LogMonitor
	switch proxyConfig.LogToStdout {
	case config.LogToStdoutNone:
		muxLogger = NewLogMonitorWriter(io.Discard)
		upstreamLogger = NewLogMonitorWriter(io.Discard)
		proxyLogger = NewLogMonitorWriter(io.Discard)
	case config.LogToStdoutBoth:
		muxLogger = NewLogMonitorWriter(os.Stdout)
		upstreamLogger = NewLogMonitorWriter(muxLogger)
		proxyLogger = NewLogMonitorWriter(muxLogger)
	case config.LogToStdoutUpstream:
		muxLogger = NewLogMonitorWriter(os.Stdout)
		upstreamLogger = NewLogMonitorWriter(muxLogger)
		proxyLogger = NewLogMonitorWriter(io.Discard)
	default:
		// same as config.LogToStdoutProxy
		// helpful because some old tests create a config.Config directly and it
		// may not have LogToStdout set explicitly
		muxLogger = NewLogMonitorWriter(os.Stdout)
		upstreamLogger = NewLogMonitorWriter(io.Discard)
		proxyLogger = NewLogMonitorWriter(muxLogger)
	}

	if proxyConfig.LogRequests {
		proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
	}

	// unknown or empty log levels fall back to info
	switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
	case "debug":
		proxyLogger.SetLogLevel(LevelDebug)
		upstreamLogger.SetLogLevel(LevelDebug)
	case "info":
		proxyLogger.SetLogLevel(LevelInfo)
		upstreamLogger.SetLogLevel(LevelInfo)
	case "warn":
		proxyLogger.SetLogLevel(LevelWarn)
		upstreamLogger.SetLogLevel(LevelWarn)
	case "error":
		proxyLogger.SetLogLevel(LevelError)
		upstreamLogger.SetLogLevel(LevelError)
	default:
		proxyLogger.SetLogLevel(LevelInfo)
		upstreamLogger.SetLogLevel(LevelInfo)
	}

	// see: https://go.dev/src/time/format.go
	timeFormats := map[string]string{
		"ansic":       time.ANSIC,
		"unixdate":    time.UnixDate,
		"rubydate":    time.RubyDate,
		"rfc822":      time.RFC822,
		"rfc822z":     time.RFC822Z,
		"rfc850":      time.RFC850,
		"rfc1123":     time.RFC1123,
		"rfc1123z":    time.RFC1123Z,
		"rfc3339":     time.RFC3339,
		"rfc3339nano": time.RFC3339Nano,
		"kitchen":     time.Kitchen,
		"stamp":       time.Stamp,
		"stampmilli":  time.StampMilli,
		"stampmicro":  time.StampMicro,
		"stampnano":   time.StampNano,
	}

	// an unrecognized time format leaves the loggers' format untouched
	if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
		proxyLogger.SetLogTimeFormat(timeFormat)
		upstreamLogger.SetLogTimeFormat(timeFormat)
	}

	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())

	var maxMetrics int
	if proxyConfig.MetricsMaxInMemory <= 0 {
		maxMetrics = 1000 // Default fallback
	} else {
		maxMetrics = proxyConfig.MetricsMaxInMemory
	}

	pm := &ProxyManager{
		config:         proxyConfig,
		ginEngine:      gin.New(),
		proxyLogger:    proxyLogger,
		muxLogger:      muxLogger,
		upstreamLogger: upstreamLogger,
		metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
		processGroups:  make(map[string]*ProcessGroup),
		shutdownCtx:    shutdownCtx,
		shutdownCancel: shutdownCancel,
		buildDate:      "unknown",
		commit:         "abcd1234",
		version:        "0",
	}

	// create the process groups
	for groupID := range proxyConfig.Groups {
		processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
		pm.processGroups[groupID] = processGroup
	}

	pm.setupGinEngine()

	// run any startup hooks
	if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
		// do it in the background, don't block startup -- not sure if good idea yet
		go func() {
			discardWriter := &DiscardWriter{}

			// preload swaps to the model's process group and sends it a
			// throwaway GET / request; any error along the way is returned so
			// the caller can report the preload as failed.
			preload := func(modelName string) error {
				processGroup, _, err := pm.swapProcessGroup(modelName)
				if err != nil {
					return err
				}
				req, err := http.NewRequest("GET", "/", nil)
				if err != nil {
					return err
				}
				return processGroup.ProxyRequest(modelName, discardWriter, req)
			}

			for _, realModelName := range proxyConfig.Hooks.OnStartup.Preload {
				proxyLogger.Infof("Preloading model: %s", realModelName)

				if err := preload(realModelName); err != nil {
					event.Emit(ModelPreloadedEvent{
						ModelName: realModelName,
						Success:   false,
					})
					proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
					continue
				}

				event.Emit(ModelPreloadedEvent{
					ModelName: realModelName,
					Success:   true,
				})
			}
		}()
	}

	return pm
}
// setupGinEngine installs the middleware and routes on pm.ginEngine: request
// logging, permissive CORS preflight handling, the inference endpoints, log
// endpoints, UI endpoints, upstream passthrough and the API handlers.
func (pm *ProxyManager) setupGinEngine() {
	engine := pm.ginEngine

	// request logging middleware
	engine.Use(func(c *gin.Context) {
		// don't log the Wake on Lan proxy health check
		if c.Request.URL.Path == "/wol-health" {
			c.Next()
			return
		}

		began := time.Now()

		// snapshot these up front because /upstream/:model rewrites them in c.Next()
		remoteIP := c.ClientIP()
		verb := c.Request.Method
		urlPath := c.Request.URL.Path

		c.Next()

		elapsed := time.Since(began)
		status := c.Writer.Status()
		written := c.Writer.Size()

		pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
			remoteIP, verb, urlPath, c.Request.Proto,
			status, written, c.Request.UserAgent(), elapsed,
		)
	})

	// see: issue: #81, #77 and #42 for CORS issues
	// respond with permissive OPTIONS for any endpoint
	engine.Use(func(c *gin.Context) {
		if c.Request.Method != "OPTIONS" {
			c.Next()
			return
		}

		c.Header("Access-Control-Allow-Origin", "*")
		c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")

		// allow whatever the client requested, otherwise fall back to a fixed list
		allowHeaders := "Content-Type, Authorization, Accept, X-Requested-With"
		if requested := c.Request.Header.Get("Access-Control-Request-Headers"); requested != "" {
			allowHeaders = SanitizeAccessControlRequestHeaderValues(requested)
		}
		c.Header("Access-Control-Allow-Headers", allowHeaders)

		c.Header("Access-Control-Max-Age", "86400")
		c.AbortWithStatus(http.StatusNoContent)
	})

	// POST endpoints that all go through proxyInferenceHandler
	inferenceRoutes := []string{
		"/v1/chat/completions",
		// Support legacy /v1/completions api, see issue #12
		"/v1/completions",
		// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
		"/v1/messages",
		// Support embeddings and reranking
		"/v1/embeddings",
		// llama-server's /reranking endpoint + aliases
		"/reranking",
		"/rerank",
		"/v1/rerank",
		"/v1/reranking",
		// llama-server's /infill endpoint for code infilling
		"/infill",
		// llama-server's /completion endpoint
		"/completion",
		// Support audio/speech endpoint
		"/v1/audio/speech",
	}
	for _, route := range inferenceRoutes {
		engine.POST(route, pm.proxyInferenceHandler)
	}

	engine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
	engine.GET("/v1/models", pm.listModelsHandler)

	// in proxymanager_loghandlers.go
	engine.GET("/logs", pm.sendLogsHandlers)
	engine.GET("/logs/stream", pm.streamLogsHandler)
	engine.GET("/logs/stream/*logMonitorID", pm.streamLogsHandler)

	/**
	 * User Interface Endpoints
	 */
	engine.GET("/", func(c *gin.Context) {
		c.Redirect(http.StatusFound, "/ui")
	})
	engine.GET("/upstream", func(c *gin.Context) {
		c.Redirect(http.StatusFound, "/ui/models")
	})
	engine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
	engine.GET("/unload", pm.unloadAllModelsHandler)
	engine.GET("/running", pm.listRunningProcessesHandler)

	// both health endpoints answer with a plain 200 "OK"
	respondOK := func(c *gin.Context) {
		c.String(http.StatusOK, "OK")
	}
	engine.GET("/health", respondOK)
	// see cmd/wol-proxy/wol-proxy.go, not logged
	engine.GET("/wol-health", respondOK)

	engine.GET("/favicon.ico", func(c *gin.Context) {
		icon, err := reactStaticFS.ReadFile("ui_dist/favicon.ico")
		if err != nil {
			c.String(http.StatusInternalServerError, err.Error())
			return
		}
		c.Data(http.StatusOK, "image/x-icon", icon)
	})

	if reactFS, err := GetReactFS(); err != nil {
		pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
	} else {
		// serve files that exist under /ui/*
		engine.StaticFS("/ui", reactFS)

		// serve the SPA entry point for any other path under /ui
		engine.NoRoute(func(c *gin.Context) {
			if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
				c.AbortWithStatus(http.StatusNotFound)
				return
			}

			index, err := reactFS.Open("index.html")
			if err != nil {
				c.String(http.StatusInternalServerError, err.Error())
				return
			}
			defer index.Close()

			http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), index)
		})
	}

	// see: proxymanager_api.go
	// add API handler functions
	addApiHandlers(pm)

	// Disable console color for testing
	gin.DisableConsoleColor()
}
// ServeHTTP implements the http.Handler interface by handing the request
// straight to the gin engine, which runs the middleware and routes registered
// in setupGinEngine.
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r)
}
// StopProcesses stops the processes of every process group using the given
// strategy. It holds the ProxyManager lock for the whole call, fans the work
// out to one goroutine per group and returns once all of them have finished.
// Unlike Shutdown it does not cancel the shutdown context, so the manager
// stays usable afterwards.
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
	pm.Lock()
	defer pm.Unlock()

	var pending sync.WaitGroup
	pending.Add(len(pm.processGroups))
	for _, group := range pm.processGroups {
		group := group // per-iteration copy for the goroutine below
		go func() {
			defer pending.Done()
			group.StopProcesses(strategy)
		}()
	}
	pending.Wait()
}
// Shutdown shuts down every process group concurrently while holding the
// ProxyManager lock, waits for all of them to finish and then cancels the
// shutdown context.
func (pm *ProxyManager) Shutdown() {
	pm.Lock()
	defer pm.Unlock()

	pm.proxyLogger.Debug("Shutdown() called in proxy manager")

	// one goroutine per group; wait for all before signaling shutdown
	var pending sync.WaitGroup
	pending.Add(len(pm.processGroups))
	for _, group := range pm.processGroups {
		group := group // per-iteration copy for the goroutine below
		go func() {
			defer pending.Done()
			group.Shutdown()
		}()
	}
	pending.Wait()

	pm.shutdownCancel()
}
// swapProcessGroup resolves requestedModel (which may be an alias) to its real
// model name and returns the process group that contains it. When that group
// is exclusive, every other non-persistent group is asked to stop its
// processes (waiting for in-flight requests) before returning.
//
// Returns the group, the resolved model name and an error when either the
// model name or its group cannot be found.
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
	realModelName, known := pm.config.RealModelName(requestedModel)
	if !known {
		return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
	}

	target := pm.findGroupByModelName(realModelName)
	if target == nil {
		return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
	}

	if !target.exclusive {
		return target, realModelName, nil
	}

	pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", target.id)
	for id, group := range pm.processGroups {
		if id == target.id || group.persistent {
			continue
		}
		group.StopProcesses(StopWaitForInflightRequest)
	}

	return target, realModelName, nil
}
// listModelsHandler responds with a JSON list of the configured models that
// are not marked Unlisted. When IncludeAliasesInList is set, each non-blank
// alias is listed as its own record carrying the same name, description and
// metadata as its model. Records are sorted by id.
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
	records := make([]gin.H, 0, len(pm.config.Models))
	created := time.Now().Unix()

	for modelID, cfg := range pm.config.Models {
		if cfg.Unlisted {
			continue
		}

		// the model's own id first, followed by any aliases to be listed
		ids := []string{modelID}
		if pm.config.IncludeAliasesInList {
			for _, rawAlias := range cfg.Aliases {
				if trimmed := strings.TrimSpace(rawAlias); trimmed != "" {
					ids = append(ids, trimmed)
				}
			}
		}

		for _, id := range ids {
			entry := gin.H{
				"id":       id,
				"object":   "model",
				"created":  created,
				"owned_by": "llama-swap",
			}

			if displayName := strings.TrimSpace(cfg.Name); displayName != "" {
				entry["name"] = displayName
			}

			if summary := strings.TrimSpace(cfg.Description); summary != "" {
				entry["description"] = summary
			}

			// metadata is only attached when there is any
			if len(cfg.Metadata) > 0 {
				entry["meta"] = gin.H{
					"llamaswap": cfg.Metadata,
				}
			}

			records = append(records, entry)
		}
	}

	// order the records by their "id" value
	sort.Slice(records, func(a, b int) bool {
		left, _ := records[a]["id"].(string)
		right, _ := records[b]["id"].(string)
		return left < right
	})

	// echo the Origin header back for CORS when the client sent one
	if origin := c.GetHeader("Origin"); origin != "" {
		c.Header("Access-Control-Allow-Origin", origin)
	}

	c.JSON(http.StatusOK, gin.H{
		"object": "list",
		"data":   records,
	})
}
// findModelInPath looks for a configured model name at the start of a
// slash-separated path. Non-empty segments are joined with "/" one at a time
// and each candidate is checked against the configuration; the first match
// wins.
//
// Returns (searchModelName, realModelName, remainingPath, found), where
// remainingPath is "/" plus the segments that follow the match.
// Example: "/author/model/endpoint" with model "author/model" yields
// ("author/model", "author/model", "/endpoint", true).
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
	segments := strings.Split(strings.TrimSpace(path), "/")

	var candidate string
	for idx, segment := range segments {
		switch {
		case segment == "":
			// empty segments (leading, trailing or doubled slashes) are skipped
			continue
		case candidate == "":
			candidate = segment
		default:
			candidate += "/" + segment
		}

		resolved, ok := pm.config.RealModelName(candidate)
		if !ok {
			continue
		}
		return candidate, resolved, "/" + strings.Join(segments[idx+1:], "/"), true
	}

	return "", "", "", false
}
// proxyToUpstream handles /upstream/*upstreamPath. It extracts the model name
// from the path, redirects a bare "/upstream/<model>" to the trailing-slash
// form, swaps to the model's process group and proxies the request with the
// model prefix removed from the URL path. POST requests go through the metrics
// monitor when one is configured.
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
	rawPath := c.Param("upstreamPath")

	searchModelName, modelName, remainingPath, ok := pm.findModelInPath(rawPath)
	if !ok {
		pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
		return
	}

	// exactly a model name, no extra path and no trailing slash: redirect
	if remainingPath == "/" && !strings.HasSuffix(rawPath, "/") {
		// keep the query string on the redirect target
		target := "/upstream/" + searchModelName + "/"
		if query := c.Request.URL.RawQuery; query != "" {
			target += "?" + query
		}

		// 308 preserves the method for anything other than GET/HEAD
		status := http.StatusPermanentRedirect
		switch c.Request.Method {
		case http.MethodGet, http.MethodHead:
			status = http.StatusMovedPermanently
		}
		c.Redirect(status, target)
		return
	}

	processGroup, realModelName, err := pm.swapProcessGroup(modelName)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
		return
	}

	// rewrite the path, remembering the original for error logs
	originalPath := c.Request.URL.Path
	c.Request.URL.Path = remainingPath

	// metrics are only attempted for POST requests
	withMetrics := pm.metricsMonitor != nil && c.Request.Method == "POST"
	if !withMetrics {
		if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
			pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
			pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
		}
		return
	}

	if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
		pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
	}
}
// proxyInferenceHandler handles the JSON inference endpoints. It reads the
// whole request body, resolves the "model" key to a configured model, swaps to
// that model's process group, optionally rewrites the model name and strips
// configured parameters from the body, and then proxies the request upstream.
//
// Responds with 400 when the body cannot be read, the model key is missing or
// the model is unknown, and with 500 when swapping, rewriting or proxying
// fails.
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusBadRequest, "could not read request body")
		return
	}

	requestedModel := gjson.GetBytes(bodyBytes, "model").String()
	if requestedModel == "" {
		pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
		return
	}

	realModelName, found := pm.config.RealModelName(requestedModel)
	if !found {
		pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
		return
	}

	processGroup, _, err := pm.swapProcessGroup(realModelName)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
		return
	}

	// look the model configuration up once; it is used for both the model
	// name rewrite and the parameter stripping below
	modelConfig := pm.config.Models[realModelName]

	// issue #69 allow custom model names to be sent to upstream
	if useModelName := modelConfig.UseModelName; useModelName != "" {
		bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
		if err != nil {
			pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
			return
		}
	}

	// issue #174 strip parameters from the JSON body
	stripParams, err := modelConfig.Filters.SanitizedStripParams()
	if err != nil { // just log it and continue
		pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", modelConfig.Filters.StripParams, err.Error())
	} else {
		for _, param := range stripParams {
			pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
			bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
				return
			}
		}
	}

	// the original body has been consumed; hand the (possibly modified) bytes
	// to the upstream request
	c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

	// dechunk it as we already have all the body bytes see issue #11
	c.Request.Header.Del("transfer-encoding")
	c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
	c.Request.ContentLength = int64(len(bodyBytes))

	// issue #366 extract values that downstream handlers may need
	isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
	ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
	ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
	c.Request = c.Request.WithContext(ctx)

	if pm.metricsMonitor != nil && c.Request.Method == "POST" {
		if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
			pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
			pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
			return
		}
	} else {
		if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
			pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
			pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
			return
		}
	}
}
// proxyOAIPostFormHandler handles multipart/form-data endpoints. It parses the
// form, picks the model from the "model" field, swaps to that model's process
// group, rebuilds the multipart body (the original has been consumed by
// parsing) and proxies the rebuilt request upstream.
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
	// 32MB max memory, larger files go to tmp disk
	if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
		pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
		return
	}

	requestedModel := c.Request.FormValue("model")
	if requestedModel == "" {
		pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
		return
	}

	processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
		return
	}

	// issue #69 allow custom model names to be sent to upstream; without an
	// override the model field carries the requested model name
	upstreamModel := requestedModel
	if override := pm.config.Models[realModelName].UseModelName; override != "" {
		upstreamModel = override
	}

	form := c.Request.MultipartForm

	// rebuild the multipart body into a fresh buffer
	var rebuilt bytes.Buffer
	multipartWriter := multipart.NewWriter(&rebuilt)

	// plain form values
	for key, values := range form.Value {
		for _, value := range values {
			if key == "model" {
				value = upstreamModel
			}

			part, err := multipartWriter.CreateFormField(key)
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
				return
			}

			if _, err = part.Write([]byte(value)); err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
				return
			}
		}
	}

	// uploaded files
	for key, headers := range form.File {
		for _, header := range headers {
			part, err := multipartWriter.CreateFormFile(key, header.Filename)
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
				return
			}

			src, err := header.Open()
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
				return
			}

			// close the upload right after copying, on success and on failure
			_, err = io.Copy(part, src)
			src.Close()
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
				return
			}
		}
	}

	// closing the writer emits the terminating boundary
	if err := multipartWriter.Close(); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
		return
	}

	// new request carrying the rebuilt body, same context, method and URL
	modifiedReq, err := http.NewRequestWithContext(
		c.Request.Context(),
		c.Request.Method,
		c.Request.URL.String(),
		&rebuilt,
	)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
		return
	}

	// original headers, with content type and length matching the rebuilt body
	bodyLen := rebuilt.Len()
	modifiedReq.Header = c.Request.Header.Clone()
	modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
	modifiedReq.Header.Set("Content-Length", strconv.Itoa(bodyLen))
	modifiedReq.ContentLength = int64(bodyLen)

	if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
		pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
		return
	}
}
// sendErrorResponse writes message with the given status code. Clients whose
// Accept header mentions application/json get {"error": message}; everyone
// else gets the message as plain text.
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
	wantsJSON := strings.Contains(c.GetHeader("Accept"), "application/json")
	if !wantsJSON {
		c.String(statusCode, message)
		return
	}
	c.JSON(statusCode, gin.H{"error": message})
}
// unloadAllModelsHandler stops the processes of every process group with the
// StopImmediately strategy and then replies 200 "OK".
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses(StopImmediately)
c.String(http.StatusOK, "OK")
}
// listRunningProcessesHandler responds with a JSON object whose "running" key
// lists every process, across all process groups, whose current state is
// StateReady. The list is an empty array when nothing is running and the
// status is always 200.
func (pm *ProxyManager) listRunningProcessesHandler(c *gin.Context) {
	c.Header("Content-Type", "application/json")

	runningProcesses := make([]gin.H, 0) // non-nil so it encodes as [] rather than null
	for _, processGroup := range pm.processGroups {
		for _, process := range processGroup.processes {
			// Read the state once through the accessor so the value that is
			// reported is the same one that was filtered on.
			// NOTE(review): assumes CurrentState() returns the same type as
			// the process.state field it replaces here -- confirm.
			state := process.CurrentState()
			if state != StateReady {
				continue
			}
			runningProcesses = append(runningProcesses, gin.H{
				"model": process.ID,
				"state": state,
			})
		}
	}

	// Put the results under the `running` key.
	c.JSON(http.StatusOK, gin.H{
		"running": runningProcesses,
	}) // Always return 200 OK
}
// findGroupByModelName returns a process group that reports modelName as a
// member, or nil when no group does.
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
	for _, candidate := range pm.processGroups {
		if !candidate.HasMember(modelName) {
			continue
		}
		return candidate
	}
	return nil
}
// SetVersion records the build date, commit and version strings on the
// ProxyManager while holding its lock.
func (pm *ProxyManager) SetVersion(buildDate string, commit string, version string) {
	pm.Lock()
	defer pm.Unlock()

	pm.buildDate, pm.commit, pm.version = buildDate, commit, version
}