f70c7c0842
Reads Ollama's NDJSON stream (one JSON object per line) and emits provider.StreamEvent values for text, thinking, tool-call start/delta/end, and a final Done event carrying assembled Response and Usage. Uses bufio.Scanner with a 4 MiB max-line buffer so multi-KB tool-call deltas parse cleanly, and accepts tool-call arguments delivered either as escaped string fragments (delta-style) or a complete JSON object (one-shot). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
574 lines
17 KiB
Go
574 lines
17 KiB
Go
package ollama
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
|
)
|
|
|
|
// captureRequest holds what a stub server observed about the most recent
// inbound HTTP request so a test can assert on it after the call returns.
type captureRequest struct {
	// Request line and the two request headers the tests inspect.
	method, path            string
	authHeader, contentType string

	// body is the raw request payload; parsedBody is that same payload
	// decoded as a JSON object on a best-effort basis.
	body       []byte
	parsedBody map[string]any
}
|
|
|
|
func newTestServer(t *testing.T, captured *captureRequest, status int, respBody string, respContentType string) *httptest.Server {
|
|
t.Helper()
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
captured.method = r.Method
|
|
captured.path = r.URL.Path
|
|
captured.authHeader = r.Header.Get("Authorization")
|
|
captured.contentType = r.Header.Get("Content-Type")
|
|
body, _ := io.ReadAll(r.Body)
|
|
captured.body = body
|
|
_ = json.Unmarshal(body, &captured.parsedBody)
|
|
if respContentType == "" {
|
|
respContentType = "application/json"
|
|
}
|
|
w.Header().Set("Content-Type", respContentType)
|
|
w.WriteHeader(status)
|
|
_, _ = w.Write([]byte(respBody))
|
|
}))
|
|
t.Cleanup(srv.Close)
|
|
return srv
|
|
}
|
|
|
|
func TestCompleteBasic(t *testing.T) {
|
|
resp := `{
|
|
"model": "kimi-k2.5",
|
|
"message": {"role": "assistant", "content": "hello there"},
|
|
"done": true,
|
|
"done_reason": "stop",
|
|
"prompt_eval_count": 10,
|
|
"eval_count": 3
|
|
}`
|
|
cap := &captureRequest{}
|
|
srv := newTestServer(t, cap, 200, resp, "")
|
|
|
|
p := newNative("test-key", srv.URL)
|
|
got, err := p.Complete(context.Background(), provider.Request{
|
|
Model: "kimi-k2.5",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
|
|
if cap.method != "POST" {
|
|
t.Errorf("method: want POST, got %q", cap.method)
|
|
}
|
|
if cap.path != "/api/chat" {
|
|
t.Errorf("path: want /api/chat, got %q", cap.path)
|
|
}
|
|
if cap.authHeader != "Bearer test-key" {
|
|
t.Errorf("auth header: want %q, got %q", "Bearer test-key", cap.authHeader)
|
|
}
|
|
if cap.contentType != "application/json" {
|
|
t.Errorf("content-type: want application/json, got %q", cap.contentType)
|
|
}
|
|
if cap.parsedBody["model"] != "kimi-k2.5" {
|
|
t.Errorf("body.model: want kimi-k2.5, got %v", cap.parsedBody["model"])
|
|
}
|
|
if cap.parsedBody["stream"] != false {
|
|
t.Errorf("body.stream: want false, got %v", cap.parsedBody["stream"])
|
|
}
|
|
msgs, _ := cap.parsedBody["messages"].([]any)
|
|
if len(msgs) != 1 {
|
|
t.Fatalf("messages: want 1 entry, got %d", len(msgs))
|
|
}
|
|
m0, _ := msgs[0].(map[string]any)
|
|
if m0["role"] != "user" || m0["content"] != "hi" {
|
|
t.Errorf("first message: want role=user content=hi, got %v", m0)
|
|
}
|
|
|
|
if got.Text != "hello there" {
|
|
t.Errorf("Text: want %q, got %q", "hello there", got.Text)
|
|
}
|
|
if got.Usage == nil {
|
|
t.Fatal("Usage: want non-nil")
|
|
}
|
|
if got.Usage.InputTokens != 10 || got.Usage.OutputTokens != 3 {
|
|
t.Errorf("Usage: want input=10 output=3, got input=%d output=%d", got.Usage.InputTokens, got.Usage.OutputTokens)
|
|
}
|
|
if got.Usage.TotalTokens != 13 {
|
|
t.Errorf("Usage.TotalTokens: want 13, got %d", got.Usage.TotalTokens)
|
|
}
|
|
}
|
|
|
|
func TestCompleteNoAuthHeaderWhenLocal(t *testing.T) {
|
|
resp := `{"message":{"role":"assistant","content":"ok"},"done":true}`
|
|
cap := &captureRequest{}
|
|
srv := newTestServer(t, cap, 200, resp, "")
|
|
|
|
p := newNative("", srv.URL)
|
|
if _, err := p.Complete(context.Background(), provider.Request{
|
|
Model: "llama3.2",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
}); err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
|
|
if cap.authHeader != "" {
|
|
t.Errorf("auth header: want empty (local mode), got %q", cap.authHeader)
|
|
}
|
|
}
|
|
|
|
func TestVisionImagesEncoded(t *testing.T) {
|
|
resp := `{"message":{"role":"assistant","content":"a cat"},"done":true}`
|
|
cap := &captureRequest{}
|
|
srv := newTestServer(t, cap, 200, resp, "")
|
|
|
|
p := newNative("", srv.URL)
|
|
if _, err := p.Complete(context.Background(), provider.Request{
|
|
Model: "llava",
|
|
Messages: []provider.Message{{
|
|
Role: "user",
|
|
Content: "what's in this?",
|
|
Images: []provider.Image{
|
|
{Base64: "AAAA", ContentType: "image/png"},
|
|
},
|
|
}},
|
|
}); err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
|
|
msgs, _ := cap.parsedBody["messages"].([]any)
|
|
if len(msgs) != 1 {
|
|
t.Fatalf("messages: want 1, got %d", len(msgs))
|
|
}
|
|
m0, _ := msgs[0].(map[string]any)
|
|
imgs, _ := m0["images"].([]any)
|
|
if len(imgs) != 1 {
|
|
t.Fatalf("images: want 1 entry, got %d (msg=%v)", len(imgs), m0)
|
|
}
|
|
if imgs[0] != "AAAA" {
|
|
t.Errorf("images[0]: want raw base64 AAAA, got %v", imgs[0])
|
|
}
|
|
}
|
|
|
|
func TestThinkingField(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
reasoning string
|
|
want any // expected value of "think" in body, or nil if absent
|
|
}{
|
|
{"absent", "", nil},
|
|
{"high", "high", "high"},
|
|
{"low", "low", "low"},
|
|
{"medium", "medium", "medium"},
|
|
{"true", "true", true},
|
|
{"false", "false", false},
|
|
}
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
resp := `{"message":{"role":"assistant","content":"ok"},"done":true}`
|
|
cap := &captureRequest{}
|
|
srv := newTestServer(t, cap, 200, resp, "")
|
|
|
|
p := newNative("", srv.URL)
|
|
_, err := p.Complete(context.Background(), provider.Request{
|
|
Model: "kimi-k2.5",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
Reasoning: c.reasoning,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
got, present := cap.parsedBody["think"]
|
|
if c.want == nil {
|
|
if present {
|
|
t.Errorf("think field should be absent, got %v", got)
|
|
}
|
|
return
|
|
}
|
|
if !present {
|
|
t.Fatalf("think field absent; want %v", c.want)
|
|
}
|
|
if got != c.want {
|
|
t.Errorf("think: want %v (%T), got %v (%T)", c.want, c.want, got, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestToolRoundTrip(t *testing.T) {
|
|
t.Run("response tool_calls convert to provider.Response", func(t *testing.T) {
|
|
resp := `{
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": "",
|
|
"tool_calls": [
|
|
{"function": {"name": "search", "arguments": {"query": "foo"}}}
|
|
]
|
|
},
|
|
"done": true,
|
|
"prompt_eval_count": 5,
|
|
"eval_count": 2
|
|
}`
|
|
cap := &captureRequest{}
|
|
srv := newTestServer(t, cap, 200, resp, "")
|
|
|
|
p := newNative("", srv.URL)
|
|
got, err := p.Complete(context.Background(), provider.Request{
|
|
Model: "kimi-k2.5",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
Tools: []provider.ToolDef{
|
|
{
|
|
Name: "search",
|
|
Description: "Run a search",
|
|
Schema: map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{
|
|
"query": map[string]any{"type": "string"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
|
|
// Verify request shape: tools array present.
|
|
toolsArr, _ := cap.parsedBody["tools"].([]any)
|
|
if len(toolsArr) != 1 {
|
|
t.Fatalf("tools: want 1 entry, got %d", len(toolsArr))
|
|
}
|
|
t0, _ := toolsArr[0].(map[string]any)
|
|
if t0["type"] != "function" {
|
|
t.Errorf("tools[0].type: want function, got %v", t0["type"])
|
|
}
|
|
fn, _ := t0["function"].(map[string]any)
|
|
if fn["name"] != "search" {
|
|
t.Errorf("tools[0].function.name: want search, got %v", fn["name"])
|
|
}
|
|
|
|
// Verify response conversion.
|
|
if len(got.ToolCalls) != 1 {
|
|
t.Fatalf("ToolCalls: want 1, got %d", len(got.ToolCalls))
|
|
}
|
|
tc := got.ToolCalls[0]
|
|
if tc.Name != "search" {
|
|
t.Errorf("ToolCall.Name: want search, got %q", tc.Name)
|
|
}
|
|
// Arguments should be valid JSON containing query=foo
|
|
var args map[string]any
|
|
if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil {
|
|
t.Fatalf("ToolCall.Arguments not valid JSON: %v (got %q)", err, tc.Arguments)
|
|
}
|
|
if args["query"] != "foo" {
|
|
t.Errorf("ToolCall.Arguments.query: want foo, got %v", args["query"])
|
|
}
|
|
})
|
|
|
|
t.Run("subsequent request includes assistant tool_calls and tool-role response", func(t *testing.T) {
|
|
resp := `{"message":{"role":"assistant","content":"done"},"done":true}`
|
|
cap := &captureRequest{}
|
|
srv := newTestServer(t, cap, 200, resp, "")
|
|
|
|
p := newNative("", srv.URL)
|
|
_, err := p.Complete(context.Background(), provider.Request{
|
|
Model: "kimi-k2.5",
|
|
Messages: []provider.Message{
|
|
{Role: "user", Content: "search foo"},
|
|
{
|
|
Role: "assistant",
|
|
ToolCalls: []provider.ToolCall{{
|
|
ID: "tc1",
|
|
Name: "search",
|
|
Arguments: `{"query":"foo"}`,
|
|
}},
|
|
},
|
|
{
|
|
Role: "tool",
|
|
ToolCallID: "tc1",
|
|
Content: `{"result":"bar"}`,
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
|
|
msgs, _ := cap.parsedBody["messages"].([]any)
|
|
if len(msgs) != 3 {
|
|
t.Fatalf("messages: want 3, got %d", len(msgs))
|
|
}
|
|
|
|
// Assistant message must carry tool_calls with the JSON-object arguments.
|
|
asst, _ := msgs[1].(map[string]any)
|
|
if asst["role"] != "assistant" {
|
|
t.Errorf("msgs[1].role: want assistant, got %v", asst["role"])
|
|
}
|
|
tc, _ := asst["tool_calls"].([]any)
|
|
if len(tc) != 1 {
|
|
t.Fatalf("assistant.tool_calls: want 1, got %d", len(tc))
|
|
}
|
|
fn, _ := tc[0].(map[string]any)["function"].(map[string]any)
|
|
if fn["name"] != "search" {
|
|
t.Errorf("assistant.tool_calls[0].function.name: want search, got %v", fn["name"])
|
|
}
|
|
args, _ := fn["arguments"].(map[string]any)
|
|
if args["query"] != "foo" {
|
|
t.Errorf("assistant.tool_calls[0].function.arguments.query: want foo, got %v", args["query"])
|
|
}
|
|
|
|
// Tool-role message must have role=tool, tool_call_id, and content.
|
|
tool, _ := msgs[2].(map[string]any)
|
|
if tool["role"] != "tool" {
|
|
t.Errorf("msgs[2].role: want tool, got %v", tool["role"])
|
|
}
|
|
if tool["tool_call_id"] != "tc1" {
|
|
t.Errorf("msgs[2].tool_call_id: want tc1, got %v", tool["tool_call_id"])
|
|
}
|
|
if !strings.Contains(toString(tool["content"]), "bar") {
|
|
t.Errorf("msgs[2].content: want to contain bar, got %v", tool["content"])
|
|
}
|
|
})
|
|
}
|
|
|
|
// toString renders v as text for substring assertions: a string is returned
// unchanged, anything else is returned as its JSON encoding. Encoding errors
// are ignored, in which case the result is the empty string.
func toString(v any) string {
	switch val := v.(type) {
	case string:
		return val
	default:
		encoded, _ := json.Marshal(val)
		return string(encoded)
	}
}
|
|
|
|
// streamServer returns an httptest.Server that writes the given NDJSON lines
|
|
// (each terminated with \n) as the response body.
|
|
func streamServer(t *testing.T, captured *captureRequest, lines []string) *httptest.Server {
|
|
t.Helper()
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
captured.method = r.Method
|
|
captured.path = r.URL.Path
|
|
captured.authHeader = r.Header.Get("Authorization")
|
|
captured.contentType = r.Header.Get("Content-Type")
|
|
body, _ := io.ReadAll(r.Body)
|
|
captured.body = body
|
|
_ = json.Unmarshal(body, &captured.parsedBody)
|
|
w.Header().Set("Content-Type", "application/x-ndjson")
|
|
w.WriteHeader(200)
|
|
flusher, _ := w.(http.Flusher)
|
|
for _, line := range lines {
|
|
_, _ = w.Write([]byte(line + "\n"))
|
|
if flusher != nil {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
}))
|
|
t.Cleanup(srv.Close)
|
|
return srv
|
|
}
|
|
|
|
func collectStream(t *testing.T, p *Provider, req provider.Request) []provider.StreamEvent {
|
|
t.Helper()
|
|
events := make(chan provider.StreamEvent, 64)
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- p.Stream(context.Background(), req, events)
|
|
}()
|
|
var out []provider.StreamEvent
|
|
timeout := time.After(5 * time.Second)
|
|
streamErrored := false
|
|
loop:
|
|
for {
|
|
select {
|
|
case ev, ok := <-events:
|
|
if !ok {
|
|
break loop
|
|
}
|
|
out = append(out, ev)
|
|
if ev.Type == provider.StreamEventError {
|
|
streamErrored = true
|
|
}
|
|
case err := <-done:
|
|
if err != nil && !streamErrored {
|
|
t.Fatalf("Stream returned error: %v", err)
|
|
}
|
|
// Drain any final events buffered in the channel.
|
|
for {
|
|
select {
|
|
case ev, ok := <-events:
|
|
if !ok {
|
|
return out
|
|
}
|
|
out = append(out, ev)
|
|
default:
|
|
return out
|
|
}
|
|
}
|
|
case <-timeout:
|
|
t.Fatal("Stream did not complete within 5s")
|
|
}
|
|
}
|
|
if err := <-done; err != nil && !streamErrored {
|
|
t.Fatalf("Stream returned error: %v", err)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func TestStreamBasic(t *testing.T) {
|
|
lines := []string{
|
|
`{"message":{"role":"assistant","content":"hello"},"done":false}`,
|
|
`{"message":{"role":"assistant","content":" world","thinking":"reasoning"},"done":false}`,
|
|
`{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":12,"eval_count":2}`,
|
|
}
|
|
cap := &captureRequest{}
|
|
srv := streamServer(t, cap, lines)
|
|
|
|
p := newNative("", srv.URL)
|
|
events := collectStream(t, p, provider.Request{
|
|
Model: "kimi-k2.5",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
|
|
// Verify request shape: stream:true.
|
|
if cap.parsedBody["stream"] != true {
|
|
t.Errorf("body.stream: want true, got %v", cap.parsedBody["stream"])
|
|
}
|
|
|
|
// Filter to relevant events (text, thinking, done) preserving order.
|
|
var kinds []string
|
|
var texts []string
|
|
var doneEvent *provider.StreamEvent
|
|
for i, ev := range events {
|
|
switch ev.Type {
|
|
case provider.StreamEventText:
|
|
kinds = append(kinds, "text")
|
|
texts = append(texts, ev.Text)
|
|
case provider.StreamEventThinking:
|
|
kinds = append(kinds, "thinking")
|
|
texts = append(texts, ev.Text)
|
|
case provider.StreamEventDone:
|
|
kinds = append(kinds, "done")
|
|
e := events[i]
|
|
doneEvent = &e
|
|
}
|
|
}
|
|
|
|
wantKinds := []string{"text", "thinking", "text", "done"}
|
|
if !equalStrings(kinds, wantKinds) {
|
|
t.Errorf("event kinds: want %v, got %v", wantKinds, kinds)
|
|
}
|
|
if len(texts) >= 3 {
|
|
if texts[0] != "hello" {
|
|
t.Errorf("first text: want hello, got %q", texts[0])
|
|
}
|
|
if texts[1] != "reasoning" {
|
|
t.Errorf("thinking: want reasoning, got %q", texts[1])
|
|
}
|
|
if texts[2] != " world" {
|
|
t.Errorf("second text: want \" world\", got %q", texts[2])
|
|
}
|
|
}
|
|
if doneEvent == nil || doneEvent.Response == nil {
|
|
t.Fatal("Done event missing Response")
|
|
}
|
|
if doneEvent.Response.Text != "hello world" {
|
|
t.Errorf("Response.Text: want %q, got %q", "hello world", doneEvent.Response.Text)
|
|
}
|
|
if doneEvent.Response.Thinking != "reasoning" {
|
|
t.Errorf("Response.Thinking: want %q, got %q", "reasoning", doneEvent.Response.Thinking)
|
|
}
|
|
if doneEvent.Response.Usage == nil {
|
|
t.Fatal("Response.Usage missing")
|
|
}
|
|
if doneEvent.Response.Usage.InputTokens != 12 || doneEvent.Response.Usage.OutputTokens != 2 {
|
|
t.Errorf("Usage: want input=12 output=2, got input=%d output=%d", doneEvent.Response.Usage.InputTokens, doneEvent.Response.Usage.OutputTokens)
|
|
}
|
|
}
|
|
|
|
func TestStreamToolDeltaAccumulation(t *testing.T) {
|
|
lines := []string{
|
|
`{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"name":"search","arguments":"{\"que"}}]},"done":false}`,
|
|
`{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"arguments":"ry\":\"foo\"}"}}]},"done":false}`,
|
|
`{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":4,"eval_count":1}`,
|
|
}
|
|
cap := &captureRequest{}
|
|
srv := streamServer(t, cap, lines)
|
|
|
|
p := newNative("", srv.URL)
|
|
events := collectStream(t, p, provider.Request{
|
|
Model: "kimi-k2.5",
|
|
Messages: []provider.Message{{Role: "user", Content: "search foo"}},
|
|
Tools: []provider.ToolDef{
|
|
{Name: "search", Schema: map[string]any{"type": "object"}},
|
|
},
|
|
})
|
|
|
|
// Build a slim trace of tool events.
|
|
type traceEntry struct {
|
|
kind string
|
|
args string
|
|
name string
|
|
id string
|
|
}
|
|
var trace []traceEntry
|
|
var doneEvent *provider.StreamEvent
|
|
for i, ev := range events {
|
|
switch ev.Type {
|
|
case provider.StreamEventToolStart:
|
|
trace = append(trace, traceEntry{kind: "start", name: ev.ToolCall.Name, id: ev.ToolCall.ID})
|
|
case provider.StreamEventToolDelta:
|
|
trace = append(trace, traceEntry{kind: "delta", args: ev.ToolCall.Arguments})
|
|
case provider.StreamEventToolEnd:
|
|
trace = append(trace, traceEntry{kind: "end", args: ev.ToolCall.Arguments, name: ev.ToolCall.Name, id: ev.ToolCall.ID})
|
|
case provider.StreamEventDone:
|
|
e := events[i]
|
|
doneEvent = &e
|
|
}
|
|
}
|
|
|
|
if len(trace) != 4 {
|
|
t.Fatalf("trace: want 4 entries (start, delta, delta, end), got %d: %+v", len(trace), trace)
|
|
}
|
|
if trace[0].kind != "start" || trace[0].name != "search" || trace[0].id != "tc1" {
|
|
t.Errorf("trace[0]: want start search tc1, got %+v", trace[0])
|
|
}
|
|
if trace[1].kind != "delta" || trace[1].args != `{"que` {
|
|
t.Errorf("trace[1]: want delta args=%q, got %+v", `{"que`, trace[1])
|
|
}
|
|
if trace[2].kind != "delta" || trace[2].args != `ry":"foo"}` {
|
|
t.Errorf("trace[2]: want delta args=%q, got %+v", `ry":"foo"}`, trace[2])
|
|
}
|
|
if trace[3].kind != "end" || trace[3].args != `{"query":"foo"}` {
|
|
t.Errorf("trace[3]: want end args=%q, got %+v", `{"query":"foo"}`, trace[3])
|
|
}
|
|
|
|
if doneEvent == nil || doneEvent.Response == nil {
|
|
t.Fatal("Done event missing Response")
|
|
}
|
|
if len(doneEvent.Response.ToolCalls) != 1 {
|
|
t.Fatalf("Done.Response.ToolCalls: want 1, got %d", len(doneEvent.Response.ToolCalls))
|
|
}
|
|
tc := doneEvent.Response.ToolCalls[0]
|
|
if tc.ID != "tc1" || tc.Name != "search" || tc.Arguments != `{"query":"foo"}` {
|
|
t.Errorf("Done.Response.ToolCalls[0]: want tc1/search/{...}, got %+v", tc)
|
|
}
|
|
}
|
|
|
|
// equalStrings reports whether a and b have the same length and hold equal
// strings at every index. A nil slice and an empty slice compare equal.
func equalStrings(a, b []string) bool {
	differs := len(a) != len(b)
	for i := 0; !differs && i < len(a); i++ {
		differs = a[i] != b[i]
	}
	return !differs
}
|