proxy: meter /upstream requests via metrics middleware (#858)
Wrap /upstream/{upstreamPath...} in the metrics middleware so activity
log entries are recorded for model-dispatched endpoints accessed through
the upstream passthrough.
- Move findModelInPath to shared.FindModelInPath and reuse it in
handleUpstream, the log monitor lookup, and FetchContext.
- Extend FetchContext to resolve the model from /upstream/<model>/...
paths without consuming the request body.
- Add isMetricsRecordPath to limit recording to the model-dispatched
endpoints that produce token usage/timings.
- Add tests for upstream metrics recording and FetchContext upstream
path resolution.
Fixes #855
This commit is contained in:
+65
-4
@@ -91,16 +91,24 @@ func SendResponse(w http.ResponseWriter, r *http.Request, status int, message st
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
// FetchContext will attempt to get the model id from the context then
|
||||
// from the model body. If it extracts the model from the body it will
|
||||
// store the model in the context for downstream handlers. An error
|
||||
// will be returned when model can not be fetch from either location.
|
||||
// FetchContext will attempt to get the model id from the context, then
|
||||
// from an /upstream/<model> path prefix, then from the request body/query.
|
||||
// If it extracts the model it will store it in the context for downstream
|
||||
// handlers. An error will be returned when a model cannot be identified.
|
||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||
data, ok := ReadContext(r.Context())
|
||||
if ok {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(r.URL.Path, "/upstream/") {
|
||||
if data, ok := extractUpstreamContext(r, cfg); ok {
|
||||
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||
return data, nil
|
||||
}
|
||||
return ReqContextData{}, ErrNoModelInContext
|
||||
}
|
||||
|
||||
if data, err := extractContext(r); err == nil && data.Model != "" {
|
||||
realName, _ := cfg.RealModelName(data.Model)
|
||||
if realName == "" {
|
||||
@@ -117,6 +125,59 @@ func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||
return ReqContextData{}, ErrNoModelInContext
|
||||
}
|
||||
|
||||
// extractUpstreamContext resolves the model from an /upstream/<model>/... path.
|
||||
func extractUpstreamContext(r *http.Request, cfg config.Config) (ReqContextData, bool) {
|
||||
searchName, realName, _, found := FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
|
||||
if !found {
|
||||
return ReqContextData{}, false
|
||||
}
|
||||
return ReqContextData{
|
||||
Model: searchName,
|
||||
ModelID: realName,
|
||||
ApiKey: ExtractAPIKey(r),
|
||||
Streaming: r.URL.Query().Get("stream") == "true",
|
||||
SendLoadingState: sendLoadingState(cfg, realName),
|
||||
Metadata: make(map[string]string),
|
||||
}, true
|
||||
}
|
||||
|
||||
// sendLoadingState reports whether the configured model wants loading-state SSEs.
|
||||
func sendLoadingState(cfg config.Config, modelID string) bool {
|
||||
if mc, ok := cfg.Models[modelID]; ok {
|
||||
return mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FindModelInPath walks a slash-separated path, building up segments until one
|
||||
// matches a configured model. This resolves model names that contain slashes
|
||||
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
||||
// remaining path, and whether a match was found.
|
||||
func FindModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||
name := ""
|
||||
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if name == "" {
|
||||
name = part
|
||||
} else {
|
||||
name = name + "/" + part
|
||||
}
|
||||
|
||||
if modelID, ok := cfg.RealModelName(name); ok {
|
||||
searchName = name
|
||||
realName = modelID
|
||||
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||
return context.WithValue(ctx, ReqContextKey, data)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
func TestExtractContext_GET(t *testing.T) {
|
||||
@@ -456,3 +458,68 @@ func TestServer_ExtractAPIKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchContext_UpstreamPath(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"m1": {},
|
||||
"author/model": {},
|
||||
"real": {Aliases: []string{"nick"}},
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
path string
|
||||
wantModel string
|
||||
wantModelID string
|
||||
wantErr bool
|
||||
}{
|
||||
{"known model", "/upstream/m1/v1/chat/completions", "m1", "m1", false},
|
||||
{"model with slash", "/upstream/author/model/v1/chat", "author/model", "author/model", false},
|
||||
{"unknown model", "/upstream/nope/v1/chat/completions", "", "", true},
|
||||
{"bare model path", "/upstream/m1/", "m1", "m1", false},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, c.path, strings.NewReader(`{}`))
|
||||
data, err := FetchContext(r, cfg)
|
||||
if (err != nil) != c.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", c.wantErr, err)
|
||||
}
|
||||
if c.wantErr {
|
||||
return
|
||||
}
|
||||
if data.Model != c.wantModel {
|
||||
t.Errorf("model = %q, want %q", data.Model, c.wantModel)
|
||||
}
|
||||
if data.ModelID != c.wantModelID {
|
||||
t.Errorf("modelID = %q, want %q", data.ModelID, c.wantModelID)
|
||||
}
|
||||
if data.Metadata == nil {
|
||||
t.Error("metadata map not initialized")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchContext_UpstreamPath_DoesNotReadBody(t *testing.T) {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||
body := `{"model":"should-not-matter"}`
|
||||
r := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(body))
|
||||
|
||||
_, err := FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("FetchContext: %v", err)
|
||||
}
|
||||
|
||||
// The body should be untouched so the upstream handler can still read it.
|
||||
got, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
if string(got) != body {
|
||||
t.Errorf("body was consumed: %q", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user