Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2d00120781 | |||
| afc9aef058 | |||
| d7b390df74 | |||
| 5025c2f1f3 | |||
| e3a0b013c1 | |||
| f5763a94a0 | |||
| 8ada72eb57 | |||
| 2441b383d3 | |||
| 25f251699c | |||
| 7f37bcc6eb | |||
| 519c3a4d22 |
@@ -0,0 +1,37 @@
|
|||||||
|
---
|
||||||
|
name: Bug Report
|
||||||
|
about: Something is not working as expected...
|
||||||
|
title: ''
|
||||||
|
labels: bug
|
||||||
|
assignees: ''
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Describe the bug**
|
||||||
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
|
**Expected behaviour**
|
||||||
|
A clear and concise description of what you expected to happen.
|
||||||
|
|
||||||
|
**Operating system and version**
|
||||||
|
|
||||||
|
- OS: (linux, osx, windows, freebsd, etc)
|
||||||
|
- GPUs: (list architecture)
|
||||||
|
|
||||||
|
**My Configuration**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# copy / paste your configuration here
|
||||||
|
```
|
||||||
|
|
||||||
|
**Proxy Logs**
|
||||||
|
|
||||||
|
```
|
||||||
|
# copy / paste from /logs
|
||||||
|
```
|
||||||
|
|
||||||
|
**Upstream Logs**
|
||||||
|
|
||||||
|
```
|
||||||
|
# copy / paste from /logs
|
||||||
|
```
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
name: Windows CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
|
||||||
|
# Allows manual triggering of the workflow
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
run-tests:
|
||||||
|
runs-on: windows-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: '1.23'
|
||||||
|
|
||||||
|
# cache simple-responder to save the build time
|
||||||
|
- name: Restore Simple Responder
|
||||||
|
id: restore-simple-responder
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
|
# necessary for testing proxy/Process swapping
|
||||||
|
- name: Create simple-responder
|
||||||
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: make simple-responder-windows
|
||||||
|
|
||||||
|
- name: Save Simple Responder
|
||||||
|
# nothing new to save ... skip this step
|
||||||
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
|
id: save-simple-responder
|
||||||
|
uses: actions/cache/save@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
|
- name: Test all
|
||||||
|
shell: bash
|
||||||
|
run: make test-all
|
||||||
@@ -1,6 +1,4 @@
|
|||||||
# This workflow will build a golang project
|
name: Linux CI
|
||||||
|
|
||||||
name: CI
|
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -24,9 +22,26 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
go-version: '1.23'
|
go-version: '1.23'
|
||||||
|
|
||||||
|
# cache simple-responder to save the build time
|
||||||
|
- name: Restore Simple Responder
|
||||||
|
id: restore-simple-responder
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
# necessary for testing proxy/Process swapping
|
# necessary for testing proxy/Process swapping
|
||||||
- name: Create simple-responder
|
- name: Create simple-responder
|
||||||
run: make simple-responder
|
run: make simple-responder
|
||||||
|
|
||||||
|
- name: Save Simple Responder
|
||||||
|
# nothing new to save ... skip this step
|
||||||
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
|
id: save-simple-responder
|
||||||
|
uses: actions/cache/save@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
- name: Test all
|
- name: Test all
|
||||||
run: make test-all
|
run: make test-all
|
||||||
@@ -20,10 +20,10 @@ clean:
|
|||||||
rm -rf $(BUILD_DIR)
|
rm -rf $(BUILD_DIR)
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test -short -v ./proxy
|
go test -short -v -count=1 ./proxy
|
||||||
|
|
||||||
test-all:
|
test-all:
|
||||||
go test -v ./proxy
|
go test -v -count=1 ./proxy
|
||||||
|
|
||||||
# Build OSX binary
|
# Build OSX binary
|
||||||
mac:
|
mac:
|
||||||
@@ -46,6 +46,10 @@ simple-responder:
|
|||||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
||||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
||||||
|
|
||||||
|
simple-responder-windows:
|
||||||
|
@echo "Building simple responder for windows"
|
||||||
|
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
|
||||||
|
|
||||||
# Ensure build directory exists
|
# Ensure build directory exists
|
||||||
$(BUILD_DIR):
|
$(BUILD_DIR):
|
||||||
mkdir -p $(BUILD_DIR)
|
mkdir -p $(BUILD_DIR)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ go 1.23.0
|
|||||||
require (
|
require (
|
||||||
github.com/fsnotify/fsnotify v1.9.0
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
@@ -13,6 +12,7 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/billziss-gh/golib v0.2.0 // indirect
|
||||||
github.com/bytedance/sonic v1.11.6 // indirect
|
github.com/bytedance/sonic v1.11.6 // indirect
|
||||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
|
||||||
|
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
|
||||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||||
|
|||||||
+1
-1
@@ -84,7 +84,7 @@ func main() {
|
|||||||
case newManager := <-reloadChan:
|
case newManager := <-reloadChan:
|
||||||
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
||||||
// Stop old manager processes gracefully (this waits for in-flight requests)
|
// Stop old manager processes gracefully (this waits for in-flight requests)
|
||||||
currentManager.StopProcesses()
|
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
|
||||||
// Now do a full shutdown to clear the process map
|
// Now do a full shutdown to clear the process map
|
||||||
currentManager.Shutdown()
|
currentManager.Shutdown()
|
||||||
currentManager = newManager
|
currentManager = newManager
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ func main() {
|
|||||||
|
|
||||||
silent := flag.Bool("silent", false, "disable all logging")
|
silent := flag.Bool("silent", false, "disable all logging")
|
||||||
|
|
||||||
|
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
|
||||||
|
|
||||||
flag.Parse() // Parse the command-line flags
|
flag.Parse() // Parse the command-line flags
|
||||||
|
|
||||||
// Create a new Gin router
|
// Create a new Gin router
|
||||||
@@ -190,6 +192,10 @@ func main() {
|
|||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !*silent {
|
||||||
|
fmt.Printf("My PID: %d\n", os.Getpid())
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Printf("simple-responder listening on %s\n", address)
|
log.Printf("simple-responder listening on %s\n", address)
|
||||||
// service connections
|
// service connections
|
||||||
@@ -200,11 +206,36 @@ func main() {
|
|||||||
|
|
||||||
// Wait for interrupt signal to gracefully shutdown the server with
|
// Wait for interrupt signal to gracefully shutdown the server with
|
||||||
// a timeout of 5 seconds.
|
// a timeout of 5 seconds.
|
||||||
quit := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
// kill (no param) sends syscall.SIGTERM by default
|
// kill (no param) sends syscall.SIGTERM by default
|
||||||
// kill -2 is syscall.SIGINT
|
// kill -2 is syscall.SIGINT
|
||||||
// kill -9 is syscall.SIGKILL, which can't be caught, so there is no need to add it
|
// kill -9 is syscall.SIGKILL, which can't be caught, so there is no need to add it
|
||||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
<-quit
|
|
||||||
|
countSigInt := 0
|
||||||
|
|
||||||
|
runloop:
|
||||||
|
for {
|
||||||
|
signal := <-sigChan
|
||||||
|
switch signal {
|
||||||
|
case syscall.SIGINT:
|
||||||
|
countSigInt++
|
||||||
|
if countSigInt > 1 {
|
||||||
|
break runloop
|
||||||
|
} else {
|
||||||
|
log.Println("Recieved SIGINT, send another SIGINT to shutdown")
|
||||||
|
}
|
||||||
|
case syscall.SIGTERM:
|
||||||
|
if *ignoreSigTerm {
|
||||||
|
log.Println("Ignoring SIGTERM")
|
||||||
|
} else {
|
||||||
|
log.Println("Recieved SIGTERM, shutting down")
|
||||||
|
break runloop
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break runloop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Println("simple-responder shutting down")
|
log.Println("simple-responder shutting down")
|
||||||
}
|
}
|
||||||
|
|||||||
+24
-7
@@ -4,11 +4,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/shlex"
|
"github.com/billziss-gh/golib/shlex"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -228,14 +229,30 @@ func AddDefaultGroupToConfig(config Config) Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||||
// Remove trailing backslashes
|
var cleanedLines []string
|
||||||
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Handle trailing backslashes by replacing with space
|
||||||
|
if strings.HasSuffix(trimmed, "\\") {
|
||||||
|
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||||
|
} else {
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// put it back together
|
||||||
|
cmdStr = strings.Join(cleanedLines, "\n")
|
||||||
|
|
||||||
// Split the command into arguments
|
// Split the command into arguments
|
||||||
args, err := shlex.Split(cmdStr)
|
var args []string
|
||||||
if err != nil {
|
if runtime.GOOS == "windows" {
|
||||||
return nil, err
|
args = shlex.Windows.Split(cmdStr)
|
||||||
|
} else {
|
||||||
|
args = shlex.Posix.Split(cmdStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the command is not empty
|
// Ensure the command is not empty
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
|
// Test a command with spaces and newlines
|
||||||
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
|
-a "double quotes" \
|
||||||
|
--arg2 'single quotes'
|
||||||
|
-s
|
||||||
|
# comment 1
|
||||||
|
--arg3 123 \
|
||||||
|
|
||||||
|
# comment 2
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
|
||||||
|
|
||||||
|
# this will get stripped out as well as the white space above
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"--arg2", "single quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", `"string in string"`,
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
|
// Test an empty command
|
||||||
|
args, err = SanitizeCommand("")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, args)
|
||||||
|
}
|
||||||
@@ -258,34 +258,6 @@ func TestConfig_FindConfig(t *testing.T) {
|
|||||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
|
||||||
|
|
||||||
// Test a command with spaces and newlines
|
|
||||||
args, err := SanitizeCommand(`python model1.py \
|
|
||||||
-a "double quotes" \
|
|
||||||
--arg2 'single quotes'
|
|
||||||
-s
|
|
||||||
--arg3 123 \
|
|
||||||
--arg4 '"string in string"'
|
|
||||||
-c "'single quoted'"
|
|
||||||
`)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{
|
|
||||||
"python", "model1.py",
|
|
||||||
"-a", "double quotes",
|
|
||||||
"--arg2", "single quotes",
|
|
||||||
"-s",
|
|
||||||
"--arg3", "123",
|
|
||||||
"--arg4", `"string in string"`,
|
|
||||||
"-c", `'single quoted'`,
|
|
||||||
}, args)
|
|
||||||
|
|
||||||
// Test an empty command
|
|
||||||
args, err = SanitizeCommand("")
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
||||||
|
|
||||||
t.Run("Default Port Ranges", func(t *testing.T) {
|
t.Run("Default Port Ranges", func(t *testing.T) {
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
|
// does not support single quoted strings like in config_posix_test.go
|
||||||
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
|
|
||||||
|
-a "double quotes" \
|
||||||
|
-s
|
||||||
|
--arg3 123 \
|
||||||
|
|
||||||
|
# comment 2
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# this will get stripped out as well as the white space above
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
|
// Test an empty command
|
||||||
|
args, err = SanitizeCommand("")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, args)
|
||||||
|
|
||||||
|
}
|
||||||
+12
-3
@@ -45,17 +45,26 @@ func TestMain(m *testing.M) {
|
|||||||
func getSimpleResponderPath() string {
|
func getSimpleResponderPath() string {
|
||||||
goos := runtime.GOOS
|
goos := runtime.GOOS
|
||||||
goarch := runtime.GOARCH
|
goarch := runtime.GOARCH
|
||||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
|
||||||
|
if goos == "windows" {
|
||||||
|
return filepath.Join("..", "build", "simple-responder.exe")
|
||||||
|
} else {
|
||||||
|
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
func getTestPort() int {
|
||||||
portMutex.Lock()
|
portMutex.Lock()
|
||||||
defer portMutex.Unlock()
|
defer portMutex.Unlock()
|
||||||
|
|
||||||
port := nextTestPort
|
port := nextTestPort
|
||||||
nextTestPort++
|
nextTestPort++
|
||||||
|
|
||||||
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
return port
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||||
|
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||||
|
|||||||
+43
-3
@@ -30,6 +30,13 @@ const (
|
|||||||
StateShutdown ProcessState = ProcessState("shutdown")
|
StateShutdown ProcessState = ProcessState("shutdown")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type StopStrategy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StopImmediately StopStrategy = iota
|
||||||
|
StopWaitForInflightRequest
|
||||||
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config ModelConfig
|
||||||
@@ -60,6 +67,12 @@ type Process struct {
|
|||||||
|
|
||||||
// for managing concurrency limits
|
// for managing concurrency limits
|
||||||
concurrencyLimitSemaphore chan struct{}
|
concurrencyLimitSemaphore chan struct{}
|
||||||
|
|
||||||
|
// stop timeout waiting for graceful shutdown
|
||||||
|
gracefulStopTimeout time.Duration
|
||||||
|
|
||||||
|
// track that this happened
|
||||||
|
upstreamWasStoppedWithKill bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||||
@@ -85,6 +98,10 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo
|
|||||||
|
|
||||||
// concurrency limit
|
// concurrency limit
|
||||||
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||||
|
|
||||||
|
// stop timeout
|
||||||
|
gracefulStopTimeout: 5 * time.Second,
|
||||||
|
upstreamWasStoppedWithKill: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,6 +218,15 @@ func (p *Process) start() error {
|
|||||||
go func() {
|
go func() {
|
||||||
exitErr := p.cmd.Wait()
|
exitErr := p.cmd.Wait()
|
||||||
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||||
|
|
||||||
|
// there is a race condition when SIGKILL is used, p.cmd.Wait() returns, and then
|
||||||
|
// the code below fires, putting an error into cmdWaitChan. This code is to prevent this
|
||||||
|
if p.upstreamWasStoppedWithKill {
|
||||||
|
p.proxyLogger.Debugf("<%s> process was killed, NOT sending exitErr: %v", p.ID, exitErr)
|
||||||
|
p.upstreamWasStoppedWithKill = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.cmdWaitChan <- exitErr
|
p.cmdWaitChan <- exitErr
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -313,13 +339,25 @@ func (p *Process) start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop will wait for inflight requests to complete before stopping the process.
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for any inflight requests before proceeding
|
// wait for any inflight requests before proceeding
|
||||||
|
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
|
||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
|
p.StopImmediately()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||||
|
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||||
|
func (p *Process) StopImmediately() {
|
||||||
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
|
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
|
||||||
|
|
||||||
// calling Stop() when state is invalid is a no-op
|
// calling Stop() when state is invalid is a no-op
|
||||||
@@ -329,7 +367,7 @@ func (p *Process) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// stop the process with a graceful exit timeout
|
// stop the process with a graceful exit timeout
|
||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(p.gracefulStopTimeout)
|
||||||
|
|
||||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
|
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
|
||||||
@@ -338,10 +376,11 @@ func (p *Process) Stop() {
|
|||||||
|
|
||||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||||
// of time for any inflight requests to complete before shutting down. If the Process
|
// of time for any inflight requests to complete before shutting down. If the Process
|
||||||
// is in the state of starting, it will cancel it and shut it down
|
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||||
|
// the StateShutdown state, it can not be started again.
|
||||||
func (p *Process) Shutdown() {
|
func (p *Process) Shutdown() {
|
||||||
p.shutdownCancel()
|
p.shutdownCancel()
|
||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(p.gracefulStopTimeout)
|
||||||
p.state = StateShutdown
|
p.state = StateShutdown
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -368,6 +407,7 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
|||||||
select {
|
select {
|
||||||
case <-sigtermTimeout.Done():
|
case <-sigtermTimeout.Done():
|
||||||
p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID)
|
p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID)
|
||||||
|
p.upstreamWasStoppedWithKill = true
|
||||||
if err := p.cmd.Process.Kill(); err != nil {
|
if err := p.cmd.Process.Kill(); err != nil {
|
||||||
p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err)
|
p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -372,3 +373,79 @@ func TestProcess_ConcurrencyLimit(t *testing.T) {
|
|||||||
process.ProxyRequest(w, denied)
|
process.ProxyRequest(w, denied)
|
||||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_StopImmediately(t *testing.T) {
|
||||||
|
expectedMessage := "test_stop_immediate"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
go func() {
|
||||||
|
// slow, but will get killed by StopImmediately
|
||||||
|
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
}()
|
||||||
|
<-time.After(time.Millisecond)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||||
|
// the upstream command
|
||||||
|
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||||
|
|
||||||
|
expectedMessage := "test_sigkill"
|
||||||
|
binaryPath := getSimpleResponderPath()
|
||||||
|
port := getTestPort()
|
||||||
|
|
||||||
|
config := ModelConfig{
|
||||||
|
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||||
|
// to force the process to exit
|
||||||
|
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
// reduce to make testing go faster
|
||||||
|
process.gracefulStopTimeout = time.Second
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
|
||||||
|
waitChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
// slow, but will get killed by StopImmediately
|
||||||
|
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
|
||||||
|
// StatusOK because that was already sent before the kill
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
||||||
|
// then the unexpected EOF is sent after the kill
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||||
|
} else {
|
||||||
|
assert.Contains(t, w.Body.String(), "unexpected EOF")
|
||||||
|
}
|
||||||
|
|
||||||
|
close(waitChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-time.After(time.Millisecond)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
|
||||||
|
// the request should have been interrupted by SIGKILL
|
||||||
|
<-waitChan
|
||||||
|
}
|
||||||
|
|||||||
@@ -76,14 +76,10 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
|||||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pg *ProcessGroup) StopProcesses() {
|
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||||
pg.Lock()
|
pg.Lock()
|
||||||
defer pg.Unlock()
|
defer pg.Unlock()
|
||||||
pg.stopProcesses()
|
|
||||||
}
|
|
||||||
|
|
||||||
// stopProcesses stops all processes in the group
|
|
||||||
func (pg *ProcessGroup) stopProcesses() {
|
|
||||||
if len(pg.processes) == 0 {
|
if len(pg.processes) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -94,7 +90,12 @@ func (pg *ProcessGroup) stopProcesses() {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(process *Process) {
|
go func(process *Process) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
process.Stop()
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
process.StopImmediately()
|
||||||
|
default:
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
}(process)
|
}(process)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
|||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
||||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses()
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model1", "model2"}
|
tests := []string{"model1", "model2"}
|
||||||
|
|
||||||
@@ -74,7 +74,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
|||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses()
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model3", "model4"}
|
tests := []string{"model3", "model4"}
|
||||||
|
|
||||||
|
|||||||
+31
-6
@@ -208,7 +208,7 @@ func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
// This is the public method safe for concurrent calls.
|
// This is the public method safe for concurrent calls.
|
||||||
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
||||||
// a complete shutdown, allowing for process replacement without full termination.
|
// a complete shutdown, allowing for process replacement without full termination.
|
||||||
func (pm *ProxyManager) StopProcesses() {
|
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
@@ -218,7 +218,7 @@ func (pm *ProxyManager) StopProcesses() {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(processGroup *ProcessGroup) {
|
go func(processGroup *ProcessGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
processGroup.stopProcesses()
|
processGroup.StopProcesses(strategy)
|
||||||
}(processGroup)
|
}(processGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,7 +260,7 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
|||||||
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
||||||
for groupId, otherGroup := range pm.processGroups {
|
for groupId, otherGroup := range pm.processGroups {
|
||||||
if groupId != processGroup.id && !otherGroup.persistent {
|
if groupId != processGroup.id && !otherGroup.persistent {
|
||||||
otherGroup.StopProcesses()
|
otherGroup.StopProcesses(StopWaitForInflightRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -334,7 +334,31 @@ func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
|||||||
|
|
||||||
// Iterate over sorted keys
|
// Iterate over sorted keys
|
||||||
for _, modelID := range modelIDs {
|
for _, modelID := range modelIDs {
|
||||||
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
|
// Get process state
|
||||||
|
processGroup := pm.findGroupByModelName(modelID)
|
||||||
|
var state string
|
||||||
|
if processGroup != nil {
|
||||||
|
process := processGroup.processes[modelID]
|
||||||
|
if process != nil {
|
||||||
|
var stateStr string
|
||||||
|
switch process.CurrentState() {
|
||||||
|
case StateReady:
|
||||||
|
stateStr = "Ready"
|
||||||
|
case StateStarting:
|
||||||
|
stateStr = "Starting"
|
||||||
|
case StateStopping:
|
||||||
|
stateStr = "Stopping"
|
||||||
|
case StateFailed:
|
||||||
|
stateStr = "Failed"
|
||||||
|
case StateShutdown:
|
||||||
|
stateStr = "Shutdown"
|
||||||
|
default:
|
||||||
|
stateStr = "Unknown"
|
||||||
|
}
|
||||||
|
state = stateStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a> - %s</li>", modelID, modelID, state))
|
||||||
}
|
}
|
||||||
html.WriteString("</ul></body></html>")
|
html.WriteString("</ul></body></html>")
|
||||||
c.Header("Content-Type", "text/html")
|
c.Header("Content-Type", "text/html")
|
||||||
@@ -374,7 +398,8 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
c.Request.Header.Del("transfer-encoding")
|
c.Request.Header.Del("transfer-encoding")
|
||||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
c.Request.ContentLength = int64(len(bodyBytes))
|
||||||
|
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
@@ -504,7 +529,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||||
pm.StopProcesses()
|
pm.StopProcesses(StopImmediately)
|
||||||
c.String(http.StatusOK, "OK")
|
c.String(http.StatusOK, "OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+18
-12
@@ -14,6 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||||
@@ -27,7 +28,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
for _, modelName := range []string{"model1", "model2"} {
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
@@ -63,7 +64,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model1", "model2"}
|
tests := []string{"model1", "model2"}
|
||||||
for _, requestedModel := range tests {
|
for _, requestedModel := range tests {
|
||||||
@@ -105,7 +106,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
// make requests to load all models, loading model1 should not affect model2
|
// make requests to load all models, loading model1 should not affect model2
|
||||||
tests := []string{"model2", "model1"}
|
tests := []string{"model2", "model1"}
|
||||||
@@ -141,7 +142,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
results := map[string]string{}
|
results := map[string]string{}
|
||||||
|
|
||||||
@@ -339,7 +340,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
LogLevel: "debug",
|
LogLevel: "warn",
|
||||||
})
|
})
|
||||||
|
|
||||||
// Define a helper struct to parse the JSON response.
|
// Define a helper struct to parse the JSON response.
|
||||||
@@ -352,7 +353,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
|
|
||||||
// Create proxy once for all tests
|
// Create proxy once for all tests
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
t.Run("no models loaded", func(t *testing.T) {
|
t.Run("no models loaded", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/running", nil)
|
req := httptest.NewRequest("GET", "/running", nil)
|
||||||
@@ -407,7 +408,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
// Create a buffer with multipart form data
|
// Create a buffer with multipart form data
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
@@ -448,7 +449,6 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||||
func TestProxyManager_UseModelName(t *testing.T) {
|
func TestProxyManager_UseModelName(t *testing.T) {
|
||||||
upstreamModelName := "upstreamModel"
|
upstreamModelName := "upstreamModel"
|
||||||
|
|
||||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||||
modelConfig.UseModelName = upstreamModelName
|
modelConfig.UseModelName = upstreamModelName
|
||||||
|
|
||||||
@@ -461,7 +461,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
requestedModel := "model1"
|
requestedModel := "model1"
|
||||||
|
|
||||||
@@ -473,6 +473,12 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||||
|
|
||||||
|
// make sure the content length was set correctly
|
||||||
|
// simple-responder will return the content length it got in the response
|
||||||
|
body := w.Body.Bytes()
|
||||||
|
contentLength := int(gjson.GetBytes(body, "h_content_length").Int())
|
||||||
|
assert.Equal(t, len(fmt.Sprintf(`{"model":"%s"}`, upstreamModelName)), contentLength)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
|
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
|
||||||
@@ -557,7 +563,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
||||||
for k, v := range tt.requestHeaders {
|
for k, v := range tt.requestHeaders {
|
||||||
@@ -586,7 +592,7 @@ func TestProxyManager_Upstream(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
@@ -604,7 +610,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
|||||||
Reference in New Issue
Block a user