feat(v2/ollama): implement native Stream() with NDJSON parsing
Reads Ollama's NDJSON stream (one JSON object per line) and emits provider.StreamEvent values for text, thinking, tool-call start/delta/end, and a final Done event carrying assembled Response and Usage. Uses bufio.Scanner with a 4 MiB max-line buffer so multi-KB tool-call deltas parse cleanly, and accepts tool-call arguments delivered either as escaped string fragments (delta-style) or a complete JSON object (one-shot). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+201
-2
@@ -5,6 +5,7 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
@@ -168,9 +169,207 @@ func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider
|
||||
}
|
||||
|
||||
// Stream performs a streaming chat completion via /api/chat with
|
||||
// `stream: true`, parsing NDJSON line-by-line.
|
||||
// `stream: true`, parsing NDJSON line-by-line. Tool-call argument deltas are
|
||||
// accumulated across chunks keyed by id (or function index) and finalized
|
||||
// when the upstream Done flag arrives.
|
||||
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
||||
return fmt.Errorf("ollama native provider: Stream not implemented")
|
||||
defer close(events)
|
||||
|
||||
body, err := p.buildChatRequest(req, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpResp, err := p.doChatRequest(ctx, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
return fmt.Errorf("ollama: HTTP %d: %s", httpResp.StatusCode, string(b))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
// Ollama can emit multi-KB lines on tool-call deltas. Generous buffer.
|
||||
const maxLineSize = 4 * 1024 * 1024
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
type toolAcc struct {
|
||||
id string
|
||||
name string
|
||||
args strings.Builder
|
||||
index int // ToolIndex emitted on stream events
|
||||
}
|
||||
tools := map[string]*toolAcc{}
|
||||
var toolOrder []*toolAcc
|
||||
|
||||
var (
|
||||
fullText strings.Builder
|
||||
fullThinking strings.Builder
|
||||
usage *provider.Usage
|
||||
streamErr error
|
||||
)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(bytes.TrimSpace(line)) == 0 {
|
||||
continue
|
||||
}
|
||||
var chunk nativeChatResponse
|
||||
if err := json.Unmarshal(line, &chunk); err != nil {
|
||||
streamErr = fmt.Errorf("ollama: decode stream chunk: %w", err)
|
||||
break
|
||||
}
|
||||
|
||||
if chunk.Message.Thinking != "" {
|
||||
fullThinking.WriteString(chunk.Message.Thinking)
|
||||
events <- provider.StreamEvent{
|
||||
Type: provider.StreamEventThinking,
|
||||
Text: chunk.Message.Thinking,
|
||||
}
|
||||
}
|
||||
if chunk.Message.Content != "" {
|
||||
fullText.WriteString(chunk.Message.Content)
|
||||
events <- provider.StreamEvent{
|
||||
Type: provider.StreamEventText,
|
||||
Text: chunk.Message.Content,
|
||||
}
|
||||
}
|
||||
|
||||
for pos, tc := range chunk.Message.ToolCalls {
|
||||
key := streamToolKey(tc, pos)
|
||||
acc, exists := tools[key]
|
||||
if !exists {
|
||||
acc = &toolAcc{
|
||||
id: tc.ID,
|
||||
name: tc.Function.Name,
|
||||
index: len(toolOrder),
|
||||
}
|
||||
if acc.id == "" {
|
||||
acc.id = fmt.Sprintf("tc_%d", acc.index)
|
||||
}
|
||||
tools[key] = acc
|
||||
toolOrder = append(toolOrder, acc)
|
||||
events <- provider.StreamEvent{
|
||||
Type: provider.StreamEventToolStart,
|
||||
ToolIndex: acc.index,
|
||||
ToolCall: &provider.ToolCall{
|
||||
ID: acc.id,
|
||||
Name: acc.name,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Continuation chunk may carry the tool's name late; capture it.
|
||||
if tc.Function.Name != "" && acc.name == "" {
|
||||
acc.name = tc.Function.Name
|
||||
}
|
||||
}
|
||||
|
||||
delta := decodeArgumentDelta(tc.Function.Arguments)
|
||||
if delta != "" {
|
||||
acc.args.WriteString(delta)
|
||||
events <- provider.StreamEvent{
|
||||
Type: provider.StreamEventToolDelta,
|
||||
ToolIndex: acc.index,
|
||||
ToolCall: &provider.ToolCall{
|
||||
Arguments: delta,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chunk.Done {
|
||||
if chunk.PromptEvalCount > 0 || chunk.EvalCount > 0 {
|
||||
usage = &provider.Usage{
|
||||
InputTokens: chunk.PromptEvalCount,
|
||||
OutputTokens: chunk.EvalCount,
|
||||
TotalTokens: chunk.PromptEvalCount + chunk.EvalCount,
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil && streamErr == nil {
|
||||
streamErr = fmt.Errorf("ollama: stream read: %w", err)
|
||||
}
|
||||
|
||||
if streamErr != nil {
|
||||
events <- provider.StreamEvent{
|
||||
Type: provider.StreamEventError,
|
||||
Error: streamErr,
|
||||
}
|
||||
return streamErr
|
||||
}
|
||||
|
||||
// Finalize accumulated tool calls.
|
||||
finalCalls := make([]provider.ToolCall, 0, len(toolOrder))
|
||||
for _, acc := range toolOrder {
|
||||
args := acc.args.String()
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
final := provider.ToolCall{
|
||||
ID: acc.id,
|
||||
Name: acc.name,
|
||||
Arguments: args,
|
||||
}
|
||||
finalCalls = append(finalCalls, final)
|
||||
events <- provider.StreamEvent{
|
||||
Type: provider.StreamEventToolEnd,
|
||||
ToolIndex: acc.index,
|
||||
ToolCall: &final,
|
||||
}
|
||||
}
|
||||
|
||||
events <- provider.StreamEvent{
|
||||
Type: provider.StreamEventDone,
|
||||
Response: &provider.Response{
|
||||
Text: fullText.String(),
|
||||
Thinking: fullThinking.String(),
|
||||
ToolCalls: finalCalls,
|
||||
Usage: usage,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamToolKey computes a stable map key correlating tool-call deltas
|
||||
// across stream chunks. Prefer the wire id, fall back to function index,
|
||||
// finally fall back to the tool's position in the chunk's tool_calls array
|
||||
// (a single-tool stream collapses cleanly under any strategy).
|
||||
func streamToolKey(tc nativeToolCall, position int) string {
|
||||
if tc.ID != "" {
|
||||
return "id:" + tc.ID
|
||||
}
|
||||
if tc.Function.Index != nil {
|
||||
return fmt.Sprintf("idx:%d", *tc.Function.Index)
|
||||
}
|
||||
return fmt.Sprintf("pos:%d", position)
|
||||
}
|
||||
|
||||
// decodeArgumentDelta converts the arguments value of a streamed tool-call
// chunk into the text to append to that call's accumulated arguments.
//
// Ollama may deliver arguments either as a JSON-encoded string fragment
// (concatenated chunk by chunk, openaicompat style) or as a complete value in
// one shot. Both are accepted:
//   - empty, whitespace-only, or JSON null input yields "";
//   - a valid JSON string is unquoted and its contents returned;
//   - anything else (objects, arrays, or a string that fails to decode) is
//     returned verbatim with surrounding whitespace trimmed.
func decodeArgumentDelta(raw json.RawMessage) string {
	payload := bytes.TrimSpace(raw)
	switch {
	case len(payload) == 0, string(payload) == "null":
		return ""
	case payload[0] != '"':
		// Not a string fragment: pass the raw JSON value through untouched.
		return string(payload)
	}

	var fragment string
	if json.Unmarshal(payload, &fragment) != nil {
		// Looks like a string but does not decode; fall back to the raw bytes.
		return string(payload)
	}
	return fragment
}
|
||||
|
||||
// buildChatRequest converts a provider.Request into the native wire body
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
@@ -345,3 +346,228 @@ func toString(v any) string {
|
||||
b, _ := json.Marshal(v)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// streamServer returns an httptest.Server that writes the given NDJSON lines
|
||||
// (each terminated with \n) as the response body.
|
||||
func streamServer(t *testing.T, captured *captureRequest, lines []string) *httptest.Server {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.path = r.URL.Path
|
||||
captured.authHeader = r.Header.Get("Authorization")
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
captured.body = body
|
||||
_ = json.Unmarshal(body, &captured.parsedBody)
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.WriteHeader(200)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
for _, line := range lines {
|
||||
_, _ = w.Write([]byte(line + "\n"))
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
return srv
|
||||
}
|
||||
|
||||
func collectStream(t *testing.T, p *Provider, req provider.Request) []provider.StreamEvent {
|
||||
t.Helper()
|
||||
events := make(chan provider.StreamEvent, 64)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- p.Stream(context.Background(), req, events)
|
||||
}()
|
||||
var out []provider.StreamEvent
|
||||
timeout := time.After(5 * time.Second)
|
||||
streamErrored := false
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
break loop
|
||||
}
|
||||
out = append(out, ev)
|
||||
if ev.Type == provider.StreamEventError {
|
||||
streamErrored = true
|
||||
}
|
||||
case err := <-done:
|
||||
if err != nil && !streamErrored {
|
||||
t.Fatalf("Stream returned error: %v", err)
|
||||
}
|
||||
// Drain any final events buffered in the channel.
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return out
|
||||
}
|
||||
out = append(out, ev)
|
||||
default:
|
||||
return out
|
||||
}
|
||||
}
|
||||
case <-timeout:
|
||||
t.Fatal("Stream did not complete within 5s")
|
||||
}
|
||||
}
|
||||
if err := <-done; err != nil && !streamErrored {
|
||||
t.Fatalf("Stream returned error: %v", err)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestStreamBasic(t *testing.T) {
|
||||
lines := []string{
|
||||
`{"message":{"role":"assistant","content":"hello"},"done":false}`,
|
||||
`{"message":{"role":"assistant","content":" world","thinking":"reasoning"},"done":false}`,
|
||||
`{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":12,"eval_count":2}`,
|
||||
}
|
||||
cap := &captureRequest{}
|
||||
srv := streamServer(t, cap, lines)
|
||||
|
||||
p := newNative("", srv.URL)
|
||||
events := collectStream(t, p, provider.Request{
|
||||
Model: "kimi-k2.5",
|
||||
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||
})
|
||||
|
||||
// Verify request shape: stream:true.
|
||||
if cap.parsedBody["stream"] != true {
|
||||
t.Errorf("body.stream: want true, got %v", cap.parsedBody["stream"])
|
||||
}
|
||||
|
||||
// Filter to relevant events (text, thinking, done) preserving order.
|
||||
var kinds []string
|
||||
var texts []string
|
||||
var doneEvent *provider.StreamEvent
|
||||
for i, ev := range events {
|
||||
switch ev.Type {
|
||||
case provider.StreamEventText:
|
||||
kinds = append(kinds, "text")
|
||||
texts = append(texts, ev.Text)
|
||||
case provider.StreamEventThinking:
|
||||
kinds = append(kinds, "thinking")
|
||||
texts = append(texts, ev.Text)
|
||||
case provider.StreamEventDone:
|
||||
kinds = append(kinds, "done")
|
||||
e := events[i]
|
||||
doneEvent = &e
|
||||
}
|
||||
}
|
||||
|
||||
wantKinds := []string{"text", "thinking", "text", "done"}
|
||||
if !equalStrings(kinds, wantKinds) {
|
||||
t.Errorf("event kinds: want %v, got %v", wantKinds, kinds)
|
||||
}
|
||||
if len(texts) >= 3 {
|
||||
if texts[0] != "hello" {
|
||||
t.Errorf("first text: want hello, got %q", texts[0])
|
||||
}
|
||||
if texts[1] != "reasoning" {
|
||||
t.Errorf("thinking: want reasoning, got %q", texts[1])
|
||||
}
|
||||
if texts[2] != " world" {
|
||||
t.Errorf("second text: want \" world\", got %q", texts[2])
|
||||
}
|
||||
}
|
||||
if doneEvent == nil || doneEvent.Response == nil {
|
||||
t.Fatal("Done event missing Response")
|
||||
}
|
||||
if doneEvent.Response.Text != "hello world" {
|
||||
t.Errorf("Response.Text: want %q, got %q", "hello world", doneEvent.Response.Text)
|
||||
}
|
||||
if doneEvent.Response.Thinking != "reasoning" {
|
||||
t.Errorf("Response.Thinking: want %q, got %q", "reasoning", doneEvent.Response.Thinking)
|
||||
}
|
||||
if doneEvent.Response.Usage == nil {
|
||||
t.Fatal("Response.Usage missing")
|
||||
}
|
||||
if doneEvent.Response.Usage.InputTokens != 12 || doneEvent.Response.Usage.OutputTokens != 2 {
|
||||
t.Errorf("Usage: want input=12 output=2, got input=%d output=%d", doneEvent.Response.Usage.InputTokens, doneEvent.Response.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamToolDeltaAccumulation(t *testing.T) {
|
||||
lines := []string{
|
||||
`{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"name":"search","arguments":"{\"que"}}]},"done":false}`,
|
||||
`{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"arguments":"ry\":\"foo\"}"}}]},"done":false}`,
|
||||
`{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":4,"eval_count":1}`,
|
||||
}
|
||||
cap := &captureRequest{}
|
||||
srv := streamServer(t, cap, lines)
|
||||
|
||||
p := newNative("", srv.URL)
|
||||
events := collectStream(t, p, provider.Request{
|
||||
Model: "kimi-k2.5",
|
||||
Messages: []provider.Message{{Role: "user", Content: "search foo"}},
|
||||
Tools: []provider.ToolDef{
|
||||
{Name: "search", Schema: map[string]any{"type": "object"}},
|
||||
},
|
||||
})
|
||||
|
||||
// Build a slim trace of tool events.
|
||||
type traceEntry struct {
|
||||
kind string
|
||||
args string
|
||||
name string
|
||||
id string
|
||||
}
|
||||
var trace []traceEntry
|
||||
var doneEvent *provider.StreamEvent
|
||||
for i, ev := range events {
|
||||
switch ev.Type {
|
||||
case provider.StreamEventToolStart:
|
||||
trace = append(trace, traceEntry{kind: "start", name: ev.ToolCall.Name, id: ev.ToolCall.ID})
|
||||
case provider.StreamEventToolDelta:
|
||||
trace = append(trace, traceEntry{kind: "delta", args: ev.ToolCall.Arguments})
|
||||
case provider.StreamEventToolEnd:
|
||||
trace = append(trace, traceEntry{kind: "end", args: ev.ToolCall.Arguments, name: ev.ToolCall.Name, id: ev.ToolCall.ID})
|
||||
case provider.StreamEventDone:
|
||||
e := events[i]
|
||||
doneEvent = &e
|
||||
}
|
||||
}
|
||||
|
||||
if len(trace) != 4 {
|
||||
t.Fatalf("trace: want 4 entries (start, delta, delta, end), got %d: %+v", len(trace), trace)
|
||||
}
|
||||
if trace[0].kind != "start" || trace[0].name != "search" || trace[0].id != "tc1" {
|
||||
t.Errorf("trace[0]: want start search tc1, got %+v", trace[0])
|
||||
}
|
||||
if trace[1].kind != "delta" || trace[1].args != `{"que` {
|
||||
t.Errorf("trace[1]: want delta args=%q, got %+v", `{"que`, trace[1])
|
||||
}
|
||||
if trace[2].kind != "delta" || trace[2].args != `ry":"foo"}` {
|
||||
t.Errorf("trace[2]: want delta args=%q, got %+v", `ry":"foo"}`, trace[2])
|
||||
}
|
||||
if trace[3].kind != "end" || trace[3].args != `{"query":"foo"}` {
|
||||
t.Errorf("trace[3]: want end args=%q, got %+v", `{"query":"foo"}`, trace[3])
|
||||
}
|
||||
|
||||
if doneEvent == nil || doneEvent.Response == nil {
|
||||
t.Fatal("Done event missing Response")
|
||||
}
|
||||
if len(doneEvent.Response.ToolCalls) != 1 {
|
||||
t.Fatalf("Done.Response.ToolCalls: want 1, got %d", len(doneEvent.Response.ToolCalls))
|
||||
}
|
||||
tc := doneEvent.Response.ToolCalls[0]
|
||||
if tc.ID != "tc1" || tc.Name != "search" || tc.Arguments != `{"query":"foo"}` {
|
||||
t.Errorf("Done.Response.ToolCalls[0]: want tc1/search/{...}, got %+v", tc)
|
||||
}
|
||||
}
|
||||
|
||||
// equalStrings reports whether a and b have the same length and identical
// elements at every index. A nil slice and an empty slice compare equal.
func equalStrings(a, b []string) bool {
	same := len(a) == len(b)
	// The loop stops at the first mismatch because same turns false.
	for i := 0; same && i < len(a); i++ {
		same = a[i] == b[i]
	}
	return same
}
|
||||
|
||||
Reference in New Issue
Block a user