Compare commits

...

6 Commits

Author SHA1 Message Date
George 0ab214d1c8 perf: add vendor-agnostic GPU monitoring for Windows (experimental) (#779)
Add GPU monitoring support for AMD and Intel GPUs on Windows using
D3DKMT (DirectX) and PDH performance counters.

- Add PDH-based GPU utilization via \GPU Engine(*)\Utilization
Percentage counter, summing all engine types per adapter (3D, Compute,
Copy, Video).
- Add D3DKMT bindings for adapter enumeration, memory segments, and
adapter perf data.
- Use PDH as primary utilization source (works on all vendors), with
D3DKMT RunningTime as fallback for systems without PDH counters.
- Prefer nvidia-smi when available, fall back to D3DKMT + PDH for
AMD/Intel.
- Backend priority: nvidia-smi -> D3DKMT + PDH -> ErrNoGpuTool.

Verified on AMD 7900XTX GPU with llama.cpp Vulkan & ROCm backend: GPU
utilization correctly shows ~99% during inference, ~0-2% when idle.

---

LLM disclosure: GLM 5.1 & Kimi K2.6 have been used extensively during
exploration and coding to the point that the LLM's wrote over 3/4 of the
code, and I have done additional verification myself.
As such, it should be considered experimental.
Additional verification is needed.

I have tested it on my 7900XTX system with Windows 11, and it works
correctly, but as I only have this one rig, I cannot verify it
everywhere.
2026-06-16 21:49:09 -07:00
Benson Wong d07b063ab6 internal/server,shared: support request metadata (#850)
- add support for http handlers in the request chain to append metadata
to the request
- metrics middleware will include metadata in the activity log 
- update Activity UI to support metadata, drag sort columns
- update Activity UI capture dialog to use more screen space

Updates #834
2026-06-16 21:44:55 -07:00
Benson Wong 826210dac9 .coderabbit.yaml: disable unit_tests 2026-06-16 10:10:17 -07:00
Benson Wong 6cf1317341 schedule,shared: move concurrency 429 limits into scheduler code (#849)
- make concurrency limiting the scheduler.Scheduler's responsibility
- eliminate the separate concurrency limit middleware 
- move concurrencyLimit logic into scheduler.FIFO to maintain backwards compatibility
- add HTTPError from #834 

Updates #834
2026-06-15 22:35:12 -07:00
Wojciech 8e84b2ec4f README.md: add macports install option to README (#848) 2026-06-15 15:58:24 -07:00
Benson Wong ed77385d08 ui: improve manual model load and cancel (#847)
- When a model is manually loaded show a cancel buttton and a queued
status
- Implement cancellation in scheduler.Scheduler interface and FIFO
scheduler
- Add cache bust query parameter to bypass browser cache

Fixes #844
2026-06-14 13:38:10 -07:00
33 changed files with 1824 additions and 314 deletions
+2
View File
@@ -15,6 +15,8 @@ reviews:
auto_review: auto_review:
enabled: false enabled: false
drafts: false drafts: false
unit_tests:
enabled: false
chat: chat:
auto_reply: true auto_reply: true
issue_enrichment: issue_enrichment:
+15 -4
View File
@@ -88,10 +88,11 @@ Real time log streaming:
llama-swap can be installed in multiple ways llama-swap can be installed in multiple ways
1. Docker 1. Docker
2. Homebrew (OSX and Linux) 2. Homebrew (macOS and Linux)
3. WinGet 3. MacPorts (macOS)
4. From release binaries 4. WinGet
5. From source 5. From release binaries
6. From source
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap)) ### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
@@ -155,6 +156,16 @@ brew install llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080 llama-swap --config path/to/config.yaml --listen localhost:8080
``` ```
### MacPorts (macOS)
> [!NOTE]
> Maintained by MacPorts community - [llama-swap port](https://ports.macports.org/port/llama-swap). It is not an official part of llama-swap.
```shell
sudo port install llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080
```
### WinGet Install (Windows) ### WinGet Install (Windows)
> [!NOTE] > [!NOTE]
+92
View File
@@ -0,0 +1,92 @@
package perf
type LUID struct {
LowPart uint32
HighPart int32
}
const maxEnumAdapters = 16
type D3DKMT_ENUMADAPTERS2 struct {
NumAdapters uint32
pAdapters uintptr
}
type D3DKMT_ADAPTERINFO struct {
hAdapter uint32
AdapterLuid LUID
NumOfSources uint32
bPresentMoveRegionsPreferred int32
}
type D3DKMT_OPENADAPTERFROMLUID struct {
AdapterLuid LUID
hAdapter uint32
}
type D3DKMT_CLOSEADAPTER struct {
hAdapter uint32
}
type KMTQUERYADAPTERINFOTYPE int32
const (
KMTQAITYPE_UMDRIVERPRIVATE KMTQUERYADAPTERINFOTYPE = 0
KMTQAITYPE_ADAPTERREGISTRYINFO KMTQUERYADAPTERINFOTYPE = 8
KMTQAITYPE_DRIVERVERSION KMTQUERYADAPTERINFOTYPE = 13
KMTQAITYPE_PHYSICALADAPTERDEVICEIDS KMTQUERYADAPTERINFOTYPE = 31
KMTQAITYPE_NODEPERFDATA KMTQUERYADAPTERINFOTYPE = 61
KMTQAITYPE_ADAPTERPERFDATA KMTQUERYADAPTERINFOTYPE = 62
KMTQAITYPE_ADAPTERPERFDATA_CAPS KMTQUERYADAPTERINFOTYPE = 63
)
type D3DKMT_QUERYADAPTERINFO struct {
hAdapter uint32
Type KMTQUERYADAPTERINFOTYPE
pPrivateDriverData uintptr
PrivateDriverDataSize uint32
}
type D3DKMT_ADAPTER_PERFDATA struct {
PhysicalAdapterIndex uint32
MemoryFrequency uint64
MaxMemoryFrequency uint64
MaxMemoryFrequencyOC uint64
MemoryBandwidth uint64
PCIEBandwidth uint64
FanRPM uint32
Power uint32
Temperature uint32
PowerStateOverride byte
}
type D3DKMT_QUERYSTATISTICS_TYPE int32
const (
D3DKMT_QUERYSTATISTICS_ADAPTER D3DKMT_QUERYSTATISTICS_TYPE = 0
D3DKMT_QUERYSTATISTICS_PROCESS D3DKMT_QUERYSTATISTICS_TYPE = 1
D3DKMT_QUERYSTATISTICS_PROCESS_ADAPTER D3DKMT_QUERYSTATISTICS_TYPE = 2
D3DKMT_QUERYSTATISTICS_SEGMENT D3DKMT_QUERYSTATISTICS_TYPE = 3
D3DKMT_QUERYSTATISTICS_PROCESS_SEGMENT D3DKMT_QUERYSTATISTICS_TYPE = 4
D3DKMT_QUERYSTATISTICS_NODE D3DKMT_QUERYSTATISTICS_TYPE = 5
D3DKMT_QUERYSTATISTICS_PROCESS_NODE D3DKMT_QUERYSTATISTICS_TYPE = 6
D3DKMT_QUERYSTATISTICS_VIDPNSOURCE D3DKMT_QUERYSTATISTICS_TYPE = 7
D3DKMT_QUERYSTATISTICS_PROCESS_VIDPNSOURCE D3DKMT_QUERYSTATISTICS_TYPE = 8
)
type D3DKMT_ADAPTER_PERFDATACAPS struct {
PhysicalAdapterIndex uint32
MaxMemoryBandwidth uint64
MaxPCIEBandwidth uint64
MaxFanRPM uint32
TemperatureMax uint32
TemperatureWarning uint32
}
type D3DKMT_QUERYSTATISTICS_QUERY_SEGMENT struct {
SegmentId uint32
}
type D3DKMT_QUERYSTATISTICS_QUERY_NODE struct {
NodeId uint32
}
+529
View File
@@ -0,0 +1,529 @@
//go:build windows
package perf
import (
"context"
"encoding/binary"
"fmt"
"sync"
"time"
"unsafe"
"github.com/mostlygeek/llama-swap/internal/logmon"
"golang.org/x/sys/windows"
)
var (
d3dkmDLL *windows.LazyDLL
procEnumAdapters2 *windows.LazyProc
procOpenAdapterFromLuid *windows.LazyProc
procCloseAdapter *windows.LazyProc
procQueryAdapterInfo *windows.LazyProc
procQueryStatistics *windows.LazyProc
d3dkmtInitOnce sync.Once
d3dkmtInitErr error
)
// initD3DKMT lazily loads gdi32.dll and resolves D3DKMT function pointers.
// Safe for concurrent use via sync.Once.
func initD3DKMT() error {
d3dkmtInitOnce.Do(func() {
d3dkmDLL = windows.NewLazySystemDLL("gdi32.dll")
procEnumAdapters2 = d3dkmDLL.NewProc("D3DKMTEnumAdapters2")
procOpenAdapterFromLuid = d3dkmDLL.NewProc("D3DKMTOpenAdapterFromLuid")
procCloseAdapter = d3dkmDLL.NewProc("D3DKMTCloseAdapter")
procQueryAdapterInfo = d3dkmDLL.NewProc("D3DKMTQueryAdapterInfo")
procQueryStatistics = d3dkmDLL.NewProc("D3DKMTQueryStatistics")
for name, p := range map[string]*windows.LazyProc{
"D3DKMTEnumAdapters2": procEnumAdapters2,
"D3DKMTOpenAdapterFromLuid": procOpenAdapterFromLuid,
"D3DKMTCloseAdapter": procCloseAdapter,
"D3DKMTQueryAdapterInfo": procQueryAdapterInfo,
"D3DKMTQueryStatistics": procQueryStatistics,
} {
if err := p.Find(); err != nil {
d3dkmtInitErr = fmt.Errorf("D3DKMT %s not found: %w", name, err)
return
}
}
})
return d3dkmtInitErr
}
// ntstatusCall invokes a D3DKMT function and returns a non-nil error if the
// NTSTATUS result is not STATUS_SUCCESS (0).
func ntstatusCall(proc *windows.LazyProc, arg unsafe.Pointer) error {
ret, _, _ := proc.Call(uintptr(arg))
if ret != 0 {
return fmt.Errorf("NTSTATUS 0x%08x", uint32(ret))
}
return nil
}
// d3dkmEnumerateAdapters enumerates all available graphics adapters via
// D3DKMTEnumAdapters2.
func d3dkmEnumerateAdapters() ([]D3DKMT_ADAPTERINFO, error) {
var adapters [maxEnumAdapters]D3DKMT_ADAPTERINFO
enum := D3DKMT_ENUMADAPTERS2{
NumAdapters: maxEnumAdapters,
pAdapters: uintptr(unsafe.Pointer(&adapters[0])),
}
if err := ntstatusCall(procEnumAdapters2, unsafe.Pointer(&enum)); err != nil {
return nil, fmt.Errorf("EnumAdapters2: %w", err)
}
if enum.NumAdapters == 0 {
return nil, fmt.Errorf("no adapters found")
}
result := make([]D3DKMT_ADAPTERINFO, enum.NumAdapters)
for i := uint32(0); i < enum.NumAdapters; i++ {
result[i] = adapters[i]
}
return result, nil
}
// d3dkmOpenAdapter opens a D3DKMT adapter handle for the given LUID.
func d3dkmOpenAdapter(luid LUID) (uint32, error) {
req := D3DKMT_OPENADAPTERFROMLUID{
AdapterLuid: luid,
}
if err := ntstatusCall(procOpenAdapterFromLuid, unsafe.Pointer(&req)); err != nil {
return 0, fmt.Errorf("OpenAdapterFromLuid: %w", err)
}
return req.hAdapter, nil
}
// d3dkmCloseAdapter closes a previously opened D3DKMT adapter handle.
func d3dkmCloseAdapter(hAdapter uint32) error {
req := D3DKMT_CLOSEADAPTER{hAdapter: hAdapter}
return ntstatusCall(procCloseAdapter, unsafe.Pointer(&req))
}
// d3dkmGetAdapterPerfData queries per-adapter performance data (temperature,
// fan RPM, power, bandwidth) via KMTQAITYPE_ADAPTERPERFDATA.
func d3dkmGetAdapterPerfData(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATA, error) {
var data D3DKMT_ADAPTER_PERFDATA
req := D3DKMT_QUERYADAPTERINFO{
hAdapter: hAdapter,
Type: KMTQAITYPE_ADAPTERPERFDATA,
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
}
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATA): %w", err)
}
return &data, nil
}
// d3dkmGetAdapterPerfDataCaps queries static adapter performance capabilities
// (max fan RPM, temperature limits, max bandwidth) via KMTQAITYPE_ADAPTERPERFDATA_CAPS.
func d3dkmGetAdapterPerfDataCaps(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATACAPS, error) {
var data D3DKMT_ADAPTER_PERFDATACAPS
req := D3DKMT_QUERYADAPTERINFO{
hAdapter: hAdapter,
Type: KMTQAITYPE_ADAPTERPERFDATA_CAPS,
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
}
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATACAPS): %w", err)
}
return &data, nil
}
type queryStatsBuffer struct {
Type int32 // offset 0
AdapterLuid LUID // offset 4
hProcess uintptr // offset 16
// _result mirrors the D3DKMT_QUERYSTATISTICS_RESULT union.
// sizeof(D3DKMT_QUERYSTATISTICS) == 0x328 (808 bytes) on x64.
//
// The C struct layout (x64):
// offset 0: Type (int32, 4 bytes)
// offset 4: AdapterLuid (LUID, 8 bytes)
// offset 12: 4 bytes padding (for 8-byte alignment of hProcess)
// offset 16: hProcess (HANDLE, 8 bytes)
// offset 24: QueryResult (union, 780 bytes — largest member is AdapterInformation)
// offset 804: anonymous input union (QueryNode.NodeId / QuerySegment.SegmentId, 4 bytes)
//
// Previous bug: _result was [776]byte, placing QueryId at offset 800 instead of 804.
// The kernel read NodeId/SegmentId from offset 804 (always zero from _pad),
// causing all NODE and SEGMENT queries to use index 0 regardless of the value
// passed in QueryId. This produced alternating behavior where only GPU util OR
// memory util appeared to work, depending on which test variant happened to put
// non-zero data near offset 804 in the result buffer.
_result [780]byte // offset 24, size 780 — places QueryId at offset 804
QueryId int32 // offset 804 — matches C anonymous union for NodeId/SegmentId
}
func init() {
var buf queryStatsBuffer
if unsafe.Sizeof(buf) != 808 {
panic(fmt.Sprintf("queryStatsBuffer size %d != expected 808 (sizeof D3DKMT_QUERYSTATISTICS on x64)", unsafe.Sizeof(buf)))
}
if unsafe.Offsetof(buf.QueryId) != 804 {
panic(fmt.Sprintf("queryStatsBuffer.QueryId offset %d != expected 804 (C anonymous union offset)", unsafe.Offsetof(buf.QueryId)))
}
var perfData D3DKMT_ADAPTER_PERFDATA
if unsafe.Sizeof(perfData) != 64 {
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATA size %d != expected 64 on x64", unsafe.Sizeof(perfData)))
}
var caps D3DKMT_ADAPTER_PERFDATACAPS
if unsafe.Sizeof(caps) != 40 {
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATACAPS size %d != expected 40 on x64", unsafe.Sizeof(caps)))
}
}
const (
qsoffsetNbSegments = 0
qsoffsetNodeCount = 4
qsoffsetCommitLimit = 0
qsoffsetBytesCommitted = 8
qsoffsetBytesResident = 16
qsoffsetRunningTime = 0
qsoffsetSystemRunningTime = 272
)
// d3dkmQueryAdapterStats returns the number of memory segments and compute
// nodes for the adapter identified by luid.
func d3dkmQueryAdapterStats(luid LUID) (nbSegments uint32, nodeCount uint32, err error) {
buf := queryStatsBuffer{
Type: int32(D3DKMT_QUERYSTATISTICS_ADAPTER),
AdapterLuid: luid,
}
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
return 0, 0, fmt.Errorf("QueryStatistics(ADAPTER): %w", err)
}
nbSegments = binary.LittleEndian.Uint32(buf._result[qsoffsetNbSegments : qsoffsetNbSegments+4])
nodeCount = binary.LittleEndian.Uint32(buf._result[qsoffsetNodeCount : qsoffsetNodeCount+4])
return nbSegments, nodeCount, nil
}
// d3dkmQuerySegmentStats returns the commit limit (total) and resident
// (used) bytes for the given memory segment of an adapter.
func d3dkmQuerySegmentStats(luid LUID, segmentID uint32) (commitLimit uint64, bytesResident uint64, err error) {
buf := queryStatsBuffer{
Type: int32(D3DKMT_QUERYSTATISTICS_SEGMENT),
AdapterLuid: luid,
QueryId: int32(segmentID),
}
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
return 0, 0, fmt.Errorf("QueryStatistics(SEGMENT %d): %w", segmentID, err)
}
commitLimit = binary.LittleEndian.Uint64(buf._result[qsoffsetCommitLimit : qsoffsetCommitLimit+8])
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesResident : qsoffsetBytesResident+8])
if bytesResident == 0 {
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesCommitted : qsoffsetBytesCommitted+8])
}
return commitLimit, bytesResident, nil
}
// d3dkmQueryNodeStats returns the global and system running time counters
// (in 100ns units) for the given compute node of an adapter.
func d3dkmQueryNodeStats(luid LUID, nodeID uint32) (runningTime uint64, systemRunningTime uint64, err error) {
buf := queryStatsBuffer{
Type: int32(D3DKMT_QUERYSTATISTICS_NODE),
AdapterLuid: luid,
QueryId: int32(nodeID),
}
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
return 0, 0, fmt.Errorf("QueryStatistics(NODE %d): %w", nodeID, err)
}
runningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetRunningTime : qsoffsetRunningTime+8])
systemRunningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetSystemRunningTime : qsoffsetSystemRunningTime+8])
return runningTime, systemRunningTime, nil
}
type nodeRunningTimes struct {
Global uint64
System uint64
}
// d3dkmtNodeUtil computes GPU node utilization as a percentage from running
// time deltas. Returns -1 if counters went backwards (wrap/reset), 0 if idle.
func d3dkmtNodeUtil(prevRT, curRT nodeRunningTimes, elapsed100ns int64) float64 {
if curRT.Global < prevRT.Global || curRT.System < prevRT.System {
return -1
}
gd := curRT.Global - prevRT.Global
sd := curRT.System - prevRT.System
if gd > 0 && sd > 0 {
util := float64(gd) / float64(sd)
if util > 1.0 {
util = 1.0
}
return util * 100.0
} else if gd > 0 && elapsed100ns > 0 {
util := float64(gd) / float64(elapsed100ns) * 100.0
if util > 100.0 {
util = 100.0
}
return util
}
return 0
}
// d3dkmtFanPct returns fan speed as a percentage of maxFanRPM, clamped to
// 100%. Returns 0 if maxFanRPM is unavailable or fan is not spinning.
func d3dkmtFanPct(fanRPM, maxFanRPM uint32) float64 {
if maxFanRPM > 0 && fanRPM > 0 {
pct := float64(fanRPM) / float64(maxFanRPM) * 100.0
if pct > 100.0 {
pct = 100.0
}
return pct
}
return 0
}
// d3dkmtPowerW converts power from deci-watts (as reported by D3DKMT) to
// watts. Returns 0 if the power value is zero.
func d3dkmtPowerW(power uint32) float64 {
if power > 0 {
return float64(power) / 10.0
}
return 0
}
// d3dkmtTempC converts temperature from deci-Celsius (as reported by D3DKMT)
// to degrees Celsius.
func d3dkmtTempC(tempDeciC uint32) int {
return int(tempDeciC / 10)
}
type d3dkmtAdapterState struct {
luid LUID
hAdapter uint32
nbSegments uint32
nodeCount uint32
maxFanRPM uint32
prevNodeRT map[uint32]nodeRunningTimes
prevTime time.Time
}
// tryD3DKMT attempts to start GPU monitoring using D3DKMT and optional PDH
// counters. It returns a channel of GpuStat snapshots or an error if no
// usable adapters are found.
func tryD3DKMT(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if err := initD3DKMT(); err != nil {
return nil, err
}
adapterInfos, err := d3dkmEnumerateAdapters()
if err != nil {
return nil, err
}
type adapterMeta struct {
luid LUID
nbSegments uint32
nodeCount uint32
maxFanRPM uint32
}
var metaList []adapterMeta
for i, ai := range adapterInfos {
hAdapter, err := d3dkmOpenAdapter(ai.AdapterLuid)
if err != nil {
logger.Debugf("adapter %d: open failed: %s", i, err.Error())
continue
}
nbSegments, nodeCount, err := d3dkmQueryAdapterStats(ai.AdapterLuid)
if err != nil {
logger.Debugf("adapter %d: query stats failed: %s", i, err.Error())
d3dkmCloseAdapter(hAdapter)
continue
}
caps, err := d3dkmGetAdapterPerfDataCaps(hAdapter)
if err != nil {
logger.Debugf("adapter %d: perf caps failed: %s", i, err.Error())
}
d3dkmCloseAdapter(hAdapter)
var maxFanRPM uint32
if caps != nil {
maxFanRPM = caps.MaxFanRPM
}
metaList = append(metaList, adapterMeta{
luid: ai.AdapterLuid,
nbSegments: nbSegments,
nodeCount: nodeCount,
maxFanRPM: maxFanRPM,
})
logger.Debugf("adapter %d: segments=%d nodes=%d fan_max=%d luid=%d:%d", i, nbSegments, nodeCount, maxFanRPM, ai.AdapterLuid.HighPart, ai.AdapterLuid.LowPart)
}
if len(metaList) == 0 {
return nil, fmt.Errorf("no usable D3DKMT adapters found")
}
pdhUtil, pdhErr := initPdhGpuUtil()
if pdhErr != nil {
logger.Debugf("PDH GPU utilization not available: %s", pdhErr.Error())
} else {
logger.Info("using PDH performance counters for GPU utilization")
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
if pdhUtil != nil {
defer pdhUtil.close()
}
var adapters []d3dkmtAdapterState
for _, m := range metaList {
hAdapter, err := d3dkmOpenAdapter(m.luid)
if err != nil {
logger.Debugf("reopen adapter failed: %s", err.Error())
continue
}
adapters = append(adapters, d3dkmtAdapterState{
luid: m.luid,
hAdapter: hAdapter,
nbSegments: m.nbSegments,
nodeCount: m.nodeCount,
maxFanRPM: m.maxFanRPM,
prevNodeRT: make(map[uint32]nodeRunningTimes),
})
}
if len(adapters) == 0 {
return
}
defer func() {
for _, a := range adapters {
d3dkmCloseAdapter(a.hAdapter)
}
}()
for i := range adapters {
a := &adapters[i]
for node := uint32(0); node < a.nodeCount; node++ {
globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
if err != nil {
continue
}
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
}
a.prevTime = time.Now()
}
ticker := time.NewTicker(every)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
stats := make([]GpuStat, 0, len(adapters))
now := time.Now()
var pdhUtilMap map[LUID]float64
if pdhUtil != nil {
pdhUtilMap = pdhUtil.collect()
}
for i := range adapters {
a := &adapters[i]
perfData, err := d3dkmGetAdapterPerfData(a.hAdapter)
if err != nil {
logger.Debugf("adapter %d perfdata: %s", i, err.Error())
continue
}
var memUsedMB, memTotalMB int
for seg := uint32(0); seg < a.nbSegments; seg++ {
limit, resident, err := d3dkmQuerySegmentStats(a.luid, seg)
if err != nil {
continue
}
memUsedMB += int(resident / (1024 * 1024))
memTotalMB += int(limit / (1024 * 1024))
}
var gpuUtil float64
pdhGaveValue := false
if pdhUtilMap != nil {
if util, ok := pdhUtilMap[a.luid]; ok {
gpuUtil = util
pdhGaveValue = true
}
}
if !pdhGaveValue && a.nodeCount > 0 {
elapsedNs := now.Sub(a.prevTime).Nanoseconds()
elapsed100ns := elapsedNs / 100
for node := uint32(0); node < a.nodeCount; node++ {
globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
if err != nil {
continue
}
if prevRT, ok := a.prevNodeRT[node]; ok {
if globalRT < prevRT.Global || systemRT < prevRT.System {
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
continue
}
nodeUtil := d3dkmtNodeUtil(prevRT, nodeRunningTimes{Global: globalRT, System: systemRT}, elapsed100ns)
if nodeUtil > gpuUtil {
gpuUtil = nodeUtil
}
}
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
}
a.prevTime = now
}
tempC := d3dkmtTempC(perfData.Temperature)
fanSpeedPct := d3dkmtFanPct(perfData.FanRPM, a.maxFanRPM)
powerDrawW := d3dkmtPowerW(perfData.Power)
var memUtilPct float64
if memTotalMB > 0 {
memUtilPct = float64(memUsedMB) / float64(memTotalMB) * 100.0
}
stats = append(stats, GpuStat{
Timestamp: now,
ID: i,
Name: fmt.Sprintf("GPU %d", i),
TempC: tempC,
GpuUtilPct: gpuUtil,
MemUtilPct: memUtilPct,
MemUsedMB: memUsedMB,
MemTotalMB: memTotalMB,
FanSpeedPct: fanSpeedPct,
PowerDrawW: powerDrawW,
})
}
if len(stats) > 0 {
select {
case ch <- stats:
default:
}
}
}
}
}()
return ch, nil
}
+98
View File
@@ -0,0 +1,98 @@
//go:build windows
package perf
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestD3dkmtNodeUtil_FullLoad(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 5000, System: 14000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 100.0, got)
}
func TestD3dkmtNodeUtil_PartialUtil(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 3000, System: 14000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 50.0, got)
}
func TestD3dkmtNodeUtil_Identical(t *testing.T) {
prev := nodeRunningTimes{Global: 10000, System: 10000}
cur := nodeRunningTimes{Global: 20000, System: 20000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 100.0, got)
}
func TestD3dkmtNodeUtil_CounterWrap(t *testing.T) {
prev := nodeRunningTimes{Global: 9000, System: 10000}
cur := nodeRunningTimes{Global: 1000, System: 10000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, -1.0, got)
}
func TestD3dkmtNodeUtil_SystemWrap(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 9000}
cur := nodeRunningTimes{Global: 5000, System: 1000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, -1.0, got)
}
func TestD3dkmtNodeUtil_ZeroDelta(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 1000, System: 10000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 0.0, got)
}
func TestD3dkmtNodeUtil_ElapsedFallback(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 6000, System: 10000}
got := d3dkmtNodeUtil(prev, cur, 50000)
assert.InDelta(t, 10.0, got, 0.01)
}
func TestD3dkmtFanPct_Normal(t *testing.T) {
assert.Equal(t, 50.0, d3dkmtFanPct(1500, 3000))
}
func TestD3dkmtFanPct_MaxFan(t *testing.T) {
assert.Equal(t, 100.0, d3dkmtFanPct(3000, 3000))
}
func TestD3dkmtFanPct_OverMaxClamped(t *testing.T) {
assert.Equal(t, 100.0, d3dkmtFanPct(4000, 3000))
}
func TestD3dkmtFanPct_ZeroMaxFan(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtFanPct(1500, 0))
}
func TestD3dkmtFanPct_ZeroFanRPM(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtFanPct(0, 3000))
}
func TestD3dkmtFanPct_BothZero(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtFanPct(0, 0))
}
func TestD3dkmtPowerW(t *testing.T) {
assert.Equal(t, 250.0, d3dkmtPowerW(2500))
}
func TestD3dkmtPowerW_Zero(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtPowerW(0))
}
func TestD3dkmtTempC(t *testing.T) {
assert.Equal(t, 65, d3dkmtTempC(650))
}
func TestD3dkmtTempC_Zero(t *testing.T) {
assert.Equal(t, 0, d3dkmtTempC(0))
}
+7
View File
@@ -22,6 +22,13 @@ func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monito
logger.Debugf("nvidia-smi: %s", err.Error()) logger.Debugf("nvidia-smi: %s", err.Error())
} }
if ch, err := tryD3DKMT(ctx, every, logger); err == nil {
logger.Info("using D3DKMT for GPU monitoring")
return ch, nil
} else {
logger.Debugf("D3DKMT: %s", err.Error())
}
return nil, ErrNoGpuTool return nil, ErrNoGpuTool
} }
+159
View File
@@ -0,0 +1,159 @@
//go:build windows
package perf
import (
"fmt"
"strconv"
"strings"
"unsafe"
"golang.org/x/sys/windows"
)
var (
pdhDLL = windows.NewLazySystemDLL("pdh.dll")
procPdhOpenQuery = pdhDLL.NewProc("PdhOpenQueryW")
procPdhAddEnglishCounter = pdhDLL.NewProc("PdhAddEnglishCounterW")
procPdhCollectQueryData = pdhDLL.NewProc("PdhCollectQueryData")
procPdhGetFormattedCounterArray = pdhDLL.NewProc("PdhGetFormattedCounterArrayW")
procPdhCloseQuery = pdhDLL.NewProc("PdhCloseQuery")
)
const (
pdhFmtDouble = 0x00000200
pdhMoreData = 0x800007D2
pdhNoData = 0x800007D5
)
type pdhCounterValue struct {
CStatus uint32
DblVal float64
}
type pdhCounterValueItem struct {
SzName *uint16
FmtValue pdhCounterValue
}
func init() {
var item pdhCounterValueItem
if unsafe.Sizeof(item) != 24 {
panic(fmt.Sprintf("pdhCounterValueItem size %d != expected 24 on x64", unsafe.Sizeof(item)))
}
}
type pdhGpuUtil struct {
query uintptr
counter uintptr
}
// initPdhGpuUtil creates a PDH query for the GPU Engine utilization counter.
// Returns nil with an error if PDH or the counter is unavailable.
func initPdhGpuUtil() (*pdhGpuUtil, error) {
var query uintptr
if ret, _, _ := procPdhOpenQuery.Call(0, 0, uintptr(unsafe.Pointer(&query))); ret != 0 {
return nil, fmt.Errorf("PdhOpenQuery: 0x%x", ret)
}
path, _ := windows.UTF16PtrFromString(`\GPU Engine(*)\Utilization Percentage`)
var counter uintptr
if ret, _, _ := procPdhAddEnglishCounter.Call(
query, uintptr(unsafe.Pointer(path)), 0, uintptr(unsafe.Pointer(&counter)),
); ret != 0 {
procPdhCloseQuery.Call(query)
return nil, fmt.Errorf("PdhAddEnglishCounter(GPU Engine): 0x%x", ret)
}
procPdhCollectQueryData.Call(query)
return &pdhGpuUtil{query: query, counter: counter}, nil
}
// close releases the PDH query handle.
func (p *pdhGpuUtil) close() {
if p.query != 0 {
procPdhCloseQuery.Call(p.query)
p.query = 0
}
}
// collect reads the PDH counter and returns a map of adapter LUID to
// aggregated GPU utilization percentage, summed across all engine instances
// per adapter and clamped to 100%.
func (p *pdhGpuUtil) collect() map[LUID]float64 {
ret, _, _ := procPdhCollectQueryData.Call(p.query)
if ret != 0 && ret != pdhNoData {
return nil
}
var bufSize uint32
var itemCount uint32
ret, _, _ = procPdhGetFormattedCounterArray.Call(
p.counter, pdhFmtDouble,
uintptr(unsafe.Pointer(&bufSize)),
uintptr(unsafe.Pointer(&itemCount)),
0,
)
if ret != pdhMoreData || itemCount == 0 {
return nil
}
buf := make([]byte, bufSize)
ret, _, _ = procPdhGetFormattedCounterArray.Call(
p.counter, pdhFmtDouble,
uintptr(unsafe.Pointer(&bufSize)),
uintptr(unsafe.Pointer(&itemCount)),
uintptr(unsafe.Pointer(&buf[0])),
)
if ret != 0 {
return nil
}
itemSize := uint32(unsafe.Sizeof(pdhCounterValueItem{}))
result := make(map[LUID]float64)
for i := uint32(0); i < itemCount; i++ {
item := (*pdhCounterValueItem)(unsafe.Pointer(&buf[i*itemSize]))
if item.FmtValue.CStatus != 0 {
continue
}
luid, ok := parsePdhLuid(windows.UTF16PtrToString(item.SzName))
if !ok {
continue
}
result[luid] += item.FmtValue.DblVal
}
for luid := range result {
if result[luid] > 100.0 {
result[luid] = 100.0
}
}
return result
}
// parsePdhLuid extracts the adapter LUID (high and low parts) from a PDH
// GPU Engine instance name (e.g. "pid_1234_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute").
func parsePdhLuid(name string) (LUID, bool) {
idx := strings.Index(name, "luid_0x")
if idx < 0 {
return LUID{}, false
}
rest := name[idx+7:]
parts := strings.SplitN(rest, "_", 4)
if len(parts) < 3 {
return LUID{}, false
}
hp, err := strconv.ParseUint(parts[0], 16, 32)
if err != nil {
return LUID{}, false
}
lpStr := strings.TrimPrefix(parts[1], "0x")
lp, err := strconv.ParseUint(lpStr, 16, 32)
if err != nil {
return LUID{}, false
}
return LUID{LowPart: uint32(lp), HighPart: int32(hp)}, true
}
+53
View File
@@ -0,0 +1,53 @@
//go:build windows
package perf
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestParsePdhLuid_Valid(t *testing.T) {
name := `pid_25312_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute`
got, ok := parsePdhLuid(name)
assert.True(t, ok)
assert.Equal(t, uint32(0x000148BF), got.LowPart)
assert.Equal(t, int32(0x00000000), got.HighPart)
}
func TestParsePdhLuid_ValidNvidia(t *testing.T) {
name := `pid_1388_luid_0x00000000_0x00011372_phys_0_eng_8_engtype_Compute_1`
got, ok := parsePdhLuid(name)
assert.True(t, ok)
assert.Equal(t, uint32(0x00011372), got.LowPart)
assert.Equal(t, int32(0x00000000), got.HighPart)
}
func TestParsePdhLuid_NonZeroHighPart(t *testing.T) {
name := `pid_1234_luid_0x00000001_0x0000C85A_phys_0_eng_5_engtype_Copy`
got, ok := parsePdhLuid(name)
assert.True(t, ok)
assert.Equal(t, uint32(0x0000C85A), got.LowPart)
assert.Equal(t, int32(0x00000001), got.HighPart)
}
func TestParsePdhLuid_InvalidNoLuid(t *testing.T) {
_, ok := parsePdhLuid("invalid_string_without_luid")
assert.False(t, ok)
}
func TestParsePdhLuid_InvalidEmpty(t *testing.T) {
_, ok := parsePdhLuid("")
assert.False(t, ok)
}
func TestParsePdhLuid_InvalidHex(t *testing.T) {
_, ok := parsePdhLuid("pid_1234_luid_0xZZZZ_0xGGGG_phys_0")
assert.False(t, ok)
}
func TestParsePdhLuid_ShortAfterLuid(t *testing.T) {
_, ok := parsePdhLuid("pid_1234_luid_0x00000000")
assert.False(t, ok)
}
+23 -6
View File
@@ -28,8 +28,7 @@ type unloadReq struct {
// baseRouter owns the channels, run-loop, and process machinery shared by every // baseRouter owns the channels, run-loop, and process machinery shared by every
// concrete router. Concrete routers embed *baseRouter and supply a // concrete router. Concrete routers embed *baseRouter and supply a
// scheduler.Factory (which captures their scheduler.Swapper) describing how // scheduler.Swapper describing how eviction sets are decided. baseRouter
// requests are scheduled and how their eviction set is decided. baseRouter
// implements scheduler.Effects so the scheduler can call back for side-effects. // implements scheduler.Effects so the scheduler can call back for side-effects.
type baseRouter struct { type baseRouter struct {
name string name string
@@ -54,6 +53,7 @@ type baseRouter struct {
procCancel context.CancelFunc procCancel context.CancelFunc
handlerCh chan scheduler.HandlerReq handlerCh chan scheduler.HandlerReq
cancelCh chan scheduler.HandlerReq
shutdownCh chan shutdownReq shutdownCh chan shutdownReq
unloadCh chan unloadReq unloadCh chan unloadReq
swapDoneCh chan scheduler.SwapDone swapDoneCh chan scheduler.SwapDone
@@ -74,8 +74,8 @@ func newBaseRouter(
conf config.Config, conf config.Config,
processes map[string]process.Process, processes map[string]process.Process,
logger *logmon.Monitor, logger *logmon.Monitor,
newSched scheduler.Factory, planner scheduler.Swapper,
) *baseRouter { ) (*baseRouter, error) {
shutdownCtx, shutdownFn := context.WithCancel(context.Background()) shutdownCtx, shutdownFn := context.WithCancel(context.Background())
procCtx, procCancel := context.WithCancel(context.Background()) procCtx, procCancel := context.WithCancel(context.Background())
b := &baseRouter{ b := &baseRouter{
@@ -88,14 +88,19 @@ func newBaseRouter(
procCtx: procCtx, procCtx: procCtx,
procCancel: procCancel, procCancel: procCancel,
handlerCh: make(chan scheduler.HandlerReq), handlerCh: make(chan scheduler.HandlerReq),
cancelCh: make(chan scheduler.HandlerReq),
shutdownCh: make(chan shutdownReq), shutdownCh: make(chan shutdownReq),
unloadCh: make(chan unloadReq), unloadCh: make(chan unloadReq),
swapDoneCh: make(chan scheduler.SwapDone), swapDoneCh: make(chan scheduler.SwapDone),
serveDoneCh: make(chan scheduler.ServeDoneEvent), serveDoneCh: make(chan scheduler.ServeDoneEvent),
runDone: make(chan struct{}), runDone: make(chan struct{}),
} }
b.schedule = newSched(name, logger, b) sched, err := scheduler.New(conf, name, logger, planner, b)
return b if err != nil {
return nil, err
}
b.schedule = sched
return b, nil
} }
func (b *baseRouter) notifyProcessed() { func (b *baseRouter) notifyProcessed() {
@@ -117,6 +122,10 @@ func (b *baseRouter) run() {
b.schedule.OnRequest(req) b.schedule.OnRequest(req)
b.notifyProcessed() b.notifyProcessed()
case req := <-b.cancelCh:
b.schedule.OnCancel(req)
b.notifyProcessed()
case req := <-b.unloadCh: case req := <-b.unloadCh:
b.schedule.OnUnload(req.targets, req.timeout) b.schedule.OnUnload(req.targets, req.timeout)
close(req.respond) close(req.respond)
@@ -473,6 +482,14 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
finishLoading() finishLoading()
case <-req.Context().Done(): case <-req.Context().Done():
finishLoading() finishLoading()
// Notify the scheduler so it can prune this request from its queue
// and swap waiters. Without this, a queued request whose client left
// would sit in the scheduler until drainQueue eventually starts a
// wasted model load for it.
select {
case b.cancelCh <- hr:
case <-b.shutdownCtx.Done():
}
return return
case <-b.shutdownCtx.Done(): case <-b.shutdownCtx.Done():
finishLoading() finishLoading()
+4 -4
View File
@@ -29,10 +29,10 @@ func (s *stubPlanner) OnSwapStart(string, []string) {}
func newTestBase(t *testing.T, processes map[string]process.Process, planner scheduler.Swapper) *baseRouter { func newTestBase(t *testing.T, processes map[string]process.Process, planner scheduler.Swapper) *baseRouter {
t.Helper() t.Helper()
conf := config.Config{HealthCheckTimeout: 5} conf := config.Config{HealthCheckTimeout: 5}
b := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard), b, err := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard), planner)
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { if err != nil {
return scheduler.NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, eff) t.Fatalf("newBaseRouter: %v", err)
}) }
b.testProcessed = make(chan struct{}, 64) b.testProcessed = make(chan struct{}, 64)
go b.run() go b.run()
t.Cleanup(func() { t.Cleanup(func() {
+4 -5
View File
@@ -6,7 +6,6 @@ import (
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
) )
type Group struct { type Group struct {
@@ -30,10 +29,10 @@ func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group
} }
processes := make(map[string]process.Process, len(modelToGroup)) processes := make(map[string]process.Process, len(modelToGroup))
base := newBaseRouter("group", conf, processes, proxylog, base, err := newBaseRouter("group", conf, processes, proxylog, swapper)
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { if err != nil {
return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) return nil, fmt.Errorf("creating base router: %w", err)
}) }
for mid := range modelToGroup { for mid := range modelToGroup {
modelCfg, _, ok := conf.FindConfig(mid) modelCfg, _, ok := conf.FindConfig(mid)
+4 -5
View File
@@ -10,7 +10,6 @@ import (
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
) )
// newTestGroup builds a Group directly from the supplied processes and config, // newTestGroup builds a Group directly from the supplied processes and config,
@@ -27,10 +26,10 @@ func newTestGroup(t *testing.T, conf config.Config, processes map[string]process
config: conf, config: conf,
modelToGroup: modelToGroup, modelToGroup: modelToGroup,
} }
base := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard), base, err := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard), swapper)
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { if err != nil {
return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) t.Fatalf("newBaseRouter: %v", err)
}) }
base.testProcessed = make(chan struct{}, 64) base.testProcessed = make(chan struct{}, 64)
g := &Group{baseRouter: base} g := &Group{baseRouter: base}
go base.run() go base.run()
+4 -5
View File
@@ -6,7 +6,6 @@ import (
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
) )
type Matrix struct { type Matrix struct {
@@ -27,10 +26,10 @@ func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matr
// Build a process for every model in the config. Any model can run alone // Build a process for every model in the config. Any model can run alone
// even if it is not part of a set; this mirrors proxy.NewMatrix. // even if it is not part of a set; this mirrors proxy.NewMatrix.
processes := make(map[string]process.Process, len(conf.Models)) processes := make(map[string]process.Process, len(conf.Models))
base := newBaseRouter("matrix", conf, processes, proxylog, base, err := newBaseRouter("matrix", conf, processes, proxylog, swapper)
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { if err != nil {
return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) return nil, fmt.Errorf("creating base router: %w", err)
}) }
for mid, modelCfg := range conf.Models { for mid, modelCfg := range conf.Models {
procLog := logmon.NewWriter(upstreamlog) procLog := logmon.NewWriter(upstreamlog)
+4 -5
View File
@@ -10,7 +10,6 @@ import (
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
) )
// newTestMatrix builds a Matrix router from supplied processes, bypassing // newTestMatrix builds a Matrix router from supplied processes, bypassing
@@ -22,10 +21,10 @@ func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedS
solver: newMatrixSolver(expanded, evictCosts), solver: newMatrixSolver(expanded, evictCosts),
logger: logger, logger: logger,
} }
base := newBaseRouter("matrix", conf, processes, logger, base, err := newBaseRouter("matrix", conf, processes, logger, swapper)
func(name string, l *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { if err != nil {
return scheduler.NewFIFO(name, l, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) t.Fatalf("newBaseRouter: %v", err)
}) }
base.testProcessed = make(chan struct{}, 64) base.testProcessed = make(chan struct{}, 64)
r := &Matrix{baseRouter: base} r := &Matrix{baseRouter: base}
go base.run() go base.run()
+81 -3
View File
@@ -3,13 +3,19 @@ package scheduler
import ( import (
"fmt" "fmt"
"sort" "sort"
"strconv"
"time" "time"
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/shared"
) )
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
// the model config leaves concurrencyLimit unset.
const defaultConcurrencyLimit = 10
// activeSwap tracks one in-flight swap and the callers waiting on it. // activeSwap tracks one in-flight swap and the callers waiting on it.
type activeSwap struct { type activeSwap struct {
modelID string modelID string
@@ -33,20 +39,32 @@ type FIFO struct {
cfg config.FifoConfig cfg config.FifoConfig
effects Effects effects Effects
limits map[string]int
active map[string]*activeSwap active map[string]*activeSwap
inFlight map[string]int inFlight map[string]int
queued []HandlerReq queued []HandlerReq
} }
// NewFIFO builds a FIFO scheduler. It matches scheduler.Factory once a planner // NewFIFO builds a FIFO scheduler. Per-model concurrency limits are derived
// is captured in a closure. // from models: each model's ConcurrencyLimit overrides defaultConcurrencyLimit
func NewFIFO(name string, logger *logmon.Monitor, planner Swapper, cfg config.FifoConfig, eff Effects) *FIFO { // when set to a value greater than zero.
func NewFIFO(name string, logger *logmon.Monitor, planner Swapper, cfg config.FifoConfig, models map[string]config.ModelConfig, eff Effects) *FIFO {
limits := make(map[string]int, len(models))
for id, mc := range models {
limit := defaultConcurrencyLimit
if mc.ConcurrencyLimit > 0 {
limit = mc.ConcurrencyLimit
}
limits[id] = limit
}
return &FIFO{ return &FIFO{
name: name, name: name,
logger: logger, logger: logger,
planner: planner, planner: planner,
cfg: cfg, cfg: cfg,
effects: eff, effects: eff,
limits: limits,
active: make(map[string]*activeSwap), active: make(map[string]*activeSwap),
inFlight: make(map[string]int), inFlight: make(map[string]int),
} }
@@ -116,6 +134,46 @@ func (s *FIFO) OnRequest(req HandlerReq) {
s.startSwap(req, evict, running) s.startSwap(req, evict, running)
} }
// OnCancel removes a request whose client has disconnected from the queue and
// from every in-flight swap's waiters. If the request was the sole waiter of an
// active swap, the swap goroutine is left to complete on its own — OnSwapDone
// will find no waiters and simply clean up. This prevents drainQueue from ever
// starting a model load for a caller that is no longer there.
func (s *FIFO) OnCancel(req HandlerReq) {
removed := false
// Prune from the queue.
if len(s.queued) > 0 {
kept := s.queued[:0]
for _, q := range s.queued {
if q.Respond == req.Respond {
removed = true
continue
}
kept = append(kept, q)
}
s.queued = kept
}
// Prune from any active swap's waiters.
for _, sw := range s.active {
filtered := sw.waiters[:0]
for _, w := range sw.waiters {
if w.Respond == req.Respond {
removed = true
continue
}
filtered = append(filtered, w)
}
sw.waiters = filtered
}
if removed {
s.logger.Debugf("%s: cancelled request for model %s pruned from scheduler", s.name, req.Model)
broadcastQueuePositions(s.queued)
}
}
// OnSwapDone fans the result out to every waiter that joined this swap, removes // OnSwapDone fans the result out to every waiter that joined this swap, removes
// the swap from the active map, then walks the queue once, promoting any items // the swap from the active map, then walks the queue once, promoting any items
// that no longer collide with the remaining active set. FIFO order is preserved: // that no longer collide with the remaining active set. FIFO order is preserved:
@@ -214,12 +272,32 @@ func (s *FIFO) OnShutdown(err error) {
// grantHandler hands the caller a tracked handler for modelID and, only if the // grantHandler hands the caller a tracked handler for modelID and, only if the
// caller was still there to receive it, bumps the in-flight count. Incrementing // caller was still there to receive it, bumps the in-flight count. Incrementing
// when the grant failed would strand the counter and block future evictions. // when the grant failed would strand the counter and block future evictions.
// Requests that would exceed the model's concurrency limit are rejected with a
// shared.NewConcurrencyLimitError (HTTP 429 with Retry-After).
func (s *FIFO) grantHandler(req HandlerReq, modelID string) { func (s *FIFO) grantHandler(req HandlerReq, modelID string) {
if s.inFlight[modelID] >= s.limit(modelID) {
s.effects.GrantError(req, shared.ConcurrencyLimitError{})
return
}
if err := shared.SetReqData(req.Ctx, "fifo_priority", strconv.Itoa(s.cfg.Priority[req.Model])); err != nil {
s.logger.Debugf("failed to set fifo_priority metadata: %v", err)
}
if s.effects.GrantServe(req, modelID) { if s.effects.GrantServe(req, modelID) {
s.inFlight[modelID]++ s.inFlight[modelID]++
} }
} }
// limit returns the per-model concurrency cap, defaulting to
// defaultConcurrencyLimit when the model has no explicit entry.
func (s *FIFO) limit(modelID string) int {
if l, ok := s.limits[modelID]; ok {
return l
}
return defaultConcurrencyLimit
}
// startSwap records the swap as active and launches it via Effects. running is // startSwap records the swap as active and launches it via Effects. running is
// the set EvictionFor saw, forwarded to OnSwapStart so the planner logs against // the set EvictionFor saw, forwarded to OnSwapStart so the planner logs against
// the same picture it decided on. // the same picture it decided on.
+246 -4
View File
@@ -1,14 +1,17 @@
package scheduler package scheduler
import ( import (
"context"
"errors" "errors"
"io" "io"
"net/http"
"testing" "testing"
"time" "time"
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/shared"
) )
// FIFO methods all run on the router's single run-loop goroutine, so these // FIFO methods all run on the router's single run-loop goroutine, so these
@@ -52,8 +55,9 @@ type stopRec struct {
// fakeEffects is an in-memory scheduler.Effects. Tests program process states // fakeEffects is an in-memory scheduler.Effects. Tests program process states
// and GrantServe outcomes, then assert on the recorded calls. // and GrantServe outcomes, then assert on the recorded calls.
type fakeEffects struct { type fakeEffects struct {
states map[string]process.ProcessState // model -> state; missing => not handled states map[string]process.ProcessState // model -> state; missing => not handled
serveResult map[string]bool // GrantServe return per model (default true) serveResult map[string]bool // GrantServe return per model (default true)
lastServeReq HandlerReq
starts []startRec starts []startRec
grants []grantRec grants []grantRec
@@ -96,6 +100,7 @@ func (f *fakeEffects) GrantServe(req HandlerReq, modelID string) bool {
if v, set := f.serveResult[modelID]; set { if v, set := f.serveResult[modelID]; set {
ok = v ok = v
} }
f.lastServeReq = req
f.grants = append(f.grants, grantRec{model: modelID, serve: ok}) f.grants = append(f.grants, grantRec{model: modelID, serve: ok})
return ok return ok
} }
@@ -138,11 +143,20 @@ func (f *fakeEffects) startsFor(modelID string) int {
} }
func newFIFO(planner Swapper, eff Effects) *FIFO { func newFIFO(planner Swapper, eff Effects) *FIFO {
return NewFIFO("test", logmon.NewWriter(io.Discard), planner, config.FifoConfig{}, eff) return NewFIFO("test", logmon.NewWriter(io.Discard), planner, config.FifoConfig{}, nil, eff)
} }
func req(model string) HandlerReq { return HandlerReq{Model: model} } func req(model string) HandlerReq { return HandlerReq{Model: model} }
// reqCh creates a HandlerReq with a unique Respond channel so OnCancel can
// identify it among queued requests and swap waiters.
func reqCh(model string) HandlerReq {
return HandlerReq{
Model: model,
Respond: make(chan HandlerResp, 1),
}
}
func TestFIFO_FastPath(t *testing.T) { func TestFIFO_FastPath(t *testing.T) {
eff := newFakeEffects() eff := newFakeEffects()
eff.states["a"] = process.StateReady eff.states["a"] = process.StateReady
@@ -158,6 +172,27 @@ func TestFIFO_FastPath(t *testing.T) {
} }
} }
func TestFIFO_GrantSetsPriorityMetadata(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateReady
cfg := config.FifoConfig{Priority: map[string]int{"a": 7}}
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, cfg, nil, eff)
ctx := shared.SetContext(context.Background(), shared.ReqContextData{ModelID: "a", Metadata: make(map[string]string)})
s.OnRequest(HandlerReq{Model: "a", Ctx: ctx})
if got := eff.served("a"); got != 1 {
t.Fatalf("served(a)=%d want 1", got)
}
data, ok := shared.ReadContext(eff.lastServeReq.Ctx)
if !ok {
t.Fatal("context data missing from granted request")
}
if data.Metadata["fifo_priority"] != "7" {
t.Errorf("fifo_priority = %q, want 7", data.Metadata["fifo_priority"])
}
}
func TestFIFO_ModelNotFound(t *testing.T) { func TestFIFO_ModelNotFound(t *testing.T) {
eff := newFakeEffects() // no states => model unknown eff := newFakeEffects() // no states => model unknown
s := newFIFO(&stubPlanner{}, eff) s := newFIFO(&stubPlanner{}, eff)
@@ -512,7 +547,7 @@ func TestFIFO_PriorityQueueOrder(t *testing.T) {
// loading collides with z's in-flight swap and parks in the queue. // loading collides with z's in-flight swap and parks in the queue.
planner := &stubPlanner{evict: map[string][]string{"z": {"A", "B", "C", "D"}}} planner := &stubPlanner{evict: map[string][]string{"z": {"A", "B", "C", "D"}}}
cfg := config.FifoConfig{Priority: map[string]int{"A": 10, "B": 5, "C": 5, "D": 1}} cfg := config.FifoConfig{Priority: map[string]int{"A": 10, "B": 5, "C": 5, "D": 1}}
s := NewFIFO("test", logmon.NewWriter(io.Discard), planner, cfg, eff) s := NewFIFO("test", logmon.NewWriter(io.Discard), planner, cfg, nil, eff)
s.OnRequest(req("z")) // StartSwap(z, [A,B,C,D]) s.OnRequest(req("z")) // StartSwap(z, [A,B,C,D])
@@ -535,3 +570,210 @@ func TestFIFO_PriorityQueueOrder(t *testing.T) {
} }
} }
} }
// TestFIFO_OnCancel_QueuedRequest verifies that cancelling a queued request
// prevents drainQueue from ever starting a model load for it. Without OnCancel
// the dead request would sit in the queue until a drain triggers a wasted swap.
func TestFIFO_OnCancel_QueuedRequest(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateStopped
eff.states["b"] = process.StateStopped
// b evicts a, so a request for b queues while a is loading.
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
s.OnRequest(req("a")) // StartSwap(a)
cancelledReq := reqCh("b")
s.OnRequest(cancelledReq) // queued (collides with a's in-flight swap)
if len(s.queued) != 1 {
t.Fatalf("queue len=%d want 1 before cancel", len(s.queued))
}
// Client disconnects.
s.OnCancel(cancelledReq)
if len(s.queued) != 0 {
t.Fatalf("queue len=%d want 0 after cancel", len(s.queued))
}
// a's swap finishes; drainQueue runs but b is gone — no swap for b.
eff.states["a"] = process.StateReady
s.OnSwapDone(SwapDone{ModelID: "a"})
if got := eff.startsFor("b"); got != 0 {
t.Errorf("StartSwap(b)=%d want 0 (cancelled request should not trigger a load)", got)
}
}
// TestFIFO_OnCancel_SwapWaiter verifies that cancelling a request that joined an
// in-flight swap removes it from the waiter list. When the swap completes, the
// cancelled waiter receives no grant and does not bump the in-flight count.
func TestFIFO_OnCancel_SwapWaiter(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateStopped
s := newFIFO(&stubPlanner{}, eff)
liveReq := reqCh("a")
cancelledReq := reqCh("a")
s.OnRequest(liveReq) // starts swap
s.OnRequest(cancelledReq) // joins
if sw := s.active["a"]; len(sw.waiters) != 2 {
t.Fatalf("waiters=%d want 2", len(sw.waiters))
}
s.OnCancel(cancelledReq)
if sw := s.active["a"]; len(sw.waiters) != 1 {
t.Fatalf("waiters=%d want 1 after cancel", len(sw.waiters))
}
// Swap finishes: only the live waiter is granted.
eff.states["a"] = process.StateReady
s.OnSwapDone(SwapDone{ModelID: "a"})
if got := eff.served("a"); got != 1 {
t.Errorf("served(a)=%d want 1 (only the non-cancelled waiter)", got)
}
}
// TestFIFO_OnCancel_NotPresent is a no-op: cancelling a request that was already
// granted (and is no longer queued or waiting) must not affect anything.
func TestFIFO_OnCancel_NotPresent(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateReady
s := newFIFO(&stubPlanner{}, eff)
r := reqCh("a")
s.OnRequest(r) // fast-path served immediately
// Cancel after grant — should be a harmless no-op.
s.OnCancel(r)
if got := eff.served("a"); got != 1 {
t.Errorf("served(a)=%d want 1 (cancel of granted request is a no-op)", got)
}
if len(s.queued) != 0 {
t.Errorf("queue should be empty, len=%d", len(s.queued))
}
}
// newFIFOWithLimit builds a FIFO whose single model has the given concurrency
// limit, already in StateReady so every request exercises the fast path.
func newFIFOWithLimit(t *testing.T, model string, limit int) (*FIFO, *fakeEffects) {
t.Helper()
eff := newFakeEffects()
eff.states[model] = process.StateReady
models := map[string]config.ModelConfig{
model: {ConcurrencyLimit: limit},
}
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff)
return s, eff
}
// TestFIFO_ConcurrencyLimit_RejectsOverLimit verifies that a request arriving
// while the model is at capacity gets an error grant instead of being served,
// and that a new request succeeds once an in-flight one completes.
func TestFIFO_ConcurrencyLimit_RejectsOverLimit(t *testing.T) {
s, eff := newFIFOWithLimit(t, "a", 1)
// First request: served (inFlight 0 → 1).
s.OnRequest(req("a"))
if got := eff.served("a"); got != 1 {
t.Fatalf("served(a)=%d want 1", got)
}
// Second request while slot is occupied: rejected with HTTPError 429.
s.OnRequest(req("a"))
if got := eff.errored("a"); got != 1 {
t.Fatalf("errored(a)=%d want 1 (over-limit)", got)
}
var httpErr shared.HTTPError
if !errors.As(eff.grants[len(eff.grants)-1].err, &httpErr) {
t.Fatalf("err=%v want HTTPError", eff.grants[len(eff.grants)-1].err)
}
if httpErr.StatusCode() != http.StatusTooManyRequests {
t.Fatalf("StatusCode()=%d want 429", httpErr.StatusCode())
}
if httpErr.Header().Get("Retry-After") == "" {
t.Fatal("missing Retry-After header")
}
// After the in-flight request finishes, a new request succeeds.
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
s.OnRequest(req("a"))
if got := eff.served("a"); got != 2 {
t.Fatalf("served(a)=%d want 2 after drain", got)
}
}
// TestFIFO_ConcurrencyLimit_DefaultIsTen verifies that a model without an
// explicit ConcurrencyLimit gets the default cap of 10.
func TestFIFO_ConcurrencyLimit_DefaultIsTen(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateReady
// nil models → every model gets defaultConcurrencyLimit (10).
s := newFIFO(&stubPlanner{}, eff)
for i := 0; i < 10; i++ {
s.OnRequest(req("a"))
}
if got := eff.served("a"); got != 10 {
t.Fatalf("served(a)=%d want 10 (default limit)", got)
}
// 11th request is rejected.
s.OnRequest(req("a"))
if got := eff.errored("a"); got != 1 {
t.Fatalf("errored(a)=%d want 1 (over default limit)", got)
}
}
// TestFIFO_ConcurrencyLimit_CustomLimit verifies a ConcurrencyLimit greater
// than zero overrides the default.
func TestFIFO_ConcurrencyLimit_CustomLimit(t *testing.T) {
s, eff := newFIFOWithLimit(t, "a", 2)
s.OnRequest(req("a"))
s.OnRequest(req("a"))
s.OnRequest(req("a"))
if got := eff.served("a"); got != 2 {
t.Fatalf("served(a)=%d want 2 (custom limit)", got)
}
if got := eff.errored("a"); got != 1 {
t.Fatalf("errored(a)=%d want 1 (over custom limit)", got)
}
}
// TestFIFO_ConcurrencyLimit_SwapWaiters verifies that when more swap waiters
// exist than the concurrency limit, excess waiters are rejected on swap
// completion rather than exceeding the limit.
func TestFIFO_ConcurrencyLimit_SwapWaiters(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateStopped
models := map[string]config.ModelConfig{
"a": {ConcurrencyLimit: 2},
}
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff)
// Three requests arrive while model is loading: one starts swap, two join.
s.OnRequest(req("a"))
s.OnRequest(req("a"))
s.OnRequest(req("a"))
if got := eff.startsFor("a"); got != 1 {
t.Fatalf("StartSwap(a)=%d want 1", got)
}
// Swap completes: two served (limit), one rejected.
eff.states["a"] = process.StateReady
s.OnSwapDone(SwapDone{ModelID: "a"})
if got := eff.served("a"); got != 2 {
t.Fatalf("served(a)=%d want 2 (limit on swap completion)", got)
}
if got := eff.errored("a"); got != 1 {
t.Fatalf("errored(a)=%d want 1 (excess waiter rejected)", got)
}
}
+22 -3
View File
@@ -11,9 +11,11 @@ package scheduler
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"time" "time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/shared" "github.com/mostlygeek/llama-swap/internal/shared"
@@ -47,6 +49,11 @@ type Swapper interface {
type Scheduler interface { type Scheduler interface {
// OnRequest handles one incoming ServeHTTP request. // OnRequest handles one incoming ServeHTTP request.
OnRequest(req HandlerReq) OnRequest(req HandlerReq)
// OnCancel handles a request whose client has disconnected before it was
// granted. The scheduler must remove the request from its queue and from
// any in-flight swap's waiters so it never triggers a model load or grant
// for a caller that is no longer there.
OnCancel(req HandlerReq)
// OnSwapDone handles a swap goroutine reporting completion. // OnSwapDone handles a swap goroutine reporting completion.
OnSwapDone(ev SwapDone) OnSwapDone(ev SwapDone)
// OnServeDone handles a tracked ServeHTTP finishing (in-flight decrement). // OnServeDone handles a tracked ServeHTTP finishing (in-flight decrement).
@@ -85,9 +92,21 @@ type Effects interface {
StopProcesses(timeout time.Duration, ids []string) StopProcesses(timeout time.Duration, ids []string)
} }
// Factory builds a Scheduler bound to a baseRouter's Effects. The concrete // New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured
// router captures its Swapper in the closure it passes as a Factory. // from conf and bound to the given planner and effects. Currently only "fifo"
type Factory func(name string, logger *logmon.Monitor, eff Effects) Scheduler // (the default) is supported.
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
use := conf.Routing.Scheduler.Use
if use == "" {
use = "fifo"
}
switch use {
case "fifo":
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
default:
return nil, fmt.Errorf("unsupported scheduler type: %q", use)
}
}
// HandlerReq is one in-flight ServeHTTP request waiting for a routing decision. // HandlerReq is one in-flight ServeHTTP request waiting for a routing decision.
type HandlerReq struct { type HandlerReq struct {
+2 -2
View File
@@ -271,7 +271,7 @@ func (s *Server) startPreload() {
if err != nil { if err != nil {
continue continue
} }
req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID})) req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID, Metadata: make(map[string]string)}))
dw := &discardResponseWriter{status: http.StatusOK} dw := &discardResponseWriter{status: http.StatusOK}
s.local.ServeHTTP(dw, req) s.local.ServeHTTP(dw, req)
@@ -338,7 +338,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
// Strip the /upstream/<model> prefix before forwarding. // Strip the /upstream/<model> prefix before forwarding.
r.URL.Path = remainingPath r.URL.Path = remainingPath
// Pin the resolved model so the router skips body/query extraction. // Pin the resolved model so the router skips body/query extraction.
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID})) *r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)}))
switch { switch {
case s.local.Handles(modelID): case s.local.Handles(modelID):
-57
View File
@@ -1,57 +0,0 @@
package server
import (
"net/http"
"golang.org/x/sync/semaphore"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/shared"
)
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
// the model config leaves concurrencyLimit unset. Matches the legacy
// proxy.Process default.
const defaultConcurrencyLimit = 10
// CreateConcurrencyMiddleware returns middleware that limits simultaneous
// model-dispatched requests per model. Each model gets a semaphore sized to
// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot
// immediately acquire a slot is rejected with 429. Models without a local
// config entry (e.g. peer-routed models) are not limited.
func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models))
for id, mc := range cfg.Models {
limit := defaultConcurrencyLimit
if mc.ConcurrencyLimit > 0 {
limit = mc.ConcurrencyLimit
}
semaphores[id] = semaphore.NewWeighted(int64(limit))
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := shared.FetchContext(r, cfg)
if err != nil {
shared.SendError(w, r, shared.ErrNoModelInContext)
return
}
// fall through for peer models
sem, ok := semaphores[data.ModelID]
if !ok {
next.ServeHTTP(w, r)
return
}
if !sem.TryAcquire(1) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error":"Too many requests"}`))
return
}
defer sem.Release(1)
next.ServeHTTP(w, r)
})
}
}
-75
View File
@@ -1,75 +0,0 @@
package server
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/shared"
)
func concurrencyTestReq(model string) *http.Request {
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
return r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: model, ModelID: model}))
}
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
cfg := config.Config{
Models: map[string]config.ModelConfig{
"m1": {ConcurrencyLimit: 1},
},
}
entered := make(chan struct{})
release := make(chan struct{})
var once sync.Once
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
once.Do(func() { close(entered) })
<-release
w.WriteHeader(http.StatusOK)
})
h := CreateConcurrencyMiddleware(cfg)(final)
// First request occupies the only slot.
done := make(chan struct{})
go func() {
defer close(done)
h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1"))
}()
<-entered
// Second concurrent request is rejected with 429.
w := httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("m1"))
if w.Code != http.StatusTooManyRequests {
t.Fatalf("over-limit status = %d, want 429", w.Code)
}
// Once the slot frees, a new request succeeds.
close(release)
<-done
w = httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("m1"))
if w.Code != http.StatusOK {
t.Fatalf("post-release status = %d, want 200", w.Code)
}
}
func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{}}
called := 0
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called++
w.WriteHeader(http.StatusOK)
})
h := CreateConcurrencyMiddleware(cfg)(final)
w := httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("peer-model"))
if w.Code != http.StatusOK || called != 1 {
t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called)
}
}
+17 -9
View File
@@ -33,15 +33,16 @@ type TokenMetrics struct {
// ActivityLogEntry represents parsed token statistics from llama-server logs. // ActivityLogEntry represents parsed token statistics from llama-server logs.
type ActivityLogEntry struct { type ActivityLogEntry struct {
ID int `json:"id"` ID int `json:"id"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Model string `json:"model"` Model string `json:"model"`
ReqPath string `json:"req_path"` ReqPath string `json:"req_path"`
RespContentType string `json:"resp_content_type"` RespContentType string `json:"resp_content_type"`
RespStatusCode int `json:"resp_status_code"` RespStatusCode int `json:"resp_status_code"`
Tokens TokenMetrics `json:"tokens"` Tokens TokenMetrics `json:"tokens"`
DurationMs int `json:"duration_ms"` DurationMs int `json:"duration_ms"`
HasCapture bool `json:"has_capture"` HasCapture bool `json:"has_capture"`
Metadata map[string]string `json:"metadata,omitempty"`
} }
// ActivityLogEvent carries a single activity log entry to event subscribers. // ActivityLogEvent carries a single activity log entry to event subscribers.
@@ -135,6 +136,13 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()), DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
} }
if ctxData, ok := shared.ReadContext(r.Context()); ok && len(ctxData.Metadata) > 0 {
tm.Metadata = make(map[string]string, len(ctxData.Metadata))
for k, v := range ctxData.Metadata {
tm.Metadata[k] = v
}
}
queueAndEmit := func() { queueAndEmit := func() {
tm.ID = mp.queueMetrics(tm) tm.ID = mp.queueMetrics(tm)
mp.emitMetric(tm) mp.emitMetric(tm)
+31
View File
@@ -1,9 +1,13 @@
package server package server
import ( import (
"net/http"
"net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
"github.com/mostlygeek/llama-swap/internal/shared"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -56,6 +60,33 @@ func TestServer_ProcessStreamingResponse_NoData(t *testing.T) {
} }
} }
func TestMetricsMonitor_RecordMetadata(t *testing.T) {
mm := newMetricsMonitor(nil, 10, 0)
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"usage":{}}`))
r = r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{
ModelID: "m",
Metadata: map[string]string{"client": "web", "trace": "abc"},
}))
w := httptest.NewRecorder()
copier := newBodyCopier(w)
copier.WriteHeader(http.StatusOK)
copier.Write([]byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
mm.record("m", r, copier, 0, nil, nil)
entries := mm.getMetrics()
if len(entries) != 1 {
t.Fatalf("want 1 entry, got %d", len(entries))
}
if entries[0].Metadata["client"] != "web" {
t.Errorf("client = %q, want web", entries[0].Metadata["client"])
}
if entries[0].Metadata["trace"] != "abc" {
t.Errorf("trace = %q, want abc", entries[0].Metadata["trace"])
}
}
func TestServer_ParseMetrics_Infill(t *testing.T) { func TestServer_ParseMetrics_Infill(t *testing.T) {
// /infill responses are arrays; timings live in the last element. // /infill responses are arrays; timings live in the last element.
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]` body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
-1
View File
@@ -177,7 +177,6 @@ func (s *Server) routes() {
modelChain := chain.New( modelChain := chain.New(
authMW, authMW,
CreateRequestContextMiddleware(s.cfg), CreateRequestContextMiddleware(s.cfg),
CreateConcurrencyMiddleware(s.cfg),
CreateFilterMiddleware(s.cfg), CreateFilterMiddleware(s.cfg),
CreateFormFilterMiddleware(s.cfg), CreateFormFilterMiddleware(s.cfg),
CreateInflightMiddleware(s.inflight), CreateInflightMiddleware(s.inflight),
+35
View File
@@ -26,6 +26,9 @@ type ReqContextData struct {
ModelID string ModelID string
Streaming bool Streaming bool
SendLoadingState bool SendLoadingState bool
// Metadata is a request-scoped key/value bag that handlers may mutate
// while processing. The metrics middleware copies it into ActivityLogEntry.
Metadata map[string]string
} }
var ( var (
@@ -37,6 +40,16 @@ var (
) )
func SendError(w http.ResponseWriter, r *http.Request, err error) { func SendError(w http.ResponseWriter, r *http.Request, err error) {
var httpErr HTTPError
if errors.As(err, &httpErr) {
for k, v := range httpErr.Header() {
w.Header()[k] = v
}
w.WriteHeader(httpErr.StatusCode())
w.Write(httpErr.Body())
return
}
switch { switch {
case errors.Is(err, ErrNoModelInContext): case errors.Is(err, ErrNoModelInContext):
SendResponse(w, r, http.StatusNotFound, "no model id could be identified") SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
@@ -113,6 +126,25 @@ func ReadContext(ctx context.Context) (ReqContextData, bool) {
return data, ok return data, ok
} }
// SetReqData attaches a key/value pair to the request context's metadata map.
// The metadata map must already exist in the context's ReqContextData; callers
// should ensure FetchContext has run or initialize the map themselves.
// It returns an error for nil contexts or contexts without request data.
func SetReqData(ctx context.Context, key, value string) error {
if ctx == nil {
return fmt.Errorf("cannot set request metadata on nil context")
}
data, ok := ReadContext(ctx)
if !ok {
return fmt.Errorf("no request context data found")
}
if data.Metadata == nil {
return fmt.Errorf("no metadata map in request context")
}
data.Metadata[key] = value
return nil
}
// extractContext pulls fields from an HTTP request into a ReqContextData, // extractContext pulls fields from an HTTP request into a ReqContextData,
// returning whatever is available. For GET requests it reads query parameters. // returning whatever is available. For GET requests it reads query parameters.
// For POST requests it inspects Content-Type and parses JSON, // For POST requests it inspects Content-Type and parses JSON,
@@ -129,6 +161,7 @@ func extractContext(r *http.Request) (ReqContextData, error) {
Model: q.Get("model"), Model: q.Get("model"),
Streaming: q.Get("stream") == "true", Streaming: q.Get("stream") == "true",
ApiKey: apiKey, ApiKey: apiKey,
Metadata: make(map[string]string),
}, nil }, nil
} }
@@ -147,6 +180,7 @@ func extractContext(r *http.Request) (ReqContextData, error) {
Model: gjson.GetBytes(bodyBytes, "model").String(), Model: gjson.GetBytes(bodyBytes, "model").String(),
Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(), Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
ApiKey: apiKey, ApiKey: apiKey,
Metadata: make(map[string]string),
}, nil }, nil
} }
@@ -168,6 +202,7 @@ func extractContext(r *http.Request) (ReqContextData, error) {
Model: r.FormValue("model"), Model: r.FormValue("model"),
Streaming: r.FormValue("stream") == "true", Streaming: r.FormValue("stream") == "true",
ApiKey: apiKey, ApiKey: apiKey,
Metadata: make(map[string]string),
}, nil }, nil
} }
+32
View File
@@ -387,6 +387,38 @@ func TestExtractContext_ApiKey(t *testing.T) {
} }
} }
func TestSetReqData(t *testing.T) {
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3", Metadata: make(map[string]string)})
if err := SetReqData(ctx, "client", "web"); err != nil {
t.Fatalf("SetReqData: %v", err)
}
if err := SetReqData(ctx, "trace", "abc123"); err != nil {
t.Fatalf("SetReqData: %v", err)
}
data, ok := ReadContext(ctx)
if !ok {
t.Fatal("context data missing")
}
if data.Metadata["client"] != "web" {
t.Errorf("client = %q, want %q", data.Metadata["client"], "web")
}
if data.Metadata["trace"] != "abc123" {
t.Errorf("trace = %q, want %q", data.Metadata["trace"], "abc123")
}
}
func TestSetReqData_Errors(t *testing.T) {
if err := SetReqData(context.Background(), "k", "v"); err == nil {
t.Error("expected error when no request context data exists")
}
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
if err := SetReqData(ctx, "k", "v"); err == nil {
t.Error("expected error when metadata map is missing")
}
}
func TestServer_ExtractAPIKey(t *testing.T) { func TestServer_ExtractAPIKey(t *testing.T) {
basicHeader := func(user, pass string) string { basicHeader := func(user, pass string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass)) return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
+63
View File
@@ -0,0 +1,63 @@
package shared
import (
"encoding/json"
"net/http"
"strconv"
)
// HTTPError is an error that carries a complete HTTP response. A producer (e.g.
// a scheduler shedding a request) returns one of these; a renderer (e.g.
// router.SendError) writes the status, headers, and body verbatim instead of
// mapping the error to a generic status. It is the seam that lets a component
// shed a request with a rich response (e.g. a 429 with rate-limit headers and a
// JSON hint body) without the renderer knowing the producer's internals.
type HTTPError interface {
error
StatusCode() int
Header() http.Header
Body() []byte
}
// ConcurrencyLimitError is an HTTPError for a 429 concurrency-limit rejection.
// Zero-value fields fall back to sensible defaults: a 1-second Retry-After and a
// JSON hint body.
type ConcurrencyLimitError struct {
// RetryAfter, when > 0, is sent as the Retry-After header (in seconds).
// Defaults to 1.
RetryAfter int
// Message overrides the JSON body's "error" field. Defaults to
// "Too many requests".
Message string
}
func (e ConcurrencyLimitError) Error() string { return "concurrency limit reached" }
func (e ConcurrencyLimitError) StatusCode() int { return http.StatusTooManyRequests }
func (e ConcurrencyLimitError) Header() http.Header {
h := http.Header{}
h.Set("Content-Type", "application/json")
h.Set("Retry-After", e.retryAfter())
return h
}
func (e ConcurrencyLimitError) Body() []byte {
b, _ := json.Marshal(map[string]string{"error": e.message()})
return b
}
func (e ConcurrencyLimitError) retryAfter() string {
if e.RetryAfter > 0 {
return strconv.Itoa(e.RetryAfter)
}
return "1"
}
func (e ConcurrencyLimitError) message() string {
if e.Message != "" {
return e.Message
}
return "Too many requests"
}
@@ -193,7 +193,7 @@
<dialog <dialog
bind:this={dialogEl} bind:this={dialogEl}
onclose={handleDialogClose} onclose={handleDialogClose}
class="bg-surface text-txtmain rounded-lg shadow-xl max-w-4xl w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto" class="bg-surface text-txtmain rounded-lg shadow-xl max-w-[80%] w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto"
> >
{#if capture} {#if capture}
<div class="flex flex-col max-h-[90vh]"> <div class="flex flex-col max-h-[90vh]">
@@ -0,0 +1,85 @@
<script lang="ts">
import type { Snippet } from "svelte";
interface Props {
metadata: Record<string, string> | undefined;
children: Snippet;
}
let { metadata, children }: Props = $props();
let entries = $derived(Object.entries(metadata || {}));
let triggerEl: HTMLElement | undefined = $state();
let tooltipEl: HTMLDivElement | undefined = $state();
let show = $state(false);
let tooltipStyle = $state("");
function positionTooltip() {
if (!triggerEl || !tooltipEl) return;
const triggerRect = triggerEl.getBoundingClientRect();
const tooltipRect = tooltipEl.getBoundingClientRect();
const margin = 8;
const viewportWidth = window.innerWidth;
const viewportHeight = window.innerHeight;
let left = triggerRect.left;
let top = triggerRect.bottom + margin;
// Keep tooltip within horizontal viewport bounds
if (left + tooltipRect.width > viewportWidth - margin) {
left = triggerRect.right - tooltipRect.width;
}
if (left < margin) {
left = margin;
}
// Flip above trigger if it would overflow the bottom
if (top + tooltipRect.height > viewportHeight - margin) {
top = triggerRect.top - tooltipRect.height - margin;
}
tooltipStyle = `left: ${left}px; top: ${top}px; max-width: calc(100vw - ${margin * 2}px);`;
}
function onEnter() {
show = true;
requestAnimationFrame(positionTooltip);
}
function onLeave() {
show = false;
}
</script>
<span
bind:this={triggerEl}
onmouseenter={onEnter}
onmouseleave={onLeave}
onfocus={onEnter}
onblur={onLeave}
class="inline-flex"
role="button"
tabindex="0"
aria-label="Show metadata"
>
{@render children()}
</span>
{#if show && entries.length > 0}
<div
bind:this={tooltipEl}
style={tooltipStyle}
class="fixed px-3 py-2 bg-gray-900 text-white text-sm rounded-md z-50 normal-case min-w-[12rem] max-w-[24rem] shadow-lg whitespace-normal"
>
<table class="w-full text-left">
<tbody>
{#each entries as [key, value]}
<tr class="border-b border-white/10 last:border-0">
<td class="py-1 pr-3 font-medium whitespace-nowrap text-primary">{key}</td>
<td class="py-1 break-all">{value}</td>
</tr>
{/each}
</tbody>
</table>
</div>
{/if}
+30 -3
View File
@@ -6,6 +6,8 @@
let isUnloading = $state(false); let isUnloading = $state(false);
let menuOpen = $state(false); let menuOpen = $state(false);
let pendingLoads = $state<Record<string, boolean>>({});
const loadControllers = new Map<string, AbortController>();
const showUnlistedStore = persistentStore<boolean>("showUnlisted", true); const showUnlistedStore = persistentStore<boolean>("showUnlisted", true);
const showIdorNameStore = persistentStore<"id" | "name">("showIdorName", "id"); const showIdorNameStore = persistentStore<"id" | "name">("showIdorName", "id");
@@ -42,6 +44,25 @@
} }
} }
async function handleLoadModel(modelId: string): Promise<void> {
if (pendingLoads[modelId]) return;
const controller = new AbortController();
loadControllers.set(modelId, controller);
pendingLoads[modelId] = true;
try {
await loadModel(modelId, controller.signal);
} catch (e) {
console.error(e);
} finally {
loadControllers.delete(modelId);
delete pendingLoads[modelId];
}
}
function cancelLoad(modelId: string): void {
loadControllers.get(modelId)?.abort();
}
function toggleIdorName(): void { function toggleIdorName(): void {
showIdorNameStore.update((prev) => (prev === "name" ? "id" : "name")); showIdorNameStore.update((prev) => (prev === "name" ? "id" : "name"));
} }
@@ -170,14 +191,20 @@
{/if} {/if}
</td> </td>
<td class="w-12"> <td class="w-12">
{#if model.state === "stopped"} {#if model.state === "stopped" && pendingLoads[model.id]}
<button class="btn btn--sm" onclick={() => loadModel(model.id)}>Load</button> <button class="btn btn--sm" onclick={() => cancelLoad(model.id)}>Cancel</button>
{:else if model.state === "stopped"}
<button class="btn btn--sm" onclick={() => handleLoadModel(model.id)}>Load</button>
{:else} {:else}
<button class="btn btn--sm" onclick={() => unloadSingleModel(model.id)} disabled={model.state !== "ready"}>Unload</button> <button class="btn btn--sm" onclick={() => unloadSingleModel(model.id)} disabled={model.state !== "ready"}>Unload</button>
{/if} {/if}
</td> </td>
<td class="w-20"> <td class="w-20">
<span class="w-16 text-center status status--{model.state}">{model.state}</span> {#if model.state === "stopped" && pendingLoads[model.id]}
<span class="w-16 text-center status status--queued">queued</span>
{:else}
<span class="w-16 text-center status status--{model.state}">{model.state}</span>
{/if}
</td> </td>
</tr> </tr>
{/each} {/each}
+2 -1
View File
@@ -139,7 +139,8 @@
} }
.status--starting, .status--starting,
.status--stopping { .status--stopping,
.status--queued {
@apply bg-warning/10 text-warning; @apply bg-warning/10 text-warning;
} }
+1
View File
@@ -41,6 +41,7 @@ export interface ActivityLogEntry {
tokens: TokenMetrics; tokens: TokenMetrics;
duration_ms: number; duration_ms: number;
has_capture: boolean; has_capture: boolean;
metadata?: Record<string, string>;
} }
export interface ReqRespCapture { export interface ReqRespCapture {
+172 -119
View File
@@ -2,25 +2,13 @@
import { metrics, getCapture } from "../stores/api"; import { metrics, getCapture } from "../stores/api";
import ActivityStats from "../components/ActivityStats.svelte"; import ActivityStats from "../components/ActivityStats.svelte";
import Tooltip from "../components/Tooltip.svelte"; import Tooltip from "../components/Tooltip.svelte";
import MetadataTooltip from "../components/MetadataTooltip.svelte";
import CaptureDialog from "../components/CaptureDialog.svelte"; import CaptureDialog from "../components/CaptureDialog.svelte";
import { persistentStore } from "../stores/persistent"; import { persistentStore } from "../stores/persistent";
import { onMount } from "svelte"; import { onMount } from "svelte";
import type { ReqRespCapture } from "../lib/types"; import type { ReqRespCapture } from "../lib/types";
type ColumnKey = type ColumnKey = string;
| "id"
| "time"
| "model"
| "req_path"
| "resp_status_code"
| "resp_content_type"
| "cached"
| "prompt"
| "generated"
| "prompt_speed"
| "gen_speed"
| "duration"
| "capture";
interface ColumnDef { interface ColumnDef {
key: ColumnKey; key: ColumnKey;
@@ -42,17 +30,21 @@
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true }, { key: "gen_speed", label: "Gen Speed", defaultVisible: true },
{ key: "duration", label: "Duration", defaultVisible: true }, { key: "duration", label: "Duration", defaultVisible: true },
{ key: "capture", label: "Capture", defaultVisible: true }, { key: "capture", label: "Capture", defaultVisible: true },
{ key: "meta", label: "Meta", defaultVisible: false },
]; ];
const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key); const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key);
const visibleColumns = persistentStore<ColumnKey[]>( const visibleColumns = persistentStore<ColumnKey[]>("activity-columns", defaultVisibleKeys);
"activity-columns", const columnOrder = persistentStore<ColumnKey[]>(
defaultVisibleKeys "activity-column-order",
columns.map((c) => c.key)
); );
let columnsMenuOpen = $state(false); let columnsMenuOpen = $state(false);
let dropdownContainer: HTMLDivElement | null = null; let dropdownContainer: HTMLDivElement | null = null;
let dragKey: ColumnKey | null = $state(null);
let dragOverKey: ColumnKey | null = $state(null);
onMount(() => { onMount(() => {
function handleKeydown(e: KeyboardEvent) { function handleKeydown(e: KeyboardEvent) {
@@ -84,6 +76,84 @@
} }
} }
function isColumnVisible(key: ColumnKey): boolean {
return $visibleColumns.includes(key);
}
function handleDragStart(e: DragEvent, key: ColumnKey) {
dragKey = key;
e.dataTransfer?.setData("text/plain", key);
if (e.dataTransfer) {
e.dataTransfer.effectAllowed = "move";
}
}
function handleDragOver(e: DragEvent, key: ColumnKey) {
e.preventDefault();
if (e.dataTransfer) {
e.dataTransfer.dropEffect = "move";
}
dragOverKey = key;
}
function handleDrop(e: DragEvent, targetKey: ColumnKey) {
e.preventDefault();
if (!dragKey || dragKey === targetKey) return;
const order = [...$columnOrder];
const fromIndex = order.indexOf(dragKey);
let toIndex = order.indexOf(targetKey);
if (fromIndex === -1 || toIndex === -1) return;
order.splice(fromIndex, 1);
if (fromIndex < toIndex) {
toIndex -= 1;
}
order.splice(toIndex, 0, dragKey);
columnOrder.set(order);
}
function handleDragEnd() {
dragKey = null;
dragOverKey = null;
}
let orderedColumns = $derived(
columns.slice().sort((a, b) => {
const aIndex = $columnOrder.indexOf(a.key);
const bIndex = $columnOrder.indexOf(b.key);
if (aIndex === -1 && bIndex === -1) return 0;
if (aIndex === -1) return 1;
if (bIndex === -1) return -1;
return aIndex - bIndex;
})
);
let activeVisibleColumns = $derived(
columns
.filter((c) => isColumnVisible(c.key))
.sort((a, b) => {
const aIndex = $columnOrder.indexOf(a.key);
const bIndex = $columnOrder.indexOf(b.key);
if (aIndex === -1 && bIndex === -1) return 0;
if (aIndex === -1) return 1;
if (bIndex === -1) return -1;
return aIndex - bIndex;
})
.map((c) => c.key)
);
let columnLabelMap = $derived(Object.fromEntries(columns.map((c) => [c.key, c.label])));
$effect(() => {
const staticKeys = new Set(columns.map((c) => c.key));
const order = $columnOrder;
const hasStale = order.some((k) => !staticKeys.has(k));
const missing = columns.filter((c) => !order.includes(c.key)).map((c) => c.key);
if (hasStale || missing.length > 0) {
const cleaned = order.filter((k) => staticKeys.has(k));
columnOrder.set([...cleaned, ...missing]);
}
});
function formatSpeed(speed: number): string { function formatSpeed(speed: number): string {
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s"; return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
} }
@@ -157,22 +227,37 @@
</svg> </svg>
</button> </button>
{#if columnsMenuOpen} {#if columnsMenuOpen}
<div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]"> <div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]" role="list">
<div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10"> <div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10" role="presentation">
Columns Columns
</div> </div>
{#each columns as col (col.key)} {#each orderedColumns as col (col.key)}
<label {@const key = col.key}
class="flex items-center gap-2 px-3 py-1.5 text-sm cursor-pointer hover:bg-secondary-hover transition-colors" <div
class="flex items-center gap-2 px-3 py-1.5 text-sm hover:bg-secondary-hover transition-colors {dragOverKey === key && dragKey !== key ? 'bg-primary/10 ring-1 ring-primary/40' : ''} {dragKey === key ? 'opacity-40' : ''}"
role="listitem"
ondragover={(e) => handleDragOver(e, key)}
ondrop={(e) => handleDrop(e, key)}
> >
<input <span
type="checkbox" class="text-txtsecondary select-none cursor-grab"
checked={$visibleColumns.includes(col.key)} draggable={true}
onchange={() => toggleColumn(col.key)} role="button"
class="rounded" tabindex="-1"
/> aria-label="Drag to reorder {col.label}"
{col.label} ondragstart={(e) => handleDragStart(e, key)}
</label> ondragend={handleDragEnd}
>⋮⋮</span>
<label class="flex items-center gap-2 flex-1 cursor-pointer">
<input
type="checkbox"
checked={isColumnVisible(key)}
onchange={() => toggleColumn(key)}
class="rounded"
/>
{col.label}
</label>
</div>
{/each} {/each}
</div> </div>
{/if} {/if}
@@ -182,112 +267,80 @@
<table class="min-w-full divide-y"> <table class="min-w-full divide-y">
<thead class="border-gray-200 dark:border-white/10"> <thead class="border-gray-200 dark:border-white/10">
<tr class="text-left text-xs uppercase tracking-wider"> <tr class="text-left text-xs uppercase tracking-wider">
{#if $visibleColumns.includes("id")} {#each activeVisibleColumns as key (key)}
<th class="px-6 py-3">ID</th>
{/if}
{#if $visibleColumns.includes("time")}
<th class="px-6 py-3">Time</th>
{/if}
{#if $visibleColumns.includes("model")}
<th class="px-6 py-3">Model</th>
{/if}
{#if $visibleColumns.includes("req_path")}
<th class="px-6 py-3">Path</th>
{/if}
{#if $visibleColumns.includes("resp_status_code")}
<th class="px-6 py-3">Status</th>
{/if}
{#if $visibleColumns.includes("resp_content_type")}
<th class="px-6 py-3">Content-Type</th>
{/if}
{#if $visibleColumns.includes("cached")}
<th class="px-6 py-3"> <th class="px-6 py-3">
Cached <Tooltip content="prompt tokens from cache" /> {#if key === "cached"}
Cached <Tooltip content="prompt tokens from cache" />
{:else if key === "prompt"}
Prompt <Tooltip content="new prompt tokens processed" />
{:else}
{columnLabelMap[key] ?? key}
{/if}
</th> </th>
{/if} {/each}
{#if $visibleColumns.includes("prompt")}
<th class="px-6 py-3">
Prompt <Tooltip content="new prompt tokens processed" />
</th>
{/if}
{#if $visibleColumns.includes("generated")}
<th class="px-6 py-3">Generated</th>
{/if}
{#if $visibleColumns.includes("prompt_speed")}
<th class="px-6 py-3">Prompt Speed</th>
{/if}
{#if $visibleColumns.includes("gen_speed")}
<th class="px-6 py-3">Gen Speed</th>
{/if}
{#if $visibleColumns.includes("duration")}
<th class="px-6 py-3">Duration</th>
{/if}
{#if $visibleColumns.includes("capture")}
<th class="px-6 py-3">Capture</th>
{/if}
</tr> </tr>
</thead> </thead>
<tbody class="divide-y"> <tbody class="divide-y">
{#if sortedMetrics.length === 0} {#if sortedMetrics.length === 0}
<tr> <tr>
<td colspan={$visibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400"> <td colspan={activeVisibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
No activity recorded No activity recorded
</td> </td>
</tr> </tr>
{:else} {:else}
{#each sortedMetrics as metric (metric.id)} {#each sortedMetrics as metric (metric.id)}
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10"> <tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
{#if $visibleColumns.includes("id")} {#each activeVisibleColumns as key (key)}
<td class="px-4 py-4">{metric.id + 1}</td>
{/if}
{#if $visibleColumns.includes("time")}
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
{/if}
{#if $visibleColumns.includes("model")}
<td class="px-6 py-4">{metric.model}</td>
{/if}
{#if $visibleColumns.includes("req_path")}
<td class="px-6 py-4">{metric.req_path || "-"}</td>
{/if}
{#if $visibleColumns.includes("resp_status_code")}
<td class="px-6 py-4">{metric.resp_status_code || "-"}</td>
{/if}
{#if $visibleColumns.includes("resp_content_type")}
<td class="px-6 py-4">{metric.resp_content_type || "-"}</td>
{/if}
{#if $visibleColumns.includes("cached")}
<td class="px-6 py-4">{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}</td>
{/if}
{#if $visibleColumns.includes("prompt")}
<td class="px-6 py-4">{metric.tokens.input_tokens.toLocaleString()}</td>
{/if}
{#if $visibleColumns.includes("generated")}
<td class="px-6 py-4">{metric.tokens.output_tokens.toLocaleString()}</td>
{/if}
{#if $visibleColumns.includes("prompt_speed")}
<td class="px-6 py-4">{formatSpeed(metric.tokens.prompt_per_second)}</td>
{/if}
{#if $visibleColumns.includes("gen_speed")}
<td class="px-6 py-4">{formatSpeed(metric.tokens.tokens_per_second)}</td>
{/if}
{#if $visibleColumns.includes("duration")}
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
{/if}
{#if $visibleColumns.includes("capture")}
<td class="px-6 py-4"> <td class="px-6 py-4">
{#if metric.has_capture} {#if key === "id"}
<button {metric.id + 1}
onclick={() => viewCapture(metric.id)} {:else if key === "time"}
disabled={loadingCaptureId === metric.id} {formatRelativeTime(metric.timestamp)}
class="btn btn--sm" {:else if key === "model"}
> {metric.model}
{loadingCaptureId === metric.id ? "..." : "View"} {:else if key === "req_path"}
</button> {metric.req_path || "-"}
{:else if key === "resp_status_code"}
{metric.resp_status_code || "-"}
{:else if key === "resp_content_type"}
{metric.resp_content_type || "-"}
{:else if key === "cached"}
{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}
{:else if key === "prompt"}
{metric.tokens.input_tokens.toLocaleString()}
{:else if key === "generated"}
{metric.tokens.output_tokens.toLocaleString()}
{:else if key === "prompt_speed"}
{formatSpeed(metric.tokens.prompt_per_second)}
{:else if key === "gen_speed"}
{formatSpeed(metric.tokens.tokens_per_second)}
{:else if key === "duration"}
{formatDuration(metric.duration_ms)}
{:else if key === "capture"}
{#if metric.has_capture}
<button
onclick={() => viewCapture(metric.id)}
disabled={loadingCaptureId === metric.id}
class="btn btn--sm"
>
{loadingCaptureId === metric.id ? "..." : "View"}
</button>
{:else}
<span class="text-txtsecondary">-</span>
{/if}
{:else if key === "meta"}
{#if Object.keys(metric.metadata || {}).length > 0}
<MetadataTooltip metadata={metric.metadata}>
<span class="cursor-help text-txtsecondary hover:text-txtmain">...</span>
</MetadataTooltip>
{:else}
<span class="text-txtsecondary">-</span>
{/if}
{:else} {:else}
<span class="text-txtsecondary">-</span> -
{/if} {/if}
</td> </td>
{/if} {/each}
</tr> </tr>
{/each} {/each}
{/if} {/if}
+6 -2
View File
@@ -176,15 +176,19 @@ export async function unloadSingleModel(model: string): Promise<void> {
} }
} }
export async function loadModel(model: string): Promise<void> { export async function loadModel(model: string, signal?: AbortSignal): Promise<void> {
try { try {
const response = await fetch(`/upstream/${model}/`, { const response = await fetch(`/upstream/${model}/?_=${Date.now()}`, {
method: "GET", method: "GET",
signal,
}); });
if (!response.ok) { if (!response.ok) {
throw new Error(`Failed to load model: ${response.status}`); throw new Error(`Failed to load model: ${response.status}`);
} }
} catch (error) { } catch (error) {
if (error instanceof DOMException && error.name === "AbortError") {
return;
}
console.error("Failed to load model:", error); console.error("Failed to load model:", error);
throw error; throw error;
} }