package ollama

import (
	"context"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// captureRequest records the parts of an inbound HTTP request that the tests
// assert on. The test servers in this file fill it in from their handlers.
type captureRequest struct {
	method      string
	path        string
	authHeader  string
	contentType string
	body        []byte         // raw request body
	parsedBody  map[string]any // body decoded as a JSON object; unset if decoding fails
}

// newTestServer starts an httptest.Server that records each request it
// receives into captured and answers with the given status code and body.
// An empty respContentType defaults to "application/json". The server is
// closed automatically when the test finishes.
func newTestServer(t *testing.T, captured *captureRequest, status int, respBody string, respContentType string) *httptest.Server {
	t.Helper()
	// Resolve the default once, outside the handler, so the handler never
	// writes to this shared variable (which would race across requests).
	if respContentType == "" {
		respContentType = "application/json"
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		captured.method = r.Method
		captured.path = r.URL.Path
		captured.authHeader = r.Header.Get("Authorization")
		captured.contentType = r.Header.Get("Content-Type")
		body, _ := io.ReadAll(r.Body)
		captured.body = body
		// Best-effort: a body that is not a JSON object leaves parsedBody
		// unset, which the individual assertions then report.
		_ = json.Unmarshal(body, &captured.parsedBody)

		w.Header().Set("Content-Type", respContentType)
		w.WriteHeader(status)
		_, _ = w.Write([]byte(respBody))
	}))
	t.Cleanup(srv.Close)
	return srv
}

// TestCompleteBasic checks the request shape of a non-streaming chat call
// (method, path, auth header, content type, model, stream flag, messages) and
// the conversion of the reply into a provider response, including usage.
func TestCompleteBasic(t *testing.T) {
	resp := `{
		"model": "kimi-k2.5",
		"message": {"role": "assistant", "content": "hello there"},
		"done": true,
		"done_reason": "stop",
		"prompt_eval_count": 10,
		"eval_count": 3
	}`
	rec := &captureRequest{}
	srv := newTestServer(t, rec, http.StatusOK, resp, "")
	p := newNative("test-key", srv.URL)

	got, err := p.Complete(context.Background(), provider.Request{
		Model:    "kimi-k2.5",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	})
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}

	// Request shape.
	if rec.method != http.MethodPost {
		t.Errorf("method: want POST, got %q", rec.method)
	}
	if rec.path != "/api/chat" {
		t.Errorf("path: want /api/chat, got %q", rec.path)
	}
	if rec.authHeader != "Bearer test-key" {
		t.Errorf("auth header: want %q, got %q", "Bearer test-key", rec.authHeader)
	}
	if rec.contentType != "application/json" {
		t.Errorf("content-type: want application/json, got %q", rec.contentType)
	}
	if rec.parsedBody["model"] != "kimi-k2.5" {
		t.Errorf("body.model: want kimi-k2.5, got %v", rec.parsedBody["model"])
	}
	if rec.parsedBody["stream"] != false {
		t.Errorf("body.stream: want false, got %v", rec.parsedBody["stream"])
	}
	msgs, _ := rec.parsedBody["messages"].([]any)
	if len(msgs) != 1 {
		t.Fatalf("messages: want 1 entry, got %d", len(msgs))
	}
	m0, _ := msgs[0].(map[string]any)
	if m0["role"] != "user" || m0["content"] != "hi" {
		t.Errorf("first message: want role=user content=hi, got %v", m0)
	}

	// Response conversion.
	if got.Text != "hello there" {
		t.Errorf("Text: want %q, got %q", "hello there", got.Text)
	}
	if got.Usage == nil {
		t.Fatal("Usage: want non-nil")
	}
	if got.Usage.InputTokens != 10 || got.Usage.OutputTokens != 3 {
		t.Errorf("Usage: want input=10 output=3, got input=%d output=%d", got.Usage.InputTokens, got.Usage.OutputTokens)
	}
	if got.Usage.TotalTokens != 13 {
		t.Errorf("Usage.TotalTokens: want 13, got %d", got.Usage.TotalTokens)
	}
}

