feat(v2/ollama): implement native Stream() with NDJSON parsing

Reads Ollama's NDJSON stream (one JSON object per line) and emits
provider.StreamEvent values for text, thinking, tool-call start/delta/end,
and a final Done event carrying assembled Response and Usage. Uses
bufio.Scanner with a 4 MiB max-line buffer so multi-KB tool-call deltas
parse cleanly, and accepts tool-call arguments delivered either as
escaped string fragments (delta-style) or a complete JSON object
(one-shot).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-01 18:29:04 +00:00
parent 583f8724b2
commit f70c7c0842
2 changed files with 427 additions and 2 deletions
+226
View File
@@ -8,6 +8,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
@@ -345,3 +346,228 @@ func toString(v any) string {
b, _ := json.Marshal(v)
return string(b)
}
// streamServer returns an httptest.Server that writes the given NDJSON lines
// (each terminated with \n) as the response body.
func streamServer(t *testing.T, captured *captureRequest, lines []string) *httptest.Server {
t.Helper()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.path = r.URL.Path
captured.authHeader = r.Header.Get("Authorization")
captured.contentType = r.Header.Get("Content-Type")
body, _ := io.ReadAll(r.Body)
captured.body = body
_ = json.Unmarshal(body, &captured.parsedBody)
w.Header().Set("Content-Type", "application/x-ndjson")
w.WriteHeader(200)
flusher, _ := w.(http.Flusher)
for _, line := range lines {
_, _ = w.Write([]byte(line + "\n"))
if flusher != nil {
flusher.Flush()
}
}
}))
t.Cleanup(srv.Close)
return srv
}
func collectStream(t *testing.T, p *Provider, req provider.Request) []provider.StreamEvent {
t.Helper()
events := make(chan provider.StreamEvent, 64)
done := make(chan error, 1)
go func() {
done <- p.Stream(context.Background(), req, events)
}()
var out []provider.StreamEvent
timeout := time.After(5 * time.Second)
streamErrored := false
loop:
for {
select {
case ev, ok := <-events:
if !ok {
break loop
}
out = append(out, ev)
if ev.Type == provider.StreamEventError {
streamErrored = true
}
case err := <-done:
if err != nil && !streamErrored {
t.Fatalf("Stream returned error: %v", err)
}
// Drain any final events buffered in the channel.
for {
select {
case ev, ok := <-events:
if !ok {
return out
}
out = append(out, ev)
default:
return out
}
}
case <-timeout:
t.Fatal("Stream did not complete within 5s")
}
}
if err := <-done; err != nil && !streamErrored {
t.Fatalf("Stream returned error: %v", err)
}
return out
}
func TestStreamBasic(t *testing.T) {
lines := []string{
`{"message":{"role":"assistant","content":"hello"},"done":false}`,
`{"message":{"role":"assistant","content":" world","thinking":"reasoning"},"done":false}`,
`{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":12,"eval_count":2}`,
}
cap := &captureRequest{}
srv := streamServer(t, cap, lines)
p := newNative("", srv.URL)
events := collectStream(t, p, provider.Request{
Model: "kimi-k2.5",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
// Verify request shape: stream:true.
if cap.parsedBody["stream"] != true {
t.Errorf("body.stream: want true, got %v", cap.parsedBody["stream"])
}
// Filter to relevant events (text, thinking, done) preserving order.
var kinds []string
var texts []string
var doneEvent *provider.StreamEvent
for i, ev := range events {
switch ev.Type {
case provider.StreamEventText:
kinds = append(kinds, "text")
texts = append(texts, ev.Text)
case provider.StreamEventThinking:
kinds = append(kinds, "thinking")
texts = append(texts, ev.Text)
case provider.StreamEventDone:
kinds = append(kinds, "done")
e := events[i]
doneEvent = &e
}
}
wantKinds := []string{"text", "thinking", "text", "done"}
if !equalStrings(kinds, wantKinds) {
t.Errorf("event kinds: want %v, got %v", wantKinds, kinds)
}
if len(texts) >= 3 {
if texts[0] != "hello" {
t.Errorf("first text: want hello, got %q", texts[0])
}
if texts[1] != "reasoning" {
t.Errorf("thinking: want reasoning, got %q", texts[1])
}
if texts[2] != " world" {
t.Errorf("second text: want \" world\", got %q", texts[2])
}
}
if doneEvent == nil || doneEvent.Response == nil {
t.Fatal("Done event missing Response")
}
if doneEvent.Response.Text != "hello world" {
t.Errorf("Response.Text: want %q, got %q", "hello world", doneEvent.Response.Text)
}
if doneEvent.Response.Thinking != "reasoning" {
t.Errorf("Response.Thinking: want %q, got %q", "reasoning", doneEvent.Response.Thinking)
}
if doneEvent.Response.Usage == nil {
t.Fatal("Response.Usage missing")
}
if doneEvent.Response.Usage.InputTokens != 12 || doneEvent.Response.Usage.OutputTokens != 2 {
t.Errorf("Usage: want input=12 output=2, got input=%d output=%d", doneEvent.Response.Usage.InputTokens, doneEvent.Response.Usage.OutputTokens)
}
}
func TestStreamToolDeltaAccumulation(t *testing.T) {
lines := []string{
`{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"name":"search","arguments":"{\"que"}}]},"done":false}`,
`{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"arguments":"ry\":\"foo\"}"}}]},"done":false}`,
`{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":4,"eval_count":1}`,
}
cap := &captureRequest{}
srv := streamServer(t, cap, lines)
p := newNative("", srv.URL)
events := collectStream(t, p, provider.Request{
Model: "kimi-k2.5",
Messages: []provider.Message{{Role: "user", Content: "search foo"}},
Tools: []provider.ToolDef{
{Name: "search", Schema: map[string]any{"type": "object"}},
},
})
// Build a slim trace of tool events.
type traceEntry struct {
kind string
args string
name string
id string
}
var trace []traceEntry
var doneEvent *provider.StreamEvent
for i, ev := range events {
switch ev.Type {
case provider.StreamEventToolStart:
trace = append(trace, traceEntry{kind: "start", name: ev.ToolCall.Name, id: ev.ToolCall.ID})
case provider.StreamEventToolDelta:
trace = append(trace, traceEntry{kind: "delta", args: ev.ToolCall.Arguments})
case provider.StreamEventToolEnd:
trace = append(trace, traceEntry{kind: "end", args: ev.ToolCall.Arguments, name: ev.ToolCall.Name, id: ev.ToolCall.ID})
case provider.StreamEventDone:
e := events[i]
doneEvent = &e
}
}
if len(trace) != 4 {
t.Fatalf("trace: want 4 entries (start, delta, delta, end), got %d: %+v", len(trace), trace)
}
if trace[0].kind != "start" || trace[0].name != "search" || trace[0].id != "tc1" {
t.Errorf("trace[0]: want start search tc1, got %+v", trace[0])
}
if trace[1].kind != "delta" || trace[1].args != `{"que` {
t.Errorf("trace[1]: want delta args=%q, got %+v", `{"que`, trace[1])
}
if trace[2].kind != "delta" || trace[2].args != `ry":"foo"}` {
t.Errorf("trace[2]: want delta args=%q, got %+v", `ry":"foo"}`, trace[2])
}
if trace[3].kind != "end" || trace[3].args != `{"query":"foo"}` {
t.Errorf("trace[3]: want end args=%q, got %+v", `{"query":"foo"}`, trace[3])
}
if doneEvent == nil || doneEvent.Response == nil {
t.Fatal("Done event missing Response")
}
if len(doneEvent.Response.ToolCalls) != 1 {
t.Fatalf("Done.Response.ToolCalls: want 1, got %d", len(doneEvent.Response.ToolCalls))
}
tc := doneEvent.Response.ToolCalls[0]
if tc.ID != "tc1" || tc.Name != "search" || tc.Arguments != `{"query":"foo"}` {
t.Errorf("Done.Response.ToolCalls[0]: want tc1/search/{...}, got %+v", tc)
}
}
func equalStrings(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}