Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 316ad63f76 | |||
| e37077a963 | |||
| eff9b60434 | |||
| 9bcddad91b | |||
| a15e47922c |
@@ -572,6 +572,24 @@
|
|||||||
"default": {},
|
"default": {},
|
||||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||||
},
|
},
|
||||||
|
"upstream": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Controls behaviour of the /upstream passthrough endpoint. Recommended to only use in special use cases; leaving it as the default will typically be the best experience.",
|
||||||
|
"properties": {
|
||||||
|
"ignorePaths": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"default": [
|
||||||
|
".*\\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$"
|
||||||
|
],
|
||||||
|
"description": "List of RE2 compatible regular expressions. Any request to a path matching any of the regular expressions will be ignored and not trigger a swap. When not specified, defaults to a pattern matching common static-asset suffixes (.js, .json, .css, .png, .gif, .jpg, .jpeg, .ico, .txt)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"default": {}
|
||||||
|
},
|
||||||
"routing": {
|
"routing": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
|
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
|
||||||
|
|||||||
@@ -134,6 +134,18 @@ apiKeys:
|
|||||||
- "${env.API_KEY_1}"
|
- "${env.API_KEY_1}"
|
||||||
- "${env.API_KEY_2}"
|
- "${env.API_KEY_2}"
|
||||||
|
|
||||||
|
# upstream: controls behaviour of the /upstream passthrough endpoint
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - recommended to only use in special use cases. Leaving it as the
|
||||||
|
# default will typically be the best experience
|
||||||
|
upstream:
|
||||||
|
# ignorePaths: list of RE2 compatible regular expressions
|
||||||
|
# - default: (see below)
|
||||||
|
# - any request to a path matching any of the regular expressions
|
||||||
|
# will be ignored and not trigger a swap
|
||||||
|
ignorePaths:
|
||||||
|
- '.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$'
|
||||||
|
|
||||||
# models: a dictionary of model configurations
|
# models: a dictionary of model configurations
|
||||||
# - required
|
# - required
|
||||||
# - each key is the model's ID, used in API requests
|
# - each key is the model's ID, used in API requests
|
||||||
|
|||||||
@@ -163,6 +163,9 @@ type Config struct {
|
|||||||
|
|
||||||
// support remote peers, see issue #433, #296
|
// support remote peers, see issue #433, #296
|
||||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||||
|
|
||||||
|
// upstream controls behaviour of the /upstream passthrough endpoint
|
||||||
|
Upstream UpstreamConfig `yaml:"upstream"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoutingConfig is the canonical, normalized routing/scheduling configuration.
|
// RoutingConfig is the canonical, normalized routing/scheduling configuration.
|
||||||
@@ -270,6 +273,12 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply default for upstream.ignorePaths when not specified. The default
|
||||||
|
// matches common static-asset suffixes so they do not trigger a swap.
|
||||||
|
if len(config.Upstream.IgnorePaths) == 0 {
|
||||||
|
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
||||||
|
}
|
||||||
|
|
||||||
switch config.LogToStdout {
|
switch config.LogToStdout {
|
||||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -266,6 +266,9 @@ groups:
|
|||||||
"mthree": "model3",
|
"mthree": "model3",
|
||||||
},
|
},
|
||||||
Groups: expectedGroups,
|
Groups: expectedGroups,
|
||||||
|
Upstream: UpstreamConfig{
|
||||||
|
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||||
|
},
|
||||||
Routing: RoutingConfig{
|
Routing: RoutingConfig{
|
||||||
Router: RouterConfig{
|
Router: RouterConfig{
|
||||||
Use: "group",
|
Use: "group",
|
||||||
|
|||||||
@@ -255,6 +255,9 @@ groups:
|
|||||||
"mthree": "model3",
|
"mthree": "model3",
|
||||||
},
|
},
|
||||||
Groups: expectedGroups,
|
Groups: expectedGroups,
|
||||||
|
Upstream: UpstreamConfig{
|
||||||
|
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||||
|
},
|
||||||
Routing: RoutingConfig{
|
Routing: RoutingConfig{
|
||||||
Router: RouterConfig{
|
Router: RouterConfig{
|
||||||
Use: "group",
|
Use: "group",
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultUpstreamIgnorePathsPattern is the default regular expression applied
|
||||||
|
// to upstream.ignorePaths when the section is empty or absent from the config.
|
||||||
|
// It matches common static-asset suffixes so requests for .js/.css/.png/etc.
|
||||||
|
// files do not trigger a model swap.
|
||||||
|
const DefaultUpstreamIgnorePathsPattern = `.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$`
|
||||||
|
|
||||||
|
// DefaultUpstreamIgnorePaths returns the default compiled ignore paths used
|
||||||
|
// when upstream.ignorePaths is not specified in the config. The returned slice
|
||||||
|
// is fresh so callers may mutate it without affecting other configs.
|
||||||
|
func DefaultUpstreamIgnorePaths() []*regexp.Regexp {
|
||||||
|
return []*regexp.Regexp{regexp.MustCompile(DefaultUpstreamIgnorePathsPattern)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamConfig controls behaviour of the /upstream passthrough endpoint.
|
||||||
|
type UpstreamConfig struct {
|
||||||
|
// IgnorePaths is a slice of compiled regular expressions. Any request to
|
||||||
|
// /upstream/<model>/<path> whose remaining path matches any of these
|
||||||
|
// expressions will be ignored and not trigger a swap. When the config
|
||||||
|
// does not specify any patterns, DefaultUpstreamIgnorePaths is applied.
|
||||||
|
IgnorePaths []*regexp.Regexp `yaml:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// rawUpstreamConfig is the intermediate form used to unmarshal the YAML into
|
||||||
|
// plain strings, which are then compiled into *regexp.Regexp.
|
||||||
|
type rawUpstreamConfig struct {
|
||||||
|
IgnorePaths []string `yaml:"ignorePaths"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalYAML compiles each ignorePaths entry into a *regexp.Regexp. If any
|
||||||
|
// entry fails to compile, an error is returned.
|
||||||
|
func (u *UpstreamConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
var raw rawUpstreamConfig
|
||||||
|
if err := value.Decode(&raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
patterns := make([]*regexp.Regexp, 0, len(raw.IgnorePaths))
|
||||||
|
for _, p := range raw.IgnorePaths {
|
||||||
|
re, err := regexp.Compile(p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("upstream.ignorePaths: invalid regular expression %q: %w", p, err)
|
||||||
|
}
|
||||||
|
patterns = append(patterns, re)
|
||||||
|
}
|
||||||
|
u.IgnorePaths = patterns
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const upstreamConfigHeader = `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_DefaultWhenAbsent(t *testing.T) {
|
||||||
|
// When upstream is not specified at all, the default pattern is applied.
|
||||||
|
content := upstreamConfigHeader
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||||
|
|
||||||
|
def := cfg.Upstream.IgnorePaths[0]
|
||||||
|
assert.IsType(t, &regexp.Regexp{}, def)
|
||||||
|
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, def.String())
|
||||||
|
|
||||||
|
// The default matches common static-asset suffixes.
|
||||||
|
assert.True(t, def.MatchString("/foo.js"))
|
||||||
|
assert.True(t, def.MatchString("/bar/baz.json"))
|
||||||
|
assert.True(t, def.MatchString("/static/img.png"))
|
||||||
|
assert.True(t, def.MatchString("/notes.txt"))
|
||||||
|
assert.True(t, def.MatchString("/favicon.ico"))
|
||||||
|
// And does not match inference API paths.
|
||||||
|
assert.False(t, def.MatchString("/v1/chat/completions"))
|
||||||
|
assert.False(t, def.MatchString("/v1/models"))
|
||||||
|
assert.False(t, def.MatchString("/health"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_DefaultWhenSectionEmpty(t *testing.T) {
|
||||||
|
// When upstream is present but ignorePaths is omitted, the default is still
|
||||||
|
// applied.
|
||||||
|
content := `upstream: {}` + "\n" + upstreamConfigHeader
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||||
|
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, cfg.Upstream.IgnorePaths[0].String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_Compiles(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
upstream:
|
||||||
|
ignorePaths:
|
||||||
|
- ".*\\.(js|json|css|png|gif|jpg|jpeg|txt)$"
|
||||||
|
- "^/static/.*"
|
||||||
|
` + upstreamConfigHeader
|
||||||
|
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, cfg.Upstream.IgnorePaths, 2)
|
||||||
|
|
||||||
|
// Verify the patterns are compiled into *regexp.Regexp and match as expected.
|
||||||
|
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/foo.js"))
|
||||||
|
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/bar/baz.json"))
|
||||||
|
assert.False(t, cfg.Upstream.IgnorePaths[0].MatchString("/v1/chat/completions"))
|
||||||
|
assert.True(t, cfg.Upstream.IgnorePaths[1].MatchString("/static/foo.png"))
|
||||||
|
assert.False(t, cfg.Upstream.IgnorePaths[1].MatchString("/v1/chat/completions"))
|
||||||
|
|
||||||
|
// Confirm the type is *regexp.Regexp to satisfy the API contract.
|
||||||
|
for _, re := range cfg.Upstream.IgnorePaths {
|
||||||
|
assert.IsType(t, &regexp.Regexp{}, re)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_InvalidRegexReturnsError(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
upstream:
|
||||||
|
ignorePaths:
|
||||||
|
- "[invalid("
|
||||||
|
` + upstreamConfigHeader
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "upstream.ignorePaths")
|
||||||
|
assert.Contains(t, err.Error(), "invalid regular expression")
|
||||||
|
}
|
||||||
+25
-27
@@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -9,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/event"
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -314,7 +316,7 @@ func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
|
|||||||
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||||
upstreamPath := r.PathValue("upstreamPath")
|
upstreamPath := r.PathValue("upstreamPath")
|
||||||
|
|
||||||
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
|
searchName, modelID, remainingPath, found := shared.FindModelInPath(s.cfg, "/"+upstreamPath)
|
||||||
if !found {
|
if !found {
|
||||||
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||||
return
|
return
|
||||||
@@ -340,6 +342,28 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Pin the resolved model so the router skips body/query extraction.
|
// Pin the resolved model so the router skips body/query extraction.
|
||||||
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)}))
|
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)}))
|
||||||
|
|
||||||
|
// If the path matches an upstream.ignorePaths entry and the model is
|
||||||
|
// not already loaded, refuse the request without triggering a swap. A 409
|
||||||
|
// Conflict is returned because the request cannot be served while the
|
||||||
|
// model is not in the ready state.
|
||||||
|
for _, re := range s.cfg.Upstream.IgnorePaths {
|
||||||
|
if !re.MatchString(remainingPath) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s.local.Handles(modelID) {
|
||||||
|
state, ok := s.local.RunningModels()[modelID]
|
||||||
|
if !ok || state != process.StateReady {
|
||||||
|
shared.SendResponse(w, r, http.StatusConflict,
|
||||||
|
fmt.Sprintf("model %s is not loaded; path matches upstream.ignorePaths", modelID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Either the model is already loaded (no swap would be triggered)
|
||||||
|
// or this is a peer model (peer proxying never swaps). Fall through
|
||||||
|
// to normal dispatch.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case s.local.Handles(modelID):
|
case s.local.Handles(modelID):
|
||||||
s.local.ServeHTTP(w, r)
|
s.local.ServeHTTP(w, r)
|
||||||
@@ -349,29 +373,3 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
|||||||
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// findModelInPath walks a slash-separated path, building up segments until one
|
|
||||||
// matches a configured model. This resolves model names that contain slashes
|
|
||||||
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
|
||||||
// remaining path, and whether a match was found.
|
|
||||||
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
|
||||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
|
||||||
name := ""
|
|
||||||
|
|
||||||
for i, part := range parts {
|
|
||||||
if part == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if name == "" {
|
|
||||||
name = part
|
|
||||||
} else {
|
|
||||||
name = name + "/" + part
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelID, ok := cfg.RealModelName(name); ok {
|
|
||||||
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", "", "", false
|
|
||||||
}
|
|
||||||
|
|||||||
+169
-2
@@ -2,11 +2,17 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_HandleListModels(t *testing.T) {
|
func TestServer_HandleListModels(t *testing.T) {
|
||||||
@@ -78,6 +84,7 @@ func TestServer_HandleListModels_Aliases(t *testing.T) {
|
|||||||
|
|
||||||
func TestServer_FindModelInPath(t *testing.T) {
|
func TestServer_FindModelInPath(t *testing.T) {
|
||||||
cfg := config.Config{Models: map[string]config.ModelConfig{
|
cfg := config.Config{Models: map[string]config.ModelConfig{
|
||||||
|
"author": {},
|
||||||
"author/model": {},
|
"author/model": {},
|
||||||
"simple": {},
|
"simple": {},
|
||||||
}}
|
}}
|
||||||
@@ -91,13 +98,14 @@ func TestServer_FindModelInPath(t *testing.T) {
|
|||||||
{"/simple/v1/chat", "simple", "/v1/chat", true},
|
{"/simple/v1/chat", "simple", "/v1/chat", true},
|
||||||
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
|
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
|
||||||
{"/author/model", "author/model", "/", true},
|
{"/author/model", "author/model", "/", true},
|
||||||
|
{"/author/v1/chat", "author", "/v1/chat", true},
|
||||||
{"/missing/v1", "", "", false},
|
{"/missing/v1", "", "", false},
|
||||||
{"/", "", "", false},
|
{"/", "", "", false},
|
||||||
}
|
}
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
name, _, rem, found := findModelInPath(cfg, c.path)
|
name, _, rem, found := shared.FindModelInPath(cfg, c.path)
|
||||||
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
|
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
|
||||||
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
t.Errorf("FindModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
||||||
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
|
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -133,6 +141,165 @@ func TestServer_HandleUpstream(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func upstreamMetricsServer(response string) *Server {
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
proxylog := logmon.NewWriter(io.Discard)
|
||||||
|
s := &Server{
|
||||||
|
cfg: cfg,
|
||||||
|
muxlog: logmon.NewWriter(io.Discard),
|
||||||
|
proxylog: proxylog,
|
||||||
|
upstreamlog: logmon.NewWriter(io.Discard),
|
||||||
|
inflight: &inflightCounter{},
|
||||||
|
metrics: newMetricsMonitor(proxylog, 10, 0),
|
||||||
|
local: newStubRouter([]string{"m1"}, response),
|
||||||
|
peer: newStubRouter(nil, ""),
|
||||||
|
}
|
||||||
|
s.routes()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_IgnorePaths(t *testing.T) {
|
||||||
|
// Compile a pattern that matches static asset suffixes.
|
||||||
|
pattern := regexp.MustCompile(`.*\.(js|json|css|png|gif|jpg|jpeg|txt)$`)
|
||||||
|
|
||||||
|
t.Run("matched path, model not loaded, returns 409", func(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||||
|
// running is nil/empty: model is not in RunningModels() => not loaded.
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusConflict {
|
||||||
|
t.Fatalf("status = %d, want %d (body=%q)", w.Code, http.StatusConflict, w.Body.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "not loaded") {
|
||||||
|
t.Errorf("body = %q, want it to contain 'not loaded'", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("matched path, model already loaded, serves normally", func(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||||
|
local.running = map[string]process.ProcessState{"m1": process.StateReady}
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||||
|
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-matched path, model not loaded, serves normally", func(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||||
|
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("matched path, peer model, serves normally", func(t *testing.T) {
|
||||||
|
// Peer routers do not appear via RunningModels on the local router;
|
||||||
|
// they should fall through to normal dispatch without 409.
|
||||||
|
local := newStubRouter(nil, "")
|
||||||
|
peer := newStubRouter([]string{"m1"}, "peer-body")
|
||||||
|
s := newTestServer(local, peer)
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "peer-body" {
|
||||||
|
t.Fatalf("status=%d body=%q, want 200 'peer-body'", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) {
|
||||||
|
resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}`
|
||||||
|
s := upstreamMetricsServer(resp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != resp {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
entries := s.metrics.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 metrics entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Model != "m1" {
|
||||||
|
t.Errorf("model = %q, want m1", entries[0].Model)
|
||||||
|
}
|
||||||
|
if entries[0].ReqPath != "/v1/chat/completions" {
|
||||||
|
t.Errorf("req_path = %q, want /v1/chat/completions", entries[0].ReqPath)
|
||||||
|
}
|
||||||
|
if entries[0].Tokens.InputTokens != 3 || entries[0].Tokens.OutputTokens != 5 {
|
||||||
|
t.Errorf("tokens = %+v, want input=3 output=5", entries[0].Tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_MetricsSkipsUnsupportedPath(t *testing.T) {
|
||||||
|
s := upstreamMetricsServer("ok")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/probe", strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "ok" {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if len(s.metrics.getMetrics()) != 0 {
|
||||||
|
t.Errorf("want no metrics entries for unsupported path, got %d", len(s.metrics.getMetrics()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_MetricsSkipsGET(t *testing.T) {
|
||||||
|
s := upstreamMetricsServer(`{"usage":{}}`)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d", w.Code)
|
||||||
|
}
|
||||||
|
if len(s.metrics.getMetrics()) != 0 {
|
||||||
|
t.Errorf("want no metrics entries for GET upstream, got %d", len(s.metrics.getMetrics()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
|
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
|
||||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
|||||||
@@ -105,7 +105,9 @@ func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
|||||||
// filtered to samples after the ?after=<RFC3339> timestamp.
|
// filtered to samples after the ?after=<RFC3339> timestamp.
|
||||||
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||||
if s.perf == nil {
|
if s.perf == nil {
|
||||||
shared.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
json.NewEncoder(w).Encode(map[string]bool{"enabled": false})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +138,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"enabled": true,
|
||||||
"sys_stats": sysStats,
|
"sys_stats": sysStats,
|
||||||
"gpu_stats": gpuStats,
|
"gpu_stats": gpuStats,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) {
|
|||||||
case "upstream":
|
case "upstream":
|
||||||
return s.upstreamlog, nil
|
return s.upstreamlog, nil
|
||||||
default:
|
default:
|
||||||
if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found {
|
if _, modelID, _, found := shared.FindModelInPath(s.cfg, "/"+logMonitorID); found {
|
||||||
if log, ok := s.local.ProcessLogger(modelID); ok {
|
if log, ok := s.local.ProcessLogger(modelID); ok {
|
||||||
return log, nil
|
return log, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+100
-9
@@ -25,6 +25,8 @@ import (
|
|||||||
// TokenMetrics holds token usage and performance metrics.
|
// TokenMetrics holds token usage and performance metrics.
|
||||||
type TokenMetrics struct {
|
type TokenMetrics struct {
|
||||||
CachedTokens int `json:"cache_tokens"`
|
CachedTokens int `json:"cache_tokens"`
|
||||||
|
DraftTokens int `json:"draft_tokens"`
|
||||||
|
DraftAccTokens int `json:"draft_acc_tokens"`
|
||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||||
@@ -42,6 +44,7 @@ type ActivityLogEntry struct {
|
|||||||
Tokens TokenMetrics `json:"tokens"`
|
Tokens TokenMetrics `json:"tokens"`
|
||||||
DurationMs int `json:"duration_ms"`
|
DurationMs int `json:"duration_ms"`
|
||||||
HasCapture bool `json:"has_capture"`
|
HasCapture bool `json:"has_capture"`
|
||||||
|
ErrorMsg string `json:"error_msg,omitempty"`
|
||||||
Metadata map[string]string `json:"metadata,omitempty"`
|
Metadata map[string]string `json:"metadata,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,9 +126,11 @@ func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// record parses a completed response body and stores/emits an activity entry.
|
// record parses a completed response body and stores/emits an activity entry.
|
||||||
// When captures are enabled, a zstd+CBOR capture is stored for successful
|
// Successful requests store a zstd+CBOR capture (when enabled) with cf
|
||||||
// requests, with cf controlling which request/response parts are retained.
|
// controlling which parts are retained. Failed (non-200) requests capture the
|
||||||
// reqBody and reqHeaders are the request data buffered before dispatch.
|
// request only and set ErrorMsg to a description of the failure, so the error
|
||||||
|
// can be inspected without storing unreadable raw response bytes. reqBody and
|
||||||
|
// reqHeaders are the request data buffered before dispatch.
|
||||||
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
||||||
tm := ActivityLogEntry{
|
tm := ActivityLogEntry{
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
@@ -150,7 +155,13 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
|
|
||||||
if recorder.Status() != http.StatusOK {
|
if recorder.Status() != http.StatusOK {
|
||||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
||||||
queueAndEmit()
|
decoded, decErr := mp.decodeResponseBody(recorder, r.URL.Path)
|
||||||
|
tm.ErrorMsg = failedErrorMessage(recorder.Status(), decoded, decErr)
|
||||||
|
tm.ID = mp.queueMetrics(tm)
|
||||||
|
// Capture the request only; the failure is surfaced via ErrorMsg
|
||||||
|
// rather than storing the (possibly undisplayable) response body.
|
||||||
|
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf&^captureRespBody, reqBody, reqHeaders, nil)
|
||||||
|
mp.emitMetric(tm)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +176,7 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
decoded, err := decompressBody(body, encoding)
|
decoded, err := decompressBody(body, encoding)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||||
|
tm.ErrorMsg = fmt.Sprintf("response decompression failed: %v", err)
|
||||||
queueAndEmit()
|
queueAndEmit()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -203,9 +215,20 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
}
|
}
|
||||||
|
|
||||||
tm.ID = mp.queueMetrics(tm)
|
tm.ID = mp.queueMetrics(tm)
|
||||||
if mp.enableCaptures {
|
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf, reqBody, reqHeaders, body)
|
||||||
|
mp.emitMetric(tm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeCapture assembles a ReqRespCapture for id, honoring the captureFields
|
||||||
|
// mask, and stores it when captures are enabled. body is the response body to
|
||||||
|
// capture (already decompressed by the caller); pass nil to omit it. Returns
|
||||||
|
// true if a capture was stored.
|
||||||
|
func (mp *metricsMonitor) storeCapture(id int, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string, body []byte) bool {
|
||||||
|
if !mp.enableCaptures {
|
||||||
|
return false
|
||||||
|
}
|
||||||
capture := ReqRespCapture{
|
capture := ReqRespCapture{
|
||||||
ID: tm.ID,
|
ID: id,
|
||||||
ReqPath: r.URL.Path,
|
ReqPath: r.URL.Path,
|
||||||
ReqHeaders: reqHeaders,
|
ReqHeaders: reqHeaders,
|
||||||
}
|
}
|
||||||
@@ -220,11 +243,71 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
if cf&captureRespBody != 0 {
|
if cf&captureRespBody != 0 {
|
||||||
capture.RespBody = body
|
capture.RespBody = body
|
||||||
}
|
}
|
||||||
if mp.addCapture(capture) {
|
return mp.addCapture(capture)
|
||||||
tm.HasCapture = true
|
}
|
||||||
|
|
||||||
|
// decodeResponseBody returns the buffered response body, decompressing it when
|
||||||
|
// the upstream set a Content-Encoding we recognize. On decompression failure it
|
||||||
|
// logs a warning and returns an error so the caller can record a description
|
||||||
|
// (via ErrorMsg) instead of storing unreadable raw bytes.
|
||||||
|
func (mp *metricsMonitor) decodeResponseBody(recorder *responseBodyCopier, path string) ([]byte, error) {
|
||||||
|
body := recorder.body.Bytes()
|
||||||
|
if len(body) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
encoding := recorder.Header().Get("Content-Encoding")
|
||||||
|
if encoding == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
decoded, err := decompressBody(body, encoding)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("metrics: response decompression failed: %v, path=%s", err, path)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorMessagePaths lists JSON paths where a human-readable error message can
|
||||||
|
// live across OpenAI- and llama.cpp-style error responses.
|
||||||
|
var errorMessagePaths = []string{"error.message", "error", "message", "detail"}
|
||||||
|
|
||||||
|
// extractErrorMessage pulls a human-readable error string from a JSON error
|
||||||
|
// response. Returns "" if no message is found or the body is not valid JSON.
|
||||||
|
func extractErrorMessage(body []byte) string {
|
||||||
|
if !gjson.ValidBytes(body) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parsed := gjson.ParseBytes(body)
|
||||||
|
for _, path := range errorMessagePaths {
|
||||||
|
v := parsed.Get(path)
|
||||||
|
if v.Exists() && v.Type == gjson.String {
|
||||||
|
if s := strings.TrimSpace(v.String()); s != "" {
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mp.emitMetric(tm)
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// failedErrorMessage builds a human-readable description for a non-200 response.
|
||||||
|
// It prefers an error message parsed from the (decompressed) body and falls back
|
||||||
|
// to the HTTP status text. A non-nil decErr indicates the body could not be
|
||||||
|
// decoded, in which case the decode error is described instead.
|
||||||
|
func failedErrorMessage(status int, body []byte, decErr error) string {
|
||||||
|
const maxLen = 500
|
||||||
|
if decErr != nil {
|
||||||
|
return fmt.Sprintf("response decode failed: %v", decErr)
|
||||||
|
}
|
||||||
|
if msg := extractErrorMessage(body); msg != "" {
|
||||||
|
if len(msg) > maxLen {
|
||||||
|
msg = msg[:maxLen] + "..."
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
if text := http.StatusText(status); text != "" {
|
||||||
|
return fmt.Sprintf("%d %s", status, text)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("HTTP %d", status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// usagePaths lists the JSON paths where a per-event usage object can live.
|
// usagePaths lists the JSON paths where a per-event usage object can live.
|
||||||
@@ -345,6 +428,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
|||||||
durationMs := wallDurationMs
|
durationMs := wallDurationMs
|
||||||
tokensPerSecond := -1.0
|
tokensPerSecond := -1.0
|
||||||
promptPerSecond := -1.0
|
promptPerSecond := -1.0
|
||||||
|
draftTokens := -1
|
||||||
|
draftAccTokens := -1
|
||||||
|
|
||||||
if timings.Exists() {
|
if timings.Exists() {
|
||||||
inputTokens = timings.Get("prompt_n").Int()
|
inputTokens = timings.Get("prompt_n").Int()
|
||||||
@@ -358,6 +443,10 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
|||||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||||
cachedTokens = cachedValue.Int()
|
cachedTokens = cachedValue.Int()
|
||||||
}
|
}
|
||||||
|
if timings.Get("draft_n").Exists() && timings.Get("draft_n_accepted").Exists() {
|
||||||
|
draftTokens = int(timings.Get("draft_n").Int())
|
||||||
|
draftAccTokens = int(timings.Get("draft_n_accepted").Int())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ActivityLogEntry{
|
return ActivityLogEntry{
|
||||||
@@ -365,6 +454,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
|||||||
Model: modelID,
|
Model: modelID,
|
||||||
Tokens: TokenMetrics{
|
Tokens: TokenMetrics{
|
||||||
CachedTokens: int(cachedTokens),
|
CachedTokens: int(cachedTokens),
|
||||||
|
DraftTokens: draftTokens,
|
||||||
|
DraftAccTokens: draftAccTokens,
|
||||||
InputTokens: int(inputTokens),
|
InputTokens: int(inputTokens),
|
||||||
OutputTokens: int(outputTokens),
|
OutputTokens: int(outputTokens),
|
||||||
PromptPerSecond: promptPerSecond,
|
PromptPerSecond: promptPerSecond,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
@@ -21,8 +22,27 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine the model-routed endpoint path. Regular routes are
|
||||||
|
// already meterable; /upstream/<model>/<path> is metered only when
|
||||||
|
// the remaining path matches a model-dispatched endpoint.
|
||||||
|
checkPath := r.URL.Path
|
||||||
|
if strings.HasPrefix(r.URL.Path, "/upstream/") {
|
||||||
|
var found bool
|
||||||
|
_, _, checkPath, found = shared.FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
|
||||||
|
if !found {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isMetricsRecordPath(checkPath) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Resolve the model now so downstream dispatch hits the context
|
// Resolve the model now so downstream dispatch hits the context
|
||||||
// fast path; FetchContext restores the request body.
|
// fast path; FetchContext restores the request body for regular
|
||||||
|
// routes and extracts the model from the URL for /upstream routes.
|
||||||
data, err := shared.FetchContext(r, cfg)
|
data, err := shared.FetchContext(r, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
@@ -31,7 +51,7 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
|
|||||||
|
|
||||||
// Buffer the request body/headers for capture before dispatch
|
// Buffer the request body/headers for capture before dispatch
|
||||||
// consumes them.
|
// consumes them.
|
||||||
cf := captureFieldsFor(r.URL.Path)
|
cf := captureFieldsFor(checkPath)
|
||||||
var reqBody []byte
|
var reqBody []byte
|
||||||
var reqHeaders map[string]string
|
var reqHeaders map[string]string
|
||||||
if mm.enableCaptures {
|
if mm.enableCaptures {
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
@@ -87,6 +90,172 @@ func TestMetricsMonitor_RecordMetadata(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordFailedRequestCapture(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
reqHeaders := map[string]string{"content-type": "application/json"}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.Header().Set("Content-Type", "application/json")
|
||||||
|
copier.WriteHeader(http.StatusBadGateway)
|
||||||
|
copier.Write([]byte(`{"error":{"message":"model unavailable"}}`))
|
||||||
|
|
||||||
|
reqBody := []byte(`{"model":"m","messages":[]}`)
|
||||||
|
mm.record("m", r, copier, captureAll, reqBody, reqHeaders)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
entry := entries[0]
|
||||||
|
if entry.RespStatusCode != http.StatusBadGateway {
|
||||||
|
t.Errorf("status = %d, want %d", entry.RespStatusCode, http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
if entry.ErrorMsg != "model unavailable" {
|
||||||
|
t.Errorf("error_msg = %q, want extracted message", entry.ErrorMsg)
|
||||||
|
}
|
||||||
|
if !entry.HasCapture {
|
||||||
|
t.Fatal("failed request should capture the request so it can be inspected")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := mm.getCaptureByID(entry.ID)
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("capture not found")
|
||||||
|
}
|
||||||
|
if string(got.ReqBody) != `{"model":"m","messages":[]}` {
|
||||||
|
t.Errorf("req body = %q", got.ReqBody)
|
||||||
|
}
|
||||||
|
if len(got.RespBody) != 0 {
|
||||||
|
t.Errorf("resp body stored for failed request (len=%d); want none", len(got.RespBody))
|
||||||
|
}
|
||||||
|
if got.RespHeaders["Content-Type"] != "application/json" {
|
||||||
|
t.Errorf("resp Content-Type = %q", got.RespHeaders["Content-Type"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordFailedRequestStatusFallback(t *testing.T) {
|
||||||
|
// Non-JSON error body: ErrorMsg falls back to the HTTP status text.
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.WriteHeader(http.StatusBadGateway)
|
||||||
|
copier.Write([]byte("<html>upstream down</html>"))
|
||||||
|
|
||||||
|
mm.record("m", r, copier, captureAll, nil, nil)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].ErrorMsg != "502 Bad Gateway" {
|
||||||
|
t.Errorf("error_msg = %q, want status text", entries[0].ErrorMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordFailedRequestCaptureDisabled(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 0) // captures disabled
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.WriteHeader(http.StatusInternalServerError)
|
||||||
|
copier.Write([]byte(`{"error":"boom"}`))
|
||||||
|
|
||||||
|
mm.record("m", r, copier, captureAll, []byte("req"), nil)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].HasCapture {
|
||||||
|
t.Fatal("captures disabled, HasCapture should be false")
|
||||||
|
}
|
||||||
|
// ErrorMsg is independent of whether captures are enabled.
|
||||||
|
if entries[0].ErrorMsg != "boom" {
|
||||||
|
t.Errorf("error_msg = %q, want boom", entries[0].ErrorMsg)
|
||||||
|
}
|
||||||
|
if mm.getCaptureByID(entries[0].ID) != nil {
|
||||||
|
t.Fatal("no capture should be stored when disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordDecompressionFailureSetsErrorMsg(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.Header().Set("Content-Encoding", "gzip")
|
||||||
|
copier.WriteHeader(http.StatusOK)
|
||||||
|
copier.Write([]byte("not-really-gzip"))
|
||||||
|
|
||||||
|
mm.record("m", r, copier, captureAll, []byte("req"), nil)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].ErrorMsg == "" {
|
||||||
|
t.Fatal("expected ErrorMsg for decompression failure")
|
||||||
|
}
|
||||||
|
// Raw bytes must not be stored when the body could not be decoded.
|
||||||
|
if entries[0].HasCapture {
|
||||||
|
t.Fatal("decompression failure should not store a capture")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_DecodeResponseBody(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
|
||||||
|
// No Content-Encoding: body returned unchanged.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.Write([]byte("plain"))
|
||||||
|
got, err := mm.decodeResponseBody(copier, "/p")
|
||||||
|
if err != nil || string(got) != "plain" {
|
||||||
|
t.Fatalf("plain body = %q, err = %v", got, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bogus gzip payload: returns an error and no body (no raw bytes kept).
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
copier2 := newBodyCopier(w2)
|
||||||
|
copier2.Header().Set("Content-Encoding", "gzip")
|
||||||
|
copier2.Write([]byte("not-really-gzip"))
|
||||||
|
got, err = mm.decodeResponseBody(copier2, "/p")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected decompression error")
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Errorf("expected nil body on failure, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ExtractErrorMessage(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"openai object", `{"error":{"message":"rate limited"}}`, "rate limited"},
|
||||||
|
{"string error", `{"error":"bad request"}`, "bad request"},
|
||||||
|
{"message field", `{"message":"nope"}`, "nope"},
|
||||||
|
{"detail field", `{"detail":"oops"}`, "oops"},
|
||||||
|
{"object error ignored", `{"error":{"code":42}}`, ""},
|
||||||
|
{"no error", `{"usage":{}}`, ""},
|
||||||
|
{"invalid json", `not-json`, ""},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := extractErrorMessage([]byte(tc.body)); got != tc.want {
|
||||||
|
t.Errorf("extractErrorMessage = %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
||||||
// /infill responses are arrays; timings live in the last element.
|
// /infill responses are arrays; timings live in the last element.
|
||||||
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
||||||
@@ -103,3 +272,40 @@ func TestServer_ParseMetrics_Infill(t *testing.T) {
|
|||||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody verifies that
|
||||||
|
// an /upstream/<model>/v1/audio/speech request uses the path-specific capture
|
||||||
|
// mask (headers only) rather than falling back to captureAll.
|
||||||
|
func TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
|
||||||
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "audio/mpeg")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("BINARY-AUDIO-DATA"))
|
||||||
|
})
|
||||||
|
handler := CreateMetricsMiddleware(mm, cfg)(inner)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/audio/speech", strings.NewReader(`{"model":"m1"}`))
|
||||||
|
handler.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) == 0 {
|
||||||
|
t.Fatal("no metrics recorded")
|
||||||
|
}
|
||||||
|
last := entries[len(entries)-1]
|
||||||
|
if !last.HasCapture {
|
||||||
|
t.Fatal("expected capture to be stored")
|
||||||
|
}
|
||||||
|
cap := mm.getCaptureByID(last.ID)
|
||||||
|
if cap == nil {
|
||||||
|
t.Fatal("capture not found")
|
||||||
|
}
|
||||||
|
if len(cap.RespBody) != 0 {
|
||||||
|
t.Errorf("RespBody stored for /upstream audio route (len=%d); want path-specific mask to skip body", len(cap.RespBody))
|
||||||
|
}
|
||||||
|
if len(cap.RespHeaders) == 0 {
|
||||||
|
t.Error("RespHeaders not stored; want captureRespHeaders mask")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -89,6 +89,27 @@ var modelGetRoutes = []string{
|
|||||||
"/sdapi/v1/loras",
|
"/sdapi/v1/loras",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isMetricsRecordPath reports whether path is one of the model-dispatched
|
||||||
|
// endpoints that the metrics middleware records in the activity log.
|
||||||
|
func isMetricsRecordPath(path string) bool {
|
||||||
|
for _, p := range modelPostJSONRoutes {
|
||||||
|
if p == path {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, p := range modelPostFormRoutes {
|
||||||
|
if p == path {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, p := range modelGetRoutes {
|
||||||
|
if p == path {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// BuildInfo carries version metadata surfaced by GET /api/version.
|
// BuildInfo carries version metadata surfaced by GET /api/version.
|
||||||
type BuildInfo struct {
|
type BuildInfo struct {
|
||||||
Version string
|
Version string
|
||||||
@@ -219,9 +240,11 @@ func (s *Server) routes() {
|
|||||||
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
||||||
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
|
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
|
||||||
|
|
||||||
// Upstream passthrough.
|
// Upstream passthrough. Meter only the model-dispatched endpoints that can
|
||||||
|
// produce token usage/timings.
|
||||||
|
upstreamChain := apiChain.Append(CreateMetricsMiddleware(s.metrics, s.cfg))
|
||||||
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
|
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
|
||||||
mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream))
|
mux.Handle("/upstream/{upstreamPath...}", upstreamChain.ThenFunc(s.handleUpstream))
|
||||||
|
|
||||||
// API group (API-key protected) consumed by the UI.
|
// API group (API-key protected) consumed by the UI.
|
||||||
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
|
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
|
||||||
|
|||||||
+65
-4
@@ -91,16 +91,24 @@ func SendResponse(w http.ResponseWriter, r *http.Request, status int, message st
|
|||||||
w.Write(resp)
|
w.Write(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchContext will attempt to get the model id from the context then
|
// FetchContext will attempt to get the model id from the context, then
|
||||||
// from the model body. If it extracts the model from the body it will
|
// from an /upstream/<model> path prefix, then from the request body/query.
|
||||||
// store the model in the context for downstream handlers. An error
|
// If it extracts the model it will store it in the context for downstream
|
||||||
// will be returned when model can not be fetch from either location.
|
// handlers. An error will be returned when a model cannot be identified.
|
||||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||||
data, ok := ReadContext(r.Context())
|
data, ok := ReadContext(r.Context())
|
||||||
if ok {
|
if ok {
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(r.URL.Path, "/upstream/") {
|
||||||
|
if data, ok := extractUpstreamContext(r, cfg); ok {
|
||||||
|
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
return ReqContextData{}, ErrNoModelInContext
|
||||||
|
}
|
||||||
|
|
||||||
if data, err := extractContext(r); err == nil && data.Model != "" {
|
if data, err := extractContext(r); err == nil && data.Model != "" {
|
||||||
realName, _ := cfg.RealModelName(data.Model)
|
realName, _ := cfg.RealModelName(data.Model)
|
||||||
if realName == "" {
|
if realName == "" {
|
||||||
@@ -117,6 +125,59 @@ func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
|||||||
return ReqContextData{}, ErrNoModelInContext
|
return ReqContextData{}, ErrNoModelInContext
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractUpstreamContext resolves the model from an /upstream/<model>/... path.
|
||||||
|
func extractUpstreamContext(r *http.Request, cfg config.Config) (ReqContextData, bool) {
|
||||||
|
searchName, realName, _, found := FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
|
||||||
|
if !found {
|
||||||
|
return ReqContextData{}, false
|
||||||
|
}
|
||||||
|
return ReqContextData{
|
||||||
|
Model: searchName,
|
||||||
|
ModelID: realName,
|
||||||
|
ApiKey: ExtractAPIKey(r),
|
||||||
|
Streaming: r.URL.Query().Get("stream") == "true",
|
||||||
|
SendLoadingState: sendLoadingState(cfg, realName),
|
||||||
|
Metadata: make(map[string]string),
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendLoadingState reports whether the configured model wants loading-state SSEs.
|
||||||
|
func sendLoadingState(cfg config.Config, modelID string) bool {
|
||||||
|
if mc, ok := cfg.Models[modelID]; ok {
|
||||||
|
return mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindModelInPath walks a slash-separated path, building up segments until one
|
||||||
|
// matches a configured model. This resolves model names that contain slashes
|
||||||
|
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
||||||
|
// remaining path, and whether a match was found.
|
||||||
|
func FindModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
||||||
|
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||||
|
name := ""
|
||||||
|
|
||||||
|
for i, part := range parts {
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
name = part
|
||||||
|
} else {
|
||||||
|
name = name + "/" + part
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelID, ok := cfg.RealModelName(name); ok {
|
||||||
|
searchName = name
|
||||||
|
realName = modelID
|
||||||
|
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||||
return context.WithValue(ctx, ReqContextKey, data)
|
return context.WithValue(ctx, ReqContextKey, data)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExtractContext_GET(t *testing.T) {
|
func TestExtractContext_GET(t *testing.T) {
|
||||||
@@ -456,3 +458,68 @@ func TestServer_ExtractAPIKey(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFetchContext_UpstreamPath(t *testing.T) {
|
||||||
|
cfg := config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"m1": {},
|
||||||
|
"author/model": {},
|
||||||
|
"real": {Aliases: []string{"nick"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
wantModel string
|
||||||
|
wantModelID string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"known model", "/upstream/m1/v1/chat/completions", "m1", "m1", false},
|
||||||
|
{"model with slash", "/upstream/author/model/v1/chat", "author/model", "author/model", false},
|
||||||
|
{"unknown model", "/upstream/nope/v1/chat/completions", "", "", true},
|
||||||
|
{"bare model path", "/upstream/m1/", "m1", "m1", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, c.path, strings.NewReader(`{}`))
|
||||||
|
data, err := FetchContext(r, cfg)
|
||||||
|
if (err != nil) != c.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", c.wantErr, err)
|
||||||
|
}
|
||||||
|
if c.wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data.Model != c.wantModel {
|
||||||
|
t.Errorf("model = %q, want %q", data.Model, c.wantModel)
|
||||||
|
}
|
||||||
|
if data.ModelID != c.wantModelID {
|
||||||
|
t.Errorf("modelID = %q, want %q", data.ModelID, c.wantModelID)
|
||||||
|
}
|
||||||
|
if data.Metadata == nil {
|
||||||
|
t.Error("metadata map not initialized")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchContext_UpstreamPath_DoesNotReadBody(t *testing.T) {
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
body := `{"model":"should-not-matter"}`
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(body))
|
||||||
|
|
||||||
|
_, err := FetchContext(r, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FetchContext: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The body should be untouched so the upstream handler can still read it.
|
||||||
|
got, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read body: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != body {
|
||||||
|
t.Errorf("body was consumed: %q", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
import Performance from "./routes/Performance.svelte";
|
import Performance from "./routes/Performance.svelte";
|
||||||
import Playground from "./routes/Playground.svelte";
|
import Playground from "./routes/Playground.svelte";
|
||||||
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
||||||
import { enableAPIEvents } from "./stores/api";
|
import { enableAPIEvents, checkPerformanceEnabled } from "./stores/api";
|
||||||
import { initScreenWidth, initSystemThemeListener, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
import { initScreenWidth, initSystemThemeListener, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
||||||
import { currentRoute } from "./stores/route";
|
import { currentRoute } from "./stores/route";
|
||||||
|
|
||||||
@@ -39,6 +39,7 @@
|
|||||||
const cleanupScreenWidth = initScreenWidth();
|
const cleanupScreenWidth = initScreenWidth();
|
||||||
const cleanupSystemTheme = initSystemThemeListener();
|
const cleanupSystemTheme = initSystemThemeListener();
|
||||||
enableAPIEvents(true);
|
enableAPIEvents(true);
|
||||||
|
checkPerformanceEnabled();
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
cleanupScreenWidth();
|
cleanupScreenWidth();
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import { screenWidth, toggleTheme, themeMode, appTitle, isNarrow } from "../stores/theme";
|
import { screenWidth, toggleTheme, themeMode, appTitle, isNarrow } from "../stores/theme";
|
||||||
import { currentRoute } from "../stores/route";
|
import { currentRoute } from "../stores/route";
|
||||||
import { playgroundActivity } from "../stores/playgroundActivity";
|
import { playgroundActivity } from "../stores/playgroundActivity";
|
||||||
|
import { performanceEnabled } from "../stores/api";
|
||||||
import ConnectionStatus from "./ConnectionStatus.svelte";
|
import ConnectionStatus from "./ConnectionStatus.svelte";
|
||||||
|
|
||||||
function handleTitleChange(newTitle: string): void {
|
function handleTitleChange(newTitle: string): void {
|
||||||
@@ -84,6 +85,7 @@
|
|||||||
>
|
>
|
||||||
Logs
|
Logs
|
||||||
</a>
|
</a>
|
||||||
|
{#if $performanceEnabled}
|
||||||
<a
|
<a
|
||||||
href="/performance"
|
href="/performance"
|
||||||
use:link
|
use:link
|
||||||
@@ -94,6 +96,7 @@
|
|||||||
>
|
>
|
||||||
Performance
|
Performance
|
||||||
</a>
|
</a>
|
||||||
|
{/if}
|
||||||
<button onclick={toggleTheme} title="Toggle theme (current: {$themeMode})">
|
<button onclick={toggleTheme} title="Toggle theme (current: {$themeMode})">
|
||||||
{#if $themeMode === "system"}
|
{#if $themeMode === "system"}
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ export interface Model {
|
|||||||
|
|
||||||
export interface TokenMetrics {
|
export interface TokenMetrics {
|
||||||
cache_tokens: number;
|
cache_tokens: number;
|
||||||
|
draft_tokens: number;
|
||||||
|
draft_acc_tokens: number;
|
||||||
input_tokens: number;
|
input_tokens: number;
|
||||||
output_tokens: number;
|
output_tokens: number;
|
||||||
prompt_per_second: number;
|
prompt_per_second: number;
|
||||||
@@ -41,6 +43,7 @@ export interface ActivityLogEntry {
|
|||||||
tokens: TokenMetrics;
|
tokens: TokenMetrics;
|
||||||
duration_ms: number;
|
duration_ms: number;
|
||||||
has_capture: boolean;
|
has_capture: boolean;
|
||||||
|
error_msg?: string;
|
||||||
metadata?: Record<string, string>;
|
metadata?: Record<string, string>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,11 +21,12 @@
|
|||||||
{ key: "time", label: "Time", defaultVisible: true },
|
{ key: "time", label: "Time", defaultVisible: true },
|
||||||
{ key: "model", label: "Model", defaultVisible: true },
|
{ key: "model", label: "Model", defaultVisible: true },
|
||||||
{ key: "req_path", label: "Path", defaultVisible: false },
|
{ key: "req_path", label: "Path", defaultVisible: false },
|
||||||
{ key: "resp_status_code", label: "Status", defaultVisible: false },
|
{ key: "resp_status_code", label: "Status", defaultVisible: true },
|
||||||
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
|
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
|
||||||
{ key: "cached", label: "Cached", defaultVisible: true },
|
{ key: "cached", label: "Cached", defaultVisible: true },
|
||||||
{ key: "prompt", label: "Prompt", defaultVisible: true },
|
{ key: "prompt", label: "Prompt", defaultVisible: true },
|
||||||
{ key: "generated", label: "Generated", defaultVisible: true },
|
{ key: "generated", label: "Generated", defaultVisible: true },
|
||||||
|
{ key: "drafted", label: "Drafted", defaultVisible: false },
|
||||||
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
|
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
|
||||||
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
|
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
|
||||||
{ key: "duration", label: "Duration", defaultVisible: true },
|
{ key: "duration", label: "Duration", defaultVisible: true },
|
||||||
@@ -158,6 +159,10 @@
|
|||||||
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function formatDrafted(drafted: number, accepted: number): string {
|
||||||
|
return drafted > 0 ? (accepted * 100 / drafted).toFixed(1) + "% (" + accepted + "/" + drafted + ")" : "-";
|
||||||
|
}
|
||||||
|
|
||||||
function formatDuration(ms: number): string {
|
function formatDuration(ms: number): string {
|
||||||
return (ms / 1000).toFixed(2) + "s";
|
return (ms / 1000).toFixed(2) + "s";
|
||||||
}
|
}
|
||||||
@@ -273,6 +278,8 @@
|
|||||||
Cached <Tooltip content="prompt tokens from cache" />
|
Cached <Tooltip content="prompt tokens from cache" />
|
||||||
{:else if key === "prompt"}
|
{:else if key === "prompt"}
|
||||||
Prompt <Tooltip content="new prompt tokens processed" />
|
Prompt <Tooltip content="new prompt tokens processed" />
|
||||||
|
{:else if key === "drafted"}
|
||||||
|
Drafted <Tooltip content="acceptance rate (accepted/drafted)" />
|
||||||
{:else}
|
{:else}
|
||||||
{columnLabelMap[key] ?? key}
|
{columnLabelMap[key] ?? key}
|
||||||
{/if}
|
{/if}
|
||||||
@@ -301,7 +308,13 @@
|
|||||||
{:else if key === "req_path"}
|
{:else if key === "req_path"}
|
||||||
{metric.req_path || "-"}
|
{metric.req_path || "-"}
|
||||||
{:else if key === "resp_status_code"}
|
{:else if key === "resp_status_code"}
|
||||||
|
{#if metric.error_msg}
|
||||||
|
<span class="text-red-500 dark:text-red-400 cursor-help" title={metric.error_msg}>
|
||||||
{metric.resp_status_code || "-"}
|
{metric.resp_status_code || "-"}
|
||||||
|
</span>
|
||||||
|
{:else}
|
||||||
|
{metric.resp_status_code || "-"}
|
||||||
|
{/if}
|
||||||
{:else if key === "resp_content_type"}
|
{:else if key === "resp_content_type"}
|
||||||
{metric.resp_content_type || "-"}
|
{metric.resp_content_type || "-"}
|
||||||
{:else if key === "cached"}
|
{:else if key === "cached"}
|
||||||
@@ -310,6 +323,8 @@
|
|||||||
{metric.tokens.input_tokens.toLocaleString()}
|
{metric.tokens.input_tokens.toLocaleString()}
|
||||||
{:else if key === "generated"}
|
{:else if key === "generated"}
|
||||||
{metric.tokens.output_tokens.toLocaleString()}
|
{metric.tokens.output_tokens.toLocaleString()}
|
||||||
|
{:else if key === "drafted"}
|
||||||
|
{formatDrafted(metric.tokens.draft_tokens, metric.tokens.draft_acc_tokens)}
|
||||||
{:else if key === "prompt_speed"}
|
{:else if key === "prompt_speed"}
|
||||||
{formatSpeed(metric.tokens.prompt_per_second)}
|
{formatSpeed(metric.tokens.prompt_per_second)}
|
||||||
{:else if key === "gen_speed"}
|
{:else if key === "gen_speed"}
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ export const proxyLogs = writable<string>("");
|
|||||||
export const upstreamLogs = writable<string>("");
|
export const upstreamLogs = writable<string>("");
|
||||||
export const metrics = writable<ActivityLogEntry[]>([]);
|
export const metrics = writable<ActivityLogEntry[]>([]);
|
||||||
export const inFlightRequests = writable<number>(0);
|
export const inFlightRequests = writable<number>(0);
|
||||||
|
export const performanceEnabled = writable<boolean>(false);
|
||||||
export const versionInfo = writable<VersionInfo>({
|
export const versionInfo = writable<VersionInfo>({
|
||||||
build_date: "unknown",
|
build_date: "unknown",
|
||||||
commit: "unknown",
|
commit: "unknown",
|
||||||
@@ -210,6 +211,20 @@ export async function getCapture(id: number): Promise<ReqRespCapture | null> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function checkPerformanceEnabled(): Promise<void> {
|
||||||
|
try {
|
||||||
|
const response = await fetch("/api/performance");
|
||||||
|
if (!response.ok) {
|
||||||
|
performanceEnabled.set(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const data = await response.json();
|
||||||
|
performanceEnabled.set(data.enabled);
|
||||||
|
} catch {
|
||||||
|
performanceEnabled.set(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export async function fetchPerformance(after?: string): Promise<PerformanceResponse | null> {
|
export async function fetchPerformance(after?: string): Promise<PerformanceResponse | null> {
|
||||||
try {
|
try {
|
||||||
const url = after ? `/api/performance?after=${encodeURIComponent(after)}` : "/api/performance";
|
const url = after ? `/api/performance?after=${encodeURIComponent(after)}` : "/api/performance";
|
||||||
|
|||||||
Reference in New Issue
Block a user