// TestCompleteNoAuthHeaderWhenLocal checks that no Authorization header is
// sent when the provider is built with an empty API key (local mode).
func TestCompleteNoAuthHeaderWhenLocal(t *testing.T) {
	resp := `{"message":{"role":"assistant","content":"ok"},"done":true}`
	rec := &captureRequest{}
	srv := newTestServer(t, rec, http.StatusOK, resp, "")
	p := newNative("", srv.URL)

	if _, err := p.Complete(context.Background(), provider.Request{
		Model:    "llama3.2",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}); err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if rec.authHeader != "" {
		t.Errorf("auth header: want empty (local mode), got %q", rec.authHeader)
	}
}

// TestVisionImagesEncoded checks that message images are sent in the
// message's "images" array as exactly the raw base64 payload.
func TestVisionImagesEncoded(t *testing.T) {
	resp := `{"message":{"role":"assistant","content":"a cat"},"done":true}`
	rec := &captureRequest{}
	srv := newTestServer(t, rec, http.StatusOK, resp, "")
	p := newNative("", srv.URL)

	if _, err := p.Complete(context.Background(), provider.Request{
		Model: "llava",
		Messages: []provider.Message{{
			Role:    "user",
			Content: "what's in this?",
			Images: []provider.Image{
				{Base64: "AAAA", ContentType: "image/png"},
			},
		}},
	}); err != nil {
		t.Fatalf("Complete: %v", err)
	}

	msgs, _ := rec.parsedBody["messages"].([]any)
	if len(msgs) != 1 {
		t.Fatalf("messages: want 1, got %d", len(msgs))
	}
	m0, _ := msgs[0].(map[string]any)
	imgs, _ := m0["images"].([]any)
	if len(imgs) != 1 {
		t.Fatalf("images: want 1 entry, got %d (msg=%v)", len(imgs), m0)
	}
	if imgs[0] != "AAAA" {
		t.Errorf("images[0]: want raw base64 AAAA, got %v", imgs[0])
	}
}

// TestThinkingField checks how Request.Reasoning maps onto the request's
// "think" field: an empty value omits it, "true"/"false" become JSON
// booleans, and effort levels are passed through as strings.
func TestThinkingField(t *testing.T) {
	cases := []struct {
		name      string
		reasoning string
		want      any // expected value of "think" in the body; nil means it must be absent
	}{
		{"absent", "", nil},
		{"high", "high", "high"},
		{"low", "low", "low"},
		{"medium", "medium", "medium"},
		{"true", "true", true},
		{"false", "false", false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			resp := `{"message":{"role":"assistant","content":"ok"},"done":true}`
			rec := &captureRequest{}
			srv := newTestServer(t, rec, http.StatusOK, resp, "")
			p := newNative("", srv.URL)

			_, err := p.Complete(context.Background(), provider.Request{
				Model:     "kimi-k2.5",
				Messages:  []provider.Message{{Role: "user", Content: "hi"}},
				Reasoning: c.reasoning,
			})
			if err != nil {
				t.Fatalf("Complete: %v", err)
			}

			got, present := rec.parsedBody["think"]
			if c.want == nil {
				if present {
					t.Errorf("think field should be absent, got %v", got)
				}
				return
			}
			if !present {
				t.Fatalf("think field absent; want %v", c.want)
			}
			if got != c.want {
				t.Errorf("think: want %v (%T), got %v (%T)", c.want, c.want, got, got)
			}
		})
	}
}

