Compare commits

...

4 Commits

Author SHA1 Message Date
Benson Wong 12b69fb718 proxy: recover from panic in Process.statusUpdate (#378)
Process.statusUpdate() panics when it can not write data, usually from a
client disconnect. Since it runs in a goroutine and did not have a
recover() the result was a crash.

ref: https://github.com/mostlygeek/llama-swap/discussions/326#discussioncomment-14856197
2025-11-03 05:30:09 -08:00
Ryan Steed f91a8b2462 refactor: update Containerfile to support non-root user execution and improve security (#368)
Set default container user/group to lower privilege app user 

* refactor: update Containerfile to support non-root user execution and improve security

- Updated LS_VER argument from 89 to 170 to use the latest version
- Added UID/GID arguments with default values of 0 (root) for backward compatibility
- Added USER_HOME environment variable set to /root
- Implemented conditional user/group creation logic that only runs when UID/GID are not 0
- Created necessary directory structure with proper ownership using mkdir and chown commands
- Switched to non-root user execution for improved security posture
- Updated COPY instruction to use --chown flag for proper file ownership

* chore: update containerfile to use non-root user with proper UID/GID

- Changed default UID and GID from 0 (root) to 10001 for security best practices
- Updated USER_HOME from /root to /app to avoid running as root user
2025-10-31 17:01:04 -07:00
Benson Wong a89b803d4a Stream loading state when swapping models (#371)
Swapping models can take a long time and leave a lot of silence while the model is loading. Rather than silently load the model in the background, this PR allows llama-swap to send status updates in the reasoning_content of a streaming chat response.

Fixes: #366
2025-10-29 00:09:39 -07:00
Benson Wong f852689104 proxy: add panic recovery to Process.ProxyRequest (#363)
Switching to use httputil.ReverseProxy in #342 introduced a possible
panic if a client disconnects while streaming the body. Since llama-swap
does not use http.Server the recover() is not automatically there.

- introduce a recover() in Process.ProxyRequest to recover and log the
  event
- add TestProcess_ReverseProxyPanicIsHandled to reproduce and test the
  fix

fixes: #362
2025-10-25 20:40:05 -07:00
10 changed files with 491 additions and 54 deletions
+12
View File
@@ -35,6 +35,14 @@ metricsMaxInMemory: 1000
# - it is automatically incremented for every model that uses it # - it is automatically incremented for every model that uses it
startPort: 10001 startPort: 10001
# sendLoadingState: inject loading status updates into the reasoning (thinking)
# field
# - optional, default: false
# - when true, a stream of loading messages will be sent to the client in the
# reasoning field so chat UIs can show that loading is in progress.
# - see #366 for more details
sendLoadingState: true
# macros: a dictionary of string substitutions # macros: a dictionary of string substitutions
# - optional, default: empty dictionary # - optional, default: empty dictionary
# - macros are reusable snippets # - macros are reusable snippets
@@ -184,6 +192,10 @@ models:
# - recommended to be omitted and the default used # - recommended to be omitted and the default used
concurrencyLimit: 0 concurrencyLimit: 0
# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
# Unlisted model example: # Unlisted model example:
"qwen-unlisted": "qwen-unlisted":
# unlisted: boolean, true or false # unlisted: boolean, true or false
+24 -2
View File
@@ -2,7 +2,29 @@ ARG BASE_TAG=server-cuda
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG} FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
# has to be after the FROM # has to be after the FROM
ARG LS_VER=89 ARG LS_VER=170
# Set default UID/GID arguments
ARG UID=10001
ARG GID=10001
ARG USER_HOME=/app
# Add user/group
ENV HOME=$USER_HOME
RUN if [ $UID -ne 0 ]; then \
if [ $GID -ne 0 ]; then \
addgroup --system --gid $GID app; \
fi; \
adduser --system --no-create-home --uid $UID --gid $GID \
--home $USER_HOME app; \
fi
# Handle paths
RUN mkdir --parents $HOME /app
RUN chown --recursive $UID:$GID $HOME /app
# Switch user
USER $UID:$GID
WORKDIR /app WORKDIR /app
RUN \ RUN \
@@ -10,7 +32,7 @@ RUN \
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \ tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
COPY config.example.yaml /app/config.yaml COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1 HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ] ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
+10
View File
@@ -129,6 +129,9 @@ type Config struct {
// hooks, see: #209 // hooks, see: #209
Hooks HooksConfig `yaml:"hooks"` Hooks HooksConfig `yaml:"hooks"`
// send loading state in reasoning
SendLoadingState bool `yaml:"sendLoadingState"`
} }
func (c *Config) RealModelName(search string) (string, bool) { func (c *Config) RealModelName(search string) (string, bool) {
@@ -350,6 +353,13 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
) )
} }
// if sendLoadingState is nil, set it to the global config value
// see #366
if modelConfig.SendLoadingState == nil {
v := config.SendLoadingState // copy it
modelConfig.SendLoadingState = &v
}
config.Models[modelId] = modelConfig config.Models[modelId] = modelConfig
} }
+7
View File
@@ -160,6 +160,8 @@ groups:
t.Fatalf("Failed to load config: %v", err) t.Fatalf("Failed to load config: %v", err)
} }
modelLoadingState := false
expected := Config{ expected := Config{
LogLevel: "info", LogLevel: "info",
StartPort: 5800, StartPort: 5800,
@@ -171,6 +173,7 @@ groups:
Preload: []string{"model1", "model2"}, Preload: []string{"model1", "model2"},
}, },
}, },
SendLoadingState: false,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": { "model1": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -180,6 +183,7 @@ groups:
CheckEndpoint: "/health", CheckEndpoint: "/health",
Name: "Model 1", Name: "Model 1",
Description: "This is model 1", Description: "This is model 1",
SendLoadingState: &modelLoadingState,
}, },
"model2": { "model2": {
Cmd: "path/to/server --arg1 one", Cmd: "path/to/server --arg1 one",
@@ -187,6 +191,7 @@ groups:
Aliases: []string{"m2"}, Aliases: []string{"m2"},
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
}, },
"model3": { "model3": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -194,6 +199,7 @@ groups:
Aliases: []string{"mthree"}, Aliases: []string{"mthree"},
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
}, },
"model4": { "model4": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -201,6 +207,7 @@ groups:
CheckEndpoint: "/", CheckEndpoint: "/",
Aliases: []string{}, Aliases: []string{},
Env: []string{}, Env: []string{},
SendLoadingState: &modelLoadingState,
}, },
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
+7
View File
@@ -152,12 +152,15 @@ groups:
t.Fatalf("Failed to load config: %v", err) t.Fatalf("Failed to load config: %v", err)
} }
modelLoadingState := false
expected := Config{ expected := Config{
LogLevel: "info", LogLevel: "info",
StartPort: 5800, StartPort: 5800,
Macros: MacroList{ Macros: MacroList{
{"svr-path", "path/to/server"}, {"svr-path", "path/to/server"},
}, },
SendLoadingState: false,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": { "model1": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -166,6 +169,7 @@ groups:
Aliases: []string{"m1", "model-one"}, Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"}, Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health", CheckEndpoint: "/health",
SendLoadingState: &modelLoadingState,
}, },
"model2": { "model2": {
Cmd: "path/to/server --arg1 one", Cmd: "path/to/server --arg1 one",
@@ -174,6 +178,7 @@ groups:
Aliases: []string{"m2"}, Aliases: []string{"m2"},
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
}, },
"model3": { "model3": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -182,6 +187,7 @@ groups:
Aliases: []string{"mthree"}, Aliases: []string{"mthree"},
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
}, },
"model4": { "model4": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -190,6 +196,7 @@ groups:
CheckEndpoint: "/", CheckEndpoint: "/",
Aliases: []string{}, Aliases: []string{},
Env: []string{}, Env: []string{},
SendLoadingState: &modelLoadingState,
}, },
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
+3
View File
@@ -35,6 +35,9 @@ type ModelConfig struct {
// Metadata: see #264 // Metadata: see #264
// Arbitrary metadata that can be exposed through the API // Arbitrary metadata that can be exposed through the API
Metadata map[string]any `yaml:"metadata"` Metadata map[string]any `yaml:"metadata"`
// override global setting
SendLoadingState *bool `yaml:"sendLoadingState"`
} }
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
+21 -1
View File
@@ -50,5 +50,25 @@ models:
} }
}) })
} }
}
func TestConfig_ModelSendLoadingState(t *testing.T) {
content := `
sendLoadingState: true
models:
model1:
cmd: path/to/cmd --port ${PORT}
sendLoadingState: false
model2:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.True(t, config.SendLoadingState)
if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
assert.False(t, *config.Models["model1"].SendLoadingState)
}
if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
assert.True(t, *config.Models["model2"].SendLoadingState)
}
} }
+280 -3
View File
@@ -2,8 +2,10 @@ package proxy
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
@@ -462,6 +464,12 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
} }
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
if p.reverseProxy == nil {
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
return
}
requestBeginTime := time.Now() requestBeginTime := time.Now()
var startDuration time.Duration var startDuration time.Duration
@@ -488,21 +496,65 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
p.inFlightRequests.Done() p.inFlightRequests.Done()
}() }()
// for #366
// - extract streaming param from request context, should have been set by proxymanager
var srw *statusResponseWriter
swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
// start the process on demand // start the process on demand
if p.CurrentState() != StateReady { if p.CurrentState() != StateReady {
// start a goroutine to stream loading status messages into the response writer
// add a sync so the streaming client only runs when the goroutine has exited
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming {
srw = newStatusResponseWriter(p, w)
go srw.statusUpdates(swapCtx)
} else {
p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
}
beginStartTime := time.Now() beginStartTime := time.Now()
if err := p.start(); err != nil { if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err) errstr := fmt.Sprintf("unable to start process: %s", err)
cancelLoadCtx()
if srw != nil {
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
// before closing the connection. Without this, the connection would close before
// the goroutine can write its cleanup messages, causing incomplete SSE output.
srw.waitForCompletion(100 * time.Millisecond)
} else {
http.Error(w, errstr, http.StatusBadGateway) http.Error(w, errstr, http.StatusBadGateway)
}
return return
} }
startDuration = time.Since(beginStartTime) startDuration = time.Since(beginStartTime)
} }
if p.reverseProxy != nil { // should trigger srw to stop sending loading events ...
p.reverseProxy.ServeHTTP(w, r) cancelLoadCtx()
// recover from http.ErrAbortHandler panics that can occur when the client
// disconnects before the response is sent
defer func() {
if r := recover(); r != nil {
if r == http.ErrAbortHandler {
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
} else { } else {
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError) p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
}
}
}()
if srw != nil {
// Wait for the goroutine to finish writing its final messages
const completionTimeout = 1 * time.Second
if !srw.waitForCompletion(completionTimeout) {
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
}
p.reverseProxy.ServeHTTP(srw, r)
} else {
p.reverseProxy.ServeHTTP(w, r)
} }
totalTime := time.Since(requestBeginTime) totalTime := time.Since(requestBeginTime)
@@ -588,3 +640,228 @@ func (p *Process) cmdStopUpstreamProcess() error {
return nil return nil
} }
var loadingRemarks = []string{
"Still faster than your last standup meeting...",
"Reticulating splines...",
"Waking up the hamsters...",
"Teaching the model manners...",
"Convincing the GPU to participate...",
"Loading weights (they're heavy)...",
"Herding electrons...",
"Compiling excuses for the delay...",
"Downloading more RAM...",
"Asking the model nicely to boot up...",
"Bribing CUDA with cookies...",
"Still loading (blame VRAM)...",
"The model is fashionably late...",
"Warming up those tensors...",
"Making the neural net do push-ups...",
"Your patience is appreciated (really)...",
"Almost there (probably)...",
"Loading like it's 1999...",
"The model forgot where it put its keys...",
"Quantum tunneling through layers...",
"Negotiating with the PCIe bus...",
"Defrosting frozen parameters...",
"Teaching attention heads to focus...",
"Running the matrix (slowly)...",
"Untangling transformer blocks...",
"Calibrating the flux capacitor...",
"Spinning up the probability wheels...",
"Waiting for the GPU to wake from its nap...",
"Converting caffeine to compute...",
"Allocating virtual patience...",
"Performing arcane CUDA rituals...",
"The model is stuck in traffic...",
"Inflating embeddings...",
"Summoning computational demons...",
"Pleading with the OOM killer...",
"Calculating the meaning of life (still at 42)...",
"Training the training wheels...",
"Optimizing the optimizer...",
"Bootstrapping the bootstrapper...",
"Loading loading screen...",
"Processing processing logs...",
"Buffering buffer overflow jokes...",
"The model hit snooze...",
"Debugging the debugger...",
"Compiling the compiler...",
"Parsing the parser (meta)...",
"Tokenizing tokens...",
"Encoding the encoder...",
"Hashing hash browns...",
"Forking spoons (not forks)...",
"The model is contemplating existence...",
"Transcending dimensional barriers...",
"Invoking elder tensor gods...",
"Unfurling probability clouds...",
"Synchronizing parallel universes...",
"The GPU is having second thoughts...",
"Recalibrating reality matrices...",
"Time is an illusion, loading doubly so...",
"Convincing bits to flip themselves...",
"The model is reading its own documentation...",
}
type statusResponseWriter struct {
hasWritten bool
writer http.ResponseWriter
process *Process
wg sync.WaitGroup // Track goroutine completion
start time.Time
}
func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
s := &statusResponseWriter{
writer: w,
process: p,
start: time.Now(),
}
s.Header().Set("Content-Type", "text/event-stream") // SSE
s.Header().Set("Cache-Control", "no-cache") // no-cache
s.Header().Set("Connection", "keep-alive") // keep-alive
s.WriteHeader(http.StatusOK) // send status code 200
s.sendLine("━━━━━")
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
return s
}
// statusUpdates sends status updates to the client while the model is loading
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
s.wg.Add(1)
defer s.wg.Done()
// Recover from panics caused by client disconnection
// Note: recover() only works within the same goroutine, so we need it here
defer func() {
if r := recover(); r != nil {
s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r)
}
}()
defer func() {
duration := time.Since(s.start)
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
s.sendLine("━━━━━")
s.sendLine(" ")
}()
// Create a shuffled copy of loadingRemarks
remarks := make([]string, len(loadingRemarks))
copy(remarks, loadingRemarks)
rand.Shuffle(len(remarks), func(i, j int) {
remarks[i], remarks[j] = remarks[j], remarks[i]
})
ri := 0
// Pick a random duration to send a remark
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
lastRemarkTime := time.Now()
ticker := time.NewTicker(time.Second)
defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if s.process.CurrentState() == StateReady {
return
}
// Check if it's time for a snarky remark
if time.Since(lastRemarkTime) >= nextRemarkIn {
remark := remarks[ri%len(remarks)]
ri++
s.sendLine(fmt.Sprintf("\n%s", remark))
lastRemarkTime = time.Now()
// Pick a new random duration for the next remark
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
} else {
s.sendData(".")
}
}
}
}
// waitForCompletion waits for the statusUpdates goroutine to finish
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
return true
case <-time.After(timeout):
return false
}
}
func (s *statusResponseWriter) sendLine(line string) {
s.sendData(line + "\n")
}
func (s *statusResponseWriter) sendData(data string) {
// Create the proper SSE JSON structure
type Delta struct {
ReasoningContent string `json:"reasoning_content"`
}
type Choice struct {
Delta Delta `json:"delta"`
}
type SSEMessage struct {
Choices []Choice `json:"choices"`
}
msg := SSEMessage{
Choices: []Choice{
{
Delta: Delta{
ReasoningContent: data,
},
},
},
}
jsonData, err := json.Marshal(msg)
if err != nil {
s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
return
}
// Write SSE formatted data, panic if not able to write
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
if err != nil {
panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
}
s.Flush()
}
func (s *statusResponseWriter) Header() http.Header {
return s.writer.Header()
}
func (s *statusResponseWriter) Write(data []byte) (int, error) {
return s.writer.Write(data)
}
func (s *statusResponseWriter) WriteHeader(statusCode int) {
if s.hasWritten {
return
}
s.hasWritten = true
s.writer.WriteHeader(statusCode)
s.Flush()
}
// Add Flush method
func (s *statusResponseWriter) Flush() {
if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush()
}
}
+71
View File
@@ -494,3 +494,74 @@ func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1") assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
} }
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
// handled appropriately.
//
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
// can't copy the body. This can be caused by a client disconnecting before the full
// response is sent from some reason.
//
// bug: https://github.com/mostlygeek/llama-swap/issues/362
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
defer func() {
if r := recover(); r != nil {
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
}
}()
expectedMessage := "panic_test"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
defer process.Stop()
// Start the process
err := process.start()
assert.Nil(t, err)
assert.Equal(t, StateReady, process.CurrentState())
// Create a custom ResponseWriter that simulates a client disconnect
// by panicking when Write is called after headers are sent
panicWriter := &panicOnWriteResponseWriter{
ResponseRecorder: httptest.NewRecorder(),
shouldPanic: true,
}
// Make a request that will trigger the panic
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
// ProxyRequest should catch and handle this panic gracefully.
process.ProxyRequest(panicWriter, req)
// If we get here, the panic was properly recovered in ProxyRequest
// The process should still be in a ready state
assert.Equal(t, StateReady, process.CurrentState())
}
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
// to simulate a client disconnect after headers are sent
// used by: TestProcess_ReverseProxyPanicIsHandled
type panicOnWriteResponseWriter struct {
*httptest.ResponseRecorder
shouldPanic bool
headerWritten bool
}
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
w.headerWritten = true
w.ResponseRecorder.WriteHeader(statusCode)
}
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
if w.shouldPanic && w.headerWritten {
// Simulate the panic that httputil.ReverseProxy throws
panic(http.ErrAbortHandler)
}
return w.ResponseRecorder.Write(b)
}
+8
View File
@@ -25,6 +25,8 @@ const (
PROFILE_SPLIT_CHAR = ":" PROFILE_SPLIT_CHAR = ":"
) )
type proxyCtxKey string
type ProxyManager struct { type ProxyManager struct {
sync.Mutex sync.Mutex
@@ -555,6 +557,12 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes))) c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes)) c.Request.ContentLength = int64(len(bodyBytes))
// issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" { if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil { if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))