diff --git a/proxy/proxymanager_loghandlers.go b/proxy/proxymanager_loghandlers.go index 484d26c2..b08d5ae2 100644 --- a/proxy/proxymanager_loghandlers.go +++ b/proxy/proxymanager_loghandlers.go @@ -107,6 +107,12 @@ func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) { return process.Logger(), nil } } + // also check the matrix when processGroups doesn't contain the model + if pm.matrix != nil { + if process, found := pm.matrix.GetProcess(name); found { + return process.Logger(), nil + } + } } return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID") diff --git a/proxy/proxymanager_loghandlers_test.go b/proxy/proxymanager_loghandlers_test.go index 21c3a9b3..1b9ba5b1 100644 --- a/proxy/proxymanager_loghandlers_test.go +++ b/proxy/proxymanager_loghandlers_test.go @@ -1,8 +1,15 @@ package proxy import ( + "context" + "net/http/httptest" "strings" "testing" + "time" + + "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLogMonitorIdQueryParameterStripping(t *testing.T) { @@ -47,3 +54,120 @@ func TestLogMonitorIdQueryParameterStripping(t *testing.T) { }) } } + +// TestProxyManager_GetLogger_ProcessGroups verifies getLogger resolves the +// well-known "proxy"/"upstream" loggers and a model ID managed by processGroups. +func TestProxyManager_GetLogger_ProcessGroups(t *testing.T) { + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 +`) + pm := New(cfg) + defer pm.StopProcesses(StopImmediately) + + tests := []struct { + id string + wantErr bool + }{ + {"proxy", false}, + {"upstream", false}, + {"model1", false}, + {"does-not-exist", true}, + } + + for _, tt := range tests { + t.Run(tt.id, func(t *testing.T) { + logger, err := pm.getLogger(tt.id) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid logger") + } else { + require.NoError(t, err) + assert.NotNil(t, logger) + } + }) + } +} + +// TestProxyManager_GetLogger_Matrix verifies that getLogger can resolve a model +// ID when the proxy is configured with a swap matrix (pm.processGroups is empty +// for matrix-managed models). +func TestProxyManager_GetLogger_Matrix(t *testing.T) { + cfg := config.Config{ + HealthCheckTimeout: 15, + Models: map[string]config.ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + "model2": getTestSimpleResponderConfig("model2"), + }, + ExpandedSets: []config.ExpandedSet{ + {SetName: "s1", Models: []string{"model1", "model2"}}, + }, + Matrix: &config.MatrixConfig{}, + } + + pm := New(cfg) + defer pm.StopProcesses(StopImmediately) + + tests := []struct { + id string + wantErr bool + }{ + {"proxy", false}, + {"upstream", false}, + {"model1", false}, + {"model2", false}, + {"does-not-exist", true}, + } + + for _, tt := range tests { + t.Run(tt.id, func(t *testing.T) { + logger, err := pm.getLogger(tt.id) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid logger") + } else { + require.NoError(t, err) + assert.NotNil(t, logger) + } + }) + } +} + +// TestProxyManager_StreamLogs_Matrix verifies that /logs/stream/ +// returns 200 (not 400) for a model managed by the swap matrix. +func TestProxyManager_StreamLogs_Matrix(t *testing.T) { + cfg := config.Config{ + HealthCheckTimeout: 15, + Models: map[string]config.ModelConfig{ + "matrix-model": getTestSimpleResponderConfig("matrix-model"), + }, + ExpandedSets: []config.ExpandedSet{ + {SetName: "s1", Models: []string{"matrix-model"}}, + }, + Matrix: &config.MatrixConfig{}, + } + + pm := New(cfg) + defer pm.StopProcesses(StopImmediately) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + req := httptest.NewRequest("GET", "/logs/stream/matrix-model", nil) + req = req.WithContext(ctx) + rec := CreateTestResponseRecorder() + + done := make(chan struct{}) + go func() { + defer close(done) + pm.ServeHTTP(rec, req) + }() + + <-ctx.Done() + <-done + + assert.Equal(t, 200, rec.Code) +}