// TestToolRoundTrip checks both directions of tool calling: tool definitions
// and reply tool_calls on the first request, then assistant tool_calls and a
// tool-role result message on the follow-up request.
func TestToolRoundTrip(t *testing.T) {
	t.Run("response tool_calls convert to provider.Response", func(t *testing.T) {
		resp := `{
			"message": {
				"role": "assistant",
				"content": "",
				"tool_calls": [
					{"function": {"name": "search", "arguments": {"query": "foo"}}}
				]
			},
			"done": true,
			"prompt_eval_count": 5,
			"eval_count": 2
		}`
		rec := &captureRequest{}
		srv := newTestServer(t, rec, http.StatusOK, resp, "")
		p := newNative("", srv.URL)

		got, err := p.Complete(context.Background(), provider.Request{
			Model:    "kimi-k2.5",
			Messages: []provider.Message{{Role: "user", Content: "hi"}},
			Tools: []provider.ToolDef{{
				Name:        "search",
				Description: "Run a search",
				Schema: map[string]any{
					"type": "object",
					"properties": map[string]any{
						"query": map[string]any{"type": "string"},
					},
				},
			}},
		})
		if err != nil {
			t.Fatalf("Complete: %v", err)
		}

		// Verify request shape: tools array present.
		toolsArr, _ := rec.parsedBody["tools"].([]any)
		if len(toolsArr) != 1 {
			t.Fatalf("tools: want 1 entry, got %d", len(toolsArr))
		}
		t0, _ := toolsArr[0].(map[string]any)
		if t0["type"] != "function" {
			t.Errorf("tools[0].type: want function, got %v", t0["type"])
		}
		fn, _ := t0["function"].(map[string]any)
		if fn["name"] != "search" {
			t.Errorf("tools[0].function.name: want search, got %v", fn["name"])
		}

		// Verify response conversion.
		if len(got.ToolCalls) != 1 {
			t.Fatalf("ToolCalls: want 1, got %d", len(got.ToolCalls))
		}
		tc := got.ToolCalls[0]
		if tc.Name != "search" {
			t.Errorf("ToolCall.Name: want search, got %q", tc.Name)
		}
		// The JSON-object arguments from the reply must come back as a JSON string.
		var args map[string]any
		if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil {
			t.Fatalf("ToolCall.Arguments not valid JSON: %v (got %q)", err, tc.Arguments)
		}
		if args["query"] != "foo" {
			t.Errorf("ToolCall.Arguments.query: want foo, got %v", args["query"])
		}
	})

	t.Run("subsequent request includes assistant tool_calls and tool-role response", func(t *testing.T) {
		resp := `{"message":{"role":"assistant","content":"done"},"done":true}`
		rec := &captureRequest{}
		srv := newTestServer(t, rec, http.StatusOK, resp, "")
		p := newNative("", srv.URL)

		_, err := p.Complete(context.Background(), provider.Request{
			Model: "kimi-k2.5",
			Messages: []provider.Message{
				{Role: "user", Content: "search foo"},
				{
					Role: "assistant",
					ToolCalls: []provider.ToolCall{{
						ID:        "tc1",
						Name:      "search",
						Arguments: `{"query":"foo"}`,
					}},
				},
				{
					Role:       "tool",
					ToolCallID: "tc1",
					Content:    `{"result":"bar"}`,
				},
			},
		})
		if err != nil {
			t.Fatalf("Complete: %v", err)
		}

		msgs, _ := rec.parsedBody["messages"].([]any)
		if len(msgs) != 3 {
			t.Fatalf("messages: want 3, got %d", len(msgs))
		}

		// Assistant message must carry tool_calls with the JSON-object arguments.
		asst, _ := msgs[1].(map[string]any)
		if asst["role"] != "assistant" {
			t.Errorf("msgs[1].role: want assistant, got %v", asst["role"])
		}
		calls, _ := asst["tool_calls"].([]any)
		if len(calls) != 1 {
			t.Fatalf("assistant.tool_calls: want 1, got %d", len(calls))
		}
		// Two comma-ok assertions so an unexpected shape fails the checks
		// below instead of panicking.
		call0, _ := calls[0].(map[string]any)
		fn, _ := call0["function"].(map[string]any)
		if fn["name"] != "search" {
			t.Errorf("assistant.tool_calls[0].function.name: want search, got %v", fn["name"])
		}
		args, _ := fn["arguments"].(map[string]any)
		if args["query"] != "foo" {
			t.Errorf("assistant.tool_calls[0].function.arguments.query: want foo, got %v", args["query"])
		}

		// Tool-role message must have role=tool, tool_call_id, and content.
		tool, _ := msgs[2].(map[string]any)
		if tool["role"] != "tool" {
			t.Errorf("msgs[2].role: want tool, got %v", tool["role"])
		}
		if tool["tool_call_id"] != "tc1" {
			t.Errorf("msgs[2].tool_call_id: want tc1, got %v", tool["tool_call_id"])
		}
		if !strings.Contains(toString(tool["content"]), "bar") {
			t.Errorf("msgs[2].content: want to contain bar, got %v", tool["content"])
		}
	})
}

