Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 014a2fa9a3 | |||
| 5ceaef6144 |
@@ -1,9 +1,4 @@
|
|||||||

|

|
||||||

|
|
||||||

|
|
||||||

|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# llama-swap
|
# llama-swap
|
||||||
|
|
||||||
|
|||||||
+1
-3
@@ -304,9 +304,7 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.terminateProcess(); err != nil {
|
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||||
fmt.Fprintf(p.logMonitor, "!!! failed to gracefully terminate process [%s]: %v\n", p.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-sigtermTimeout.Done():
|
case <-sigtermTimeout.Done():
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import "syscall"
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
return p.cmd.Process.Signal(syscall.SIGTERM)
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
|
||||||
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
+4
-15
@@ -72,25 +72,14 @@ func New(config *Config) *ProxyManager {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// see: issue: #81, #77 and #42 for CORS issues
|
// see: https://github.com/mostlygeek/llama-swap/issues/42
|
||||||
// respond with permissive OPTIONS for any endpoint
|
// respond with permissive OPTIONS for any endpoint
|
||||||
pm.ginEngine.Use(func(c *gin.Context) {
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
if c.Request.Method == "OPTIONS" {
|
if c.Request.Method == "OPTIONS" {
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||||
|
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||||
// allow whatever the client requested by default
|
c.AbortWithStatus(204)
|
||||||
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
|
||||||
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
|
|
||||||
c.Header("Access-Control-Allow-Headers", sanitized)
|
|
||||||
} else {
|
|
||||||
c.Header(
|
|
||||||
"Access-Control-Allow-Headers",
|
|
||||||
"Content-Type, Authorization, Accept, X-Requested-With",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
c.Header("Access-Control-Max-Age", "86400")
|
|
||||||
c.AbortWithStatus(http.StatusNoContent)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|||||||
@@ -639,72 +639,5 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
assert.Equal(t, upstreamModelName, response["model"])
|
assert.Equal(t, upstreamModelName, response["model"])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
},
|
|
||||||
LogRequests: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
method string
|
|
||||||
requestHeaders map[string]string
|
|
||||||
expectedStatus int
|
|
||||||
expectedHeaders map[string]string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "OPTIONS with no headers",
|
|
||||||
method: "OPTIONS",
|
|
||||||
expectedStatus: http.StatusNoContent,
|
|
||||||
expectedHeaders: map[string]string{
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "OPTIONS with specific headers",
|
|
||||||
method: "OPTIONS",
|
|
||||||
requestHeaders: map[string]string{
|
|
||||||
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
|
|
||||||
},
|
|
||||||
expectedStatus: http.StatusNoContent,
|
|
||||||
expectedHeaders: map[string]string{
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Non-OPTIONS request",
|
|
||||||
method: "GET",
|
|
||||||
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
|
||||||
for k, v := range tt.requestHeaders {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.ginEngine.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
|
||||||
|
|
||||||
for header, expectedValue := range tt.expectedHeaders {
|
|
||||||
assert.Equal(t, expectedValue, w.Header().Get(header))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func isTokenChar(r rune) bool {
|
|
||||||
switch {
|
|
||||||
case r >= 'a' && r <= 'z':
|
|
||||||
case r >= 'A' && r <= 'Z':
|
|
||||||
case r >= '0' && r <= '9':
|
|
||||||
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
|
||||||
parts := strings.Split(headerValues, ",")
|
|
||||||
valid := make([]string, 0, len(parts))
|
|
||||||
|
|
||||||
for _, p := range parts {
|
|
||||||
v := strings.TrimSpace(p)
|
|
||||||
if v == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
validPart := true
|
|
||||||
for _, c := range v {
|
|
||||||
if !isTokenChar(c) {
|
|
||||||
validPart = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if validPart {
|
|
||||||
valid = append(valid, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.Join(valid, ", ")
|
|
||||||
}
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty string",
|
|
||||||
input: "",
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "whitespace only",
|
|
||||||
input: " ",
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single valid value",
|
|
||||||
input: "content-type",
|
|
||||||
expected: "content-type",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple valid values",
|
|
||||||
input: "content-type, authorization, x-requested-with",
|
|
||||||
expected: "content-type, authorization, x-requested-with",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "values with extra spaces",
|
|
||||||
input: " content-type , authorization ",
|
|
||||||
expected: "content-type, authorization",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "values with tabs",
|
|
||||||
input: "content-type,\tauthorization",
|
|
||||||
expected: "content-type, authorization",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "values with invalid characters",
|
|
||||||
input: "content-type, auth\n, x-requested-with\r",
|
|
||||||
expected: "content-type, auth, x-requested-with",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty values in list",
|
|
||||||
input: "content-type,,authorization",
|
|
||||||
expected: "content-type, authorization",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "leading and trailing commas",
|
|
||||||
input: ",content-type,authorization,",
|
|
||||||
expected: "content-type, authorization",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed valid and invalid values",
|
|
||||||
input: "content-type, \x00invalid, x-requested-with",
|
|
||||||
expected: "content-type, x-requested-with",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed case values",
|
|
||||||
input: "Content-Type, my-Valid-Header, Another-hEader",
|
|
||||||
expected: "Content-Type, my-Valid-Header, Another-hEader",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := SanitizeAccessControlRequestHeaderValues(tt.input)
|
|
||||||
if got != tt.expected {
|
|
||||||
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
|
|
||||||
tt.input, got, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user