Files
go-llm/v2/ollama/native_test.go
T
steve 583f8724b2 feat(v2/ollama): implement native Complete() with tools, vision, thinking
Non-streaming /api/chat support including:
- Vision via images: []base64
- Tool calls on assistant + tool-role response messages
- think field accepting string reasoning levels (or "true"/"false")
- Authorization header when apiKey is non-empty (cloud mode)

Tool-call arguments are passed as JSON objects to the wire and surfaced
as JSON-string Arguments on provider.ToolCall. Tool calls are assigned
synthetic IDs (tc_<index>) when Ollama omits one, so the round-trip
back as an assistant tool_calls + tool-role message remains correlated.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-01 18:24:02 +00:00

348 lines
9.7 KiB
Go

package ollama
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// captureRequest holds what the fake Ollama server received for the most
// recent request, so tests can assert on the outbound wire format.
type captureRequest struct {
	// Request line and the headers the tests care about.
	method, path string
	authHeader   string // value of the Authorization header ("" when absent)
	contentType  string // value of the Content-Type header

	// Raw request body, plus its decoded form when it parses as a JSON object.
	body       []byte
	parsedBody map[string]any
}
// newTestServer starts an httptest.Server that impersonates an Ollama
// endpoint. Each inbound request is recorded into captured and answered with
// status and respBody. The recorded fields are the method, path,
// Authorization and Content-Type headers, the raw body, and the body decoded
// as a JSON object. An empty respContentType defaults to "application/json".
// The server is closed automatically via t.Cleanup.
//
// Recording is best-effort. A body that cannot be read or decoded is logged
// rather than failing the test, and the corresponding field is left empty.
func newTestServer(t *testing.T, captured *captureRequest, status int, respBody string, respContentType string) *httptest.Server {
	t.Helper()

	// Resolve the default once, outside the handler, so the handler never
	// writes to a captured variable (that would race if requests overlapped).
	contentType := respContentType
	if contentType == "" {
		contentType = "application/json"
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		captured.method = r.Method
		captured.path = r.URL.Path
		captured.authHeader = r.Header.Get("Authorization")
		captured.contentType = r.Header.Get("Content-Type")

		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Logf("test server: reading request body: %v", err)
		}
		captured.body = body

		// Reset before decoding: json.Unmarshal into a non-nil map merges keys,
		// so a later request would otherwise inherit fields from an earlier one.
		captured.parsedBody = nil
		if len(body) > 0 {
			if err := json.Unmarshal(body, &captured.parsedBody); err != nil {
				t.Logf("test server: request body is not a JSON object: %v", err)
			}
		}

		w.Header().Set("Content-Type", contentType)
		w.WriteHeader(status)
		_, _ = w.Write([]byte(respBody))
	}))
	t.Cleanup(srv.Close)
	return srv
}
// TestCompleteBasic covers the happy path of a non-streaming Complete call.
// It checks the request shape:
//   - a POST to /api/chat with a bearer token;
//   - a JSON body carrying the model, stream=false and the user message.
//
// It also checks the response mapping: the assistant text and token usage
// (prompt_eval_count -> InputTokens, eval_count -> OutputTokens, summed into
// TotalTokens) are surfaced on the provider response.
func TestCompleteBasic(t *testing.T) {
	resp := `{
  "model": "kimi-k2.5",
  "message": {"role": "assistant", "content": "hello there"},
  "done": true,
  "done_reason": "stop",
  "prompt_eval_count": 10,
  "eval_count": 3
}`
	// Named "captured" rather than "cap" to avoid shadowing the builtin.
	captured := &captureRequest{}
	srv := newTestServer(t, captured, 200, resp, "")
	p := newNative("test-key", srv.URL)

	got, err := p.Complete(context.Background(), provider.Request{
		Model:    "kimi-k2.5",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	})
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}

	// Request shape.
	if captured.method != http.MethodPost {
		t.Errorf("method: want POST, got %q", captured.method)
	}
	if captured.path != "/api/chat" {
		t.Errorf("path: want /api/chat, got %q", captured.path)
	}
	if captured.authHeader != "Bearer test-key" {
		t.Errorf("auth header: want %q, got %q", "Bearer test-key", captured.authHeader)
	}
	if captured.contentType != "application/json" {
		t.Errorf("content-type: want application/json, got %q", captured.contentType)
	}
	if captured.parsedBody["model"] != "kimi-k2.5" {
		t.Errorf("body.model: want kimi-k2.5, got %v", captured.parsedBody["model"])
	}
	if captured.parsedBody["stream"] != false {
		t.Errorf("body.stream: want false, got %v", captured.parsedBody["stream"])
	}
	msgs, _ := captured.parsedBody["messages"].([]any)
	if len(msgs) != 1 {
		t.Fatalf("messages: want 1 entry, got %d", len(msgs))
	}
	m0, _ := msgs[0].(map[string]any)
	if m0["role"] != "user" || m0["content"] != "hi" {
		t.Errorf("first message: want role=user content=hi, got %v", m0)
	}

	// Response conversion.
	if got.Text != "hello there" {
		t.Errorf("Text: want %q, got %q", "hello there", got.Text)
	}
	if got.Usage == nil {
		t.Fatal("Usage: want non-nil")
	}
	if got.Usage.InputTokens != 10 || got.Usage.OutputTokens != 3 {
		t.Errorf("Usage: want input=10 output=3, got input=%d output=%d", got.Usage.InputTokens, got.Usage.OutputTokens)
	}
	if got.Usage.TotalTokens != 13 {
		t.Errorf("Usage.TotalTokens: want 13, got %d", got.Usage.TotalTokens)
	}
}
// TestCompleteNoAuthHeaderWhenLocal checks that a provider built with an
// empty API key (local mode) sends no Authorization header at all.
func TestCompleteNoAuthHeaderWhenLocal(t *testing.T) {
	resp := `{"message":{"role":"assistant","content":"ok"},"done":true}`
	// Named "captured" rather than "cap" to avoid shadowing the builtin.
	captured := &captureRequest{}
	srv := newTestServer(t, captured, 200, resp, "")
	p := newNative("", srv.URL)

	_, err := p.Complete(context.Background(), provider.Request{
		Model:    "llama3.2",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	})
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}

	if captured.authHeader != "" {
		t.Errorf("auth header: want empty (local mode), got %q", captured.authHeader)
	}
}
// TestVisionImagesEncoded checks how message images are serialized. Each
// image's base64 payload must be forwarded verbatim (raw base64, not a
// data: URL) in the message's "images" array.
func TestVisionImagesEncoded(t *testing.T) {
	resp := `{"message":{"role":"assistant","content":"a cat"},"done":true}`
	// Named "captured" rather than "cap" to avoid shadowing the builtin.
	captured := &captureRequest{}
	srv := newTestServer(t, captured, 200, resp, "")
	p := newNative("", srv.URL)

	_, err := p.Complete(context.Background(), provider.Request{
		Model: "llava",
		Messages: []provider.Message{{
			Role:    "user",
			Content: "what's in this?",
			Images: []provider.Image{
				{Base64: "AAAA", ContentType: "image/png"},
			},
		}},
	})
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}

	msgs, _ := captured.parsedBody["messages"].([]any)
	if len(msgs) != 1 {
		t.Fatalf("messages: want 1, got %d", len(msgs))
	}
	m0, _ := msgs[0].(map[string]any)
	imgs, _ := m0["images"].([]any)
	if len(imgs) != 1 {
		t.Fatalf("images: want 1 entry, got %d (msg=%v)", len(imgs), m0)
	}
	if imgs[0] != "AAAA" {
		t.Errorf("images[0]: want raw base64 AAAA, got %v", imgs[0])
	}
}
// TestThinkingField checks how provider.Request.Reasoning maps onto the
// "think" field of the request body:
//   - an empty value omits the field;
//   - level names ("high", "low", "medium") are sent as strings;
//   - "true" and "false" are sent as JSON booleans.
func TestThinkingField(t *testing.T) {
	cases := []struct {
		name      string
		reasoning string
		want      any // expected value of "think" in body, or nil if absent
	}{
		{"absent", "", nil},
		{"high", "high", "high"},
		{"low", "low", "low"},
		{"medium", "medium", "medium"},
		{"true", "true", true},
		{"false", "false", false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			resp := `{"message":{"role":"assistant","content":"ok"},"done":true}`
			// Named "captured" rather than "cap" to avoid shadowing the builtin.
			captured := &captureRequest{}
			srv := newTestServer(t, captured, 200, resp, "")
			p := newNative("", srv.URL)

			_, err := p.Complete(context.Background(), provider.Request{
				Model:     "kimi-k2.5",
				Messages:  []provider.Message{{Role: "user", Content: "hi"}},
				Reasoning: c.reasoning,
			})
			if err != nil {
				t.Fatalf("Complete: %v", err)
			}

			got, present := captured.parsedBody["think"]
			if c.want == nil {
				if present {
					t.Errorf("think field should be absent, got %v", got)
				}
				return
			}
			if !present {
				t.Fatalf("think field absent; want %v", c.want)
			}
			// Comparing `any` values also checks the dynamic type, so a string
			// "true" does not satisfy a wanted bool true.
			if got != c.want {
				t.Errorf("think: want %v (%T), got %v (%T)", c.want, c.want, got, got)
			}
		})
	}
}
// TestToolRoundTrip covers both directions of tool calling:
//  1. Tool definitions are sent as {"type":"function","function":{...}}.
//     A response's tool_calls, whose arguments are JSON objects, become
//     provider tool calls with JSON-string Arguments and a non-empty ID,
//     synthesized when Ollama omits one.
//  2. A follow-up request replays the assistant's tool_calls, with arguments
//     as JSON objects, plus the tool-role result message.
func TestToolRoundTrip(t *testing.T) {
	t.Run("response tool_calls convert to provider.Response", func(t *testing.T) {
		resp := `{
  "message": {
    "role": "assistant",
    "content": "",
    "tool_calls": [
      {"function": {"name": "search", "arguments": {"query": "foo"}}}
    ]
  },
  "done": true,
  "prompt_eval_count": 5,
  "eval_count": 2
}`
		// Named "captured" rather than "cap" to avoid shadowing the builtin.
		captured := &captureRequest{}
		srv := newTestServer(t, captured, 200, resp, "")
		p := newNative("", srv.URL)

		got, err := p.Complete(context.Background(), provider.Request{
			Model:    "kimi-k2.5",
			Messages: []provider.Message{{Role: "user", Content: "hi"}},
			Tools: []provider.ToolDef{
				{
					Name:        "search",
					Description: "Run a search",
					Schema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"query": map[string]any{"type": "string"},
						},
					},
				},
			},
		})
		if err != nil {
			t.Fatalf("Complete: %v", err)
		}

		// Verify request shape: tools array present.
		toolsArr, _ := captured.parsedBody["tools"].([]any)
		if len(toolsArr) != 1 {
			t.Fatalf("tools: want 1 entry, got %d", len(toolsArr))
		}
		t0, _ := toolsArr[0].(map[string]any)
		if t0["type"] != "function" {
			t.Errorf("tools[0].type: want function, got %v", t0["type"])
		}
		fn, _ := t0["function"].(map[string]any)
		if fn["name"] != "search" {
			t.Errorf("tools[0].function.name: want search, got %v", fn["name"])
		}

		// Verify response conversion.
		if len(got.ToolCalls) != 1 {
			t.Fatalf("ToolCalls: want 1, got %d", len(got.ToolCalls))
		}
		tc := got.ToolCalls[0]
		if tc.Name != "search" {
			t.Errorf("ToolCall.Name: want search, got %q", tc.Name)
		}
		// The response carried no id, so the provider must synthesize one to
		// keep the assistant tool_calls and tool-role messages correlated on
		// the next turn.
		if tc.ID == "" {
			t.Errorf("ToolCall.ID: want a synthetic non-empty id when Ollama omits one, got empty")
		}
		// Arguments should be valid JSON containing query=foo.
		var args map[string]any
		if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil {
			t.Fatalf("ToolCall.Arguments not valid JSON: %v (got %q)", err, tc.Arguments)
		}
		if args["query"] != "foo" {
			t.Errorf("ToolCall.Arguments.query: want foo, got %v", args["query"])
		}
	})

	t.Run("subsequent request includes assistant tool_calls and tool-role response", func(t *testing.T) {
		resp := `{"message":{"role":"assistant","content":"done"},"done":true}`
		captured := &captureRequest{}
		srv := newTestServer(t, captured, 200, resp, "")
		p := newNative("", srv.URL)

		_, err := p.Complete(context.Background(), provider.Request{
			Model: "kimi-k2.5",
			Messages: []provider.Message{
				{Role: "user", Content: "search foo"},
				{
					Role: "assistant",
					ToolCalls: []provider.ToolCall{{
						ID:        "tc1",
						Name:      "search",
						Arguments: `{"query":"foo"}`,
					}},
				},
				{
					Role:       "tool",
					ToolCallID: "tc1",
					Content:    `{"result":"bar"}`,
				},
			},
		})
		if err != nil {
			t.Fatalf("Complete: %v", err)
		}

		msgs, _ := captured.parsedBody["messages"].([]any)
		if len(msgs) != 3 {
			t.Fatalf("messages: want 3, got %d", len(msgs))
		}

		// Assistant message must carry tool_calls with the JSON-object arguments.
		asst, _ := msgs[1].(map[string]any)
		if asst["role"] != "assistant" {
			t.Errorf("msgs[1].role: want assistant, got %v", asst["role"])
		}
		calls, _ := asst["tool_calls"].([]any)
		if len(calls) != 1 {
			t.Fatalf("assistant.tool_calls: want 1, got %d", len(calls))
		}
		// Two-value assertions throughout so an unexpected shape fails the
		// test instead of panicking.
		call0, _ := calls[0].(map[string]any)
		fn, _ := call0["function"].(map[string]any)
		if fn["name"] != "search" {
			t.Errorf("assistant.tool_calls[0].function.name: want search, got %v", fn["name"])
		}
		args, _ := fn["arguments"].(map[string]any)
		if args["query"] != "foo" {
			t.Errorf("assistant.tool_calls[0].function.arguments.query: want foo, got %v", args["query"])
		}

		// Tool-role message must have role=tool, tool_call_id, and content.
		tool, _ := msgs[2].(map[string]any)
		if tool["role"] != "tool" {
			t.Errorf("msgs[2].role: want tool, got %v", tool["role"])
		}
		if tool["tool_call_id"] != "tc1" {
			t.Errorf("msgs[2].tool_call_id: want tc1, got %v", tool["tool_call_id"])
		}
		if !strings.Contains(toString(tool["content"]), "bar") {
			t.Errorf("msgs[2].content: want to contain bar, got %v", tool["content"])
		}
	})
}
// toString renders v as text for loose assertions. Strings come back
// unchanged and any other value is JSON-encoded. A value that cannot be
// encoded yields "".
func toString(v any) string {
	switch s := v.(type) {
	case string:
		return s
	default:
		encoded, _ := json.Marshal(v)
		return string(encoded)
	}
}