diff --git a/cmd/foreman/main.go b/cmd/foreman/main.go index 9af03f6..4c04a3c 100644 --- a/cmd/foreman/main.go +++ b/cmd/foreman/main.go @@ -7,11 +7,16 @@ package main import ( + "context" + "encoding/json" "fmt" "log/slog" "os" + "os/signal" + "syscall" "gitea.stevedudenhoeffer.com/steve/foreman/internal/config" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/ollama" "gitea.stevedudenhoeffer.com/steve/foreman/internal/server" "gitea.stevedudenhoeffer.com/steve/foreman/internal/store" ) @@ -47,10 +52,12 @@ func main() { } } -// runServe loads configuration, opens the store, and starts the HTTP server. +// runServe loads configuration, opens the store, creates the Ollama client, +// starts the model poller, warms the embedder, and starts the HTTP server. // // Why: the serve subcommand is the daemon's primary mode of operation. -// What: wires config -> store -> server and blocks on ListenAndServe. +// What: wires config -> store -> ollama client -> poller -> server and blocks on +// ListenAndServe. Graceful shutdown on SIGINT/SIGTERM cancels the poller. // Test: tested indirectly via integration tests; each component is unit tested. func runServe(logger *slog.Logger) error { cfg, err := config.Load() @@ -73,6 +80,45 @@ func runServe(logger *slog.Logger) error { } defer st.Close() - srv := server.New(cfg, st, logger) + // Create the Ollama client. + client := ollama.NewClient(cfg.OllamaURL, cfg.OllamaToken) + + // Create the model inventory and start the poller. + inventory := ollama.NewModelInventory(client, logger) + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + go inventory.Start(ctx, cfg.PollInterval) + + // Warm the embedder model if configured (non-blocking). + if cfg.EmbedModel != "" { + warmEmbedder(ctx, client, cfg.EmbedModel, logger) + } + + srv := server.New(cfg, st, client, inventory, logger) return srv.ListenAndServe() } + +// warmEmbedder sends a trivial embed request with keep_alive=-1 to pin the +// embedder in slot 1 (ADR-0013). Does not block startup on failure. +// +// Why: the embedder must be always-resident so embedding requests are fast and +// never trigger a swap (ADR-0013). +// What: issues /api/embed with keep_alive:-1 to load and pin the model. +// Test: start foreman with FOREMAN_EMBED_MODEL set, verify the warmup call fires. +func warmEmbedder(ctx context.Context, client ollama.Client, model string, logger *slog.Logger) { + logger.Info("warming embedder model", "model", model) + + req := ollama.EmbedRequest{ + Model: model, + Input: json.RawMessage(`"warmup"`), + KeepAlive: json.RawMessage(`-1`), + } + + _, err := client.Embed(ctx, req) + if err != nil { + logger.Warn("embedder warmup failed (non-fatal)", "model", model, "error", err) + return + } + logger.Info("embedder warmed successfully", "model", model) +} diff --git a/internal/ollama/client.go b/internal/ollama/client.go new file mode 100644 index 0000000..0dc0c99 --- /dev/null +++ b/internal/ollama/client.go @@ -0,0 +1,326 @@ +package ollama + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strings" +) + +// scannerBufSize is the buffer size for the NDJSON scanner (4 MB). +// Large enough to handle big tool-call payloads in a single line. +const scannerBufSize = 4 * 1024 * 1024 + +// Client defines the interface for communicating with an Ollama target. +// +// Why: an interface allows the worker loop, passthrough handlers, and tests to +// share a single contract and swap in stubs. +// What: covers the four Ollama endpoints foreman uses: chat, embed, tags, and ps. +// Test: implement with a stub HTTP server; verify round-trip for each method. +type Client interface { + // Chat sends a chat request. When stream is false, returns (*ChatResponse, nil, nil). + // When stream is true, returns (nil, <-chan ChatResponse, nil) with chunks delivered + // on the channel. The channel is closed when the stream ends. + Chat(ctx context.Context, req ChatRequest, stream bool) (*ChatResponse, <-chan ChatResponse, error) + + // Embed sends an embedding request to /api/embed. + Embed(ctx context.Context, req EmbedRequest) (*EmbedResponse, error) + + // Tags returns the list of installed models from /api/tags. + Tags(ctx context.Context) (*TagsResponse, error) + + // Ps returns the list of currently-loaded models from /api/ps. + Ps(ctx context.Context) (*PsResponse, error) + + // RawChat performs a raw proxied chat request, returning the http.Response for + // the caller to stream directly to a downstream client. The caller is responsible + // for closing the response body. + RawChat(ctx context.Context, body []byte) (*http.Response, error) + + // RawEmbed performs a raw proxied embed request, returning the http.Response. + // The caller is responsible for closing the response body. + RawEmbed(ctx context.Context, body []byte) (*http.Response, error) +} + +// httpClient is the concrete implementation of Client backed by net/http. +type httpClient struct { + baseURL string + token string + httpClient *http.Client +} + +// NewClient creates a new Ollama HTTP client. +// +// Why: centralizes base URL, auth token, and HTTP client configuration. +// What: returns a Client that makes HTTP requests to the given Ollama base URL. +// Test: create with a httptest.Server URL, call Tags, verify correct request path. +func NewClient(baseURL, token string) Client { + // Trim trailing slash for consistent URL construction. + baseURL = strings.TrimRight(baseURL, "/") + return &httpClient{ + baseURL: baseURL, + token: token, + httpClient: &http.Client{}, + } +} + +// Chat sends a POST /api/chat to the Ollama target. +// +// Why: the worker loop and sync passthrough both need structured chat access. +// What: POSTs the chat request, returns either a single response or a channel of +// streamed chunks depending on the stream parameter. +// Test: stub a /api/chat endpoint returning NDJSON or a single JSON object; verify +// both streaming and non-streaming paths. +func (c *httpClient) Chat(ctx context.Context, req ChatRequest, stream bool) (*ChatResponse, <-chan ChatResponse, error) { + streamVal := stream + req.Stream = &streamVal + + body, err := json.Marshal(req) + if err != nil { + return nil, nil, fmt.Errorf("marshal chat request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/chat", bytes.NewReader(body)) + if err != nil { + return nil, nil, fmt.Errorf("create chat request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + c.setAuth(httpReq) + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, nil, c.wrapConnErr(err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer resp.Body.Close() + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, nil, &HTTPError{StatusCode: resp.StatusCode, Body: string(errBody)} + } + + if !stream { + defer resp.Body.Close() + var chatResp ChatResponse + if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { + return nil, nil, fmt.Errorf("decode chat response: %w", err) + } + return &chatResp, nil, nil + } + + // Streaming: read NDJSON lines and send on channel. + ch := make(chan ChatResponse, 64) + go func() { + defer close(ch) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, scannerBufSize), scannerBufSize) + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + var chunk ChatResponse + if err := json.Unmarshal(line, &chunk); err != nil { + continue + } + ch <- chunk + } + }() + + return nil, ch, nil +} + +// Embed sends a POST /api/embed to the Ollama target. +// +// Why: embedding requests bypass the queue and go directly to the target (ADR-0013). +// What: POSTs the embed request and returns the parsed response. +// Test: stub /api/embed, send a request, verify embeddings in the response. +func (c *httpClient) Embed(ctx context.Context, req EmbedRequest) (*EmbedResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("marshal embed request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/embed", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create embed request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + c.setAuth(httpReq) + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, c.wrapConnErr(err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, &HTTPError{StatusCode: resp.StatusCode, Body: string(errBody)} + } + + var embedResp EmbedResponse + if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil { + return nil, fmt.Errorf("decode embed response: %w", err) + } + return &embedResp, nil +} + +// Tags fetches GET /api/tags from the Ollama target. +// +// Why: the model poller needs the installed model list for inventory and validation. +// What: GETs /api/tags and returns the parsed response. +// Test: stub /api/tags with a model list, verify Tags() returns it. +func (c *httpClient) Tags(ctx context.Context) (*TagsResponse, error) { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/api/tags", nil) + if err != nil { + return nil, fmt.Errorf("create tags request: %w", err) + } + c.setAuth(httpReq) + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, c.wrapConnErr(err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, &HTTPError{StatusCode: resp.StatusCode, Body: string(errBody)} + } + + var tagsResp TagsResponse + if err := json.NewDecoder(resp.Body).Decode(&tagsResp); err != nil { + return nil, fmt.Errorf("decode tags response: %w", err) + } + return &tagsResp, nil +} + +// Ps fetches GET /api/ps from the Ollama target. +// +// Why: the poller and scheduler need to know which models are currently loaded. +// What: GETs /api/ps and returns the parsed response. +// Test: stub /api/ps with running models, verify Ps() returns them. +func (c *httpClient) Ps(ctx context.Context) (*PsResponse, error) { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/api/ps", nil) + if err != nil { + return nil, fmt.Errorf("create ps request: %w", err) + } + c.setAuth(httpReq) + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, c.wrapConnErr(err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, &HTTPError{StatusCode: resp.StatusCode, Body: string(errBody)} + } + + var psResp PsResponse + if err := json.NewDecoder(resp.Body).Decode(&psResp); err != nil { + return nil, fmt.Errorf("decode ps response: %w", err) + } + return &psResp, nil +} + +// RawChat performs a raw proxied POST /api/chat, returning the http.Response for +// direct streaming to a downstream client. +// +// Why: the passthrough handler needs raw access to the response body for NDJSON +// streaming without double-parsing. +// What: POSTs the raw body to /api/chat and returns the raw HTTP response. +// Test: stub /api/chat, call RawChat, verify response status and body forwarding. +func (c *httpClient) RawChat(ctx context.Context, body []byte) (*http.Response, error) { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/chat", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create raw chat request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + c.setAuth(httpReq) + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, c.wrapConnErr(err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer resp.Body.Close() + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, &HTTPError{StatusCode: resp.StatusCode, Body: string(errBody)} + } + + return resp, nil +} + +// RawEmbed performs a raw proxied POST /api/embed, returning the http.Response. +// +// Why: the embed passthrough handler proxies the raw body/response without parsing. +// What: POSTs the raw body to /api/embed and returns the raw HTTP response. +// Test: stub /api/embed, call RawEmbed, verify response forwarding. +func (c *httpClient) RawEmbed(ctx context.Context, body []byte) (*http.Response, error) { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/embed", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create raw embed request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + c.setAuth(httpReq) + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, c.wrapConnErr(err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer resp.Body.Close() + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, &HTTPError{StatusCode: resp.StatusCode, Body: string(errBody)} + } + + return resp, nil +} + +// setAuth adds the bearer token to the request if configured. +func (c *httpClient) setAuth(req *http.Request) { + if c.token != "" { + req.Header.Set("Authorization", "Bearer "+c.token) + } +} + +// wrapConnErr checks if the error is a network-level failure and wraps it in a +// ConnectionError. Non-network errors are returned as-is. +func (c *httpClient) wrapConnErr(err error) error { + if err == nil { + return nil + } + // Check for common network error types. + if _, ok := err.(*net.OpError); ok { + return &ConnectionError{URL: c.baseURL, Err: err} + } + if _, ok := err.(net.Error); ok { + return &ConnectionError{URL: c.baseURL, Err: err} + } + // Also catch connection refused, DNS errors, etc. that might be wrapped. + if isConnectionError(err) { + return &ConnectionError{URL: c.baseURL, Err: err} + } + return err +} + +// isConnectionError checks for common connection-level error patterns in wrapped errors. +func isConnectionError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "connection refused") || + strings.Contains(msg, "no such host") || + strings.Contains(msg, "network is unreachable") || + strings.Contains(msg, "i/o timeout") || + strings.Contains(msg, "dial tcp") +} diff --git a/internal/ollama/client_test.go b/internal/ollama/client_test.go new file mode 100644 index 0000000..1b85298 --- /dev/null +++ b/internal/ollama/client_test.go @@ -0,0 +1,279 @@ +package ollama + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" +) + +func TestTags_ParsesResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/tags" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(TagsResponse{ + Models: []ModelInfo{ + {Name: "qwen3:30b", Model: "qwen3:30b", Size: 19000000000}, + {Name: "nomic-embed-text", Model: "nomic-embed-text", Size: 300000000}, + }, + }) + })) + defer srv.Close() + + client := NewClient(srv.URL, "") + resp, err := client.Tags(context.Background()) + if err != nil { + t.Fatalf("Tags: %v", err) + } + if len(resp.Models) != 2 { + t.Fatalf("got %d models, want 2", len(resp.Models)) + } + if resp.Models[0].Name != "qwen3:30b" { + t.Errorf("first model = %q, want %q", resp.Models[0].Name, "qwen3:30b") + } +} + +func TestPs_ParsesResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/ps" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(PsResponse{ + Models: []RunningModel{ + {Name: "nomic-embed-text", Model: "nomic-embed-text", Size: 300000000}, + }, + }) + })) + defer srv.Close() + + client := NewClient(srv.URL, "") + resp, err := client.Ps(context.Background()) + if err != nil { + t.Fatalf("Ps: %v", err) + } + if len(resp.Models) != 1 { + t.Fatalf("got %d models, want 1", len(resp.Models)) + } +} + +func TestChat_NonStreaming(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req ChatRequest + if err := json.Unmarshal(body, &req); err != nil { + t.Errorf("unmarshal: %v", err) + } + if req.Stream != nil && *req.Stream { + t.Error("expected stream=false") + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(ChatResponse{ + Model: "qwen3:30b", + Done: true, + Message: &Message{Role: "assistant", Content: "Hello!"}, + }) + })) + defer srv.Close() + + client := NewClient(srv.URL, "") + resp, ch, err := client.Chat(context.Background(), ChatRequest{ + Model: "qwen3:30b", + Messages: []Message{{Role: "user", Content: "hi"}}, + }, false) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if ch != nil { + t.Error("expected nil channel for non-streaming") + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if resp.Message.Content != "Hello!" { + t.Errorf("content = %q, want %q", resp.Message.Content, "Hello!") + } +} + +func TestChat_Streaming(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-ndjson") + flusher := w.(http.Flusher) + + chunks := []ChatResponse{ + {Model: "qwen3:30b", Done: false, Message: &Message{Role: "assistant", Content: "Hel"}}, + {Model: "qwen3:30b", Done: false, Message: &Message{Role: "assistant", Content: "lo"}}, + {Model: "qwen3:30b", Done: true, DoneReason: "stop"}, + } + for _, c := range chunks { + b, _ := json.Marshal(c) + w.Write(b) + w.Write([]byte("\n")) + flusher.Flush() + } + })) + defer srv.Close() + + client := NewClient(srv.URL, "") + resp, ch, err := client.Chat(context.Background(), ChatRequest{ + Model: "qwen3:30b", + Messages: []Message{{Role: "user", Content: "hi"}}, + }, true) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp != nil { + t.Error("expected nil response for streaming") + } + if ch == nil { + t.Fatal("expected non-nil channel for streaming") + } + + var collected []ChatResponse + for c := range ch { + collected = append(collected, c) + } + if len(collected) != 3 { + t.Fatalf("got %d chunks, want 3", len(collected)) + } + if !collected[2].Done { + t.Error("last chunk should have done=true") + } +} + +func TestEmbed_ParsesResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/embed" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(EmbedResponse{ + Model: "nomic-embed-text", + Embeddings: [][]float64{{0.1, 0.2, 0.3}}, + }) + })) + defer srv.Close() + + client := NewClient(srv.URL, "") + resp, err := client.Embed(context.Background(), EmbedRequest{ + Model: "nomic-embed-text", + Input: json.RawMessage(`"test input"`), + }) + if err != nil { + t.Fatalf("Embed: %v", err) + } + if len(resp.Embeddings) != 1 { + t.Fatalf("got %d embeddings, want 1", len(resp.Embeddings)) + } + if len(resp.Embeddings[0]) != 3 { + t.Errorf("embedding dim = %d, want 3", len(resp.Embeddings[0])) + } +} + +func TestClient_SetsAuthToken(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(TagsResponse{}) + })) + defer srv.Close() + + client := NewClient(srv.URL, "my-secret-token") + _, err := client.Tags(context.Background()) + if err != nil { + t.Fatalf("Tags: %v", err) + } + if gotAuth != "Bearer my-secret-token" { + t.Errorf("auth header = %q, want %q", gotAuth, "Bearer my-secret-token") + } +} + +func TestClient_ConnectionError(t *testing.T) { + // Use a server that immediately closes. + client := NewClient("http://127.0.0.1:1", "") + _, err := client.Tags(context.Background()) + if err == nil { + t.Fatal("expected error for unreachable target") + } + + var connErr *ConnectionError + if !errors.As(err, &connErr) { + t.Errorf("expected *ConnectionError, got %T: %v", err, err) + } +} + +func TestClient_HTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error":"bad model"}`, http.StatusBadRequest) + })) + defer srv.Close() + + client := NewClient(srv.URL, "") + _, err := client.Tags(context.Background()) + if err == nil { + t.Fatal("expected error for 400 response") + } + + var httpErr *HTTPError + if !errors.As(err, &httpErr) { + t.Errorf("expected *HTTPError, got %T: %v", err, err) + } + if httpErr.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want %d", httpErr.StatusCode, http.StatusBadRequest) + } +} + +func TestRawChat_ForwardsBody(t *testing.T) { + var gotBody []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"done":true}`)) + })) + defer srv.Close() + + client := NewClient(srv.URL, "") + body := []byte(`{"model":"qwen3:30b","messages":[{"role":"user","content":"test"}]}`) + resp, err := client.RawChat(context.Background(), body) + if err != nil { + t.Fatalf("RawChat: %v", err) + } + defer resp.Body.Close() + + if string(gotBody) != string(body) { + t.Errorf("body mismatch: got %s, want %s", gotBody, body) + } +} + +func TestRawEmbed_ForwardsBody(t *testing.T) { + var callCount atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount.Add(1) + if r.URL.Path != "/api/embed" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(fmt.Sprintf(`{"model":"nomic-embed-text","embeddings":[[0.%d]]}`, callCount.Load()))) + })) + defer srv.Close() + + client := NewClient(srv.URL, "") + body := []byte(`{"model":"nomic-embed-text","input":"test"}`) + resp, err := client.RawEmbed(context.Background(), body) + if err != nil { + t.Fatalf("RawEmbed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } +} diff --git a/internal/ollama/errors.go b/internal/ollama/errors.go new file mode 100644 index 0000000..8101128 --- /dev/null +++ b/internal/ollama/errors.go @@ -0,0 +1,41 @@ +package ollama + +import ( + "fmt" +) + +// ConnectionError wraps a network-level error when the Ollama target is unreachable. +// Phase 3 uses this to distinguish connection failures (retry-eligible) from HTTP +// errors (usually not retryable). +// +// Why: callers must differentiate "target is down" from "target returned 4xx/5xx" +// to decide on retry strategy. +// What: wraps a net-level error and satisfies the error and Unwrap interfaces. +// Test: create a ConnectionError, verify errors.Is/As can match it. +type ConnectionError struct { + URL string + Err error +} + +func (e *ConnectionError) Error() string { + return fmt.Sprintf("connection to ollama target %s failed: %v", e.URL, e.Err) +} + +func (e *ConnectionError) Unwrap() error { + return e.Err +} + +// HTTPError represents a non-2xx HTTP response from the Ollama target. +// +// Why: callers need the status code to distinguish client errors (4xx) from +// server errors (5xx) and decide on retry logic. +// What: holds the HTTP status code and response body for error diagnosis. +// Test: create an HTTPError with status 500, verify Error() includes the code. +type HTTPError struct { + StatusCode int + Body string +} + +func (e *HTTPError) Error() string { + return fmt.Sprintf("ollama target returned HTTP %d: %s", e.StatusCode, e.Body) +} diff --git a/internal/ollama/inventory.go b/internal/ollama/inventory.go new file mode 100644 index 0000000..1b3848c --- /dev/null +++ b/internal/ollama/inventory.go @@ -0,0 +1,167 @@ +package ollama + +import ( + "context" + "log/slog" + "sync" + "time" +) + +// ModelInventory maintains an in-memory cache of the Ollama target's installed +// and running models, refreshed by a background poller. +// +// Why: foreman needs a reasonably fresh view of installed models for validation, +// passthrough, and scheduling without hitting the target on every request. +// What: wraps a mutex-protected model list updated by a polling goroutine. +// Test: create with a stub client, poll, verify Models()/HasModel()/Degraded(). +type ModelInventory struct { + client Client + logger *slog.Logger + + mu sync.RWMutex + models []ModelInfo + runningModels []RunningModel + lastPoll time.Time + degraded bool +} + +// NewModelInventory creates a new ModelInventory backed by the given client. +// +// Why: separates construction from the poll loop so callers can control lifecycle. +// What: returns an inventory ready for polling; call Start to begin the background loop. +// Test: create, call Refresh manually, assert Models() is populated. +func NewModelInventory(client Client, logger *slog.Logger) *ModelInventory { + return &ModelInventory{ + client: client, + logger: logger, + } +} + +// Models returns the current known model list. +// +// Why: the /api/tags handler and model validation need the cached list. +// What: returns a copy of the model slice under a read lock. +// Test: Refresh, call Models(), verify the returned slice matches. +func (inv *ModelInventory) Models() []ModelInfo { + inv.mu.RLock() + defer inv.mu.RUnlock() + out := make([]ModelInfo, len(inv.models)) + copy(out, inv.models) + return out +} + +// HasModel checks whether a model name is in the current inventory. +// +// Why: job submission validates the model name before queuing. +// What: scans the model list for an exact name match. +// Test: Refresh with known models, assert HasModel returns true/false correctly. +func (inv *ModelInventory) HasModel(name string) bool { + inv.mu.RLock() + defer inv.mu.RUnlock() + for _, m := range inv.models { + if m.Name == name { + return true + } + } + return false +} + +// ResidentModels returns the list of currently-loaded models from the last /api/ps poll. +// +// Why: the scheduler (Phase 3) uses this to decide if a model swap is needed. +// What: returns a copy of the running model slice under a read lock. +// Test: Refresh, call ResidentModels(), verify it matches /api/ps output. +func (inv *ModelInventory) ResidentModels() []RunningModel { + inv.mu.RLock() + defer inv.mu.RUnlock() + out := make([]RunningModel, len(inv.runningModels)) + copy(out, inv.runningModels) + return out +} + +// LastPoll returns the timestamp of the most recent successful poll. +// +// Why: health/diagnostics use this to judge staleness. +// What: returns the lastPoll time under a read lock. +// Test: Refresh, assert LastPoll is non-zero and recent. +func (inv *ModelInventory) LastPoll() time.Time { + inv.mu.RLock() + defer inv.mu.RUnlock() + return inv.lastPoll +} + +// Degraded reports whether the target was unreachable on the last poll attempt. +// +// Why: the /healthz endpoint surfaces this to operators and probes. +// What: returns the degraded flag under a read lock. +// Test: Refresh with an unreachable stub, assert Degraded() is true; then with a +// reachable stub, assert it clears. +func (inv *ModelInventory) Degraded() bool { + inv.mu.RLock() + defer inv.mu.RUnlock() + return inv.degraded +} + +// Refresh performs an immediate poll of /api/tags and /api/ps on the target. +// +// Why: called by the poller goroutine and on-demand (e.g., model-miss re-check). +// What: fetches tags and ps, updates the cached lists, clears or sets the degraded +// flag based on success/failure. +// Test: stub both endpoints, call Refresh, verify Models() and ResidentModels() match. +func (inv *ModelInventory) Refresh(ctx context.Context) error { + tags, tagsErr := inv.client.Tags(ctx) + ps, psErr := inv.client.Ps(ctx) + + inv.mu.Lock() + defer inv.mu.Unlock() + + if tagsErr != nil { + inv.degraded = true + inv.logger.Warn("model poll failed: tags", + "error", tagsErr, + "retained_models", len(inv.models), + ) + return tagsErr + } + + // Tags succeeded — update model list and clear degraded. + inv.models = tags.Models + inv.lastPoll = time.Now() + inv.degraded = false + + if psErr != nil { + // Tags worked but ps failed — log but don't mark degraded for ps alone. + inv.logger.Warn("model poll partial: ps failed", "error", psErr) + } else { + inv.runningModels = ps.Models + } + + return nil +} + +// Start begins the background polling loop. It blocks until ctx is cancelled. +// Call this in a goroutine. +// +// Why: continuous polling keeps the model inventory fresh for validation and scheduling. +// What: polls at the given interval, respecting context cancellation for clean shutdown. +// Test: start with a short interval and cancelled context, verify it exits cleanly. +func (inv *ModelInventory) Start(ctx context.Context, interval time.Duration) { + // Do an initial poll immediately. + if err := inv.Refresh(ctx); err != nil { + inv.logger.Warn("initial model poll failed", "error", err) + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := inv.Refresh(ctx); err != nil { + inv.logger.Warn("periodic model poll failed", "error", err) + } + } + } +} diff --git a/internal/ollama/inventory_test.go b/internal/ollama/inventory_test.go new file mode 100644 index 0000000..e72cca8 --- /dev/null +++ b/internal/ollama/inventory_test.go @@ -0,0 +1,201 @@ +package ollama + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "testing" + "time" +) + +// mockClient implements Client for inventory testing. +type mockClient struct { + tagsFn func(ctx context.Context) (*TagsResponse, error) + psFn func(ctx context.Context) (*PsResponse, error) +} + +func (m *mockClient) Chat(ctx context.Context, req ChatRequest, stream bool) (*ChatResponse, <-chan ChatResponse, error) { + return nil, nil, fmt.Errorf("not implemented") +} + +func (m *mockClient) Embed(ctx context.Context, req EmbedRequest) (*EmbedResponse, error) { + return nil, fmt.Errorf("not implemented") +} + +func (m *mockClient) Tags(ctx context.Context) (*TagsResponse, error) { + return m.tagsFn(ctx) +} + +func (m *mockClient) Ps(ctx context.Context) (*PsResponse, error) { + return m.psFn(ctx) +} + +func (m *mockClient) RawChat(ctx context.Context, body []byte) (*http.Response, error) { + return nil, fmt.Errorf("not implemented") +} + +func (m *mockClient) RawEmbed(ctx context.Context, body []byte) (*http.Response, error) { + return nil, fmt.Errorf("not implemented") +} + +func TestInventory_RefreshPopulatesModels(t *testing.T) { + client := &mockClient{ + tagsFn: func(ctx context.Context) (*TagsResponse, error) { + return &TagsResponse{ + Models: []ModelInfo{ + {Name: "qwen3:30b"}, + {Name: "nomic-embed-text"}, + }, + }, nil + }, + psFn: func(ctx context.Context) (*PsResponse, error) { + return &PsResponse{ + Models: []RunningModel{ + {Name: "nomic-embed-text"}, + }, + }, nil + }, + } + + inv := NewModelInventory(client, slog.Default()) + if err := inv.Refresh(context.Background()); err != nil { + t.Fatalf("Refresh: %v", err) + } + + models := inv.Models() + if len(models) != 2 { + t.Fatalf("got %d models, want 2", len(models)) + } + + if !inv.HasModel("qwen3:30b") { + t.Error("HasModel(qwen3:30b) = false, want true") + } + if inv.HasModel("nonexistent") { + t.Error("HasModel(nonexistent) = true, want false") + } + + resident := inv.ResidentModels() + if len(resident) != 1 { + t.Fatalf("got %d resident models, want 1", len(resident)) + } + + if inv.Degraded() { + t.Error("degraded should be false after successful refresh") + } + if inv.LastPoll().IsZero() { + t.Error("lastPoll should be non-zero after refresh") + } +} + +func TestInventory_DegradedOnFailure(t *testing.T) { + callCount := 0 + client := &mockClient{ + tagsFn: func(ctx context.Context) (*TagsResponse, error) { + callCount++ + if callCount == 1 { + return &TagsResponse{ + Models: []ModelInfo{{Name: "qwen3:30b"}}, + }, nil + } + return nil, fmt.Errorf("connection refused") + }, + psFn: func(ctx context.Context) (*PsResponse, error) { + return &PsResponse{}, nil + }, + } + + inv := NewModelInventory(client, slog.Default()) + + // First refresh succeeds. + if err := inv.Refresh(context.Background()); err != nil { + t.Fatalf("first Refresh: %v", err) + } + if inv.Degraded() { + t.Error("should not be degraded after first successful poll") + } + + // Second refresh fails — should retain models but mark degraded. + if err := inv.Refresh(context.Background()); err == nil { + t.Fatal("expected error on second refresh") + } + if !inv.Degraded() { + t.Error("should be degraded after failed poll") + } + + // Models should be retained. + if !inv.HasModel("qwen3:30b") { + t.Error("should retain models after failed poll") + } +} + +func TestInventory_RecoveryFromDegraded(t *testing.T) { + failing := true + client := &mockClient{ + tagsFn: func(ctx context.Context) (*TagsResponse, error) { + if failing { + return nil, fmt.Errorf("connection refused") + } + return &TagsResponse{ + Models: []ModelInfo{{Name: "qwen3:30b"}}, + }, nil + }, + psFn: func(ctx context.Context) (*PsResponse, error) { + return &PsResponse{}, nil + }, + } + + inv := NewModelInventory(client, slog.Default()) + + // First refresh fails. + inv.Refresh(context.Background()) + if !inv.Degraded() { + t.Error("should be degraded after failed poll") + } + + // Target recovers. + failing = false + if err := inv.Refresh(context.Background()); err != nil { + t.Fatalf("recovery Refresh: %v", err) + } + if inv.Degraded() { + t.Error("should not be degraded after successful poll") + } +} + +func TestInventory_StartAndCancel(t *testing.T) { + pollCount := 0 + client := &mockClient{ + tagsFn: func(ctx context.Context) (*TagsResponse, error) { + pollCount++ + return &TagsResponse{}, nil + }, + psFn: func(ctx context.Context) (*PsResponse, error) { + return &PsResponse{}, nil + }, + } + + inv := NewModelInventory(client, slog.Default()) + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + inv.Start(ctx, 10*time.Millisecond) + close(done) + }() + + // Let it poll a few times. + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case <-done: + // Clean exit. + case <-time.After(2 * time.Second): + t.Fatal("Start did not exit after context cancellation") + } + + if pollCount < 2 { + t.Errorf("poll count = %d, want >= 2 (initial + at least one tick)", pollCount) + } +} diff --git a/internal/ollama/types.go b/internal/ollama/types.go new file mode 100644 index 0000000..c3184aa --- /dev/null +++ b/internal/ollama/types.go @@ -0,0 +1,99 @@ +// Package ollama provides a client for talking to an Ollama target. +// +// Why: foreman needs a clean, testable interface to the Ollama HTTP API so the +// worker loop and passthrough handlers share a single client contract. +// What: defines wire types matching Ollama's native JSON API and a Client +// interface for chat, embed, tags, and ps operations. +// Test: use a stub HTTP server that returns canned Ollama JSON; verify the client +// parses responses and surfaces errors correctly. +package ollama + +import ( + "encoding/json" + "time" +) + +// ChatRequest is the JSON body for POST /api/chat. +// Fields use json.RawMessage where polymorphism or pass-through fidelity is required. +type ChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Stream *bool `json:"stream,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` + Options json.RawMessage `json:"options,omitempty"` + KeepAlive json.RawMessage `json:"keep_alive,omitempty"` + Think json.RawMessage `json:"think,omitempty"` + Format json.RawMessage `json:"format,omitempty"` + Template string `json:"template,omitempty"` + System string `json:"system,omitempty"` + Context json.RawMessage `json:"context,omitempty"` +} + +// Message is a single message in a chat conversation. +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + Images []string `json:"images,omitempty"` + ToolCalls json.RawMessage `json:"tool_calls,omitempty"` + Thinking json.RawMessage `json:"thinking,omitempty"` +} + +// ChatResponse is the JSON response from POST /api/chat. +type ChatResponse struct { + Model string `json:"model,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + Message *Message `json:"message,omitempty"` + Done bool `json:"done"` + DoneReason string `json:"done_reason,omitempty"` + TotalDuration int64 `json:"total_duration,omitempty"` + LoadDuration int64 `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration int64 `json:"eval_duration,omitempty"` + Context json.RawMessage `json:"context,omitempty"` +} + +// EmbedRequest is the JSON body for POST /api/embed (and /api/embeddings). +type EmbedRequest struct { + Model string `json:"model"` + Input json.RawMessage `json:"input"` + KeepAlive json.RawMessage `json:"keep_alive,omitempty"` + Options json.RawMessage `json:"options,omitempty"` +} + +// EmbedResponse is the JSON response from POST /api/embed. +type EmbedResponse struct { + Model string `json:"model,omitempty"` + Embeddings [][]float64 `json:"embeddings,omitempty"` +} + +// ModelInfo describes an installed model, as returned by GET /api/tags. +type ModelInfo struct { + Name string `json:"name"` + Model string `json:"model"` + ModifiedAt time.Time `json:"modified_at"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Details json.RawMessage `json:"details,omitempty"` +} + +// TagsResponse is the JSON response from GET /api/tags. +type TagsResponse struct { + Models []ModelInfo `json:"models"` +} + +// RunningModel describes a currently-loaded model, as returned by GET /api/ps. +type RunningModel struct { + Name string `json:"name"` + Model string `json:"model"` + Size int64 `json:"size"` + Digest string `json:"digest"` + ExpiresAt time.Time `json:"expires_at"` + Details json.RawMessage `json:"details,omitempty"` + SizeVRAM int64 `json:"size_vram,omitempty"` +} + +// PsResponse is the JSON response from GET /api/ps. +type PsResponse struct { + Models []RunningModel `json:"models"` +} diff --git a/internal/server/server.go b/internal/server/server.go index bfbc68d..8ff06c5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,41 +3,54 @@ // Why: foreman exposes a native Ollama-compatible API plus async job endpoints; // centralizing routing and middleware here keeps cmd/foreman thin. // What: creates a stdlib net/http server with health checks, optional bearer-token -// auth, and an extensible mux for later phases. +// auth, Ollama passthrough (chat, tags, ps, embed), and an extensible mux. // Test: start the server with httptest, hit /healthz, verify 200; set a token, -// verify 401 without it. +// verify 401 without it; test Ollama passthrough routes. package server import ( + "bufio" "encoding/json" + "io" "log/slog" "net/http" "strings" "gitea.stevedudenhoeffer.com/steve/foreman/internal/config" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/ollama" "gitea.stevedudenhoeffer.com/steve/foreman/internal/store" ) +// scannerBufSize is the buffer size for the NDJSON scanner (4 MB). +const scannerBufSize = 4 * 1024 * 1024 + // Server holds the HTTP server and its dependencies. type Server struct { - cfg config.Config - store *store.Store - mux *http.ServeMux - logger *slog.Logger + cfg config.Config + store *store.Store + client ollama.Client + inventory *ollama.ModelInventory + chatGate chan struct{} + mux *http.ServeMux + logger *slog.Logger } -// New creates a new Server with the given configuration and store. The mux is -// populated with initial routes; callers can add more before calling ListenAndServe. +// New creates a new Server with the given configuration, store, Ollama client, +// and model inventory. The mux is populated with all routes. // // Why: dependency injection makes the server testable and extensible. -// What: wires config, store, and logger into the server, registers routes. +// What: wires config, store, client, inventory, and logger into the server, +// registers routes, and creates the single-flight chat gate. // Test: create with New, use httptest to exercise routes. -func New(cfg config.Config, st *store.Store, logger *slog.Logger) *Server { +func New(cfg config.Config, st *store.Store, client ollama.Client, inv *ollama.ModelInventory, logger *slog.Logger) *Server { s := &Server{ - cfg: cfg, - store: st, - mux: http.NewServeMux(), - logger: logger, + cfg: cfg, + store: st, + client: client, + inventory: inv, + chatGate: make(chan struct{}, 1), + mux: http.NewServeMux(), + logger: logger, } s.routes() return s @@ -65,6 +78,11 @@ func (s *Server) ListenAndServe() error { // routes registers all HTTP routes on the mux. func (s *Server) routes() { s.mux.HandleFunc("GET /healthz", s.handleHealthz) + s.mux.HandleFunc("GET /api/tags", s.handleTags) + s.mux.HandleFunc("GET /api/ps", s.handlePs) + s.mux.HandleFunc("POST /api/chat", s.handleChat) + s.mux.HandleFunc("POST /api/embed", s.handleEmbed) + s.mux.HandleFunc("POST /api/embeddings", s.handleEmbed) } // healthResponse is the JSON shape returned by /healthz. @@ -73,17 +91,187 @@ type healthResponse struct { Degraded bool `json:"degraded"` } -// handleHealthz returns the daemon's health status. The degraded flag is a -// placeholder for the model poller's connectivity state (Phase 2). +// handleHealthz returns the daemon's health status, including the poller's +// degraded flag so probes and operators can see target connectivity. +// +// Why: load balancers and operators need a single endpoint for health. +// What: returns 200 with a JSON body including the degraded flag from the poller. +// Test: set up a server with a degraded inventory, assert degraded=true in response. func (s *Server) handleHealthz(w http.ResponseWriter, r *http.Request) { + degraded := false + if s.inventory != nil { + degraded = s.inventory.Degraded() + } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(healthResponse{ Status: "ok", - Degraded: false, + Degraded: degraded, }) } +// handleTags returns the cached model inventory as Ollama-format JSON. +// +// Why: foreman's /api/tags must be indistinguishable from Ollama's /api/tags. +// What: returns the poller's cached TagsResponse. +// Test: populate the inventory, GET /api/tags, assert the response matches. +func (s *Server) handleTags(w http.ResponseWriter, r *http.Request) { + models := s.inventory.Models() + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(ollama.TagsResponse{Models: models}) +} + +// handlePs returns the cached running models from the poller. +// +// Why: foreman's /api/ps lets callers see what's resident on the target. +// What: returns the poller's cached PsResponse. +// Test: populate the inventory with running models, GET /api/ps, assert match. +func (s *Server) handlePs(w http.ResponseWriter, r *http.Request) { + running := s.inventory.ResidentModels() + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(ollama.PsResponse{Models: running}) +} + +// handleEmbed proxies embedding requests directly and concurrently to the target. +// These bypass any serialization gate per ADR-0013. +// +// Why: embeddings hit the always-resident embedder and must not wait behind chat jobs. +// What: reads the request body, proxies to the target, and returns the response. +// Test: send concurrent embed requests, assert they all complete without serialization. +func (s *Server) handleEmbed(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, `{"error":"failed to read request body"}`, http.StatusBadRequest) + return + } + + resp, err := s.client.RawEmbed(r.Context(), body) + if err != nil { + s.logger.Error("embed proxy failed", "error", err) + if httpErr, ok := err.(*ollama.HTTPError); ok { + http.Error(w, httpErr.Body, httpErr.StatusCode) + return + } + http.Error(w, `{"error":"target unreachable"}`, http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // Copy response headers and body. + for k, vv := range resp.Header { + for _, v := range vv { + w.Header().Add(k, v) + } + } + if w.Header().Get("Content-Type") == "" { + w.Header().Set("Content-Type", "application/json") + } + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) +} + +// handleChat is the critical passthrough path for /api/chat. It validates the +// model, serializes through a single-flight gate, and proxies to the target +// with NDJSON streaming support. +// +// Why: the sync passthrough is foreman's primary API surface for go-llm (ADR-0003). +// What: validates model, acquires the chat gate, proxies to the target, streams +// NDJSON chunks back if streaming, releases the gate on completion. +// Test: verify model validation (404 on unknown), serialization (two concurrent +// requests don't overlap), streaming (NDJSON chunks pass through faithfully). +func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, `{"error":"failed to read request body"}`, http.StatusBadRequest) + return + } + + // Parse just enough to validate the model and detect streaming. + var partial struct { + Model string `json:"model"` + Stream *bool `json:"stream"` + } + if err := json.Unmarshal(body, &partial); err != nil { + http.Error(w, `{"error":"invalid JSON body"}`, http.StatusBadRequest) + return + } + if partial.Model == "" { + http.Error(w, `{"error":"model is required"}`, http.StatusBadRequest) + return + } + + // Validate the model exists. One re-poll on miss (ADR-0007). + if !s.inventory.HasModel(partial.Model) { + if err := s.inventory.Refresh(r.Context()); err != nil { + s.logger.Warn("model re-poll failed", "error", err) + } + if !s.inventory.HasModel(partial.Model) { + http.Error(w, `{"error":"model not found"}`, http.StatusNotFound) + return + } + } + + // Determine if streaming. Ollama defaults to streaming when "stream" is absent. + streaming := true + if partial.Stream != nil && !*partial.Stream { + streaming = false + } + + // Acquire the single-flight chat gate. This serializes all chat requests + // through one at a time. Phase 3 replaces this with the full SQLite queue + + // worker loop. + select { + case s.chatGate <- struct{}{}: + // Acquired. + case <-r.Context().Done(): + http.Error(w, `{"error":"request cancelled while waiting"}`, http.StatusServiceUnavailable) + return + } + defer func() { <-s.chatGate }() + + // Proxy to the target. + resp, err := s.client.RawChat(r.Context(), body) + if err != nil { + s.logger.Error("chat proxy failed", "error", err, "model", partial.Model) + if httpErr, ok := err.(*ollama.HTTPError); ok { + http.Error(w, httpErr.Body, httpErr.StatusCode) + return + } + http.Error(w, `{"error":"target unreachable"}`, http.StatusBadGateway) + return + } + defer resp.Body.Close() + + if streaming { + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(http.StatusOK) + + flusher, canFlush := w.(http.Flusher) + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, scannerBufSize), scannerBufSize) + + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + w.Write(line) + w.Write([]byte("\n")) + if canFlush { + flusher.Flush() + } + } + if err := scanner.Err(); err != nil { + s.logger.Warn("stream read error", "error", err, "model", partial.Model) + } + } else { + // Non-streaming: proxy the complete JSON response. + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + io.Copy(w, resp.Body) + } +} + // authMiddleware validates the Authorization: Bearer header on all // requests except /healthz. Returns 401 if the token is missing or wrong. func (s *Server) authMiddleware(next http.Handler) http.Handler { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 16775a6..60db0cf 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1,19 +1,29 @@ package server import ( + "bytes" + "context" "encoding/json" + "fmt" + "io" "log/slog" "net/http" "net/http/httptest" "path/filepath" + "strings" + "sync" + "sync/atomic" "testing" + "time" "gitea.stevedudenhoeffer.com/steve/foreman/internal/config" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/ollama" "gitea.stevedudenhoeffer.com/steve/foreman/internal/store" ) -// newTestServer creates a Server backed by a temp-dir SQLite store. -func newTestServer(t *testing.T, cfg config.Config) *Server { +// newTestServer creates a Server backed by a temp-dir SQLite store, a stub client, +// and a pre-populated inventory. +func newTestServer(t *testing.T, cfg config.Config, client ollama.Client) *Server { t.Helper() dbPath := filepath.Join(t.TempDir(), "test.db") st, err := store.Open(dbPath) @@ -23,13 +33,28 @@ func newTestServer(t *testing.T, cfg config.Config) *Server { t.Cleanup(func() { st.Close() }) logger := slog.Default() - return New(cfg, st, logger) + inv := ollama.NewModelInventory(client, logger) + return New(cfg, st, client, inv, logger) +} + +// newTestServerWithInventory creates a Server and pre-refreshes the inventory. +func newTestServerWithInventory(t *testing.T, cfg config.Config, client ollama.Client) *Server { + t.Helper() + srv := newTestServer(t, cfg, client) + if err := srv.inventory.Refresh(context.Background()); err != nil { + t.Fatalf("inventory.Refresh: %v", err) + } + return srv } func TestHealthz_OK(t *testing.T) { - srv := newTestServer(t, config.Config{ + stub := &stubClient{ + tags: &ollama.TagsResponse{}, + ps: &ollama.PsResponse{}, + } + srv := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", - }) + }, stub) req := httptest.NewRequest(http.MethodGet, "/healthz", nil) rec := httptest.NewRecorder() @@ -52,12 +77,15 @@ func TestHealthz_OK(t *testing.T) { } func TestHealthz_NoAuthRequired(t *testing.T) { - srv := newTestServer(t, config.Config{ + stub := &stubClient{ + tags: &ollama.TagsResponse{}, + ps: &ollama.PsResponse{}, + } + srv := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", Token: "secret-token", - }) + }, stub) - // /healthz should work without any auth header even when token is configured. req := httptest.NewRequest(http.MethodGet, "/healthz", nil) rec := httptest.NewRecorder() srv.Handler().ServeHTTP(rec, req) @@ -68,16 +96,20 @@ func TestHealthz_NoAuthRequired(t *testing.T) { } func TestAuth_RequiredWhenTokenSet(t *testing.T) { - srv := newTestServer(t, config.Config{ + stub := &stubClient{ + tags: &ollama.TagsResponse{}, + ps: &ollama.PsResponse{}, + } + srv := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", Token: "secret-token", - }) + }, stub) tests := []struct { - name string - path string - auth string - want int + name string + path string + auth string + want int }{ { name: "no auth header", @@ -123,13 +155,14 @@ func TestAuth_RequiredWhenTokenSet(t *testing.T) { } func TestAuth_NotRequiredWhenNoToken(t *testing.T) { - srv := newTestServer(t, config.Config{ + stub := &stubClient{ + tags: &ollama.TagsResponse{}, + ps: &ollama.PsResponse{}, + } + srv := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", - // Token intentionally empty. - }) + }, stub) - // Without a configured token, any request should pass auth (even to a - // nonexistent route, which returns 404 rather than 401). req := httptest.NewRequest(http.MethodGet, "/some-route", nil) rec := httptest.NewRecorder() srv.Handler().ServeHTTP(rec, req) @@ -138,3 +171,412 @@ func TestAuth_NotRequiredWhenNoToken(t *testing.T) { t.Error("should not require auth when no token is configured") } } + +func TestTags_ReturnsCachedModels(t *testing.T) { + stub := &stubClient{ + tags: &ollama.TagsResponse{ + Models: []ollama.ModelInfo{ + {Name: "qwen3:30b", Model: "qwen3:30b", Size: 19000000000}, + {Name: "nomic-embed-text", Model: "nomic-embed-text", Size: 300000000}, + }, + }, + ps: &ollama.PsResponse{}, + } + srv := newTestServerWithInventory(t, config.Config{ + OllamaURL: "http://localhost:11434", + }, stub) + + req := httptest.NewRequest(http.MethodGet, "/api/tags", nil) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp ollama.TagsResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if len(resp.Models) != 2 { + t.Fatalf("got %d models, want 2", len(resp.Models)) + } + if resp.Models[0].Name != "qwen3:30b" { + t.Errorf("first model = %q, want %q", resp.Models[0].Name, "qwen3:30b") + } +} + +func TestPs_ReturnsCachedRunningModels(t *testing.T) { + stub := &stubClient{ + tags: &ollama.TagsResponse{}, + ps: &ollama.PsResponse{ + Models: []ollama.RunningModel{ + {Name: "nomic-embed-text", Model: "nomic-embed-text", Size: 300000000}, + }, + }, + } + srv := newTestServerWithInventory(t, config.Config{ + OllamaURL: "http://localhost:11434", + }, stub) + + req := httptest.NewRequest(http.MethodGet, "/api/ps", nil) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp ollama.PsResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if len(resp.Models) != 1 { + t.Fatalf("got %d models, want 1", len(resp.Models)) + } +} + +func TestChat_UnknownModel404(t *testing.T) { + stub := &stubClient{ + tags: &ollama.TagsResponse{ + Models: []ollama.ModelInfo{ + {Name: "qwen3:30b"}, + }, + }, + ps: &ollama.PsResponse{}, + } + srv := newTestServerWithInventory(t, config.Config{ + OllamaURL: "http://localhost:11434", + }, stub) + + body := `{"model":"nonexistent-model","messages":[{"role":"user","content":"hi"}],"stream":false}` + req := httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(body)) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +func TestChat_NonStreaming(t *testing.T) { + chatResp := ollama.ChatResponse{ + Model: "qwen3:30b", + Done: true, + Message: &ollama.Message{Role: "assistant", Content: "Hello!"}, + } + respBytes, _ := json.Marshal(chatResp) + + stub := &stubClient{ + tags: &ollama.TagsResponse{ + Models: []ollama.ModelInfo{{Name: "qwen3:30b"}}, + }, + ps: &ollama.PsResponse{}, + rawChatResp: newRawResponse(200, "application/json", respBytes), + } + srv := newTestServerWithInventory(t, config.Config{ + OllamaURL: "http://localhost:11434", + }, stub) + + body := `{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}],"stream":false}` + req := httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(body)) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + ct := rec.Header().Get("Content-Type") + if ct != "application/json" { + t.Errorf("Content-Type = %q, want %q", ct, "application/json") + } + + var got ollama.ChatResponse + if err := json.NewDecoder(rec.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got.Message == nil || got.Message.Content != "Hello!" { + t.Errorf("response content = %v, want Hello!", got.Message) + } +} + +func TestChat_Streaming(t *testing.T) { + // Build NDJSON chunks. + chunks := []ollama.ChatResponse{ + {Model: "qwen3:30b", Done: false, Message: &ollama.Message{Role: "assistant", Content: "Hel"}}, + {Model: "qwen3:30b", Done: false, Message: &ollama.Message{Role: "assistant", Content: "lo"}}, + {Model: "qwen3:30b", Done: true, DoneReason: "stop"}, + } + var ndjson bytes.Buffer + for _, c := range chunks { + b, _ := json.Marshal(c) + ndjson.Write(b) + ndjson.WriteByte('\n') + } + + stub := &stubClient{ + tags: &ollama.TagsResponse{ + Models: []ollama.ModelInfo{{Name: "qwen3:30b"}}, + }, + ps: &ollama.PsResponse{}, + rawChatResp: newRawResponse(200, "application/x-ndjson", ndjson.Bytes()), + } + srv := newTestServerWithInventory(t, config.Config{ + OllamaURL: "http://localhost:11434", + }, stub) + + body := `{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(body)) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + ct := rec.Header().Get("Content-Type") + if ct != "application/x-ndjson" { + t.Errorf("Content-Type = %q, want %q", ct, "application/x-ndjson") + } + + // Verify chunks pass through faithfully. + lines := strings.Split(strings.TrimSpace(rec.Body.String()), "\n") + if len(lines) != 3 { + t.Fatalf("got %d lines, want 3", len(lines)) + } + + var last ollama.ChatResponse + if err := json.Unmarshal([]byte(lines[2]), &last); err != nil { + t.Fatalf("unmarshal last chunk: %v", err) + } + if !last.Done { + t.Error("last chunk should have done=true") + } +} + +func TestChat_Serialization(t *testing.T) { + // Track concurrent requests at the stub. + var inflight atomic.Int32 + var maxInflight atomic.Int32 + + stub := &stubClient{ + tags: &ollama.TagsResponse{ + Models: []ollama.ModelInfo{{Name: "qwen3:30b"}}, + }, + ps: &ollama.PsResponse{}, + rawChatFunc: func(ctx context.Context, body []byte) (*http.Response, error) { + cur := inflight.Add(1) + defer inflight.Add(-1) + for { + old := maxInflight.Load() + if cur <= old || maxInflight.CompareAndSwap(old, cur) { + break + } + } + // Simulate work. + time.Sleep(50 * time.Millisecond) + resp := ollama.ChatResponse{Model: "qwen3:30b", Done: true} + b, _ := json.Marshal(resp) + return newRawResponse(200, "application/json", b), nil + }, + } + srv := newTestServerWithInventory(t, config.Config{ + OllamaURL: "http://localhost:11434", + }, stub) + + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + body := `{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}],"stream":false}` + req := httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(body)) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }() + } + wg.Wait() + + if got := maxInflight.Load(); got > 1 { + t.Errorf("max concurrent chat requests at target = %d, want 1 (gate should serialize)", got) + } +} + +func TestEmbed_ConcurrentBypassesGate(t *testing.T) { + // Track concurrent embed requests. + var inflight atomic.Int32 + var maxInflight atomic.Int32 + + stub := &stubClient{ + tags: &ollama.TagsResponse{ + Models: []ollama.ModelInfo{{Name: "qwen3:30b"}}, + }, + ps: &ollama.PsResponse{}, + rawEmbedFunc: func(ctx context.Context, body []byte) (*http.Response, error) { + cur := inflight.Add(1) + defer inflight.Add(-1) + for { + old := maxInflight.Load() + if cur <= old || maxInflight.CompareAndSwap(old, cur) { + break + } + } + // Simulate some work so concurrent requests overlap. + time.Sleep(50 * time.Millisecond) + resp := ollama.EmbedResponse{Model: "nomic-embed-text", Embeddings: [][]float64{{0.1, 0.2}}} + b, _ := json.Marshal(resp) + return newRawResponse(200, "application/json", b), nil + }, + } + srv := newTestServerWithInventory(t, config.Config{ + OllamaURL: "http://localhost:11434", + }, stub) + + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + body := `{"model":"nomic-embed-text","input":"test"}` + req := httptest.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(body)) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Errorf("embed status = %d, want %d", rec.Code, http.StatusOK) + } + }() + } + wg.Wait() + + if got := maxInflight.Load(); got < 2 { + t.Errorf("max concurrent embed requests = %d, want >= 2 (embeds should run in parallel)", got) + } +} + +func TestEmbed_AlsoWorksOnEmbeddingsPath(t *testing.T) { + embedResp := ollama.EmbedResponse{ + Model: "nomic-embed-text", + Embeddings: [][]float64{{0.1, 0.2, 0.3}}, + } + respBytes, _ := json.Marshal(embedResp) + + stub := &stubClient{ + tags: &ollama.TagsResponse{}, + ps: &ollama.PsResponse{}, + rawEmbedFunc: func(ctx context.Context, body []byte) (*http.Response, error) { + return newRawResponse(200, "application/json", respBytes), nil + }, + } + srv := newTestServerWithInventory(t, config.Config{ + OllamaURL: "http://localhost:11434", + }, stub) + + body := `{"model":"nomic-embed-text","input":"test"}` + req := httptest.NewRequest(http.MethodPost, "/api/embeddings", strings.NewReader(body)) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } +} + +func TestHealthz_DegradedFromInventory(t *testing.T) { + stub := &stubClient{ + tagsErr: fmt.Errorf("connection refused"), + ps: &ollama.PsResponse{}, + } + srv := newTestServer(t, config.Config{ + OllamaURL: "http://localhost:11434", + }, stub) + + // Refresh will fail, setting degraded = true. + srv.inventory.Refresh(context.Background()) + + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp healthResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if !resp.Degraded { + t.Error("expected degraded=true when inventory poll failed") + } +} + +// --- Stub client for testing --- + +// stubClient implements ollama.Client for testing. +type stubClient struct { + tags *ollama.TagsResponse + tagsErr error + ps *ollama.PsResponse + psErr error + + rawChatResp *http.Response + rawChatFunc func(ctx context.Context, body []byte) (*http.Response, error) + + rawEmbedResp *http.Response + rawEmbedFunc func(ctx context.Context, body []byte) (*http.Response, error) +} + +func (s *stubClient) Chat(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { + return nil, nil, fmt.Errorf("stubClient.Chat not implemented") +} + +func (s *stubClient) Embed(ctx context.Context, req ollama.EmbedRequest) (*ollama.EmbedResponse, error) { + return nil, fmt.Errorf("stubClient.Embed not implemented") +} + +func (s *stubClient) Tags(ctx context.Context) (*ollama.TagsResponse, error) { + if s.tagsErr != nil { + return nil, s.tagsErr + } + return s.tags, nil +} + +func (s *stubClient) Ps(ctx context.Context) (*ollama.PsResponse, error) { + if s.psErr != nil { + return nil, s.psErr + } + return s.ps, nil +} + +func (s *stubClient) RawChat(ctx context.Context, body []byte) (*http.Response, error) { + if s.rawChatFunc != nil { + return s.rawChatFunc(ctx, body) + } + if s.rawChatResp != nil { + return s.rawChatResp, nil + } + return nil, fmt.Errorf("stubClient.RawChat not configured") +} + +func (s *stubClient) RawEmbed(ctx context.Context, body []byte) (*http.Response, error) { + if s.rawEmbedFunc != nil { + return s.rawEmbedFunc(ctx, body) + } + if s.rawEmbedResp != nil { + return s.rawEmbedResp, nil + } + return nil, fmt.Errorf("stubClient.RawEmbed not configured") +} + +// newRawResponse builds a minimal *http.Response for testing. +func newRawResponse(status int, contentType string, body []byte) *http.Response { + return &http.Response{ + StatusCode: status, + Header: http.Header{"Content-Type": {contentType}}, + Body: io.NopCloser(bytes.NewReader(body)), + } +} diff --git a/progress.md b/progress.md index f1081bb..9463dd7 100644 --- a/progress.md +++ b/progress.md @@ -17,3 +17,52 @@ - Dockerfile: multi-stage distroless build - Config files: `.env.example`, `.gitignore` - Tests: config validation, store CRUD + edge cases, server health + auth middleware + +## Phase 2: Ollama target client, model poller, native passthrough — 2026-05-23 + +- `internal/ollama/` — target client package: + - Wire types (`types.go`): ChatRequest/Response, EmbedRequest/Response, TagsResponse, + PsResponse, ModelInfo, RunningModel — matching Ollama's native JSON API exactly. + Polymorphic fields (think, keep_alive, tools, options) use `json.RawMessage` + for transparent passthrough fidelity. + - `Client` interface (`client.go`): Chat (stream/non-stream), Embed, Tags, Ps, + RawChat, RawEmbed. RawChat/RawEmbed return `*http.Response` for zero-copy + streaming passthrough. + - `httpClient` implementation: auth token injection, NDJSON streaming via + `bufio.Scanner` with 4 MB buffer, connection vs HTTP error classification. + - Custom error types (`errors.go`): `*ConnectionError` for network failures + (retry-eligible), `*HTTPError` for non-2xx responses. `errors.Is`/`errors.As` + compatible. + - `ModelInventory` (`inventory.go`): mutex-protected in-memory cache of installed + and running models. Methods: Models(), HasModel(), ResidentModels(), LastPoll(), + Degraded(), Refresh(). Background `Start()` goroutine polls at + `FOREMAN_POLL_INTERVAL` (default 30s). On target unreachable: retains last-known + inventory, sets `degraded=true`. Clears degraded on recovery. +- `internal/server/` — new Ollama passthrough routes: + - `GET /api/tags` — serves poller's cached model list + - `GET /api/ps` — serves poller's cached running models + - `POST /api/embed`, `POST /api/embeddings` — direct concurrent proxy to target, + bypasses the chat gate entirely (ADR-0013) + - `POST /api/chat` — critical path: validates model (re-poll on miss, 404 if + still absent), serializes through a capacity-1 channel gate, proxies to target + with NDJSON streaming (`application/x-ndjson`, flushed per chunk) or + non-streaming JSON passthrough + - `GET /healthz` — now wired to `inventory.Degraded()` for real target status +- `cmd/foreman/main.go` — full serve wiring: + - Creates Ollama client, starts model poller goroutine, warms embedder + (`keep_alive: -1`), creates server with all dependencies, signal-based + graceful shutdown via `context.NotifyContext` +- Tests (all passing with `-race`): + - Client: tags/ps parsing, chat streaming + non-streaming, embed, auth token + forwarding, `*ConnectionError` on unreachable target, `*HTTPError` on non-2xx + - Inventory: refresh populates models, degraded on failure, model retention, + recovery from degraded, Start/cancel lifecycle + - Server: tags/ps passthrough, model validation (404 on unknown), non-streaming + chat proxy, NDJSON streaming passthrough with correct Content-Type, chat + serialization (gate holds concurrent requests to max 1 in-flight), concurrent + embed bypass (multiple requests run in parallel), degraded health endpoint, + embeddings alias path + +The Mac is now usable as a go-llm target through foreman: +`llm.OllamaCloud(token, WithBaseURL("http://foreman:8080"))` works transparently +for chat (streaming + non-streaming), tags, ps, and embeddings.