From 583f8724b2e6f0ea29e5e6e4f24fd0d27c8648e0 Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Fri, 1 May 2026 18:24:02 +0000 Subject: [PATCH] feat(v2/ollama): implement native Complete() with tools, vision, thinking Non-streaming /api/chat support including: - Vision via images: []base64 - Tool calls on assistant + tool-role response messages - think field accepting string reasoning levels (or "true"/"false") - Authorization header when apiKey is non-empty (cloud mode) Tool-call arguments are passed as JSON objects to the wire and surfaced as JSON-string Arguments on provider.ToolCall. Tool calls are assigned synthetic IDs (tc_) when Ollama omits one, so the round-trip back as an assistant tool_calls + tool-role message remains correlated. Co-Authored-By: Claude Opus 4.6 --- v2/ollama/native.go | 213 +++++++++++++++++++++++- v2/ollama/native_test.go | 347 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 555 insertions(+), 5 deletions(-) create mode 100644 v2/ollama/native_test.go diff --git a/v2/ollama/native.go b/v2/ollama/native.go index 7ff3976..16dd891 100644 --- a/v2/ollama/native.go +++ b/v2/ollama/native.go @@ -5,10 +5,14 @@ package ollama import ( + "bytes" "context" + "encoding/base64" "encoding/json" - "errors" + "fmt" + "io" "net/http" + "strings" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" ) @@ -119,15 +123,214 @@ func encodeThink(reasoning string) json.RawMessage { } } -var errNotImplemented = errors.New("ollama native provider: not implemented") - // Complete performs a non-streaming chat completion via /api/chat. func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { - return provider.Response{}, errNotImplemented + body, err := p.buildChatRequest(req, false) + if err != nil { + return provider.Response{}, err + } + + httpResp, err := p.doChatRequest(ctx, body) + if err != nil { + return provider.Response{}, err + } + defer httpResp.Body.Close() + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + return provider.Response{}, fmt.Errorf("ollama: HTTP %d: %s", httpResp.StatusCode, string(b)) + } + + var chat nativeChatResponse + if err := json.NewDecoder(httpResp.Body).Decode(&chat); err != nil { + return provider.Response{}, fmt.Errorf("ollama: decode response: %w", err) + } + + resp := provider.Response{ + Text: chat.Message.Content, + Thinking: chat.Message.Thinking, + } + for i, tc := range chat.Message.ToolCalls { + resp.ToolCalls = append(resp.ToolCalls, provider.ToolCall{ + ID: toolCallID(tc, i), + Name: tc.Function.Name, + Arguments: rawMessageToArgString(tc.Function.Arguments), + }) + } + if chat.PromptEvalCount > 0 || chat.EvalCount > 0 { + resp.Usage = &provider.Usage{ + InputTokens: chat.PromptEvalCount, + OutputTokens: chat.EvalCount, + TotalTokens: chat.PromptEvalCount + chat.EvalCount, + } + } + return resp, nil } // Stream performs a streaming chat completion via /api/chat with // `stream: true`, parsing NDJSON line-by-line. func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { - return errNotImplemented + return fmt.Errorf("ollama native provider: Stream not implemented") +} + +// buildChatRequest converts a provider.Request into the native wire body +// JSON. stream toggles the stream flag (true for /api/chat streaming). +func (p *Provider) buildChatRequest(req provider.Request, stream bool) ([]byte, error) { + wire := nativeChatRequest{ + Model: req.Model, + Stream: stream, + Think: encodeThink(req.Reasoning), + } + + for _, msg := range req.Messages { + m, err := convertMessage(msg) + if err != nil { + return nil, err + } + wire.Messages = append(wire.Messages, m) + } + + for _, t := range req.Tools { + wire.Tools = append(wire.Tools, nativeToolDef{ + Type: "function", + Function: nativeFunctionDef{ + Name: t.Name, + Description: t.Description, + Parameters: t.Schema, + }, + }) + } + + if req.Temperature != nil || req.MaxTokens != nil || req.TopP != nil || len(req.Stop) > 0 { + wire.Options = map[string]any{} + if req.Temperature != nil { + wire.Options["temperature"] = *req.Temperature + } + if req.TopP != nil { + wire.Options["top_p"] = *req.TopP + } + if req.MaxTokens != nil { + wire.Options["num_predict"] = *req.MaxTokens + } + if len(req.Stop) > 0 { + wire.Options["stop"] = req.Stop + } + } + + return json.Marshal(wire) +} + +// doChatRequest POSTs the wire body to /api/chat and returns the raw HTTP +// response. The caller is responsible for closing the response body. +func (p *Provider) doChatRequest(ctx context.Context, body []byte) (*http.Response, error) { + url := strings.TrimRight(p.baseURL, "/") + "/api/chat" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("ollama: build request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) + } + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("ollama: HTTP request: %w", err) + } + return resp, nil +} + +// convertMessage maps a provider.Message into a native wire message. +func convertMessage(msg provider.Message) (nativeChatMessage, error) { + out := nativeChatMessage{ + Role: msg.Role, + Content: msg.Content, + ToolCallID: msg.ToolCallID, + } + + for _, img := range msg.Images { + b64, err := imageToBase64(img) + if err != nil { + return nativeChatMessage{}, err + } + if b64 != "" { + out.Images = append(out.Images, b64) + } + } + + for i, tc := range msg.ToolCalls { + raw := json.RawMessage(strings.TrimSpace(tc.Arguments)) + if len(raw) == 0 { + raw = json.RawMessage(`{}`) + } + // Preserve a stable index so streaming peers can correlate deltas. + idx := i + out.ToolCalls = append(out.ToolCalls, nativeToolCall{ + ID: tc.ID, + Function: nativeFunctionCall{ + Index: &idx, + Name: tc.Name, + Arguments: raw, + }, + }) + } + + return out, nil +} + +// imageToBase64 returns the base64-encoded payload of an image, fetching +// URL-only images over HTTP if no inline base64 is supplied. +func imageToBase64(img provider.Image) (string, error) { + if img.Base64 != "" { + return img.Base64, nil + } + if img.URL == "" { + return "", nil + } + resp, err := http.Get(img.URL) + if err != nil { + return "", fmt.Errorf("ollama: fetch image %q: %w", img.URL, err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("ollama: fetch image %q: HTTP %d", img.URL, resp.StatusCode) + } + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("ollama: read image %q: %w", img.URL, err) + } + return base64.StdEncoding.EncodeToString(data), nil +} + +// rawMessageToArgString converts a JSON-encoded arguments value into the +// string form the provider package uses for ToolCall.Arguments. Object/array +// values pass through verbatim; bare string values (some Ollama builds emit +// pre-stringified arguments) are unwrapped. +func rawMessageToArgString(raw json.RawMessage) string { + if len(raw) == 0 { + return "{}" + } + trimmed := strings.TrimSpace(string(raw)) + if len(trimmed) == 0 { + return "{}" + } + if trimmed[0] == '"' { + var s string + if err := json.Unmarshal([]byte(trimmed), &s); err == nil { + return s + } + } + return trimmed +} + +// toolCallID returns a stable identifier for a tool call. Ollama's native +// API typically does not include an id, so we synthesize one from the index +// when missing. +func toolCallID(tc nativeToolCall, index int) string { + if tc.ID != "" { + return tc.ID + } + if tc.Function.Index != nil { + return fmt.Sprintf("tc_%d", *tc.Function.Index) + } + return fmt.Sprintf("tc_%d", index) } diff --git a/v2/ollama/native_test.go b/v2/ollama/native_test.go new file mode 100644 index 0000000..a88c580 --- /dev/null +++ b/v2/ollama/native_test.go @@ -0,0 +1,347 @@ +package ollama + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// captureRequest is a tiny helper that records the inbound HTTP request and +// returns a configurable response body. +type captureRequest struct { + method string + path string + authHeader string + contentType string + body []byte + parsedBody map[string]any +} + +func newTestServer(t *testing.T, captured *captureRequest, status int, respBody string, respContentType string) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured.method = r.Method + captured.path = r.URL.Path + captured.authHeader = r.Header.Get("Authorization") + captured.contentType = r.Header.Get("Content-Type") + body, _ := io.ReadAll(r.Body) + captured.body = body + _ = json.Unmarshal(body, &captured.parsedBody) + if respContentType == "" { + respContentType = "application/json" + } + w.Header().Set("Content-Type", respContentType) + w.WriteHeader(status) + _, _ = w.Write([]byte(respBody)) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestCompleteBasic(t *testing.T) { + resp := `{ + "model": "kimi-k2.5", + "message": {"role": "assistant", "content": "hello there"}, + "done": true, + "done_reason": "stop", + "prompt_eval_count": 10, + "eval_count": 3 + }` + cap := &captureRequest{} + srv := newTestServer(t, cap, 200, resp, "") + + p := newNative("test-key", srv.URL) + got, err := p.Complete(context.Background(), provider.Request{ + Model: "kimi-k2.5", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("Complete: %v", err) + } + + if cap.method != "POST" { + t.Errorf("method: want POST, got %q", cap.method) + } + if cap.path != "/api/chat" { + t.Errorf("path: want /api/chat, got %q", cap.path) + } + if cap.authHeader != "Bearer test-key" { + t.Errorf("auth header: want %q, got %q", "Bearer test-key", cap.authHeader) + } + if cap.contentType != "application/json" { + t.Errorf("content-type: want application/json, got %q", cap.contentType) + } + if cap.parsedBody["model"] != "kimi-k2.5" { + t.Errorf("body.model: want kimi-k2.5, got %v", cap.parsedBody["model"]) + } + if cap.parsedBody["stream"] != false { + t.Errorf("body.stream: want false, got %v", cap.parsedBody["stream"]) + } + msgs, _ := cap.parsedBody["messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("messages: want 1 entry, got %d", len(msgs)) + } + m0, _ := msgs[0].(map[string]any) + if m0["role"] != "user" || m0["content"] != "hi" { + t.Errorf("first message: want role=user content=hi, got %v", m0) + } + + if got.Text != "hello there" { + t.Errorf("Text: want %q, got %q", "hello there", got.Text) + } + if got.Usage == nil { + t.Fatal("Usage: want non-nil") + } + if got.Usage.InputTokens != 10 || got.Usage.OutputTokens != 3 { + t.Errorf("Usage: want input=10 output=3, got input=%d output=%d", got.Usage.InputTokens, got.Usage.OutputTokens) + } + if got.Usage.TotalTokens != 13 { + t.Errorf("Usage.TotalTokens: want 13, got %d", got.Usage.TotalTokens) + } +} + +func TestCompleteNoAuthHeaderWhenLocal(t *testing.T) { + resp := `{"message":{"role":"assistant","content":"ok"},"done":true}` + cap := &captureRequest{} + srv := newTestServer(t, cap, 200, resp, "") + + p := newNative("", srv.URL) + if _, err := p.Complete(context.Background(), provider.Request{ + Model: "llama3.2", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + }); err != nil { + t.Fatalf("Complete: %v", err) + } + + if cap.authHeader != "" { + t.Errorf("auth header: want empty (local mode), got %q", cap.authHeader) + } +} + +func TestVisionImagesEncoded(t *testing.T) { + resp := `{"message":{"role":"assistant","content":"a cat"},"done":true}` + cap := &captureRequest{} + srv := newTestServer(t, cap, 200, resp, "") + + p := newNative("", srv.URL) + if _, err := p.Complete(context.Background(), provider.Request{ + Model: "llava", + Messages: []provider.Message{{ + Role: "user", + Content: "what's in this?", + Images: []provider.Image{ + {Base64: "AAAA", ContentType: "image/png"}, + }, + }}, + }); err != nil { + t.Fatalf("Complete: %v", err) + } + + msgs, _ := cap.parsedBody["messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("messages: want 1, got %d", len(msgs)) + } + m0, _ := msgs[0].(map[string]any) + imgs, _ := m0["images"].([]any) + if len(imgs) != 1 { + t.Fatalf("images: want 1 entry, got %d (msg=%v)", len(imgs), m0) + } + if imgs[0] != "AAAA" { + t.Errorf("images[0]: want raw base64 AAAA, got %v", imgs[0]) + } +} + +func TestThinkingField(t *testing.T) { + cases := []struct { + name string + reasoning string + want any // expected value of "think" in body, or nil if absent + }{ + {"absent", "", nil}, + {"high", "high", "high"}, + {"low", "low", "low"}, + {"medium", "medium", "medium"}, + {"true", "true", true}, + {"false", "false", false}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + resp := `{"message":{"role":"assistant","content":"ok"},"done":true}` + cap := &captureRequest{} + srv := newTestServer(t, cap, 200, resp, "") + + p := newNative("", srv.URL) + _, err := p.Complete(context.Background(), provider.Request{ + Model: "kimi-k2.5", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + Reasoning: c.reasoning, + }) + if err != nil { + t.Fatalf("Complete: %v", err) + } + got, present := cap.parsedBody["think"] + if c.want == nil { + if present { + t.Errorf("think field should be absent, got %v", got) + } + return + } + if !present { + t.Fatalf("think field absent; want %v", c.want) + } + if got != c.want { + t.Errorf("think: want %v (%T), got %v (%T)", c.want, c.want, got, got) + } + }) + } +} + +func TestToolRoundTrip(t *testing.T) { + t.Run("response tool_calls convert to provider.Response", func(t *testing.T) { + resp := `{ + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + {"function": {"name": "search", "arguments": {"query": "foo"}}} + ] + }, + "done": true, + "prompt_eval_count": 5, + "eval_count": 2 + }` + cap := &captureRequest{} + srv := newTestServer(t, cap, 200, resp, "") + + p := newNative("", srv.URL) + got, err := p.Complete(context.Background(), provider.Request{ + Model: "kimi-k2.5", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + Tools: []provider.ToolDef{ + { + Name: "search", + Description: "Run a search", + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + }, + }, + }) + if err != nil { + t.Fatalf("Complete: %v", err) + } + + // Verify request shape: tools array present. + toolsArr, _ := cap.parsedBody["tools"].([]any) + if len(toolsArr) != 1 { + t.Fatalf("tools: want 1 entry, got %d", len(toolsArr)) + } + t0, _ := toolsArr[0].(map[string]any) + if t0["type"] != "function" { + t.Errorf("tools[0].type: want function, got %v", t0["type"]) + } + fn, _ := t0["function"].(map[string]any) + if fn["name"] != "search" { + t.Errorf("tools[0].function.name: want search, got %v", fn["name"]) + } + + // Verify response conversion. + if len(got.ToolCalls) != 1 { + t.Fatalf("ToolCalls: want 1, got %d", len(got.ToolCalls)) + } + tc := got.ToolCalls[0] + if tc.Name != "search" { + t.Errorf("ToolCall.Name: want search, got %q", tc.Name) + } + // Arguments should be valid JSON containing query=foo + var args map[string]any + if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil { + t.Fatalf("ToolCall.Arguments not valid JSON: %v (got %q)", err, tc.Arguments) + } + if args["query"] != "foo" { + t.Errorf("ToolCall.Arguments.query: want foo, got %v", args["query"]) + } + }) + + t.Run("subsequent request includes assistant tool_calls and tool-role response", func(t *testing.T) { + resp := `{"message":{"role":"assistant","content":"done"},"done":true}` + cap := &captureRequest{} + srv := newTestServer(t, cap, 200, resp, "") + + p := newNative("", srv.URL) + _, err := p.Complete(context.Background(), provider.Request{ + Model: "kimi-k2.5", + Messages: []provider.Message{ + {Role: "user", Content: "search foo"}, + { + Role: "assistant", + ToolCalls: []provider.ToolCall{{ + ID: "tc1", + Name: "search", + Arguments: `{"query":"foo"}`, + }}, + }, + { + Role: "tool", + ToolCallID: "tc1", + Content: `{"result":"bar"}`, + }, + }, + }) + if err != nil { + t.Fatalf("Complete: %v", err) + } + + msgs, _ := cap.parsedBody["messages"].([]any) + if len(msgs) != 3 { + t.Fatalf("messages: want 3, got %d", len(msgs)) + } + + // Assistant message must carry tool_calls with the JSON-object arguments. + asst, _ := msgs[1].(map[string]any) + if asst["role"] != "assistant" { + t.Errorf("msgs[1].role: want assistant, got %v", asst["role"]) + } + tc, _ := asst["tool_calls"].([]any) + if len(tc) != 1 { + t.Fatalf("assistant.tool_calls: want 1, got %d", len(tc)) + } + fn, _ := tc[0].(map[string]any)["function"].(map[string]any) + if fn["name"] != "search" { + t.Errorf("assistant.tool_calls[0].function.name: want search, got %v", fn["name"]) + } + args, _ := fn["arguments"].(map[string]any) + if args["query"] != "foo" { + t.Errorf("assistant.tool_calls[0].function.arguments.query: want foo, got %v", args["query"]) + } + + // Tool-role message must have role=tool, tool_call_id, and content. + tool, _ := msgs[2].(map[string]any) + if tool["role"] != "tool" { + t.Errorf("msgs[2].role: want tool, got %v", tool["role"]) + } + if tool["tool_call_id"] != "tc1" { + t.Errorf("msgs[2].tool_call_id: want tc1, got %v", tool["tool_call_id"]) + } + if !strings.Contains(toString(tool["content"]), "bar") { + t.Errorf("msgs[2].content: want to contain bar, got %v", tool["content"]) + } + }) +} + +func toString(v any) string { + if s, ok := v.(string); ok { + return s + } + b, _ := json.Marshal(v) + return string(b) +}