From f70c7c08420b45fcc3f33b195fcc8cbf38406745 Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Fri, 1 May 2026 18:29:04 +0000 Subject: [PATCH] feat(v2/ollama): implement native Stream() with NDJSON parsing Reads Ollama's NDJSON stream (one JSON object per line) and emits provider.StreamEvent values for text, thinking, tool-call start/delta/end, and a final Done event carrying assembled Response and Usage. Uses bufio.Scanner with a 4 MiB max-line buffer so multi-KB tool-call deltas parse cleanly, and accepts tool-call arguments delivered either as escaped string fragments (delta-style) or a complete JSON object (one-shot). Co-Authored-By: Claude Opus 4.6 --- v2/ollama/native.go | 203 ++++++++++++++++++++++++++++++++++- v2/ollama/native_test.go | 226 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 427 insertions(+), 2 deletions(-) diff --git a/v2/ollama/native.go b/v2/ollama/native.go index 16dd891..5fd1c83 100644 --- a/v2/ollama/native.go +++ b/v2/ollama/native.go @@ -5,6 +5,7 @@ package ollama import ( + "bufio" "bytes" "context" "encoding/base64" @@ -168,9 +169,207 @@ func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider } // Stream performs a streaming chat completion via /api/chat with -// `stream: true`, parsing NDJSON line-by-line. +// `stream: true`, parsing NDJSON line-by-line. Tool-call argument deltas are +// accumulated across chunks keyed by id (or function index) and finalized +// when the upstream Done flag arrives. func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { - return fmt.Errorf("ollama native provider: Stream not implemented") + defer close(events) + + body, err := p.buildChatRequest(req, true) + if err != nil { + return err + } + + httpResp, err := p.doChatRequest(ctx, body) + if err != nil { + return err + } + defer httpResp.Body.Close() + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + return fmt.Errorf("ollama: HTTP %d: %s", httpResp.StatusCode, string(b)) + } + + scanner := bufio.NewScanner(httpResp.Body) + // Ollama can emit multi-KB lines on tool-call deltas. Generous buffer. + const maxLineSize = 4 * 1024 * 1024 + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + type toolAcc struct { + id string + name string + args strings.Builder + index int // ToolIndex emitted on stream events + } + tools := map[string]*toolAcc{} + var toolOrder []*toolAcc + + var ( + fullText strings.Builder + fullThinking strings.Builder + usage *provider.Usage + streamErr error + ) + + for scanner.Scan() { + line := scanner.Bytes() + if len(bytes.TrimSpace(line)) == 0 { + continue + } + var chunk nativeChatResponse + if err := json.Unmarshal(line, &chunk); err != nil { + streamErr = fmt.Errorf("ollama: decode stream chunk: %w", err) + break + } + + if chunk.Message.Thinking != "" { + fullThinking.WriteString(chunk.Message.Thinking) + events <- provider.StreamEvent{ + Type: provider.StreamEventThinking, + Text: chunk.Message.Thinking, + } + } + if chunk.Message.Content != "" { + fullText.WriteString(chunk.Message.Content) + events <- provider.StreamEvent{ + Type: provider.StreamEventText, + Text: chunk.Message.Content, + } + } + + for pos, tc := range chunk.Message.ToolCalls { + key := streamToolKey(tc, pos) + acc, exists := tools[key] + if !exists { + acc = &toolAcc{ + id: tc.ID, + name: tc.Function.Name, + index: len(toolOrder), + } + if acc.id == "" { + acc.id = fmt.Sprintf("tc_%d", acc.index) + } + tools[key] = acc + toolOrder = append(toolOrder, acc) + events <- provider.StreamEvent{ + Type: provider.StreamEventToolStart, + ToolIndex: acc.index, + ToolCall: &provider.ToolCall{ + ID: acc.id, + Name: acc.name, + }, + } + } else { + // Continuation chunk may carry the tool's name late; capture it. + if tc.Function.Name != "" && acc.name == "" { + acc.name = tc.Function.Name + } + } + + delta := decodeArgumentDelta(tc.Function.Arguments) + if delta != "" { + acc.args.WriteString(delta) + events <- provider.StreamEvent{ + Type: provider.StreamEventToolDelta, + ToolIndex: acc.index, + ToolCall: &provider.ToolCall{ + Arguments: delta, + }, + } + } + } + + if chunk.Done { + if chunk.PromptEvalCount > 0 || chunk.EvalCount > 0 { + usage = &provider.Usage{ + InputTokens: chunk.PromptEvalCount, + OutputTokens: chunk.EvalCount, + TotalTokens: chunk.PromptEvalCount + chunk.EvalCount, + } + } + break + } + } + + if err := scanner.Err(); err != nil && streamErr == nil { + streamErr = fmt.Errorf("ollama: stream read: %w", err) + } + + if streamErr != nil { + events <- provider.StreamEvent{ + Type: provider.StreamEventError, + Error: streamErr, + } + return streamErr + } + + // Finalize accumulated tool calls. + finalCalls := make([]provider.ToolCall, 0, len(toolOrder)) + for _, acc := range toolOrder { + args := acc.args.String() + if args == "" { + args = "{}" + } + final := provider.ToolCall{ + ID: acc.id, + Name: acc.name, + Arguments: args, + } + finalCalls = append(finalCalls, final) + events <- provider.StreamEvent{ + Type: provider.StreamEventToolEnd, + ToolIndex: acc.index, + ToolCall: &final, + } + } + + events <- provider.StreamEvent{ + Type: provider.StreamEventDone, + Response: &provider.Response{ + Text: fullText.String(), + Thinking: fullThinking.String(), + ToolCalls: finalCalls, + Usage: usage, + }, + } + return nil +} + +// streamToolKey computes a stable map key correlating tool-call deltas +// across stream chunks. Prefer the wire id, fall back to function index, +// finally fall back to the tool's position in the chunk's tool_calls array +// (a single-tool stream collapses cleanly under any strategy). +func streamToolKey(tc nativeToolCall, position int) string { + if tc.ID != "" { + return "id:" + tc.ID + } + if tc.Function.Index != nil { + return fmt.Sprintf("idx:%d", *tc.Function.Index) + } + return fmt.Sprintf("pos:%d", position) +} + +// decodeArgumentDelta returns the string fragment to append when a streamed +// tool-call chunk includes arguments. Ollama may emit arguments either as a +// JSON-encoded string fragment (chunk-by-chunk concatenation, openaicompat +// style) or as a complete object value (one-shot delivery). We accept both: +// strings are unwrapped, objects/arrays pass through verbatim. +func decodeArgumentDelta(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 || string(trimmed) == "null" { + return "" + } + if trimmed[0] == '"' { + var s string + if err := json.Unmarshal(trimmed, &s); err == nil { + return s + } + } + return string(trimmed) } // buildChatRequest converts a provider.Request into the native wire body diff --git a/v2/ollama/native_test.go b/v2/ollama/native_test.go index a88c580..311c45b 100644 --- a/v2/ollama/native_test.go +++ b/v2/ollama/native_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" ) @@ -345,3 +346,228 @@ func toString(v any) string { b, _ := json.Marshal(v) return string(b) } + +// streamServer returns an httptest.Server that writes the given NDJSON lines +// (each terminated with \n) as the response body. +func streamServer(t *testing.T, captured *captureRequest, lines []string) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured.method = r.Method + captured.path = r.URL.Path + captured.authHeader = r.Header.Get("Authorization") + captured.contentType = r.Header.Get("Content-Type") + body, _ := io.ReadAll(r.Body) + captured.body = body + _ = json.Unmarshal(body, &captured.parsedBody) + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(200) + flusher, _ := w.(http.Flusher) + for _, line := range lines { + _, _ = w.Write([]byte(line + "\n")) + if flusher != nil { + flusher.Flush() + } + } + })) + t.Cleanup(srv.Close) + return srv +} + +func collectStream(t *testing.T, p *Provider, req provider.Request) []provider.StreamEvent { + t.Helper() + events := make(chan provider.StreamEvent, 64) + done := make(chan error, 1) + go func() { + done <- p.Stream(context.Background(), req, events) + }() + var out []provider.StreamEvent + timeout := time.After(5 * time.Second) + streamErrored := false +loop: + for { + select { + case ev, ok := <-events: + if !ok { + break loop + } + out = append(out, ev) + if ev.Type == provider.StreamEventError { + streamErrored = true + } + case err := <-done: + if err != nil && !streamErrored { + t.Fatalf("Stream returned error: %v", err) + } + // Drain any final events buffered in the channel. + for { + select { + case ev, ok := <-events: + if !ok { + return out + } + out = append(out, ev) + default: + return out + } + } + case <-timeout: + t.Fatal("Stream did not complete within 5s") + } + } + if err := <-done; err != nil && !streamErrored { + t.Fatalf("Stream returned error: %v", err) + } + return out +} + +func TestStreamBasic(t *testing.T) { + lines := []string{ + `{"message":{"role":"assistant","content":"hello"},"done":false}`, + `{"message":{"role":"assistant","content":" world","thinking":"reasoning"},"done":false}`, + `{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":12,"eval_count":2}`, + } + cap := &captureRequest{} + srv := streamServer(t, cap, lines) + + p := newNative("", srv.URL) + events := collectStream(t, p, provider.Request{ + Model: "kimi-k2.5", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + }) + + // Verify request shape: stream:true. + if cap.parsedBody["stream"] != true { + t.Errorf("body.stream: want true, got %v", cap.parsedBody["stream"]) + } + + // Filter to relevant events (text, thinking, done) preserving order. + var kinds []string + var texts []string + var doneEvent *provider.StreamEvent + for i, ev := range events { + switch ev.Type { + case provider.StreamEventText: + kinds = append(kinds, "text") + texts = append(texts, ev.Text) + case provider.StreamEventThinking: + kinds = append(kinds, "thinking") + texts = append(texts, ev.Text) + case provider.StreamEventDone: + kinds = append(kinds, "done") + e := events[i] + doneEvent = &e + } + } + + wantKinds := []string{"text", "thinking", "text", "done"} + if !equalStrings(kinds, wantKinds) { + t.Errorf("event kinds: want %v, got %v", wantKinds, kinds) + } + if len(texts) >= 3 { + if texts[0] != "hello" { + t.Errorf("first text: want hello, got %q", texts[0]) + } + if texts[1] != "reasoning" { + t.Errorf("thinking: want reasoning, got %q", texts[1]) + } + if texts[2] != " world" { + t.Errorf("second text: want \" world\", got %q", texts[2]) + } + } + if doneEvent == nil || doneEvent.Response == nil { + t.Fatal("Done event missing Response") + } + if doneEvent.Response.Text != "hello world" { + t.Errorf("Response.Text: want %q, got %q", "hello world", doneEvent.Response.Text) + } + if doneEvent.Response.Thinking != "reasoning" { + t.Errorf("Response.Thinking: want %q, got %q", "reasoning", doneEvent.Response.Thinking) + } + if doneEvent.Response.Usage == nil { + t.Fatal("Response.Usage missing") + } + if doneEvent.Response.Usage.InputTokens != 12 || doneEvent.Response.Usage.OutputTokens != 2 { + t.Errorf("Usage: want input=12 output=2, got input=%d output=%d", doneEvent.Response.Usage.InputTokens, doneEvent.Response.Usage.OutputTokens) + } +} + +func TestStreamToolDeltaAccumulation(t *testing.T) { + lines := []string{ + `{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"name":"search","arguments":"{\"que"}}]},"done":false}`, + `{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"arguments":"ry\":\"foo\"}"}}]},"done":false}`, + `{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":4,"eval_count":1}`, + } + cap := &captureRequest{} + srv := streamServer(t, cap, lines) + + p := newNative("", srv.URL) + events := collectStream(t, p, provider.Request{ + Model: "kimi-k2.5", + Messages: []provider.Message{{Role: "user", Content: "search foo"}}, + Tools: []provider.ToolDef{ + {Name: "search", Schema: map[string]any{"type": "object"}}, + }, + }) + + // Build a slim trace of tool events. + type traceEntry struct { + kind string + args string + name string + id string + } + var trace []traceEntry + var doneEvent *provider.StreamEvent + for i, ev := range events { + switch ev.Type { + case provider.StreamEventToolStart: + trace = append(trace, traceEntry{kind: "start", name: ev.ToolCall.Name, id: ev.ToolCall.ID}) + case provider.StreamEventToolDelta: + trace = append(trace, traceEntry{kind: "delta", args: ev.ToolCall.Arguments}) + case provider.StreamEventToolEnd: + trace = append(trace, traceEntry{kind: "end", args: ev.ToolCall.Arguments, name: ev.ToolCall.Name, id: ev.ToolCall.ID}) + case provider.StreamEventDone: + e := events[i] + doneEvent = &e + } + } + + if len(trace) != 4 { + t.Fatalf("trace: want 4 entries (start, delta, delta, end), got %d: %+v", len(trace), trace) + } + if trace[0].kind != "start" || trace[0].name != "search" || trace[0].id != "tc1" { + t.Errorf("trace[0]: want start search tc1, got %+v", trace[0]) + } + if trace[1].kind != "delta" || trace[1].args != `{"que` { + t.Errorf("trace[1]: want delta args=%q, got %+v", `{"que`, trace[1]) + } + if trace[2].kind != "delta" || trace[2].args != `ry":"foo"}` { + t.Errorf("trace[2]: want delta args=%q, got %+v", `ry":"foo"}`, trace[2]) + } + if trace[3].kind != "end" || trace[3].args != `{"query":"foo"}` { + t.Errorf("trace[3]: want end args=%q, got %+v", `{"query":"foo"}`, trace[3]) + } + + if doneEvent == nil || doneEvent.Response == nil { + t.Fatal("Done event missing Response") + } + if len(doneEvent.Response.ToolCalls) != 1 { + t.Fatalf("Done.Response.ToolCalls: want 1, got %d", len(doneEvent.Response.ToolCalls)) + } + tc := doneEvent.Response.ToolCalls[0] + if tc.ID != "tc1" || tc.Name != "search" || tc.Arguments != `{"query":"foo"}` { + t.Errorf("Done.Response.ToolCalls[0]: want tc1/search/{...}, got %+v", tc) + } +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +}