Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 78b2bc3dbc | |||
| 6a058e4191 | |||
| 1921e570d7 | |||
| c867a6c9a2 | |||
| 3bd1b23ce0 | |||
| 10606abf89 |
+13
-1
@@ -49,7 +49,19 @@ models:
|
|||||||
cmd: |
|
cmd: |
|
||||||
# ${latest-llama} is a macro that is defined above
|
# ${latest-llama} is a macro that is defined above
|
||||||
${latest-llama}
|
${latest-llama}
|
||||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
--model path/to/llama-8B-Q4_K_M.gguf
|
||||||
|
|
||||||
|
# name: a display name for the model
|
||||||
|
# - optional, default: empty string
|
||||||
|
# - if set, it will be used in the v1/models API response
|
||||||
|
# - if not set, it will be omitted in the JSON model record
|
||||||
|
name: "llama 3.1 8B"
|
||||||
|
|
||||||
|
# description: a description for the model
|
||||||
|
# - optional, default: empty string
|
||||||
|
# - if set, it will be used in the v1/models API response
|
||||||
|
# - if not set, it will be omitted in the JSON model record
|
||||||
|
description: "A small but capable model used for quick testing"
|
||||||
|
|
||||||
# env: define an array of environment variables to inject into cmd's environment
|
# env: define an array of environment variables to inject into cmd's environment
|
||||||
# - optional, default: empty array
|
# - optional, default: empty array
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ require (
|
|||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
github.com/kelindar/event v1.5.2
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU
|
|||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/kelindar/event v1.5.2 h1:qtgssZqMh/QQMCIxlbx4wU3DoMHOrJXKdiZhphJ4YbY=
|
||||||
|
github.com/kelindar/event v1.5.2/go.mod h1:UxWPQjWK8u0o9Z3ponm2mgREimM95hm26/M9z8F488Q=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||||
|
|||||||
+105
-111
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/kelindar/event"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,137 +54,130 @@ func main() {
|
|||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyManager := proxy.New(config)
|
|
||||||
|
|
||||||
// Setup channels for server management
|
// Setup channels for server management
|
||||||
reloadChan := make(chan *proxy.ProxyManager)
|
|
||||||
exitChan := make(chan struct{})
|
exitChan := make(chan struct{})
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
// Create server with initial handler
|
// Create server with initial handler
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: *listenStr,
|
Addr: *listenStr,
|
||||||
Handler: proxyManager,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Support for watching config and reloading when it changes
|
||||||
|
reloadProxyManager := func() {
|
||||||
|
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||||
|
config, err = proxy.LoadConfig(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Configuration Changed")
|
||||||
|
currentPM.Shutdown()
|
||||||
|
srv.Handler = proxy.New(config)
|
||||||
|
fmt.Println("Configuration Reloaded")
|
||||||
|
|
||||||
|
// wait a few seconds and tell any UI to reload
|
||||||
|
time.AfterFunc(3*time.Second, func() {
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateEnd,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
config, err = proxy.LoadConfig(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
srv.Handler = proxy.New(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// load the initial proxy manager
|
||||||
|
reloadProxyManager()
|
||||||
|
debouncedReload := debounce(time.Second, reloadProxyManager)
|
||||||
|
if *watchConfig {
|
||||||
|
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
||||||
|
if e.ReloadingState == proxy.ReloadingStateStart {
|
||||||
|
debouncedReload()
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
fmt.Println("Watching Configuration for changes")
|
||||||
|
go func() {
|
||||||
|
absConfigPath, err := filepath.Abs(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
watcher, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
configDir := filepath.Dir(absConfigPath)
|
||||||
|
err = watcher.Add(configDir)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer watcher.Close()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case changeEvent := <-watcher.Events:
|
||||||
|
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case err := <-watcher.Errors:
|
||||||
|
log.Printf("File watcher error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdown on signal
|
||||||
|
go func() {
|
||||||
|
sig := <-sigChan
|
||||||
|
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||||
|
pm.Shutdown()
|
||||||
|
} else {
|
||||||
|
fmt.Println("srv.Handler is not of type *proxy.ProxyManager")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
fmt.Printf("Server shutdown error: %v\n", err)
|
||||||
|
}
|
||||||
|
close(exitChan)
|
||||||
|
}()
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
||||||
go func() {
|
go func() {
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
fmt.Printf("Fatal server error: %v\n", err)
|
log.Fatalf("Fatal server error: %v\n", err)
|
||||||
close(exitChan)
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Handle config reloads and signals
|
|
||||||
go func() {
|
|
||||||
currentManager := proxyManager
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case newManager := <-reloadChan:
|
|
||||||
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
|
||||||
// Stop old manager processes gracefully (this waits for in-flight requests)
|
|
||||||
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
|
|
||||||
// Now do a full shutdown to clear the process map
|
|
||||||
currentManager.Shutdown()
|
|
||||||
currentManager = newManager
|
|
||||||
srv.Handler = newManager
|
|
||||||
log.Println("Server handler updated with new config")
|
|
||||||
case sig := <-sigChan:
|
|
||||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
currentManager.Shutdown()
|
|
||||||
if err := srv.Shutdown(ctx); err != nil {
|
|
||||||
fmt.Printf("Server shutdown error: %v\n", err)
|
|
||||||
}
|
|
||||||
close(exitChan)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Start file watcher if requested
|
|
||||||
if *watchConfig {
|
|
||||||
absConfigPath, err := filepath.Abs(*configPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
|
|
||||||
} else {
|
|
||||||
go watchConfigFileWithReload(absConfigPath, reloadChan)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for exit signal
|
// Wait for exit signal
|
||||||
<-exitChan
|
<-exitChan
|
||||||
}
|
}
|
||||||
|
|
||||||
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
|
func debounce(interval time.Duration, f func()) func() {
|
||||||
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
|
var timer *time.Timer
|
||||||
watcher, err := fsnotify.NewWatcher()
|
return func() {
|
||||||
if err != nil {
|
if timer != nil {
|
||||||
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
|
timer.Stop()
|
||||||
return
|
|
||||||
}
|
|
||||||
defer watcher.Close()
|
|
||||||
|
|
||||||
err = watcher.Add(configPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Watching config file for changes: %s", configPath)
|
|
||||||
|
|
||||||
var debounceTimer *time.Timer
|
|
||||||
debounceDuration := 2 * time.Second
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case event, ok := <-watcher.Events:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// We only care about writes to the specific config file
|
|
||||||
if event.Name == configPath && event.Has(fsnotify.Write) {
|
|
||||||
// Reset or start the debounce timer
|
|
||||||
if debounceTimer != nil {
|
|
||||||
debounceTimer.Stop()
|
|
||||||
}
|
|
||||||
debounceTimer = time.AfterFunc(debounceDuration, func() {
|
|
||||||
log.Printf("Config file modified: %s, reloading...", event.Name)
|
|
||||||
|
|
||||||
// Try up to 3 times with exponential backoff
|
|
||||||
var newConfig proxy.Config
|
|
||||||
var err error
|
|
||||||
for retries := 0; retries < 3; retries++ {
|
|
||||||
// Load new configuration
|
|
||||||
newConfig, err = proxy.LoadConfig(configPath)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
|
|
||||||
if retries < 2 {
|
|
||||||
time.Sleep(time.Duration(1<<retries) * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to load new config after retries: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create new ProxyManager with new config
|
|
||||||
newPM := proxy.New(newConfig)
|
|
||||||
reloadChan <- newPM
|
|
||||||
log.Println("Config reloaded successfully")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
case err, ok := <-watcher.Errors:
|
|
||||||
if !ok {
|
|
||||||
log.Println("File watcher error channel closed.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("File watcher error: %v", err)
|
|
||||||
}
|
}
|
||||||
|
timer = time.AfterFunc(interval, f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ type ModelConfig struct {
|
|||||||
Unlisted bool `yaml:"unlisted"`
|
Unlisted bool `yaml:"unlisted"`
|
||||||
UseModelName string `yaml:"useModelName"`
|
UseModelName string `yaml:"useModelName"`
|
||||||
|
|
||||||
|
// #179 for /v1/models
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
Description string `yaml:"description"`
|
||||||
|
|
||||||
// Limit concurrency of HTTP requests to process
|
// Limit concurrency of HTTP requests to process
|
||||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||||
|
|
||||||
@@ -48,6 +52,8 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|||||||
Unlisted: false,
|
Unlisted: false,
|
||||||
UseModelName: "",
|
UseModelName: "",
|
||||||
ConcurrencyLimit: 0,
|
ConcurrencyLimit: 0,
|
||||||
|
Name: "",
|
||||||
|
Description: "",
|
||||||
}
|
}
|
||||||
|
|
||||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||||
|
|||||||
@@ -104,6 +104,8 @@ models:
|
|||||||
model1:
|
model1:
|
||||||
cmd: path/to/cmd --arg1 one
|
cmd: path/to/cmd --arg1 one
|
||||||
proxy: "http://localhost:8080"
|
proxy: "http://localhost:8080"
|
||||||
|
name: "Model 1"
|
||||||
|
description: "This is model 1"
|
||||||
aliases:
|
aliases:
|
||||||
- "m1"
|
- "m1"
|
||||||
- "model-one"
|
- "model-one"
|
||||||
@@ -168,6 +170,8 @@ groups:
|
|||||||
Aliases: []string{"m1", "model-one"},
|
Aliases: []string{"m1", "model-one"},
|
||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
|
Name: "Model 1",
|
||||||
|
Description: "This is model 1",
|
||||||
},
|
},
|
||||||
"model2": {
|
"model2": {
|
||||||
Cmd: "path/to/server --arg1 one",
|
Cmd: "path/to/server --arg1 one",
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
// package level registry of the different event types
|
||||||
|
|
||||||
|
const ProcessStateChangeEventID = 0x01
|
||||||
|
const ChatCompletionStatsEventID = 0x02
|
||||||
|
const ConfigFileChangedEventID = 0x03
|
||||||
|
const LogDataEventID = 0x04
|
||||||
|
|
||||||
|
type ProcessStateChangeEvent struct {
|
||||||
|
ProcessName string
|
||||||
|
NewState ProcessState
|
||||||
|
OldState ProcessState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||||
|
return ProcessStateChangeEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStats struct {
|
||||||
|
TokensGenerated int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ChatCompletionStats) Type() uint32 {
|
||||||
|
return ChatCompletionStatsEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReloadingState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ReloadingStateStart ReloadingState = iota
|
||||||
|
ReloadingStateEnd
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConfigFileChangedEvent struct {
|
||||||
|
ReloadingState ReloadingState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||||
|
return ConfigFileChangedEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogDataEvent struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e LogDataEvent) Type() uint32 {
|
||||||
|
return LogDataEventID
|
||||||
|
}
|
||||||
+14
-31
@@ -2,10 +2,13 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"container/ring"
|
"container/ring"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/kelindar/event"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LogLevel int
|
type LogLevel int
|
||||||
@@ -18,7 +21,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type LogMonitor struct {
|
type LogMonitor struct {
|
||||||
clients map[chan []byte]bool
|
eventbus *event.Dispatcher
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
buffer *ring.Ring
|
buffer *ring.Ring
|
||||||
bufferMu sync.RWMutex
|
bufferMu sync.RWMutex
|
||||||
@@ -37,11 +40,11 @@ func NewLogMonitor() *LogMonitor {
|
|||||||
|
|
||||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||||
return &LogMonitor{
|
return &LogMonitor{
|
||||||
clients: make(map[chan []byte]bool),
|
eventbus: event.NewDispatcher(),
|
||||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||||
stdout: stdout,
|
stdout: stdout,
|
||||||
level: LevelInfo,
|
level: LevelInfo,
|
||||||
prefix: "",
|
prefix: "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,34 +84,14 @@ func (w *LogMonitor) GetHistory() []byte {
|
|||||||
return history
|
return history
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) Subscribe() chan []byte {
|
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||||
w.mu.Lock()
|
return event.Subscribe(w.eventbus, func(e LogDataEvent) {
|
||||||
defer w.mu.Unlock()
|
callback(e.Data)
|
||||||
|
})
|
||||||
ch := make(chan []byte, 100)
|
|
||||||
w.clients[ch] = true
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
delete(w.clients, ch)
|
|
||||||
close(ch)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) broadcast(msg []byte) {
|
func (w *LogMonitor) broadcast(msg []byte) {
|
||||||
w.mu.RLock()
|
event.Publish(w.eventbus, LogDataEvent{Data: msg})
|
||||||
defer w.mu.RUnlock()
|
|
||||||
|
|
||||||
for client := range w.clients {
|
|
||||||
select {
|
|
||||||
case client <- msg:
|
|
||||||
default:
|
|
||||||
// If client buffer is full, skip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) SetPrefix(prefix string) {
|
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||||
|
|||||||
+13
-22
@@ -10,38 +10,29 @@ import (
|
|||||||
func TestLogMonitor(t *testing.T) {
|
func TestLogMonitor(t *testing.T) {
|
||||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
// Test subscription
|
// A WaitGroup is used to wait for all the expected writes to complete
|
||||||
client1 := logMonitor.Subscribe()
|
var wg sync.WaitGroup
|
||||||
client2 := logMonitor.Subscribe()
|
|
||||||
|
|
||||||
defer logMonitor.Unsubscribe(client1)
|
|
||||||
defer logMonitor.Unsubscribe(client2)
|
|
||||||
|
|
||||||
client1Messages := make([]byte, 0)
|
client1Messages := make([]byte, 0)
|
||||||
client2Messages := make([]byte, 0)
|
client2Messages := make([]byte, 0)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
defer logMonitor.OnLogData(func(data []byte) {
|
||||||
wg.Add(1)
|
client1Messages = append(client1Messages, data...)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
go func() {
|
defer logMonitor.OnLogData(func(data []byte) {
|
||||||
defer wg.Done()
|
client2Messages = append(client2Messages, data...)
|
||||||
for {
|
wg.Done()
|
||||||
select {
|
})()
|
||||||
case data := <-client1:
|
|
||||||
client1Messages = append(client1Messages, data...)
|
wg.Add(6) // 2 x 3 writes
|
||||||
case data := <-client2:
|
|
||||||
client2Messages = append(client2Messages, data...)
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
logMonitor.Write([]byte("1"))
|
logMonitor.Write([]byte("1"))
|
||||||
logMonitor.Write([]byte("2"))
|
logMonitor.Write([]byte("2"))
|
||||||
logMonitor.Write([]byte("3"))
|
logMonitor.Write([]byte("3"))
|
||||||
|
|
||||||
// Wait for the goroutine to finish
|
// wait for all writes to complete
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
// Check the buffer
|
// Check the buffer
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/kelindar/event"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProcessState string
|
type ProcessState string
|
||||||
@@ -127,6 +129,7 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
|||||||
|
|
||||||
p.state = newState
|
p.state = newState
|
||||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||||
|
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||||
return p.state, nil
|
return p.state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+33
-16
@@ -2,7 +2,7 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
@@ -34,6 +34,10 @@ type ProxyManager struct {
|
|||||||
muxLogger *LogMonitor
|
muxLogger *LogMonitor
|
||||||
|
|
||||||
processGroups map[string]*ProcessGroup
|
processGroups map[string]*ProcessGroup
|
||||||
|
|
||||||
|
// shutdown signaling
|
||||||
|
shutdownCtx context.Context
|
||||||
|
shutdownCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config Config) *ProxyManager {
|
func New(config Config) *ProxyManager {
|
||||||
@@ -64,6 +68,8 @@ func New(config Config) *ProxyManager {
|
|||||||
upstreamLogger.SetLogLevel(LevelInfo)
|
upstreamLogger.SetLogLevel(LevelInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
pm := &ProxyManager{
|
pm := &ProxyManager{
|
||||||
config: config,
|
config: config,
|
||||||
ginEngine: gin.New(),
|
ginEngine: gin.New(),
|
||||||
@@ -73,6 +79,9 @@ func New(config Config) *ProxyManager {
|
|||||||
upstreamLogger: upstreamLogger,
|
upstreamLogger: upstreamLogger,
|
||||||
|
|
||||||
processGroups: make(map[string]*ProcessGroup),
|
processGroups: make(map[string]*ProcessGroup),
|
||||||
|
|
||||||
|
shutdownCtx: shutdownCtx,
|
||||||
|
shutdownCancel: shutdownCancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the process groups
|
// create the process groups
|
||||||
@@ -158,9 +167,7 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
// in proxymanager_loghandlers.go
|
// in proxymanager_loghandlers.go
|
||||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
|
||||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* User Interface Endpoints
|
* User Interface Endpoints
|
||||||
@@ -262,6 +269,7 @@ func (pm *ProxyManager) Shutdown() {
|
|||||||
}(processGroup)
|
}(processGroup)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
pm.shutdownCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||||
@@ -289,32 +297,41 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||||
data := []interface{}{}
|
data := make([]gin.H, 0, len(pm.config.Models))
|
||||||
|
createdTime := time.Now().Unix()
|
||||||
|
|
||||||
for id, modelConfig := range pm.config.Models {
|
for id, modelConfig := range pm.config.Models {
|
||||||
if modelConfig.Unlisted {
|
if modelConfig.Unlisted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
data = append(data, map[string]interface{}{
|
record := gin.H{
|
||||||
"id": id,
|
"id": id,
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"created": time.Now().Unix(),
|
"created": createdTime,
|
||||||
"owned_by": "llama-swap",
|
"owned_by": "llama-swap",
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||||
|
record["name"] = name
|
||||||
|
}
|
||||||
|
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||||
|
record["description"] = desc
|
||||||
|
}
|
||||||
|
|
||||||
|
data = append(data, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the Content-Type header to application/json
|
// Set CORS headers if origin exists
|
||||||
c.Header("Content-Type", "application/json")
|
if origin := c.GetHeader("Origin"); origin != "" {
|
||||||
|
|
||||||
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
|
||||||
c.Header("Access-Control-Allow-Origin", origin)
|
c.Header("Access-Control-Allow-Origin", origin)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode the data as JSON and write it to the response writer
|
// Use gin's JSON method which handles content-type and encoding
|
||||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"object": "list", "data": data}); err != nil {
|
c.JSON(http.StatusOK, gin.H{
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
"object": "list",
|
||||||
return
|
"data": data,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||||
|
|||||||
+88
-18
@@ -1,25 +1,29 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/kelindar/event"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
State string `json:"state"`
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Unlisted bool `json:"unlisted"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func addApiHandlers(pm *ProxyManager) {
|
func addApiHandlers(pm *ProxyManager) {
|
||||||
// Add API endpoints for React to consume
|
// Add API endpoints for React to consume
|
||||||
apiGroup := pm.ginEngine.Group("/api")
|
apiGroup := pm.ginEngine.Group("/api")
|
||||||
{
|
{
|
||||||
apiGroup.GET("/models", pm.apiListModels)
|
|
||||||
apiGroup.GET("/modelsSSE", pm.apiListModelsSSE)
|
|
||||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||||
|
apiGroup.GET("/events", pm.apiSendEvents)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,37 +69,103 @@ func (pm *ProxyManager) getModelStatus() []Model {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
models = append(models, Model{
|
models = append(models, Model{
|
||||||
Id: modelID,
|
Id: modelID,
|
||||||
State: state,
|
Name: pm.config.Models[modelID].Name,
|
||||||
|
Description: pm.config.Models[modelID].Description,
|
||||||
|
State: state,
|
||||||
|
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) apiListModels(c *gin.Context) {
|
type messageType string
|
||||||
c.JSON(http.StatusOK, pm.getModelStatus())
|
|
||||||
|
const (
|
||||||
|
msgTypeModelStatus messageType = "modelStatus"
|
||||||
|
msgTypeLogData messageType = "logData"
|
||||||
|
)
|
||||||
|
|
||||||
|
type messageEnvelope struct {
|
||||||
|
Type messageType `json:"type"`
|
||||||
|
Data string `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// stream the models as a SSE
|
// sends a stream of different message types that happen on the server
|
||||||
func (pm *ProxyManager) apiListModelsSSE(c *gin.Context) {
|
func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
sendBuffer := make(chan messageEnvelope, 25)
|
||||||
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
|
sendModels := func() {
|
||||||
|
data, err := json.Marshal(pm.getModelStatus())
|
||||||
|
if err == nil {
|
||||||
|
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
|
||||||
|
select {
|
||||||
|
case sendBuffer <- msg:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendLogData := func(source string, data []byte) {
|
||||||
|
data, err := json.Marshal(gin.H{
|
||||||
|
"source": source,
|
||||||
|
"data": string(data),
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send updated models list
|
||||||
|
*/
|
||||||
|
defer event.On(func(e ProcessStateChangeEvent) {
|
||||||
|
sendModels()
|
||||||
|
})()
|
||||||
|
defer event.On(func(e ConfigFileChangedEvent) {
|
||||||
|
sendModels()
|
||||||
|
})()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send Log data
|
||||||
|
*/
|
||||||
|
defer pm.proxyLogger.OnLogData(func(data []byte) {
|
||||||
|
sendLogData("proxy", data)
|
||||||
|
})()
|
||||||
|
defer pm.upstreamLogger.OnLogData(func(data []byte) {
|
||||||
|
sendLogData("upstream", data)
|
||||||
|
})()
|
||||||
|
|
||||||
|
// send initial batch of data
|
||||||
|
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||||
|
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||||
|
sendModels()
|
||||||
|
|
||||||
// Stream new events
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-notify:
|
case <-c.Request.Context().Done():
|
||||||
|
cancel()
|
||||||
return
|
return
|
||||||
default:
|
case <-pm.shutdownCtx.Done():
|
||||||
models := pm.getModelStatus()
|
cancel()
|
||||||
c.SSEvent("message", models)
|
return
|
||||||
|
case msg := <-sendBuffer:
|
||||||
|
c.SSEvent("message", msg)
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
<-time.After(1000 * time.Millisecond)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -34,10 +35,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
c.String(http.StatusBadRequest, err.Error())
|
c.String(http.StatusBadRequest, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ch := logger.Subscribe()
|
|
||||||
defer logger.Unsubscribe(ch)
|
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||||
@@ -55,57 +53,28 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream new logs
|
sendChan := make(chan []byte, 10)
|
||||||
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
|
defer logger.OnLogData(func(data []byte) {
|
||||||
|
select {
|
||||||
|
case sendChan <- data:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-ch:
|
case <-c.Request.Context().Done():
|
||||||
_, err := c.Writer.Write(msg)
|
cancel()
|
||||||
if err != nil {
|
return
|
||||||
// just break the loop if we can't write for some reason
|
case <-pm.shutdownCtx.Done():
|
||||||
return
|
cancel()
|
||||||
}
|
return
|
||||||
|
case data := <-sendChan:
|
||||||
|
c.Writer.Write(data)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
case <-notify:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
|
||||||
|
|
||||||
logMonitorId := c.Param("logMonitorID")
|
|
||||||
logger, err := pm.getLogger(logMonitorId)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ch := logger.Subscribe()
|
|
||||||
defer logger.Unsubscribe(ch)
|
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
|
||||||
|
|
||||||
// Send history first if not skipped
|
|
||||||
_, skipHistory := c.GetQuery("no-history")
|
|
||||||
if !skipHistory {
|
|
||||||
history := logger.GetHistory()
|
|
||||||
if len(history) != 0 {
|
|
||||||
c.SSEvent("message", string(history))
|
|
||||||
c.Writer.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stream new logs
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case msg := <-ch:
|
|
||||||
c.SSEvent("message", string(msg))
|
|
||||||
c.Writer.Flush()
|
|
||||||
case <-notify:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -183,11 +183,20 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||||
|
|
||||||
|
model1Config := getTestSimpleResponderConfig("model1")
|
||||||
|
model1Config.Name = "Model 1"
|
||||||
|
model1Config.Description = "Model 1 description is used for testing"
|
||||||
|
|
||||||
|
model2Config := getTestSimpleResponderConfig("model2")
|
||||||
|
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||||
|
model2Config.Description = " "
|
||||||
|
|
||||||
config := Config{
|
config := Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": model1Config,
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": model2Config,
|
||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -213,6 +222,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
var response struct {
|
var response struct {
|
||||||
Data []map[string]interface{} `json:"data"`
|
Data []map[string]interface{} `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -227,6 +237,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
"model3": {},
|
"model3": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// make all models
|
||||||
for _, model := range response.Data {
|
for _, model := range response.Data {
|
||||||
modelID, ok := model["id"].(string)
|
modelID, ok := model["id"].(string)
|
||||||
assert.True(t, ok, "model ID should be a string")
|
assert.True(t, ok, "model ID should be a string")
|
||||||
@@ -245,6 +256,21 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
ownedBy, ok := model["owned_by"].(string)
|
ownedBy, ok := model["owned_by"].(string)
|
||||||
assert.True(t, ok, "owned_by should be a string")
|
assert.True(t, ok, "owned_by should be a string")
|
||||||
assert.Equal(t, "llama-swap", ownedBy)
|
assert.Equal(t, "llama-swap", ownedBy)
|
||||||
|
|
||||||
|
// check for optional name and description
|
||||||
|
if modelID == "model1" {
|
||||||
|
name, ok := model["name"].(string)
|
||||||
|
assert.True(t, ok, "name should be a string")
|
||||||
|
assert.Equal(t, "Model 1", name)
|
||||||
|
description, ok := model["description"].(string)
|
||||||
|
assert.True(t, ok, "description should be a string")
|
||||||
|
assert.Equal(t, "Model 1 description is used for testing", description)
|
||||||
|
} else {
|
||||||
|
_, exists := model["name"]
|
||||||
|
assert.False(t, exists, "unexpected name field for model: %s", modelID)
|
||||||
|
_, exists = model["description"]
|
||||||
|
assert.False(t, exists, "unexpected description field for model: %s", modelID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure all expected models were returned
|
// Ensure all expected models were returned
|
||||||
|
|||||||
+61
-109
@@ -6,6 +6,9 @@ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
|||||||
export interface Model {
|
export interface Model {
|
||||||
id: string;
|
id: string;
|
||||||
state: ModelStatus;
|
state: ModelStatus;
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
unlisted: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface APIProviderType {
|
interface APIProviderType {
|
||||||
@@ -13,12 +16,18 @@ interface APIProviderType {
|
|||||||
listModels: () => Promise<Model[]>;
|
listModels: () => Promise<Model[]>;
|
||||||
unloadAllModels: () => Promise<void>;
|
unloadAllModels: () => Promise<void>;
|
||||||
loadModel: (model: string) => Promise<void>;
|
loadModel: (model: string) => Promise<void>;
|
||||||
enableProxyLogs: (enabled: boolean) => void;
|
enableAPIEvents: (enabled: boolean) => void;
|
||||||
enableUpstreamLogs: (enabled: boolean) => void;
|
|
||||||
enableModelUpdates: (enabled: boolean) => void;
|
|
||||||
proxyLogs: string;
|
proxyLogs: string;
|
||||||
upstreamLogs: string;
|
upstreamLogs: string;
|
||||||
}
|
}
|
||||||
|
interface LogData {
|
||||||
|
source: "upstream" | "proxy";
|
||||||
|
data: string;
|
||||||
|
}
|
||||||
|
interface APIEventEnvelope {
|
||||||
|
type: "modelStatus" | "logData";
|
||||||
|
data: string;
|
||||||
|
}
|
||||||
|
|
||||||
const APIContext = createContext<APIProviderType | undefined>(undefined);
|
const APIContext = createContext<APIProviderType | undefined>(undefined);
|
||||||
type APIProviderProps = {
|
type APIProviderProps = {
|
||||||
@@ -30,6 +39,7 @@ export function APIProvider({ children }: APIProviderProps) {
|
|||||||
const [upstreamLogs, setUpstreamLogs] = useState("");
|
const [upstreamLogs, setUpstreamLogs] = useState("");
|
||||||
const proxyEventSource = useRef<EventSource | null>(null);
|
const proxyEventSource = useRef<EventSource | null>(null);
|
||||||
const upstreamEventSource = useRef<EventSource | null>(null);
|
const upstreamEventSource = useRef<EventSource | null>(null);
|
||||||
|
const apiEventSource = useRef<EventSource | null>(null);
|
||||||
|
|
||||||
const [models, setModels] = useState<Model[]>([]);
|
const [models, setModels] = useState<Model[]>([]);
|
||||||
const modelStatusEventSource = useRef<EventSource | null>(null);
|
const modelStatusEventSource = useRef<EventSource | null>(null);
|
||||||
@@ -41,104 +51,58 @@ export function APIProvider({ children }: APIProviderProps) {
|
|||||||
});
|
});
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const handleProxyMessage = useCallback(
|
const enableAPIEvents = useCallback((enabled: boolean) => {
|
||||||
(e: MessageEvent) => {
|
if (!enabled) {
|
||||||
appendLog(e.data, setProxyLogs);
|
apiEventSource.current?.close();
|
||||||
},
|
apiEventSource.current = null;
|
||||||
[proxyLogs, appendLog]
|
return;
|
||||||
);
|
}
|
||||||
|
|
||||||
const handleUpstreamMessage = useCallback(
|
let retryCount = 0;
|
||||||
(e: MessageEvent) => {
|
const initialDelay = 1000; // 1 second
|
||||||
appendLog(e.data, setUpstreamLogs);
|
|
||||||
},
|
|
||||||
[appendLog]
|
|
||||||
);
|
|
||||||
|
|
||||||
const enableProxyLogs = useCallback(
|
const connect = () => {
|
||||||
(enabled: boolean) => {
|
const eventSource = new EventSource("/api/events");
|
||||||
if (enabled) {
|
|
||||||
let retryCount = 0;
|
|
||||||
const maxRetries = 3;
|
|
||||||
const initialDelay = 1000; // 1 second
|
|
||||||
|
|
||||||
const connect = () => {
|
eventSource.onmessage = (e: MessageEvent) => {
|
||||||
const eventSource = new EventSource("/logs/streamSSE/proxy");
|
try {
|
||||||
|
const message = JSON.parse(e.data) as APIEventEnvelope;
|
||||||
|
switch (message.type) {
|
||||||
|
case "modelStatus":
|
||||||
|
{
|
||||||
|
const models = JSON.parse(message.data) as Model[];
|
||||||
|
setModels(models);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
eventSource.onmessage = handleProxyMessage;
|
case "logData": {
|
||||||
eventSource.onerror = () => {
|
const logData = JSON.parse(message.data) as LogData;
|
||||||
eventSource.close();
|
switch (logData.source) {
|
||||||
if (retryCount < maxRetries) {
|
case "proxy":
|
||||||
retryCount++;
|
appendLog(logData.data, setProxyLogs);
|
||||||
const delay = initialDelay * Math.pow(2, retryCount - 1);
|
break;
|
||||||
setTimeout(connect, delay);
|
case "upstream":
|
||||||
|
appendLog(logData.data, setUpstreamLogs);
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
proxyEventSource.current = eventSource;
|
|
||||||
};
|
|
||||||
|
|
||||||
connect();
|
|
||||||
} else {
|
|
||||||
proxyEventSource.current?.close();
|
|
||||||
proxyEventSource.current = null;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[handleProxyMessage]
|
|
||||||
);
|
|
||||||
|
|
||||||
const enableUpstreamLogs = useCallback(
|
|
||||||
(enabled: boolean) => {
|
|
||||||
if (enabled) {
|
|
||||||
let retryCount = 0;
|
|
||||||
const maxRetries = 3;
|
|
||||||
const initialDelay = 1000; // 1 second
|
|
||||||
|
|
||||||
const connect = () => {
|
|
||||||
const eventSource = new EventSource("/logs/streamSSE/upstream");
|
|
||||||
|
|
||||||
eventSource.onmessage = handleUpstreamMessage;
|
|
||||||
eventSource.onerror = () => {
|
|
||||||
eventSource.close();
|
|
||||||
if (retryCount < maxRetries) {
|
|
||||||
retryCount++;
|
|
||||||
const delay = initialDelay * Math.pow(2, retryCount - 1);
|
|
||||||
setTimeout(connect, delay);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
upstreamEventSource.current = eventSource;
|
|
||||||
};
|
|
||||||
|
|
||||||
connect();
|
|
||||||
} else {
|
|
||||||
upstreamEventSource.current?.close();
|
|
||||||
upstreamEventSource.current = null;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[handleUpstreamMessage]
|
|
||||||
);
|
|
||||||
|
|
||||||
const enableModelUpdates = useCallback(
|
|
||||||
(enabled: boolean) => {
|
|
||||||
if (enabled) {
|
|
||||||
const eventSource = new EventSource("/api/modelsSSE");
|
|
||||||
eventSource.onmessage = (e: MessageEvent) => {
|
|
||||||
try {
|
|
||||||
const models = JSON.parse(e.data) as Model[];
|
|
||||||
setModels(models);
|
|
||||||
} catch (e) {
|
|
||||||
console.error(e);
|
|
||||||
}
|
}
|
||||||
};
|
} catch (err) {
|
||||||
modelStatusEventSource.current = eventSource;
|
console.error(e.data, err);
|
||||||
} else {
|
}
|
||||||
modelStatusEventSource.current?.close();
|
};
|
||||||
modelStatusEventSource.current = null;
|
eventSource.onerror = () => {
|
||||||
}
|
eventSource.close();
|
||||||
},
|
retryCount++;
|
||||||
[setModels]
|
const delay = Math.min(initialDelay * Math.pow(2, retryCount - 1), 5000);
|
||||||
);
|
setTimeout(connect, delay);
|
||||||
|
};
|
||||||
|
|
||||||
|
apiEventSource.current = eventSource;
|
||||||
|
};
|
||||||
|
|
||||||
|
connect();
|
||||||
|
}, []);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
return () => {
|
return () => {
|
||||||
@@ -196,23 +160,11 @@ export function APIProvider({ children }: APIProviderProps) {
|
|||||||
listModels,
|
listModels,
|
||||||
unloadAllModels,
|
unloadAllModels,
|
||||||
loadModel,
|
loadModel,
|
||||||
enableProxyLogs,
|
enableAPIEvents,
|
||||||
enableUpstreamLogs,
|
|
||||||
enableModelUpdates,
|
|
||||||
proxyLogs,
|
proxyLogs,
|
||||||
upstreamLogs,
|
upstreamLogs,
|
||||||
}),
|
}),
|
||||||
[
|
[models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs]
|
||||||
models,
|
|
||||||
listModels,
|
|
||||||
unloadAllModels,
|
|
||||||
loadModel,
|
|
||||||
enableProxyLogs,
|
|
||||||
enableUpstreamLogs,
|
|
||||||
enableModelUpdates,
|
|
||||||
proxyLogs,
|
|
||||||
upstreamLogs,
|
|
||||||
]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
|
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
|
||||||
|
|||||||
@@ -3,14 +3,12 @@ import { useAPI } from "../contexts/APIProvider";
|
|||||||
import { usePersistentState } from "../hooks/usePersistentState";
|
import { usePersistentState } from "../hooks/usePersistentState";
|
||||||
|
|
||||||
const LogViewer = () => {
|
const LogViewer = () => {
|
||||||
const { proxyLogs, upstreamLogs, enableProxyLogs, enableUpstreamLogs } = useAPI();
|
const { proxyLogs, upstreamLogs, enableAPIEvents } = useAPI();
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
enableProxyLogs(true);
|
enableAPIEvents(true);
|
||||||
enableUpstreamLogs(true);
|
|
||||||
return () => {
|
return () => {
|
||||||
enableProxyLogs(false);
|
enableAPIEvents(false);
|
||||||
enableUpstreamLogs(false);
|
|
||||||
};
|
};
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
|||||||
+27
-12
@@ -2,17 +2,21 @@ import { useState, useEffect, useCallback, useMemo } from "react";
|
|||||||
import { useAPI } from "../contexts/APIProvider";
|
import { useAPI } from "../contexts/APIProvider";
|
||||||
import { LogPanel } from "./LogViewer";
|
import { LogPanel } from "./LogViewer";
|
||||||
import { processEvalTimes } from "../lib/Utils";
|
import { processEvalTimes } from "../lib/Utils";
|
||||||
|
import { usePersistentState } from "../hooks/usePersistentState";
|
||||||
|
|
||||||
export default function ModelsPage() {
|
export default function ModelsPage() {
|
||||||
const { models, enableModelUpdates, unloadAllModels, loadModel, upstreamLogs, enableUpstreamLogs } = useAPI();
|
const { models, unloadAllModels, loadModel, upstreamLogs, enableAPIEvents } = useAPI();
|
||||||
const [isUnloading, setIsUnloading] = useState(false);
|
const [isUnloading, setIsUnloading] = useState(false);
|
||||||
|
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
||||||
|
|
||||||
|
const filteredModels = useMemo(() => {
|
||||||
|
return models.filter((model) => showUnlisted || !model.unlisted);
|
||||||
|
}, [models, showUnlisted]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
enableModelUpdates(true);
|
enableAPIEvents(true);
|
||||||
enableUpstreamLogs(true);
|
|
||||||
return () => {
|
return () => {
|
||||||
enableModelUpdates(false);
|
enableAPIEvents(false);
|
||||||
enableUpstreamLogs(false);
|
|
||||||
};
|
};
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
@@ -41,9 +45,15 @@ export default function ModelsPage() {
|
|||||||
<div className="w-full md:w-1/2 flex items-top">
|
<div className="w-full md:w-1/2 flex items-top">
|
||||||
<div className="card w-full">
|
<div className="card w-full">
|
||||||
<h2 className="">Models</h2>
|
<h2 className="">Models</h2>
|
||||||
<button className="btn" onClick={handleUnloadAllModels} disabled={isUnloading}>
|
<div className="flex justify-between">
|
||||||
{isUnloading ? "Unloading..." : "Unload All Models"}
|
<button className="btn" onClick={() => setShowUnlisted(!showUnlisted)} style={{ lineHeight: "1.2" }}>
|
||||||
</button>
|
{showUnlisted ? "🟢 unlisted" : "⚫️ unlisted"}
|
||||||
|
</button>
|
||||||
|
<button className="btn" onClick={handleUnloadAllModels} disabled={isUnloading}>
|
||||||
|
{isUnloading ? "Stopping ..." : "Stop All"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
<table className="w-full mt-4">
|
<table className="w-full mt-4">
|
||||||
<thead>
|
<thead>
|
||||||
<tr className="border-b border-primary">
|
<tr className="border-b border-primary">
|
||||||
@@ -53,14 +63,19 @@ export default function ModelsPage() {
|
|||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
{models.map((model) => (
|
{filteredModels.map((model) => (
|
||||||
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
|
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
|
||||||
<td className="p-2">
|
<td className="p-2">
|
||||||
<a href={`/upstream/${model.id}/`} className="underline" target="_blank">
|
<a href={`/upstream/${model.id}/`} className="underline" target="_blank">
|
||||||
{model.id}
|
{model.name !== "" ? model.name : model.id}
|
||||||
</a>
|
</a>
|
||||||
|
{model.description != "" && (
|
||||||
|
<p>
|
||||||
|
<em>{model.description}</em>
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
</td>
|
</td>
|
||||||
<td className="p-2">
|
<td className="p-2 w-[50px]">
|
||||||
<button
|
<button
|
||||||
className="btn btn--sm"
|
className="btn btn--sm"
|
||||||
disabled={model.state !== "stopped"}
|
disabled={model.state !== "stopped"}
|
||||||
@@ -69,7 +84,7 @@ export default function ModelsPage() {
|
|||||||
Load
|
Load
|
||||||
</button>
|
</button>
|
||||||
</td>
|
</td>
|
||||||
<td className="p-2">
|
<td className="p-2 w-[75px]">
|
||||||
<span className={`status status--${model.state}`}>{model.state}</span>
|
<span className={`status status--${model.state}`}>{model.state}</span>
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|||||||
Reference in New Issue
Block a user