Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 558a72de17 | |||
| dc42cf366d | |||
| ba0a81937a | |||
| 574fdfabb4 | |||
| 5172cb2e12 | |||
| 5672cb03fd | |||
| 0f583163f7 | |||
| 7905fa9ea3 | |||
| bbaf172956 | |||
| fd50932dbc | |||
| 8c693e7fcf | |||
| 8f2af26a41 | |||
| 01d4838fb3 | |||
| accd65294b | |||
| 7472a25864 | |||
| cce0bc6aa1 | |||
| 36e25125e8 | |||
| 9a54273d15 | |||
| 87dce5f8f6 | |||
| 307e619521 | |||
| 6299c1b874 | |||
| a906cd459b | |||
| 78b2bc3dbc | |||
| 6a058e4191 | |||
| 1921e570d7 | |||
| c867a6c9a2 | |||
| 3bd1b23ce0 | |||
| 10606abf89 |
@@ -7,6 +7,10 @@ on:
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: 'Tag version to release (e.g. v144)'
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -20,15 +24,15 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
|
||||
-
|
||||
name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '23' # or your preferred version
|
||||
node-version: '23'
|
||||
-
|
||||
name: Install dependencies and build UI
|
||||
run: |
|
||||
@@ -46,4 +50,30 @@ jobs:
|
||||
version: '~> v2'
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
trigger-tap-update:
|
||||
runs-on: ubuntu-latest
|
||||
needs: goreleaser
|
||||
steps:
|
||||
- name: "Resolve tag to dispatch"
|
||||
id: tag
|
||||
run: |
|
||||
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||
echo "tag=${{ github.event.inputs.tag }}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "tag=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: "Trigger tap repository update"
|
||||
uses: peter-evans/repository-dispatch@v2
|
||||
with:
|
||||
token: ${{ secrets.TAP_REPO_PAT }}
|
||||
repository: mostlygeek/homebrew-llama-swap
|
||||
event-type: new-release
|
||||
client-payload: |
|
||||
{
|
||||
"release": {
|
||||
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,7 @@ Written in golang, it is very easy to install (single binary with no dependencie
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/embeddings`
|
||||
- `v1/rerank`
|
||||
- `v1/rerank`, `v1/reranking`, `rerank`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- ✅ llama-swap custom API endpoints
|
||||
@@ -27,6 +27,7 @@ Written in golang, it is very easy to install (single binary with no dependencie
|
||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- `/health` - just returns "OK"
|
||||
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
@@ -70,13 +71,22 @@ See the [configuration documentation](https://github.com/mostlygeek/llama-swap/w
|
||||
|
||||
## Web UI
|
||||
|
||||
llama-swap ships with a web based interface to make it easier to monitor logs and check the status of models.
|
||||
llama-swap ships with a real time web interface to monitor logs and status of models:
|
||||
|
||||
<img width="1758" alt="image" src="https://github.com/user-attachments/assets/31ae5bcd-5efd-46b0-b64b-6db9e60196d3" />
|
||||
<img width="1786" height="1334" alt="image" src="https://github.com/user-attachments/assets/d6258cb9-1dad-40db-828f-2be860aec8fe" />
|
||||
|
||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
## Installation
|
||||
|
||||
Docker is the quickest way to try out llama-swap:
|
||||
llama-swap can be installed in multiple ways
|
||||
|
||||
1. Docker
|
||||
2. Homebrew (OSX and Linux)
|
||||
3. From release binaries
|
||||
4. From source
|
||||
|
||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Docker images with llama-swap and llama-server are built nightly.
|
||||
|
||||
```shell
|
||||
# use CPU inference comes with the example config above
|
||||
@@ -98,7 +108,7 @@ $ curl -s http://localhost:9292/v1/chat/completions \
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Docker images are built nightly for cuda, intel, vulcan, etc ...</summary>
|
||||
<summary>Docker images are built nightly with llama-server for cuda, intel, vulcan and musa.</summary>
|
||||
|
||||
They include:
|
||||
|
||||
@@ -121,9 +131,23 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
|
||||
</details>
|
||||
|
||||
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||
### Homebrew Install (macOS/Linux)
|
||||
|
||||
Pre-built binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
||||
The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh).
|
||||
|
||||
```shell
|
||||
# Set up tap and install formula
|
||||
brew tap mostlygeek/llama-swap
|
||||
brew install llama-swap
|
||||
# Run llama-swap
|
||||
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||
```
|
||||
|
||||
This will install the `llama-swap` binary and make it available in your path. See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration)
|
||||
|
||||
### Pre-built Binaries ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||
|
||||
Binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The binary install works with any OpenAI compatible server, not just llama-server.
|
||||
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
|
||||
@@ -137,7 +161,7 @@ Pre-built binaries are available for Linux, Mac, Windows and FreeBSD. These are
|
||||
### Building from source
|
||||
|
||||
1. Build requires golang and nodejs for the user interface.
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. `git clone https://github.com/mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
|
||||
@@ -173,6 +197,13 @@ Any OpenAI compatible server would work. llama-swap was originally designed for
|
||||
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
|
||||
## Contributors
|
||||
<a href="https://github.com/mostlygeek/llama-swap/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=mostlygeek/llama-swap" />
|
||||
</a>
|
||||
|
||||
Made with [contrib.rocks](https://contrib.rocks).
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
||||
|
||||
+20
-2
@@ -15,6 +15,12 @@ healthCheckTimeout: 500
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
# - useful for limiting memory usage when processing large volumes of metrics
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||
# - optional, default: 5800
|
||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||
@@ -49,7 +55,19 @@ models:
|
||||
cmd: |
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
--model path/to/llama-8B-Q4_K_M.gguf
|
||||
|
||||
# name: a display name for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
name: "llama 3.1 8B"
|
||||
|
||||
# description: a description for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
description: "A small but capable model used for quick testing"
|
||||
|
||||
# env: define an array of environment variables to inject into cmd's environment
|
||||
# - optional, default: empty array
|
||||
@@ -188,4 +206,4 @@ groups:
|
||||
members:
|
||||
- "forever-modelA"
|
||||
- "forever-modelB"
|
||||
- "forever-modelc"
|
||||
- "forever-modelc"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
The code in `event` was originally a part of https://github.com/kelindar/event (v1.5.2)
|
||||
|
||||
The original code uses a `time.Ticker` to process the event queue which caused a large increase in CPU usage ([#189](https://github.com/mostlygeek/llama-swap/issues/189)). This code was ported to remove the ticker and instead be more event driven.
|
||||
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Default initializes a default in-process dispatcher
|
||||
var Default = NewDispatcherConfig(25000)
|
||||
|
||||
// On subscribes to an event, the type of the event will be automatically
|
||||
// inferred from the provided type. Must be constant for this to work. This
|
||||
// functions same way as Subscribe() but uses the default dispatcher instead.
|
||||
func On[T Event](handler func(T)) context.CancelFunc {
|
||||
return Subscribe(Default, handler)
|
||||
}
|
||||
|
||||
// OnType subscribes to an event with the specified event type. This functions
|
||||
// same way as SubscribeTo() but uses the default dispatcher instead.
|
||||
func OnType[T Event](eventType uint32, handler func(T)) context.CancelFunc {
|
||||
return SubscribeTo(Default, eventType, handler)
|
||||
}
|
||||
|
||||
// Emit writes an event into the dispatcher. This functions same way as
|
||||
// Publish() but uses the default dispatcher instead.
|
||||
func Emit[T Event](ev T) {
|
||||
Publish(Default, ev)
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
/*
|
||||
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
|
||||
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
|
||||
*/
|
||||
func BenchmarkSubscribeConcurrent(b *testing.B) {
|
||||
d := NewDispatcher()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
unsub := Subscribe(d, func(ev MyEvent1) {})
|
||||
unsub()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultPublish(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Subscribe
|
||||
var count int64
|
||||
defer On(func(ev MyEvent1) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
defer OnType(TypeEvent1, func(ev MyEvent1) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
// Publish
|
||||
wg.Add(4)
|
||||
Emit(MyEvent1{})
|
||||
Emit(MyEvent1{})
|
||||
|
||||
// Wait and check
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(4), count)
|
||||
}
|
||||
+324
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for details.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Event represents an event contract
|
||||
type Event interface {
|
||||
Type() uint32
|
||||
}
|
||||
|
||||
// registry holds an immutable sorted array of event mappings
|
||||
type registry struct {
|
||||
keys []uint32 // Event types (sorted)
|
||||
grps []any // Corresponding subscribers
|
||||
}
|
||||
|
||||
// ------------------------------------- Dispatcher -------------------------------------
|
||||
|
||||
// Dispatcher represents an event dispatcher.
|
||||
type Dispatcher struct {
|
||||
subs atomic.Pointer[registry] // Atomic pointer to immutable array
|
||||
done chan struct{} // Cancellation
|
||||
maxQueue int // Maximum queue size per consumer
|
||||
mu sync.Mutex // Only for writes (subscribe/unsubscribe)
|
||||
}
|
||||
|
||||
// NewDispatcher creates a new dispatcher of events.
|
||||
func NewDispatcher() *Dispatcher {
|
||||
return NewDispatcherConfig(50000)
|
||||
}
|
||||
|
||||
// NewDispatcherConfig creates a new dispatcher with configurable max queue size
|
||||
func NewDispatcherConfig(maxQueue int) *Dispatcher {
|
||||
d := &Dispatcher{
|
||||
done: make(chan struct{}),
|
||||
maxQueue: maxQueue,
|
||||
}
|
||||
|
||||
d.subs.Store(®istry{
|
||||
keys: make([]uint32, 0, 16),
|
||||
grps: make([]any, 0, 16),
|
||||
})
|
||||
return d
|
||||
}
|
||||
|
||||
// Close closes the dispatcher
|
||||
func (d *Dispatcher) Close() error {
|
||||
close(d.done)
|
||||
return nil
|
||||
}
|
||||
|
||||
// isClosed returns whether the dispatcher is closed or not
|
||||
func (d *Dispatcher) isClosed() bool {
|
||||
select {
|
||||
case <-d.done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// findGroup performs a lock-free binary search for the event type
|
||||
func (d *Dispatcher) findGroup(eventType uint32) any {
|
||||
reg := d.subs.Load()
|
||||
keys := reg.keys
|
||||
|
||||
// Inlined binary search for better cache locality
|
||||
left, right := 0, len(keys)
|
||||
for left < right {
|
||||
mid := left + (right-left)/2
|
||||
if keys[mid] < eventType {
|
||||
left = mid + 1
|
||||
} else {
|
||||
right = mid
|
||||
}
|
||||
}
|
||||
|
||||
if left < len(keys) && keys[left] == eventType {
|
||||
return reg.grps[left]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe subscribes to an event, the type of the event will be automatically
|
||||
// inferred from the provided type. Must be constant for this to work.
|
||||
func Subscribe[T Event](broker *Dispatcher, handler func(T)) context.CancelFunc {
|
||||
var event T
|
||||
return SubscribeTo(broker, event.Type(), handler)
|
||||
}
|
||||
|
||||
// SubscribeTo subscribes to an event with the specified event type.
|
||||
func SubscribeTo[T Event](broker *Dispatcher, eventType uint32, handler func(T)) context.CancelFunc {
|
||||
if broker.isClosed() {
|
||||
panic(errClosed)
|
||||
}
|
||||
|
||||
broker.mu.Lock()
|
||||
defer broker.mu.Unlock()
|
||||
|
||||
// Check if group already exists
|
||||
if existing := broker.findGroup(eventType); existing != nil {
|
||||
grp := groupOf[T](eventType, existing)
|
||||
sub := grp.Add(handler)
|
||||
return func() {
|
||||
grp.Del(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Create new group
|
||||
grp := &group[T]{cond: sync.NewCond(new(sync.Mutex)), maxQueue: broker.maxQueue}
|
||||
sub := grp.Add(handler)
|
||||
|
||||
// Copy-on-write: insert new entry in sorted position
|
||||
old := broker.subs.Load()
|
||||
idx := sort.Search(len(old.keys), func(i int) bool {
|
||||
return old.keys[i] >= eventType
|
||||
})
|
||||
|
||||
// Create new arrays with space for one more element
|
||||
newKeys := make([]uint32, len(old.keys)+1)
|
||||
newGrps := make([]any, len(old.grps)+1)
|
||||
|
||||
// Copy elements before insertion point
|
||||
copy(newKeys[:idx], old.keys[:idx])
|
||||
copy(newGrps[:idx], old.grps[:idx])
|
||||
|
||||
// Insert new element
|
||||
newKeys[idx] = eventType
|
||||
newGrps[idx] = grp
|
||||
|
||||
// Copy elements after insertion point
|
||||
copy(newKeys[idx+1:], old.keys[idx:])
|
||||
copy(newGrps[idx+1:], old.grps[idx:])
|
||||
|
||||
// Atomically store the new registry (mutex ensures no concurrent writers)
|
||||
newReg := ®istry{keys: newKeys, grps: newGrps}
|
||||
broker.subs.Store(newReg)
|
||||
|
||||
return func() {
|
||||
grp.Del(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish writes an event into the dispatcher
|
||||
func Publish[T Event](broker *Dispatcher, ev T) {
|
||||
eventType := ev.Type()
|
||||
if sub := broker.findGroup(eventType); sub != nil {
|
||||
group := groupOf[T](eventType, sub)
|
||||
group.Broadcast(ev)
|
||||
}
|
||||
}
|
||||
|
||||
// Count counts the number of subscribers, this is for testing only.
|
||||
func (d *Dispatcher) count(eventType uint32) int {
|
||||
if group := d.findGroup(eventType); group != nil {
|
||||
return group.(interface{ Count() int }).Count()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// groupOf casts the subscriber group to the specified generic type
|
||||
func groupOf[T Event](eventType uint32, subs any) *group[T] {
|
||||
if group, ok := subs.(*group[T]); ok {
|
||||
return group
|
||||
}
|
||||
|
||||
panic(errConflict[T](eventType, subs))
|
||||
}
|
||||
|
||||
// ------------------------------------- Subscriber -------------------------------------
|
||||
|
||||
// consumer represents a consumer with a message queue
|
||||
type consumer[T Event] struct {
|
||||
queue []T // Current work queue
|
||||
stop bool // Stop signal
|
||||
}
|
||||
|
||||
// Listen listens to the event queue and processes events
|
||||
func (s *consumer[T]) Listen(c *sync.Cond, fn func(T)) {
|
||||
pending := make([]T, 0, 128)
|
||||
|
||||
for {
|
||||
c.L.Lock()
|
||||
for len(s.queue) == 0 {
|
||||
switch {
|
||||
case s.stop:
|
||||
c.L.Unlock()
|
||||
return
|
||||
default:
|
||||
c.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// Swap buffers and reset the current queue
|
||||
temp := s.queue
|
||||
s.queue = pending[:0]
|
||||
pending = temp
|
||||
c.L.Unlock()
|
||||
|
||||
// Outside of the critical section, process the work
|
||||
for _, event := range pending {
|
||||
fn(event)
|
||||
}
|
||||
|
||||
// Notify potential publishers waiting due to backpressure
|
||||
c.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------- Subscriber Group -------------------------------------
|
||||
|
||||
// group represents a consumer group
|
||||
type group[T Event] struct {
|
||||
cond *sync.Cond
|
||||
subs []*consumer[T]
|
||||
maxQueue int // Maximum queue size per consumer
|
||||
maxLen int // Current maximum queue length across all consumers
|
||||
}
|
||||
|
||||
// Broadcast sends an event to all consumers
|
||||
func (s *group[T]) Broadcast(ev T) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
|
||||
// Calculate current maximum queue length
|
||||
s.maxLen = 0
|
||||
for _, sub := range s.subs {
|
||||
if len(sub.queue) > s.maxLen {
|
||||
s.maxLen = len(sub.queue)
|
||||
}
|
||||
}
|
||||
|
||||
// Backpressure: wait if queues are full
|
||||
for s.maxLen >= s.maxQueue {
|
||||
s.cond.Wait()
|
||||
|
||||
// Recalculate after wakeup
|
||||
s.maxLen = 0
|
||||
for _, sub := range s.subs {
|
||||
if len(sub.queue) > s.maxLen {
|
||||
s.maxLen = len(sub.queue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add event to all queues and track new maximum
|
||||
newMax := 0
|
||||
for _, sub := range s.subs {
|
||||
sub.queue = append(sub.queue, ev)
|
||||
if len(sub.queue) > newMax {
|
||||
newMax = len(sub.queue)
|
||||
}
|
||||
}
|
||||
s.maxLen = newMax
|
||||
s.cond.Broadcast() // Wake consumers
|
||||
}
|
||||
|
||||
// Add adds a subscriber to the list
|
||||
func (s *group[T]) Add(handler func(T)) *consumer[T] {
|
||||
sub := &consumer[T]{
|
||||
queue: make([]T, 0, 64),
|
||||
}
|
||||
|
||||
// Add the consumer to the list of active consumers
|
||||
s.cond.L.Lock()
|
||||
s.subs = append(s.subs, sub)
|
||||
s.cond.L.Unlock()
|
||||
|
||||
// Start listening
|
||||
go sub.Listen(s.cond, handler)
|
||||
return sub
|
||||
}
|
||||
|
||||
// Del removes a subscriber from the list
|
||||
func (s *group[T]) Del(sub *consumer[T]) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
|
||||
// Search and remove the subscriber
|
||||
sub.stop = true
|
||||
for i, v := range s.subs {
|
||||
if v == sub {
|
||||
copy(s.subs[i:], s.subs[i+1:])
|
||||
s.subs = s.subs[:len(s.subs)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------- Debugging -------------------------------------
|
||||
|
||||
var errClosed = fmt.Errorf("event dispatcher is closed")
|
||||
|
||||
// Count returns the number of subscribers in this group
|
||||
func (s *group[T]) Count() int {
|
||||
return len(s.subs)
|
||||
}
|
||||
|
||||
// String returns string representation of the type
|
||||
func (s *group[T]) String() string {
|
||||
typ := reflect.TypeOf(s).String()
|
||||
idx := strings.LastIndex(typ, "/")
|
||||
typ = typ[idx+1 : len(typ)-1]
|
||||
return typ
|
||||
}
|
||||
|
||||
// errConflict returns a conflict message
|
||||
func errConflict[T any](eventType uint32, existing any) string {
|
||||
var want T
|
||||
return fmt.Sprintf(
|
||||
"conflicting event type, want=<%T>, registered=<%s>, event=0x%v",
|
||||
want, existing, eventType,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Subscribe, must be received in order
|
||||
var count int64
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
// Publish
|
||||
wg.Add(3)
|
||||
Publish(d, MyEvent1{Number: 1})
|
||||
Publish(d, MyEvent1{Number: 2})
|
||||
Publish(d, MyEvent1{Number: 3})
|
||||
|
||||
// Wait and check
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||
unsubscribe := Subscribe(d, func(ev MyEvent1) {
|
||||
// Nothing
|
||||
})
|
||||
|
||||
assert.Equal(t, 1, d.count(TypeEvent1))
|
||||
unsubscribe()
|
||||
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||
}
|
||||
|
||||
func TestConcurrent(t *testing.T) {
|
||||
const max = 1000000
|
||||
var count int64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
d := NewDispatcher()
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
if current := atomic.AddInt64(&count, 1); current == max {
|
||||
wg.Done()
|
||||
}
|
||||
})()
|
||||
|
||||
// Asynchronously publish
|
||||
go func() {
|
||||
for i := 0; i < max; i++ {
|
||||
Publish(d, MyEvent1{})
|
||||
}
|
||||
}()
|
||||
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
// Subscriber that does nothing
|
||||
})()
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, max, int(count))
|
||||
}
|
||||
|
||||
func TestSubscribeDifferentType(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestPublishDifferentType(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
Publish(d, MyEvent1{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseDispatcher(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
|
||||
|
||||
assert.NoError(t, d.Close())
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatrix(t *testing.T) {
|
||||
const amount = 1000
|
||||
for _, subs := range []int{1, 10, 100} {
|
||||
for _, topics := range []int{1, 10} {
|
||||
expected := subs * topics * amount
|
||||
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(expected)
|
||||
|
||||
d := NewDispatcher()
|
||||
for i := 0; i < subs; i++ {
|
||||
for id := 0; id < topics; id++ {
|
||||
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
|
||||
count.Add(1)
|
||||
wg.Done()
|
||||
})()
|
||||
}
|
||||
}
|
||||
|
||||
for n := 0; n < amount; n++ {
|
||||
for id := 0; id < topics; id++ {
|
||||
go Publish(d, MyEvent3{ID: id})
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, expected, int(count.Load()))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentSubscriptionRace(t *testing.T) {
|
||||
// This test specifically targets the race condition that occurs when multiple
|
||||
// goroutines try to subscribe to different event types simultaneously.
|
||||
// Without the CAS loop, subscriptions could be lost due to registry corruption.
|
||||
|
||||
const numGoroutines = 100
|
||||
const numEventTypes = 50
|
||||
|
||||
d := NewDispatcher()
|
||||
defer d.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var receivedCount int64
|
||||
var subscribedTypes sync.Map // Thread-safe map
|
||||
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
// Start multiple goroutines that subscribe to different event types concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Each goroutine subscribes to a unique event type
|
||||
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
|
||||
|
||||
// Subscribe to the event type
|
||||
SubscribeTo(d, eventType, func(ev MyEvent3) {
|
||||
atomic.AddInt64(&receivedCount, 1)
|
||||
})
|
||||
|
||||
// Record that this type was subscribed
|
||||
subscribedTypes.Store(eventType, true)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all subscriptions to complete
|
||||
wg.Wait()
|
||||
|
||||
// Count the number of unique event types subscribed
|
||||
expectedTypes := 0
|
||||
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||
expectedTypes++
|
||||
return true
|
||||
})
|
||||
|
||||
// Small delay to ensure all subscriptions are fully processed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Publish events to each subscribed type
|
||||
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||
eventType := key.(uint32)
|
||||
Publish(d, MyEvent3{ID: int(eventType)})
|
||||
return true
|
||||
})
|
||||
|
||||
// Wait for all events to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Verify that we received at least the expected number of events
|
||||
// (there might be more if multiple goroutines subscribed to the same event type)
|
||||
received := atomic.LoadInt64(&receivedCount)
|
||||
assert.GreaterOrEqual(t, int(received), expectedTypes,
|
||||
"Should have received at least %d events, got %d", expectedTypes, received)
|
||||
|
||||
// Verify that we have the expected number of unique event types
|
||||
assert.Equal(t, numEventTypes, expectedTypes,
|
||||
"Should have exactly %d unique event types", numEventTypes)
|
||||
}
|
||||
|
||||
func TestConcurrentHandlerRegistration(t *testing.T) {
|
||||
const numGoroutines = 100
|
||||
|
||||
// Test concurrent subscriptions to the same event type
|
||||
t.Run("SameEventType", func(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var handlerCount int64
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start multiple goroutines subscribing to the same event type (0x1)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
|
||||
atomic.AddInt64(&handlerCount, 1)
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all handlers were registered by publishing an event
|
||||
atomic.StoreInt64(&handlerCount, 0)
|
||||
Publish(d, MyEvent1{})
|
||||
|
||||
// Small delay to ensure all handlers have executed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
|
||||
"Not all handlers were registered due to race condition")
|
||||
})
|
||||
|
||||
// Test concurrent subscriptions to different event types
|
||||
t.Run("DifferentEventTypes", func(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var wg sync.WaitGroup
|
||||
receivedEvents := make(map[uint32]*int64)
|
||||
|
||||
// Create multiple event types and subscribe concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
eventType := uint32(100 + i)
|
||||
counter := new(int64)
|
||||
receivedEvents[eventType] = counter
|
||||
|
||||
wg.Add(1)
|
||||
go func(et uint32, cnt *int64) {
|
||||
defer wg.Done()
|
||||
SubscribeTo(d, et, func(ev MyEvent3) {
|
||||
atomic.AddInt64(cnt, 1)
|
||||
})
|
||||
}(eventType, counter)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Publish events to all types
|
||||
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
|
||||
Publish(d, MyEvent3{ID: int(eventType)})
|
||||
}
|
||||
|
||||
// Small delay to ensure all handlers have executed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Verify all event types received their events
|
||||
for eventType, counter := range receivedEvents {
|
||||
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
|
||||
"Event type %d did not receive its event", eventType)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackpressure(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
d.maxQueue = 10
|
||||
|
||||
var processedCount int64
|
||||
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
|
||||
atomic.AddInt64(&processedCount, 1)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
const eventsToPublish = 1000
|
||||
for i := 0; i < eventsToPublish; i++ {
|
||||
Publish(d, MyEvent3{ID: 0x200})
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify all events were eventually processed
|
||||
finalProcessed := atomic.LoadInt64(&processedCount)
|
||||
assert.Equal(t, int64(eventsToPublish), finalProcessed)
|
||||
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
|
||||
}
|
||||
|
||||
// ------------------------------------- Test Events -------------------------------------
|
||||
|
||||
const (
|
||||
TypeEvent1 = 0x1
|
||||
TypeEvent2 = 0x2
|
||||
)
|
||||
|
||||
type MyEvent1 struct {
|
||||
Number int
|
||||
}
|
||||
|
||||
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
|
||||
|
||||
type MyEvent2 struct {
|
||||
Text string
|
||||
}
|
||||
|
||||
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
|
||||
|
||||
type MyEvent3 struct {
|
||||
ID int
|
||||
}
|
||||
|
||||
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
|
||||
@@ -3,6 +3,7 @@ module github.com/mostlygeek/llama-swap
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
@@ -12,7 +13,6 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
|
||||
@@ -32,8 +32,6 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
|
||||
+110
-111
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
)
|
||||
|
||||
@@ -53,137 +54,135 @@ func main() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
|
||||
// Setup channels for server management
|
||||
reloadChan := make(chan *proxy.ProxyManager)
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Create server with initial handler
|
||||
srv := &http.Server{
|
||||
Addr: *listenStr,
|
||||
Handler: proxyManager,
|
||||
Addr: *listenStr,
|
||||
}
|
||||
|
||||
// Support for watching config and reloading when it changes
|
||||
reloadProxyManager := func() {
|
||||
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
config, err = proxy.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Configuration Changed")
|
||||
currentPM.Shutdown()
|
||||
srv.Handler = proxy.New(config)
|
||||
fmt.Println("Configuration Reloaded")
|
||||
|
||||
// wait a few seconds and tell any UI to reload
|
||||
time.AfterFunc(3*time.Second, func() {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateEnd,
|
||||
})
|
||||
})
|
||||
} else {
|
||||
config, err = proxy.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
srv.Handler = proxy.New(config)
|
||||
}
|
||||
}
|
||||
|
||||
// load the initial proxy manager
|
||||
reloadProxyManager()
|
||||
debouncedReload := debounce(time.Second, reloadProxyManager)
|
||||
if *watchConfig {
|
||||
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
||||
if e.ReloadingState == proxy.ReloadingStateStart {
|
||||
debouncedReload()
|
||||
}
|
||||
})()
|
||||
|
||||
fmt.Println("Watching Configuration for changes")
|
||||
go func() {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
||||
return
|
||||
}
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
configDir := filepath.Dir(absConfigPath)
|
||||
err = watcher.Add(configDir)
|
||||
if err != nil {
|
||||
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
|
||||
return
|
||||
}
|
||||
|
||||
defer watcher.Close()
|
||||
for {
|
||||
select {
|
||||
case changeEvent := <-watcher.Events:
|
||||
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateStart,
|
||||
})
|
||||
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
|
||||
// the change for k8s configmap
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateStart,
|
||||
})
|
||||
}
|
||||
|
||||
case err := <-watcher.Errors:
|
||||
log.Printf("File watcher error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// shutdown on signal
|
||||
go func() {
|
||||
sig := <-sigChan
|
||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
pm.Shutdown()
|
||||
} else {
|
||||
fmt.Println("srv.Handler is not of type *proxy.ProxyManager")
|
||||
}
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
fmt.Printf("Server shutdown error: %v\n", err)
|
||||
}
|
||||
close(exitChan)
|
||||
}()
|
||||
|
||||
// Start server
|
||||
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
fmt.Printf("Fatal server error: %v\n", err)
|
||||
close(exitChan)
|
||||
log.Fatalf("Fatal server error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Handle config reloads and signals
|
||||
go func() {
|
||||
currentManager := proxyManager
|
||||
for {
|
||||
select {
|
||||
case newManager := <-reloadChan:
|
||||
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
||||
// Stop old manager processes gracefully (this waits for in-flight requests)
|
||||
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
|
||||
// Now do a full shutdown to clear the process map
|
||||
currentManager.Shutdown()
|
||||
currentManager = newManager
|
||||
srv.Handler = newManager
|
||||
log.Println("Server handler updated with new config")
|
||||
case sig := <-sigChan:
|
||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentManager.Shutdown()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
fmt.Printf("Server shutdown error: %v\n", err)
|
||||
}
|
||||
close(exitChan)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start file watcher if requested
|
||||
if *watchConfig {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
|
||||
} else {
|
||||
go watchConfigFileWithReload(absConfigPath, reloadChan)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for exit signal
|
||||
<-exitChan
|
||||
}
|
||||
|
||||
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
|
||||
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
err = watcher.Add(configPath)
|
||||
if err != nil {
|
||||
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Watching config file for changes: %s", configPath)
|
||||
|
||||
var debounceTimer *time.Timer
|
||||
debounceDuration := 2 * time.Second
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// We only care about writes to the specific config file
|
||||
if event.Name == configPath && event.Has(fsnotify.Write) {
|
||||
// Reset or start the debounce timer
|
||||
if debounceTimer != nil {
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
debounceTimer = time.AfterFunc(debounceDuration, func() {
|
||||
log.Printf("Config file modified: %s, reloading...", event.Name)
|
||||
|
||||
// Try up to 3 times with exponential backoff
|
||||
var newConfig proxy.Config
|
||||
var err error
|
||||
for retries := 0; retries < 3; retries++ {
|
||||
// Load new configuration
|
||||
newConfig, err = proxy.LoadConfig(configPath)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
|
||||
if retries < 2 {
|
||||
time.Sleep(time.Duration(1<<retries) * time.Second)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Failed to load new config after retries: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create new ProxyManager with new config
|
||||
newPM := proxy.New(newConfig)
|
||||
reloadChan <- newPM
|
||||
log.Println("Config reloaded successfully")
|
||||
})
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
log.Println("File watcher error channel closed.")
|
||||
return
|
||||
}
|
||||
log.Printf("File watcher error: %v", err)
|
||||
func debounce(interval time.Duration, f func()) func() {
|
||||
var timer *time.Timer
|
||||
return func() {
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
timer = time.AfterFunc(interval, f)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,20 +35,90 @@ func main() {
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||
"request_body": string(bodyBytes),
|
||||
})
|
||||
// Check if streaming is requested
|
||||
// Query is checked instead of JSON body since that event stream conflicts with other tests
|
||||
isStreaming := c.Query("stream") == "true"
|
||||
|
||||
if isStreaming {
|
||||
// Set headers for streaming
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
// Send 10 "asdf" tokens
|
||||
for i := 0; i < 10; i++ {
|
||||
data := gin.H{
|
||||
"created": time.Now().Unix(),
|
||||
"choices": []gin.H{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": gin.H{
|
||||
"content": "asdf",
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
c.SSEvent("message", data)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// Send final data with usage info
|
||||
finalData := gin.H{
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
// add timings to simulate llama.cpp
|
||||
"timings": gin.H{
|
||||
"prompt_n": 25,
|
||||
"prompt_ms": 13,
|
||||
"predicted_n": 10,
|
||||
"predicted_ms": 17,
|
||||
"predicted_per_second": 10,
|
||||
},
|
||||
}
|
||||
c.SSEvent("message", finalData)
|
||||
c.Writer.Flush()
|
||||
|
||||
// Send [DONE]
|
||||
c.SSEvent("message", "[DONE]")
|
||||
c.Writer.Flush()
|
||||
} else {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||
"request_body": string(bodyBytes),
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
"timings": gin.H{
|
||||
"prompt_n": 25,
|
||||
"prompt_ms": 13,
|
||||
"predicted_n": 10,
|
||||
"predicted_ms": 17,
|
||||
"predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// for issue #62 to check model name strips profile slug
|
||||
@@ -74,6 +144,11 @@ func main() {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
@@ -28,6 +28,10 @@ type ModelConfig struct {
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// #179 for /v1/models
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
|
||||
@@ -48,6 +52,8 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
@@ -136,6 +142,7 @@ type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
@@ -188,6 +195,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
MetricsMaxInMemory: 1000,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
@@ -249,6 +257,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields before macro expansion
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
|
||||
for macroName, macroValue := range config.Macros {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
@@ -400,3 +412,16 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
@@ -104,6 +104,8 @@ models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
name: "Model 1"
|
||||
description: "This is model 1"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
@@ -168,6 +170,8 @@ groups:
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
@@ -192,6 +196,7 @@ groups:
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -325,3 +326,117 @@ models:
|
||||
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripComments(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no comments",
|
||||
input: "echo hello\necho world",
|
||||
expected: "echo hello\necho world",
|
||||
},
|
||||
{
|
||||
name: "single comment line",
|
||||
input: "# this is a comment\necho hello",
|
||||
expected: "echo hello",
|
||||
},
|
||||
{
|
||||
name: "multiple comment lines",
|
||||
input: "# comment 1\necho hello\n# comment 2\necho world",
|
||||
expected: "echo hello\necho world",
|
||||
},
|
||||
{
|
||||
name: "comment with spaces",
|
||||
input: " # indented comment\necho hello",
|
||||
expected: "echo hello",
|
||||
},
|
||||
{
|
||||
name: "empty lines preserved",
|
||||
input: "echo hello\n\necho world",
|
||||
expected: "echo hello\n\necho world",
|
||||
},
|
||||
{
|
||||
name: "only comments",
|
||||
input: "# comment 1\n# comment 2",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := StripComments(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("StripComments() = %q, expected %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_MacroInCommentStrippedBeforeExpansion(t *testing.T) {
|
||||
// Test case that reproduces the original bug where a macro in a comment
|
||||
// would get expanded and cause the comment text to be included in the command
|
||||
content := `
|
||||
startPort: 9990
|
||||
macros:
|
||||
"latest-llama": >
|
||||
/user/llama.cpp/build/bin/llama-server
|
||||
--port ${PORT}
|
||||
|
||||
models:
|
||||
"test-model":
|
||||
cmd: |
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model /path/to/model.gguf
|
||||
-ngl 99
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Get the sanitized command
|
||||
sanitizedCmd, err := SanitizeCommand(config.Models["test-model"].Cmd)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Join the command for easier inspection
|
||||
cmdStr := strings.Join(sanitizedCmd, " ")
|
||||
|
||||
// Verify that comment text is NOT present in the final command as separate arguments
|
||||
commentWords := []string{"is", "macro", "that", "defined", "above"}
|
||||
for _, word := range commentWords {
|
||||
found := slices.Contains(sanitizedCmd, word)
|
||||
assert.False(t, found, "Comment text '%s' should not be present as a separate argument in final command", word)
|
||||
}
|
||||
|
||||
// Verify that the actual command components ARE present
|
||||
expectedParts := []string{
|
||||
"/user/llama.cpp/build/bin/llama-server",
|
||||
"--port",
|
||||
"9990",
|
||||
"--model",
|
||||
"/path/to/model.gguf",
|
||||
"-ngl",
|
||||
"99",
|
||||
}
|
||||
|
||||
for _, part := range expectedParts {
|
||||
assert.Contains(t, cmdStr, part, "Expected command part '%s' not found in final command", part)
|
||||
}
|
||||
|
||||
// Verify the server path appears exactly once (not duplicated due to macro expansion)
|
||||
serverPath := "/user/llama.cpp/build/bin/llama-server"
|
||||
count := strings.Count(cmdStr, serverPath)
|
||||
assert.Equal(t, 1, count, "Expected exactly 1 occurrence of server path, found %d", count)
|
||||
|
||||
// Verify the expected final command structure
|
||||
expectedCmd := "/user/llama.cpp/build/bin/llama-server --port 9990 --model /path/to/model.gguf -ngl 99"
|
||||
assert.Equal(t, expectedCmd, cmdStr, "Final command does not match expected structure")
|
||||
}
|
||||
|
||||
@@ -193,6 +193,7 @@ groups:
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package proxy
|
||||
|
||||
// package level registry of the different event types
|
||||
|
||||
const ProcessStateChangeEventID = 0x01
|
||||
const ChatCompletionStatsEventID = 0x02
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const LogDataEventID = 0x04
|
||||
const TokenMetricsEventID = 0x05
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
NewState ProcessState
|
||||
OldState ProcessState
|
||||
}
|
||||
|
||||
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||
return ProcessStateChangeEventID
|
||||
}
|
||||
|
||||
type ChatCompletionStats struct {
|
||||
TokensGenerated int
|
||||
}
|
||||
|
||||
func (e ChatCompletionStats) Type() uint32 {
|
||||
return ChatCompletionStatsEventID
|
||||
}
|
||||
|
||||
type ReloadingState int
|
||||
|
||||
const (
|
||||
ReloadingStateStart ReloadingState = iota
|
||||
ReloadingStateEnd
|
||||
)
|
||||
|
||||
type ConfigFileChangedEvent struct {
|
||||
ReloadingState ReloadingState
|
||||
}
|
||||
|
||||
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||
return ConfigFileChangedEventID
|
||||
}
|
||||
|
||||
type LogDataEvent struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (e LogDataEvent) Type() uint32 {
|
||||
return LogDataEventID
|
||||
}
|
||||
+14
-31
@@ -2,10 +2,13 @@ package proxy
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
type LogLevel int
|
||||
@@ -18,7 +21,7 @@ const (
|
||||
)
|
||||
|
||||
type LogMonitor struct {
|
||||
clients map[chan []byte]bool
|
||||
eventbus *event.Dispatcher
|
||||
mu sync.RWMutex
|
||||
buffer *ring.Ring
|
||||
bufferMu sync.RWMutex
|
||||
@@ -37,11 +40,11 @@ func NewLogMonitor() *LogMonitor {
|
||||
|
||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
return &LogMonitor{
|
||||
clients: make(map[chan []byte]bool),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
eventbus: event.NewDispatcherConfig(1000),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,34 +84,14 @@ func (w *LogMonitor) GetHistory() []byte {
|
||||
return history
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Subscribe() chan []byte {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
ch := make(chan []byte, 100)
|
||||
w.clients[ch] = true
|
||||
return ch
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
delete(w.clients, ch)
|
||||
close(ch)
|
||||
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||
return event.Subscribe(w.eventbus, func(e LogDataEvent) {
|
||||
callback(e.Data)
|
||||
})
|
||||
}
|
||||
|
||||
func (w *LogMonitor) broadcast(msg []byte) {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
|
||||
for client := range w.clients {
|
||||
select {
|
||||
case client <- msg:
|
||||
default:
|
||||
// If client buffer is full, skip
|
||||
}
|
||||
}
|
||||
event.Publish(w.eventbus, LogDataEvent{Data: msg})
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||
|
||||
+13
-22
@@ -10,38 +10,29 @@ import (
|
||||
func TestLogMonitor(t *testing.T) {
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Test subscription
|
||||
client1 := logMonitor.Subscribe()
|
||||
client2 := logMonitor.Subscribe()
|
||||
|
||||
defer logMonitor.Unsubscribe(client1)
|
||||
defer logMonitor.Unsubscribe(client2)
|
||||
// A WaitGroup is used to wait for all the expected writes to complete
|
||||
var wg sync.WaitGroup
|
||||
|
||||
client1Messages := make([]byte, 0)
|
||||
client2Messages := make([]byte, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
defer logMonitor.OnLogData(func(data []byte) {
|
||||
client1Messages = append(client1Messages, data...)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case data := <-client1:
|
||||
client1Messages = append(client1Messages, data...)
|
||||
case data := <-client2:
|
||||
client2Messages = append(client2Messages, data...)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer logMonitor.OnLogData(func(data []byte) {
|
||||
client2Messages = append(client2Messages, data...)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
wg.Add(6) // 2 x 3 writes
|
||||
|
||||
logMonitor.Write([]byte("1"))
|
||||
logMonitor.Write([]byte("2"))
|
||||
logMonitor.Write([]byte("3"))
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
// wait for all writes to complete
|
||||
wg.Wait()
|
||||
|
||||
// Check the buffer
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
|
||||
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
writer := &MetricsResponseWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
metricsRecorder: &MetricsRecorder{
|
||||
metricsMonitor: pm.metricsMonitor,
|
||||
realModelName: realModelName,
|
||||
isStreaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
|
||||
startTime: time.Now(),
|
||||
},
|
||||
}
|
||||
c.Writer = writer
|
||||
c.Next()
|
||||
|
||||
rec := writer.metricsRecorder
|
||||
rec.processBody(writer.body)
|
||||
}
|
||||
}
|
||||
|
||||
type MetricsRecorder struct {
|
||||
metricsMonitor *MetricsMonitor
|
||||
realModelName string
|
||||
isStreaming bool
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// processBody handles response processing after request completes
|
||||
func (rec *MetricsRecorder) processBody(body []byte) {
|
||||
if rec.isStreaming {
|
||||
rec.processStreamingResponse(body)
|
||||
} else {
|
||||
rec.processNonStreamingResponse(body)
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
|
||||
usage := jsonData.Get("usage")
|
||||
if !usage.Exists() {
|
||||
return false
|
||||
}
|
||||
|
||||
// default values
|
||||
outputTokens := int(jsonData.Get("usage.completion_tokens").Int())
|
||||
inputTokens := int(jsonData.Get("usage.prompt_tokens").Int())
|
||||
tokensPerSecond := -1.0
|
||||
durationMs := int(time.Since(rec.startTime).Milliseconds())
|
||||
|
||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||
if timings := jsonData.Get("timings"); timings.Exists() {
|
||||
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
||||
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
||||
}
|
||||
|
||||
rec.metricsMonitor.addMetrics(TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: rec.realModelName,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
DurationMs: durationMs,
|
||||
})
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
|
||||
// Iterate **backwards** through the lines looking for the data payload with
|
||||
// usage data
|
||||
lines := bytes.Split(body, []byte("\n"))
|
||||
|
||||
for i := len(lines) - 1; i >= 0; i-- {
|
||||
line := bytes.TrimSpace(lines[i])
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// SSE payload always follows "data:"
|
||||
prefix := []byte("data:")
|
||||
if !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
// [DONE] line itself contains nothing of interest.
|
||||
continue
|
||||
}
|
||||
|
||||
if gjson.ValidBytes(data) {
|
||||
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
|
||||
return // short circuit if a metric was recorded
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
|
||||
if len(body) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse JSON to extract usage information
|
||||
if gjson.ValidBytes(body) {
|
||||
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
|
||||
}
|
||||
}
|
||||
|
||||
// MetricsResponseWriter captures the entire response for non-streaming
|
||||
type MetricsResponseWriter struct {
|
||||
gin.ResponseWriter
|
||||
body []byte
|
||||
metricsRecorder *MetricsRecorder
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
w.body = append(w.body, b...)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||
type TokenMetrics struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
}
|
||||
|
||||
// TokenMetricsEvent represents a token metrics event
|
||||
type TokenMetricsEvent struct {
|
||||
Metrics TokenMetrics
|
||||
}
|
||||
|
||||
func (e TokenMetricsEvent) Type() uint32 {
|
||||
return TokenMetricsEventID // defined in events.go
|
||||
}
|
||||
|
||||
// MetricsMonitor parses llama-server output for token statistics
|
||||
type MetricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics []TokenMetrics
|
||||
maxMetrics int
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewMetricsMonitor(config *Config) *MetricsMonitor {
|
||||
maxMetrics := config.MetricsMaxInMemory
|
||||
if maxMetrics <= 0 {
|
||||
maxMetrics = 1000 // Default fallback
|
||||
}
|
||||
|
||||
mp := &MetricsMonitor{
|
||||
maxMetrics: maxMetrics,
|
||||
}
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
// addMetrics adds a new metric to the collection and publishes an event
|
||||
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
metric.ID = mp.nextID
|
||||
mp.nextID++
|
||||
mp.metrics = append(mp.metrics, metric)
|
||||
if len(mp.metrics) > mp.maxMetrics {
|
||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||
}
|
||||
|
||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the current metrics
|
||||
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := make([]TokenMetrics, len(mp.metrics))
|
||||
copy(result, mp.metrics)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetMetricsJSON returns metrics as JSON
|
||||
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
return json.Marshal(mp.metrics)
|
||||
}
|
||||
+6
-3
@@ -13,6 +13,8 @@ import (
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
@@ -127,6 +129,7 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
||||
|
||||
p.state = newState
|
||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||
return p.state, nil
|
||||
}
|
||||
|
||||
@@ -209,11 +212,11 @@ func (p *Process) start() error {
|
||||
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||
p.state = StateStopped // force it into a stopped state
|
||||
return fmt.Errorf(
|
||||
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
err, curState, swapErr,
|
||||
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
strings.Join(args, " "), err, curState, swapErr,
|
||||
)
|
||||
}
|
||||
return fmt.Errorf("start() failed: %v", err)
|
||||
return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err)
|
||||
}
|
||||
|
||||
// Capture the exit error for later signalling
|
||||
|
||||
@@ -107,7 +107,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "start() failed: ")
|
||||
assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
|
||||
}
|
||||
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
|
||||
+55
-21
@@ -2,7 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
@@ -33,7 +33,13 @@ type ProxyManager struct {
|
||||
upstreamLogger *LogMonitor
|
||||
muxLogger *LogMonitor
|
||||
|
||||
metricsMonitor *MetricsMonitor
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
|
||||
// shutdown signaling
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func New(config Config) *ProxyManager {
|
||||
@@ -64,6 +70,8 @@ func New(config Config) *ProxyManager {
|
||||
upstreamLogger.SetLogLevel(LevelInfo)
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
ginEngine: gin.New(),
|
||||
@@ -72,7 +80,12 @@ func New(config Config) *ProxyManager {
|
||||
muxLogger: stdoutLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
|
||||
metricsMonitor: NewMetricsMonitor(&config),
|
||||
|
||||
processGroups: make(map[string]*ProcessGroup),
|
||||
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: shutdownCancel,
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
@@ -140,14 +153,18 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
c.Next()
|
||||
})
|
||||
|
||||
mm := MetricsMiddleware(pm)
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
|
||||
|
||||
// Support embeddings
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
@@ -158,9 +175,7 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
// in proxymanager_loghandlers.go
|
||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
|
||||
|
||||
/**
|
||||
* User Interface Endpoints
|
||||
@@ -176,6 +191,9 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
|
||||
@@ -262,6 +280,7 @@ func (pm *ProxyManager) Shutdown() {
|
||||
}(processGroup)
|
||||
}
|
||||
wg.Wait()
|
||||
pm.shutdownCancel()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||
@@ -289,32 +308,41 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := []interface{}{}
|
||||
data := make([]gin.H, 0, len(pm.config.Models))
|
||||
createdTime := time.Now().Unix()
|
||||
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
data = append(data, map[string]interface{}{
|
||||
record := gin.H{
|
||||
"id": id,
|
||||
"object": "model",
|
||||
"created": time.Now().Unix(),
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
})
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
data = append(data, record)
|
||||
}
|
||||
|
||||
// Set the Content-Type header to application/json
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
||||
// Set CORS headers if origin exists
|
||||
if origin := c.GetHeader("Origin"); origin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"object": "list", "data": data}); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
||||
return
|
||||
}
|
||||
// Use gin's JSON method which handles content-type and encoding
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
@@ -349,7 +377,13 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
|
||||
+121
-18
@@ -1,25 +1,30 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
State string `json:"state"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
}
|
||||
|
||||
func addApiHandlers(pm *ProxyManager) {
|
||||
// Add API endpoints for React to consume
|
||||
apiGroup := pm.ginEngine.Group("/api")
|
||||
{
|
||||
apiGroup.GET("/models", pm.apiListModels)
|
||||
apiGroup.GET("/modelsSSE", pm.apiListModelsSSE)
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,37 +70,135 @@ func (pm *ProxyManager) getModelStatus() []Model {
|
||||
}
|
||||
}
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
State: state,
|
||||
Id: modelID,
|
||||
Name: pm.config.Models[modelID].Name,
|
||||
Description: pm.config.Models[modelID].Description,
|
||||
State: state,
|
||||
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||
})
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiListModels(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, pm.getModelStatus())
|
||||
type messageType string
|
||||
|
||||
const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
Type messageType `json:"type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// stream the models as a SSE
|
||||
func (pm *ProxyManager) apiListModelsSSE(c *gin.Context) {
|
||||
// sends a stream of different message types that happen on the server
|
||||
func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
sendBuffer := make(chan messageEnvelope, 25)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
sendModels := func() {
|
||||
data, err := json.Marshal(pm.getModelStatus())
|
||||
if err == nil {
|
||||
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
|
||||
select {
|
||||
case sendBuffer <- msg:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
sendLogData := func(source string, data []byte) {
|
||||
data, err := json.Marshal(gin.H{
|
||||
"source": source,
|
||||
"data": string(data),
|
||||
})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sendMetrics := func(metrics TokenMetrics) {
|
||||
jsonData, err := json.Marshal(metrics)
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send updated models list
|
||||
*/
|
||||
defer event.On(func(e ProcessStateChangeEvent) {
|
||||
sendModels()
|
||||
})()
|
||||
defer event.On(func(e ConfigFileChangedEvent) {
|
||||
sendModels()
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Log data
|
||||
*/
|
||||
defer pm.proxyLogger.OnLogData(func(data []byte) {
|
||||
sendLogData("proxy", data)
|
||||
})()
|
||||
defer pm.upstreamLogger.OnLogData(func(data []byte) {
|
||||
sendLogData("upstream", data)
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Metrics data
|
||||
*/
|
||||
defer event.On(func(e TokenMetricsEvent) {
|
||||
sendMetrics(e.Metrics)
|
||||
})()
|
||||
|
||||
// send initial batch of data
|
||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||
sendModels()
|
||||
for _, metrics := range pm.metricsMonitor.GetMetrics() {
|
||||
sendMetrics(metrics)
|
||||
}
|
||||
|
||||
// Stream new events
|
||||
for {
|
||||
select {
|
||||
case <-notify:
|
||||
case <-c.Request.Context().Done():
|
||||
cancel()
|
||||
return
|
||||
default:
|
||||
models := pm.getModelStatus()
|
||||
c.SSEvent("message", models)
|
||||
case <-pm.shutdownCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case msg := <-sendBuffer:
|
||||
c.SSEvent("message", msg)
|
||||
c.Writer.Flush()
|
||||
<-time.After(1000 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonData)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -34,10 +35,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
ch := logger.Subscribe()
|
||||
defer logger.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||
@@ -55,57 +53,28 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
sendChan := make(chan []byte, 10)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer logger.OnLogData(func(data []byte) {
|
||||
select {
|
||||
case sendChan <- data:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
})()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
_, err := c.Writer.Write(msg)
|
||||
if err != nil {
|
||||
// just break the loop if we can't write for some reason
|
||||
return
|
||||
}
|
||||
case <-c.Request.Context().Done():
|
||||
cancel()
|
||||
return
|
||||
case <-pm.shutdownCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case data := <-sendChan:
|
||||
c.Writer.Write(data)
|
||||
flusher.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
logMonitorId := c.Param("logMonitorID")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
ch := logger.Subscribe()
|
||||
defer logger.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
|
||||
// Send history first if not skipped
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
if !skipHistory {
|
||||
history := logger.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.SSEvent("message", string(history))
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
c.SSEvent("message", string(msg))
|
||||
c.Writer.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+122
-6
@@ -165,9 +165,11 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
var response map[string]string
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
results[key] = response["responseMessage"]
|
||||
result, ok := response["responseMessage"].(string)
|
||||
assert.Equal(t, ok, true)
|
||||
results[key] = result
|
||||
mu.Unlock()
|
||||
}(key)
|
||||
|
||||
@@ -183,11 +185,20 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
|
||||
model1Config := getTestSimpleResponderConfig("model1")
|
||||
model1Config.Name = "Model 1"
|
||||
model1Config.Description = "Model 1 description is used for testing"
|
||||
|
||||
model2Config := getTestSimpleResponderConfig("model2")
|
||||
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||
model2Config.Description = " "
|
||||
|
||||
config := Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -213,6 +224,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
var response struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
@@ -227,6 +239,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
"model3": {},
|
||||
}
|
||||
|
||||
// make all models
|
||||
for _, model := range response.Data {
|
||||
modelID, ok := model["id"].(string)
|
||||
assert.True(t, ok, "model ID should be a string")
|
||||
@@ -245,6 +258,21 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
ownedBy, ok := model["owned_by"].(string)
|
||||
assert.True(t, ok, "owned_by should be a string")
|
||||
assert.Equal(t, "llama-swap", ownedBy)
|
||||
|
||||
// check for optional name and description
|
||||
if modelID == "model1" {
|
||||
name, ok := model["name"].(string)
|
||||
assert.True(t, ok, "name should be a string")
|
||||
assert.Equal(t, "Model 1", name)
|
||||
description, ok := model["description"].(string)
|
||||
assert.True(t, ok, "description should be a string")
|
||||
assert.Equal(t, "Model 1 description is used for testing", description)
|
||||
} else {
|
||||
_, exists := model["name"]
|
||||
assert.False(t, exists, "unexpected name field for model: %s", modelID)
|
||||
_, exists = model["description"]
|
||||
assert.False(t, exists, "unexpected description field for model: %s", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure all expected models were returned
|
||||
@@ -618,7 +646,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var response map[string]string
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
assert.Equal(t, "81", response["h_content_length"])
|
||||
assert.Equal(t, "model1", response["responseMessage"])
|
||||
@@ -646,7 +674,7 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var response map[string]string
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// `temperature` and `stream` are gone but model remains
|
||||
@@ -657,3 +685,91 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
// assert.Equal(t, "abc", response["y_param"])
|
||||
// t.Logf("%v", response)
|
||||
}
|
||||
|
||||
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// Make a non-streaming request
|
||||
reqBody := `{"model":"model1", "stream": false}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check that metrics were recorded
|
||||
metrics := proxy.metricsMonitor.GetMetrics()
|
||||
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the last metric has the correct model
|
||||
lastMetric := metrics[len(metrics)-1]
|
||||
assert.Equal(t, "model1", lastMetric.Model)
|
||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
||||
}
|
||||
|
||||
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// Make a streaming request
|
||||
reqBody := `{"model":"model1", "stream": true}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check that metrics were recorded
|
||||
metrics := proxy.metricsMonitor.GetMetrics()
|
||||
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the last metric has the correct model
|
||||
lastMetric := metrics[len(metrics)-1]
|
||||
assert.Equal(t, "model1", lastMetric.Model)
|
||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
||||
}
|
||||
|
||||
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "OK", rec.Body.String())
|
||||
}
|
||||
|
||||
Generated
+21
@@ -12,6 +12,8 @@
|
||||
"@tanstack/react-query": "^5.80.6",
|
||||
"react": "^19.1.0",
|
||||
"react-dom": "^19.1.0",
|
||||
"react-icons": "^5.5.0",
|
||||
"react-resizable-panels": "^3.0.4",
|
||||
"react-router-dom": "^7.6.2",
|
||||
"tailwindcss": "^4.1.8"
|
||||
},
|
||||
@@ -3460,6 +3462,15 @@
|
||||
"react": "^19.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-icons": {
|
||||
"version": "5.5.0",
|
||||
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-5.5.0.tgz",
|
||||
"integrity": "sha512-MEFcXdkP3dLo8uumGI5xN3lDFNsRtrjbOEKDLD7yv76v4wpnEq2Lt2qeHaQOr34I/wPN3s3+N08WkQ+CW37Xiw==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"react": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/react-refresh": {
|
||||
"version": "0.17.0",
|
||||
"resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.17.0.tgz",
|
||||
@@ -3470,6 +3481,16 @@
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-resizable-panels": {
|
||||
"version": "3.0.4",
|
||||
"resolved": "https://registry.npmjs.org/react-resizable-panels/-/react-resizable-panels-3.0.4.tgz",
|
||||
"integrity": "sha512-8Y4KNgV94XhUvI2LeByyPIjoUJb71M/0hyhtzkHaqpVHs+ZQs8b627HmzyhmVYi3C9YP6R+XD1KmG7hHjEZXFQ==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"react": "^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc"
|
||||
}
|
||||
},
|
||||
"node_modules/react-router": {
|
||||
"version": "7.6.2",
|
||||
"resolved": "https://registry.npmjs.org/react-router/-/react-router-7.6.2.tgz",
|
||||
|
||||
+3
-1
@@ -14,6 +14,8 @@
|
||||
"@tanstack/react-query": "^5.80.6",
|
||||
"react": "^19.1.0",
|
||||
"react-dom": "^19.1.0",
|
||||
"react-icons": "^5.5.0",
|
||||
"react-resizable-panels": "^3.0.4",
|
||||
"react-router-dom": "^7.6.2",
|
||||
"tailwindcss": "^4.1.8"
|
||||
},
|
||||
@@ -30,4 +32,4 @@
|
||||
"typescript-eslint": "^8.30.1",
|
||||
"vite": "^6.3.5"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+14
-6
@@ -3,16 +3,19 @@ import { useTheme } from "./contexts/ThemeProvider";
|
||||
import { APIProvider } from "./contexts/APIProvider";
|
||||
import LogViewerPage from "./pages/LogViewer";
|
||||
import ModelPage from "./pages/Models";
|
||||
import ActivityPage from "./pages/Activity";
|
||||
import { RiSunFill, RiMoonFill } from "react-icons/ri";
|
||||
|
||||
function App() {
|
||||
const theme = useTheme();
|
||||
const { isNarrow, toggleTheme, isDarkMode } = useTheme();
|
||||
|
||||
return (
|
||||
<Router basename="/ui/">
|
||||
<APIProvider>
|
||||
<div>
|
||||
<div className="flex flex-col h-screen">
|
||||
<nav className="bg-surface border-b border-border p-2 h-[75px]">
|
||||
<div className="flex items-center justify-between mx-auto px-4 h-full">
|
||||
<h1 className="flex items-center p-0">llama-swap</h1>
|
||||
{!isNarrow && <h1 className="flex items-center p-0">llama-swap</h1>}
|
||||
<div className="flex items-center space-x-4">
|
||||
<NavLink to="/" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Logs
|
||||
@@ -21,17 +24,22 @@ function App() {
|
||||
<NavLink to="/models" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Models
|
||||
</NavLink>
|
||||
<button className="btn btn--sm" onClick={theme.toggleTheme}>
|
||||
{theme.isDarkMode ? "🌙" : "☀️"}
|
||||
|
||||
<NavLink to="/activity" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Activity
|
||||
</NavLink>
|
||||
<button className="" onClick={toggleTheme}>
|
||||
{isDarkMode ? <RiMoonFill /> : <RiSunFill />}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<main className="mx-auto py-4 px-4">
|
||||
<main className="flex-1 overflow-auto p-4">
|
||||
<Routes>
|
||||
<Route path="/" element={<LogViewerPage />} />
|
||||
<Route path="/models" element={<ModelPage />} />
|
||||
<Route path="/activity" element={<ActivityPage />} />
|
||||
<Route path="*" element={<Navigate to="/" replace />} />
|
||||
</Routes>
|
||||
</main>
|
||||
|
||||
+93
-117
@@ -6,6 +6,9 @@ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
||||
export interface Model {
|
||||
id: string;
|
||||
state: ModelStatus;
|
||||
name: string;
|
||||
description: string;
|
||||
unlisted: boolean;
|
||||
}
|
||||
|
||||
interface APIProviderType {
|
||||
@@ -13,26 +16,44 @@ interface APIProviderType {
|
||||
listModels: () => Promise<Model[]>;
|
||||
unloadAllModels: () => Promise<void>;
|
||||
loadModel: (model: string) => Promise<void>;
|
||||
enableProxyLogs: (enabled: boolean) => void;
|
||||
enableUpstreamLogs: (enabled: boolean) => void;
|
||||
enableModelUpdates: (enabled: boolean) => void;
|
||||
enableAPIEvents: (enabled: boolean) => void;
|
||||
proxyLogs: string;
|
||||
upstreamLogs: string;
|
||||
metrics: Metrics[];
|
||||
}
|
||||
|
||||
interface Metrics {
|
||||
id: number;
|
||||
timestamp: string;
|
||||
model: string;
|
||||
input_tokens: number;
|
||||
output_tokens: number;
|
||||
tokens_per_second: number;
|
||||
duration_ms: number;
|
||||
}
|
||||
|
||||
interface LogData {
|
||||
source: "upstream" | "proxy";
|
||||
data: string;
|
||||
}
|
||||
interface APIEventEnvelope {
|
||||
type: "modelStatus" | "logData" | "metrics";
|
||||
data: string;
|
||||
}
|
||||
|
||||
const APIContext = createContext<APIProviderType | undefined>(undefined);
|
||||
type APIProviderProps = {
|
||||
children: ReactNode;
|
||||
autoStartAPIEvents?: boolean;
|
||||
};
|
||||
|
||||
export function APIProvider({ children }: APIProviderProps) {
|
||||
export function APIProvider({ children, autoStartAPIEvents = true }: APIProviderProps) {
|
||||
const [proxyLogs, setProxyLogs] = useState("");
|
||||
const [upstreamLogs, setUpstreamLogs] = useState("");
|
||||
const proxyEventSource = useRef<EventSource | null>(null);
|
||||
const upstreamEventSource = useRef<EventSource | null>(null);
|
||||
const [metrics, setMetrics] = useState<Metrics[]>([]);
|
||||
const apiEventSource = useRef<EventSource | null>(null);
|
||||
|
||||
const [models, setModels] = useState<Model[]>([]);
|
||||
const modelStatusEventSource = useRef<EventSource | null>(null);
|
||||
|
||||
const appendLog = useCallback((newData: string, setter: React.Dispatch<React.SetStateAction<string>>) => {
|
||||
setter((prev) => {
|
||||
@@ -41,112 +62,78 @@ export function APIProvider({ children }: APIProviderProps) {
|
||||
});
|
||||
}, []);
|
||||
|
||||
const handleProxyMessage = useCallback(
|
||||
(e: MessageEvent) => {
|
||||
appendLog(e.data, setProxyLogs);
|
||||
},
|
||||
[proxyLogs, appendLog]
|
||||
);
|
||||
const enableAPIEvents = useCallback((enabled: boolean) => {
|
||||
if (!enabled) {
|
||||
apiEventSource.current?.close();
|
||||
apiEventSource.current = null;
|
||||
setMetrics([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const handleUpstreamMessage = useCallback(
|
||||
(e: MessageEvent) => {
|
||||
appendLog(e.data, setUpstreamLogs);
|
||||
},
|
||||
[appendLog]
|
||||
);
|
||||
let retryCount = 0;
|
||||
const initialDelay = 1000; // 1 second
|
||||
|
||||
const enableProxyLogs = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (enabled) {
|
||||
let retryCount = 0;
|
||||
const maxRetries = 3;
|
||||
const initialDelay = 1000; // 1 second
|
||||
const connect = () => {
|
||||
const eventSource = new EventSource("/api/events");
|
||||
|
||||
const connect = () => {
|
||||
const eventSource = new EventSource("/logs/streamSSE/proxy");
|
||||
eventSource.onmessage = (e: MessageEvent) => {
|
||||
try {
|
||||
const message = JSON.parse(e.data) as APIEventEnvelope;
|
||||
switch (message.type) {
|
||||
case "modelStatus":
|
||||
{
|
||||
const models = JSON.parse(message.data) as Model[];
|
||||
setModels(models);
|
||||
}
|
||||
break;
|
||||
|
||||
eventSource.onmessage = handleProxyMessage;
|
||||
eventSource.onerror = () => {
|
||||
eventSource.close();
|
||||
if (retryCount < maxRetries) {
|
||||
retryCount++;
|
||||
const delay = initialDelay * Math.pow(2, retryCount - 1);
|
||||
setTimeout(connect, delay);
|
||||
}
|
||||
};
|
||||
case "logData":
|
||||
const logData = JSON.parse(message.data) as LogData;
|
||||
switch (logData.source) {
|
||||
case "proxy":
|
||||
appendLog(logData.data, setProxyLogs);
|
||||
break;
|
||||
case "upstream":
|
||||
appendLog(logData.data, setUpstreamLogs);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
proxyEventSource.current = eventSource;
|
||||
};
|
||||
|
||||
connect();
|
||||
} else {
|
||||
proxyEventSource.current?.close();
|
||||
proxyEventSource.current = null;
|
||||
}
|
||||
},
|
||||
[handleProxyMessage]
|
||||
);
|
||||
|
||||
const enableUpstreamLogs = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (enabled) {
|
||||
let retryCount = 0;
|
||||
const maxRetries = 3;
|
||||
const initialDelay = 1000; // 1 second
|
||||
|
||||
const connect = () => {
|
||||
const eventSource = new EventSource("/logs/streamSSE/upstream");
|
||||
|
||||
eventSource.onmessage = handleUpstreamMessage;
|
||||
eventSource.onerror = () => {
|
||||
eventSource.close();
|
||||
if (retryCount < maxRetries) {
|
||||
retryCount++;
|
||||
const delay = initialDelay * Math.pow(2, retryCount - 1);
|
||||
setTimeout(connect, delay);
|
||||
}
|
||||
};
|
||||
|
||||
upstreamEventSource.current = eventSource;
|
||||
};
|
||||
|
||||
connect();
|
||||
} else {
|
||||
upstreamEventSource.current?.close();
|
||||
upstreamEventSource.current = null;
|
||||
}
|
||||
},
|
||||
[handleUpstreamMessage]
|
||||
);
|
||||
|
||||
const enableModelUpdates = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (enabled) {
|
||||
const eventSource = new EventSource("/api/modelsSSE");
|
||||
eventSource.onmessage = (e: MessageEvent) => {
|
||||
try {
|
||||
const models = JSON.parse(e.data) as Model[];
|
||||
setModels(models);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
case "metrics":
|
||||
{
|
||||
const newMetric = JSON.parse(message.data) as Metrics;
|
||||
setMetrics((prevMetrics) => {
|
||||
return [newMetric, ...prevMetrics];
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
modelStatusEventSource.current = eventSource;
|
||||
} else {
|
||||
modelStatusEventSource.current?.close();
|
||||
modelStatusEventSource.current = null;
|
||||
}
|
||||
},
|
||||
[setModels]
|
||||
);
|
||||
} catch (err) {
|
||||
console.error(e.data, err);
|
||||
}
|
||||
};
|
||||
eventSource.onerror = () => {
|
||||
eventSource.close();
|
||||
retryCount++;
|
||||
const delay = Math.min(initialDelay * Math.pow(2, retryCount - 1), 5000);
|
||||
setTimeout(connect, delay);
|
||||
};
|
||||
|
||||
apiEventSource.current = eventSource;
|
||||
};
|
||||
|
||||
connect();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (autoStartAPIEvents) {
|
||||
enableAPIEvents(true);
|
||||
}
|
||||
|
||||
return () => {
|
||||
proxyEventSource.current?.close();
|
||||
upstreamEventSource.current?.close();
|
||||
modelStatusEventSource.current?.close();
|
||||
enableAPIEvents(false);
|
||||
};
|
||||
}, []);
|
||||
}, [enableAPIEvents, autoStartAPIEvents]);
|
||||
|
||||
const listModels = useCallback(async (): Promise<Model[]> => {
|
||||
try {
|
||||
@@ -196,23 +183,12 @@ export function APIProvider({ children }: APIProviderProps) {
|
||||
listModels,
|
||||
unloadAllModels,
|
||||
loadModel,
|
||||
enableProxyLogs,
|
||||
enableUpstreamLogs,
|
||||
enableModelUpdates,
|
||||
enableAPIEvents,
|
||||
proxyLogs,
|
||||
upstreamLogs,
|
||||
metrics,
|
||||
}),
|
||||
[
|
||||
models,
|
||||
listModels,
|
||||
unloadAllModels,
|
||||
loadModel,
|
||||
enableProxyLogs,
|
||||
enableUpstreamLogs,
|
||||
enableModelUpdates,
|
||||
proxyLogs,
|
||||
upstreamLogs,
|
||||
]
|
||||
[models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics]
|
||||
);
|
||||
|
||||
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import { createContext, useContext, useEffect, type ReactNode } from "react";
|
||||
import { createContext, useContext, useEffect, type ReactNode, useMemo, useState } from "react";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
|
||||
type ScreenWidth = "xs" | "sm" | "md" | "lg" | "xl" | "2xl";
|
||||
type ThemeContextType = {
|
||||
isDarkMode: boolean;
|
||||
screenWidth: ScreenWidth;
|
||||
isNarrow: boolean;
|
||||
toggleTheme: () => void;
|
||||
};
|
||||
|
||||
@@ -14,14 +17,46 @@ type ThemeProviderProps = {
|
||||
|
||||
export function ThemeProvider({ children }: ThemeProviderProps) {
|
||||
const [isDarkMode, setIsDarkMode] = usePersistentState<boolean>("theme", false);
|
||||
const [screenWidth, setScreenWidth] = useState<ScreenWidth>("md"); // Default to md
|
||||
|
||||
// matches tailwind classes
|
||||
// https://tailwindcss.com/docs/responsive-design
|
||||
useEffect(() => {
|
||||
const checkInnerWidth = () => {
|
||||
const innerWidth = window.innerWidth;
|
||||
if (innerWidth < 640) {
|
||||
setScreenWidth("xs");
|
||||
} else if (innerWidth < 768) {
|
||||
setScreenWidth("sm");
|
||||
} else if (innerWidth < 1024) {
|
||||
setScreenWidth("md");
|
||||
} else if (innerWidth < 1280) {
|
||||
setScreenWidth("lg");
|
||||
} else if (innerWidth < 1536) {
|
||||
setScreenWidth("xl");
|
||||
} else {
|
||||
setScreenWidth("2xl");
|
||||
}
|
||||
};
|
||||
|
||||
checkInnerWidth();
|
||||
window.addEventListener("resize", checkInnerWidth);
|
||||
|
||||
return () => window.removeEventListener("resize", checkInnerWidth);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
document.documentElement.setAttribute("data-theme", isDarkMode ? "dark" : "light");
|
||||
}, [isDarkMode]);
|
||||
|
||||
const toggleTheme = () => setIsDarkMode((prev) => !prev);
|
||||
const isNarrow = useMemo(() => {
|
||||
return screenWidth === "xs" || screenWidth === "sm" || screenWidth === "md";
|
||||
}, [screenWidth]);
|
||||
|
||||
return <ThemeContext.Provider value={{ isDarkMode, toggleTheme }}>{children}</ThemeContext.Provider>;
|
||||
return (
|
||||
<ThemeContext.Provider value={{ isDarkMode, toggleTheme, screenWidth, isNarrow }}>{children}</ThemeContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useTheme(): ThemeContextType {
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
export function processEvalTimes(text: string) {
|
||||
const lines = text.match(/^ *eval time.*$/gm) || [];
|
||||
|
||||
let totalTokens = 0;
|
||||
let totalTime = 0;
|
||||
|
||||
lines.forEach((line) => {
|
||||
const tokensMatch = line.match(/\/\s*(\d+)\s*tokens/);
|
||||
const timeMatch = line.match(/=\s*(\d+\.\d+)\s*ms/);
|
||||
|
||||
if (tokensMatch) totalTokens += parseFloat(tokensMatch[1]);
|
||||
if (timeMatch) totalTime += parseFloat(timeMatch[1]);
|
||||
});
|
||||
|
||||
const avgTokensPerSecond = totalTime > 0 ? totalTokens / (totalTime / 1000) : 0;
|
||||
|
||||
return [lines.length, totalTokens, Math.round(avgTokensPerSecond * 100) / 100];
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
|
||||
const formatTimestamp = (timestamp: string): string => {
|
||||
return new Date(timestamp).toLocaleString();
|
||||
};
|
||||
|
||||
const formatSpeed = (speed: number): string => {
|
||||
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||
};
|
||||
|
||||
const formatDuration = (ms: number): string => {
|
||||
return (ms / 1000).toFixed(2) + "s";
|
||||
};
|
||||
|
||||
const ActivityPage = () => {
|
||||
const { metrics } = useAPI();
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (metrics.length > 0) {
|
||||
setError(null);
|
||||
}
|
||||
}, [metrics]);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="p-6">
|
||||
<h1 className="text-2xl font-bold mb-4">Activity</h1>
|
||||
<div className="bg-red-50 border border-red-200 rounded-md p-4">
|
||||
<p className="text-red-800">{error}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-6">
|
||||
<h1 className="text-2xl font-bold mb-4">Activity</h1>
|
||||
|
||||
{metrics.length === 0 ? (
|
||||
<div className="text-center py-8">
|
||||
<p className="text-gray-600">No metrics data available</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="overflow-x-auto">
|
||||
<table className="min-w-full divide-y">
|
||||
<thead>
|
||||
<tr>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Timestamp</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Model</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Input Tokens</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Output Tokens</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Generation Speed</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Duration</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="divide-y">
|
||||
{metrics.map((metric, index) => (
|
||||
<tr key={`${metric.id}-${index}`}>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatTimestamp(metric.timestamp)}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.model}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.input_tokens.toLocaleString()}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.output_tokens.toLocaleString()}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.tokens_per_second)}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatDuration(metric.duration_ms)}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ActivityPage;
|
||||
+71
-55
@@ -1,24 +1,38 @@
|
||||
import { useState, useEffect, useRef, useMemo, useCallback } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||
import {
|
||||
RiTextWrap,
|
||||
RiAlignJustify,
|
||||
RiFontSize,
|
||||
RiMenuSearchLine,
|
||||
RiMenuSearchFill,
|
||||
RiCloseCircleFill,
|
||||
} from "react-icons/ri";
|
||||
import { useTheme } from "../contexts/ThemeProvider";
|
||||
|
||||
const LogViewer = () => {
|
||||
const { proxyLogs, upstreamLogs, enableProxyLogs, enableUpstreamLogs } = useAPI();
|
||||
|
||||
useEffect(() => {
|
||||
enableProxyLogs(true);
|
||||
enableUpstreamLogs(true);
|
||||
return () => {
|
||||
enableProxyLogs(false);
|
||||
enableUpstreamLogs(false);
|
||||
};
|
||||
}, []);
|
||||
const { proxyLogs, upstreamLogs } = useAPI();
|
||||
const { isNarrow } = useTheme();
|
||||
const direction = isNarrow ? "vertical" : "horizontal";
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-5" style={{ height: "calc(100vh - 125px)" }}>
|
||||
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
|
||||
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</div>
|
||||
<PanelGroup direction={direction} className="gap-2" autoSaveId={`logviewer-panel-group-${direction}`}>
|
||||
<Panel id="proxy" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
|
||||
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
|
||||
</Panel>
|
||||
<PanelResizeHandle
|
||||
className={
|
||||
direction === "horizontal"
|
||||
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
|
||||
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
|
||||
}
|
||||
/>
|
||||
<Panel id="upstream" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
|
||||
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -26,17 +40,15 @@ interface LogPanelProps {
|
||||
id: string;
|
||||
title: string;
|
||||
logData: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
|
||||
const [isCollapsed, setIsCollapsed] = usePersistentState(`logPanel-${id}-isCollapsed`, false);
|
||||
export const LogPanel = ({ id, title, logData }: LogPanelProps) => {
|
||||
const [filterRegex, setFilterRegex] = useState("");
|
||||
const [fontSize, setFontSize] = usePersistentState<"xxs" | "xs" | "small" | "normal">(
|
||||
`logPanel-${id}-fontSize`,
|
||||
"normal"
|
||||
);
|
||||
const [wrapText, setTextWrap] = usePersistentState(`logPanel-${id}-wrapText`, false);
|
||||
const [showFilter, setShowFilter] = usePersistentState(`logPanel-${id}-showFilter`, false);
|
||||
|
||||
const textWrapClass = useMemo(() => {
|
||||
return wrapText ? "whitespace-pre-wrap" : "whitespace-pre";
|
||||
@@ -57,6 +69,19 @@ export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
|
||||
});
|
||||
}, []);
|
||||
|
||||
const toggleWrapText = useCallback(() => {
|
||||
setTextWrap((prev) => !prev);
|
||||
}, []);
|
||||
|
||||
const toggleFilter = useCallback(() => {
|
||||
if (showFilter) {
|
||||
setShowFilter(false);
|
||||
setFilterRegex(""); // Clear filter when closing
|
||||
} else {
|
||||
setShowFilter(true);
|
||||
}
|
||||
}, [filterRegex, setFilterRegex, showFilter]);
|
||||
|
||||
const fontSizeClass = useMemo(() => {
|
||||
switch (fontSize) {
|
||||
case "xxs":
|
||||
@@ -90,56 +115,47 @@ export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
|
||||
}, [filteredLogs]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`bg-surface border border-border rounded-lg overflow-hidden flex flex-col ${
|
||||
!isCollapsed && "h-full"
|
||||
} ${className || ""}`}
|
||||
>
|
||||
<div className="bg-surface border border-border rounded-lg overflow-hidden flex flex-col h-full">
|
||||
<div className="p-4 border-b border-border bg-secondary">
|
||||
<div className="flex flex-col md:flex-row md:items-center md:justify-between gap-4">
|
||||
{/* Title - Always full width on mobile, normal on desktop */}
|
||||
<div className="w-full md:w-auto" onClick={() => setIsCollapsed(!isCollapsed)}>
|
||||
<h3 className="m-0 text-lg">{title}</h3>
|
||||
<div className="flex items-center justify-between">
|
||||
<h3 className="m-0 text-lg p-0">{title}</h3>
|
||||
|
||||
<div className="flex gap-2 items-center">
|
||||
<button className="btn" onClick={toggleFontSize}>
|
||||
<RiFontSize />
|
||||
</button>
|
||||
<button className="btn" onClick={toggleWrapText}>
|
||||
{wrapText ? <RiTextWrap /> : <RiAlignJustify />}
|
||||
</button>
|
||||
<button className="btn" onClick={toggleFilter}>
|
||||
{showFilter ? <RiMenuSearchFill /> : <RiMenuSearchLine />}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col sm:flex-row gap-4 w-full md:w-auto">
|
||||
{/* Sizing Buttons - Stacks vertically on mobile */}
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<button className="btn" onClick={toggleFontSize}>
|
||||
font: {fontSize}
|
||||
</button>
|
||||
<button className="btn" onClick={() => setTextWrap((prev) => !prev)}>
|
||||
{wrapText ? "wrap" : "wrap off"}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Filtering Options - Full width on mobile, normal on desktop */}
|
||||
<div className="flex flex-1 min-w-0 gap-2">
|
||||
{/* Filtering Options - Full width on mobile, normal on desktop */}
|
||||
{showFilter && (
|
||||
<div className="mt-2 w-full">
|
||||
<div className="flex gap-2 items-center w-full">
|
||||
<input
|
||||
type="text"
|
||||
className="flex-1 min-w-[120px] text-sm border p-2 rounded"
|
||||
className="w-full text-sm border p-2 rounded"
|
||||
placeholder="Filter logs..."
|
||||
value={filterRegex}
|
||||
onChange={(e) => setFilterRegex(e.target.value)}
|
||||
/>
|
||||
<button className="btn" onClick={() => setFilterRegex("")}>
|
||||
Clear
|
||||
<button className="pl-2" onClick={() => setFilterRegex("")}>
|
||||
<RiCloseCircleFill size="24" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="bg-background font-mono text-sm flex-1 overflow-hidden">
|
||||
<pre ref={preTagRef} className={`${textWrapClass} ${fontSizeClass} h-full overflow-auto p-4`}>
|
||||
{filteredLogs}
|
||||
</pre>
|
||||
</div>
|
||||
|
||||
{!isCollapsed && (
|
||||
<div className="flex-1 bg-background font-mono text-sm p-3 overflow-hidden">
|
||||
<pre
|
||||
ref={preTagRef}
|
||||
className={`h-full p-4 overflow-y-auto whitespace-pre min-h-0 ${textWrapClass} ${fontSizeClass}`}
|
||||
>
|
||||
{filteredLogs}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
+131
-85
@@ -1,20 +1,49 @@
|
||||
import { useState, useEffect, useCallback, useMemo } from "react";
|
||||
import { useState, useCallback, useMemo } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
import { LogPanel } from "./LogViewer";
|
||||
import { processEvalTimes } from "../lib/Utils";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||
import { useTheme } from "../contexts/ThemeProvider";
|
||||
import { RiEyeFill, RiEyeOffFill, RiStopCircleLine } from "react-icons/ri";
|
||||
|
||||
export default function ModelsPage() {
|
||||
const { models, enableModelUpdates, unloadAllModels, loadModel, upstreamLogs, enableUpstreamLogs } = useAPI();
|
||||
const [isUnloading, setIsUnloading] = useState(false);
|
||||
const { isNarrow } = useTheme();
|
||||
const direction = isNarrow ? "vertical" : "horizontal";
|
||||
const { upstreamLogs } = useAPI();
|
||||
|
||||
useEffect(() => {
|
||||
enableModelUpdates(true);
|
||||
enableUpstreamLogs(true);
|
||||
return () => {
|
||||
enableModelUpdates(false);
|
||||
enableUpstreamLogs(false);
|
||||
};
|
||||
}, []);
|
||||
return (
|
||||
<PanelGroup direction={direction} className="gap-2" autoSaveId={`models-panel-group-${direction}`}>
|
||||
<Panel id="models" defaultSize={50} minSize={isNarrow ? 0 : 25} maxSize={100} collapsible={isNarrow}>
|
||||
<ModelsPanel />
|
||||
</Panel>
|
||||
|
||||
<PanelResizeHandle
|
||||
className={
|
||||
direction === "horizontal"
|
||||
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
|
||||
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
|
||||
}
|
||||
/>
|
||||
<Panel collapsible={true} defaultSize={50} minSize={0}>
|
||||
<div className="flex flex-col h-full space-y-4">
|
||||
{direction === "horizontal" && <StatsPanel />}
|
||||
<div className="flex-1 min-h-0">
|
||||
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</div>
|
||||
</div>
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
);
|
||||
}
|
||||
|
||||
function ModelsPanel() {
|
||||
const { models, loadModel, unloadAllModels } = useAPI();
|
||||
const [isUnloading, setIsUnloading] = useState(false);
|
||||
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
||||
|
||||
const filteredModels = useMemo(() => {
|
||||
return models.filter((model) => showUnlisted || !model.unlisted);
|
||||
}, [models, showUnlisted]);
|
||||
|
||||
const handleUnloadAllModels = useCallback(async () => {
|
||||
setIsUnloading(true);
|
||||
@@ -23,88 +52,105 @@ export default function ModelsPage() {
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
} finally {
|
||||
// at least give it a second to show the unloading message
|
||||
setTimeout(() => {
|
||||
setIsUnloading(false);
|
||||
}, 1000);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const [totalLines, totalTokens, avgTokensPerSecond] = useMemo(() => {
|
||||
return processEvalTimes(upstreamLogs);
|
||||
}, [upstreamLogs]);
|
||||
}, [unloadAllModels]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="flex flex-col md:flex-row gap-4">
|
||||
{/* Left Column */}
|
||||
<div className="w-full md:w-1/2 flex items-top">
|
||||
<div className="card w-full">
|
||||
<h2 className="">Models</h2>
|
||||
<button className="btn" onClick={handleUnloadAllModels} disabled={isUnloading}>
|
||||
{isUnloading ? "Unloading..." : "Unload All Models"}
|
||||
</button>
|
||||
<table className="w-full mt-4">
|
||||
<thead>
|
||||
<tr className="border-b border-primary">
|
||||
<th className="text-left p-2">Name</th>
|
||||
<th className="text-left p-2"></th>
|
||||
<th className="text-left p-2">State</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{models.map((model) => (
|
||||
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
|
||||
<td className="p-2">
|
||||
<a href={`/upstream/${model.id}/`} className="underline" target="_blank">
|
||||
{model.id}
|
||||
</a>
|
||||
</td>
|
||||
<td className="p-2">
|
||||
<button
|
||||
className="btn btn--sm"
|
||||
disabled={model.state !== "stopped"}
|
||||
onClick={() => loadModel(model.id)}
|
||||
>
|
||||
Load
|
||||
</button>
|
||||
</td>
|
||||
<td className="p-2">
|
||||
<span className={`status status--${model.state}`}>{model.state}</span>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<div className="card h-full flex flex-col">
|
||||
<div className="shrink-0">
|
||||
<h2>Models</h2>
|
||||
<div className="flex justify-between">
|
||||
<button
|
||||
className="btn flex items-center gap-2"
|
||||
onClick={() => setShowUnlisted(!showUnlisted)}
|
||||
style={{ lineHeight: "1.2" }}
|
||||
>
|
||||
{showUnlisted ? <RiEyeFill /> : <RiEyeOffFill />} unlisted
|
||||
</button>
|
||||
<button className="btn flex items-center gap-2" onClick={handleUnloadAllModels} disabled={isUnloading}>
|
||||
<RiStopCircleLine size="24" /> {isUnloading ? "Unloading..." : "Unload"}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right Column */}
|
||||
<div className="w-full md:w-1/2 flex flex-col" style={{ height: "calc(100vh - 125px)" }}>
|
||||
<div className="card mb-4 min-h-[250px]">
|
||||
<h2>Log Stats</h2>
|
||||
<p className="italic my-2">note: eval logs from llama-server</p>
|
||||
<table className="w-full border border-gray-200">
|
||||
<tbody>
|
||||
<tr className="border-b border-gray-200">
|
||||
<td className="py-2 px-4 font-medium border-r border-gray-200">Requests</td>
|
||||
<td className="py-2 px-4 text-right">{totalLines}</td>
|
||||
</tr>
|
||||
<tr className="border-b border-gray-200">
|
||||
<td className="py-2 px-4 font-medium border-r border-gray-200">Total Tokens Generated</td>
|
||||
<td className="py-2 px-4 text-right">{totalTokens}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td className="py-2 px-4 font-medium border-r border-gray-200">Average Tokens/Second</td>
|
||||
<td className="py-2 px-4 text-right">{avgTokensPerSecond}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</div>
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
<table className="w-full">
|
||||
<thead className="sticky top-0 bg-card z-10">
|
||||
<tr className="border-b border-primary bg-surface">
|
||||
<th className="text-left p-2">Name</th>
|
||||
<th className="text-left p-2"></th>
|
||||
<th className="text-left p-2">State</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{filteredModels.map((model) => (
|
||||
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
|
||||
<td className={`p-2 ${model.unlisted ? "text-txtsecondary" : ""}`}>
|
||||
<a href={`/upstream/${model.id}/`} className={`underline`} target="_blank">
|
||||
{model.name !== "" ? model.name : model.id}
|
||||
</a>
|
||||
{model.description !== "" && (
|
||||
<p className={model.unlisted ? "text-opacity-70" : ""}>
|
||||
<em>{model.description}</em>
|
||||
</p>
|
||||
)}
|
||||
</td>
|
||||
<td className="p-2 w-[50px]">
|
||||
<button
|
||||
className="btn btn--sm"
|
||||
disabled={model.state !== "stopped"}
|
||||
onClick={() => loadModel(model.id)}
|
||||
>
|
||||
Load
|
||||
</button>
|
||||
</td>
|
||||
<td className="p-2 w-[75px]">
|
||||
<span className={`status status--${model.state}`}>{model.state}</span>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function StatsPanel() {
|
||||
const { metrics } = useAPI();
|
||||
|
||||
const [totalRequests, totalTokens, avgTokensPerSecond] = useMemo(() => {
|
||||
const totalRequests = metrics.length;
|
||||
if (totalRequests === 0) {
|
||||
return [0, 0, 0];
|
||||
}
|
||||
const totalTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||
const avgTokensPerSecond = (metrics.reduce((sum, m) => sum + m.tokens_per_second, 0) / totalRequests).toFixed(2);
|
||||
return [totalRequests, totalTokens, avgTokensPerSecond];
|
||||
}, [metrics]);
|
||||
|
||||
return (
|
||||
<div className="card">
|
||||
<h2>Chat Activity</h2>
|
||||
<table className="w-full border border-gray-200">
|
||||
<tbody>
|
||||
<tr className="border-b border-gray-200">
|
||||
<td className="py-2 px-4 font-medium border-r border-gray-200">Requests</td>
|
||||
<td className="py-2 px-4 text-right">{totalRequests}</td>
|
||||
</tr>
|
||||
<tr className="border-b border-gray-200">
|
||||
<td className="py-2 px-4 font-medium border-r border-gray-200">Total Tokens Generated</td>
|
||||
<td className="py-2 px-4 text-right">{totalTokens}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td className="py-2 px-4 font-medium border-r border-gray-200">Average Tokens/Second</td>
|
||||
<td className="py-2 px-4 text-right">{avgTokensPerSecond}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user