// toString returns v unchanged when it is a string, and otherwise its JSON
// encoding (an empty string if v cannot be marshaled).
func toString(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	b, _ := json.Marshal(v)
	return string(b)
}

// streamServer returns an httptest.Server that records the inbound request
// into captured and writes the given NDJSON lines (each terminated with \n)
// as the response body, flushing after every line when the ResponseWriter
// supports it.
func streamServer(t *testing.T, captured *captureRequest, lines []string) *httptest.Server { t.Helper() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { captured.method = r.Method captured.path = r.URL.Path captured.authHeader = r.Header.Get("Authorization") captured.contentType = r.Header.Get("Content-Type") body, _ := io.ReadAll(r.Body) captured.body = body _ = json.Unmarshal(body, &captured.parsedBody) w.Header().Set("Content-Type", "application/x-ndjson") w.WriteHeader(200) flusher, _ := w.(http.Flusher) for _, line := range lines { _, _ = w.Write([]byte(line + "\n")) if flusher != nil { flusher.Flush() } } })) t.Cleanup(srv.Close) return srv } func collectStream(t *testing.T, p *Provider, req provider.Request) []provider.StreamEvent { t.Helper() events := make(chan provider.StreamEvent, 64) done := make(chan error, 1) go func() { done <- p.Stream(context.Background(), req, events) }() var out []provider.StreamEvent timeout := time.After(5 * time.Second) streamErrored := false loop: for { select { case ev, ok := <-events: if !ok { break loop } out = append(out, ev) if ev.Type == provider.StreamEventError { streamErrored = true } case err := <-done: if err != nil && !streamErrored { t.Fatalf("Stream returned error: %v", err) } // Drain any final events buffered in the channel. 
for { select { case ev, ok := <-events: if !ok { return out } out = append(out, ev) default: return out } } case <-timeout: t.Fatal("Stream did not complete within 5s") } } if err := <-done; err != nil && !streamErrored { t.Fatalf("Stream returned error: %v", err) } return out } func TestStreamBasic(t *testing.T) { lines := []string{ `{"message":{"role":"assistant","content":"hello"},"done":false}`, `{"message":{"role":"assistant","content":" world","thinking":"reasoning"},"done":false}`, `{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":12,"eval_count":2}`, } cap := &captureRequest{} srv := streamServer(t, cap, lines) p := newNative("", srv.URL) events := collectStream(t, p, provider.Request{ Model: "kimi-k2.5", Messages: []provider.Message{{Role: "user", Content: "hi"}}, }) // Verify request shape: stream:true. if cap.parsedBody["stream"] != true { t.Errorf("body.stream: want true, got %v", cap.parsedBody["stream"]) } // Filter to relevant events (text, thinking, done) preserving order. 
var kinds []string var texts []string var doneEvent *provider.StreamEvent for i, ev := range events { switch ev.Type { case provider.StreamEventText: kinds = append(kinds, "text") texts = append(texts, ev.Text) case provider.StreamEventThinking: kinds = append(kinds, "thinking") texts = append(texts, ev.Text) case provider.StreamEventDone: kinds = append(kinds, "done") e := events[i] doneEvent = &e } } wantKinds := []string{"text", "thinking", "text", "done"} if !equalStrings(kinds, wantKinds) { t.Errorf("event kinds: want %v, got %v", wantKinds, kinds) } if len(texts) >= 3 { if texts[0] != "hello" { t.Errorf("first text: want hello, got %q", texts[0]) } if texts[1] != "reasoning" { t.Errorf("thinking: want reasoning, got %q", texts[1]) } if texts[2] != " world" { t.Errorf("second text: want \" world\", got %q", texts[2]) } } if doneEvent == nil || doneEvent.Response == nil { t.Fatal("Done event missing Response") } if doneEvent.Response.Text != "hello world" { t.Errorf("Response.Text: want %q, got %q", "hello world", doneEvent.Response.Text) } if doneEvent.Response.Thinking != "reasoning" { t.Errorf("Response.Thinking: want %q, got %q", "reasoning", doneEvent.Response.Thinking) } if doneEvent.Response.Usage == nil { t.Fatal("Response.Usage missing") } if doneEvent.Response.Usage.InputTokens != 12 || doneEvent.Response.Usage.OutputTokens != 2 { t.Errorf("Usage: want input=12 output=2, got input=%d output=%d", doneEvent.Response.Usage.InputTokens, doneEvent.Response.Usage.OutputTokens) } } func TestStreamToolDeltaAccumulation(t *testing.T) { lines := []string{ `{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"name":"search","arguments":"{\"que"}}]},"done":false}`, `{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"arguments":"ry\":\"foo\"}"}}]},"done":false}`, `{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":4,"eval_count":1}`, } cap := &captureRequest{} srv := 
streamServer(t, cap, lines) p := newNative("", srv.URL) events := collectStream(t, p, provider.Request{ Model: "kimi-k2.5", Messages: []provider.Message{{Role: "user", Content: "search foo"}}, Tools: []provider.ToolDef{ {Name: "search", Schema: map[string]any{"type": "object"}}, }, }) // Build a slim trace of tool events. type traceEntry struct { kind string args string name string id string } var trace []traceEntry var doneEvent *provider.StreamEvent for i, ev := range events { switch ev.Type { case provider.StreamEventToolStart: trace = append(trace, traceEntry{kind: "start", name: ev.ToolCall.Name, id: ev.ToolCall.ID}) case provider.StreamEventToolDelta: trace = append(trace, traceEntry{kind: "delta", args: ev.ToolCall.Arguments}) case provider.StreamEventToolEnd: trace = append(trace, traceEntry{kind: "end", args: ev.ToolCall.Arguments, name: ev.ToolCall.Name, id: ev.ToolCall.ID}) case provider.StreamEventDone: e := events[i] doneEvent = &e } } if len(trace) != 4 { t.Fatalf("trace: want 4 entries (start, delta, delta, end), got %d: %+v", len(trace), trace) } if trace[0].kind != "start" || trace[0].name != "search" || trace[0].id != "tc1" { t.Errorf("trace[0]: want start search tc1, got %+v", trace[0]) } if trace[1].kind != "delta" || trace[1].args != `{"que` { t.Errorf("trace[1]: want delta args=%q, got %+v", `{"que`, trace[1]) } if trace[2].kind != "delta" || trace[2].args != `ry":"foo"}` { t.Errorf("trace[2]: want delta args=%q, got %+v", `ry":"foo"}`, trace[2]) } if trace[3].kind != "end" || trace[3].args != `{"query":"foo"}` { t.Errorf("trace[3]: want end args=%q, got %+v", `{"query":"foo"}`, trace[3]) } if doneEvent == nil || doneEvent.Response == nil { t.Fatal("Done event missing Response") } if len(doneEvent.Response.ToolCalls) != 1 { t.Fatalf("Done.Response.ToolCalls: want 1, got %d", len(doneEvent.Response.ToolCalls)) } tc := doneEvent.Response.ToolCalls[0] if tc.ID != "tc1" || tc.Name != "search" || tc.Arguments != `{"query":"foo"}` { 
t.Errorf("Done.Response.ToolCalls[0]: want tc1/search/{...}, got %+v", tc) } } func equalStrings(a, b []string) bool { if len(a) != len(b) { return false } for i := range a { if a[i] != b[i] { return false } } return true }