f70c7c0842
Reads Ollama's NDJSON stream (one JSON object per line) and emits provider.StreamEvent values for text, thinking, tool-call start/delta/end, and a final Done event carrying assembled Response and Usage. Uses bufio.Scanner with a 4 MiB max-line buffer so multi-KB tool-call deltas parse cleanly, and accepts tool-call arguments delivered either as escaped string fragments (delta-style) or a complete JSON object (one-shot). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
536 lines
15 KiB
Go
536 lines
15 KiB
Go
// Package ollama implements the go-llm v2 provider interface for Ollama,
|
|
// targeting Ollama's native /api/chat endpoint. Supports both local Ollama
|
|
// instances (no API key) and Ollama Cloud (https://ollama.com, requires an
|
|
// API key).
|
|
package ollama
|
|
|
|
import (
	"bufio"
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
|
|
|
|
// Default base URLs for the two Ollama deployment modes. Provider.baseURL
// is expected to hold one of these or a caller-supplied override.
const (
	// DefaultLocalBaseURL is the default base URL for a locally-running
	// Ollama instance.
	DefaultLocalBaseURL = "http://localhost:11434"

	// DefaultCloudBaseURL is the default base URL for Ollama Cloud.
	DefaultCloudBaseURL = "https://ollama.com"
)
|
|
|
|
// Provider implements provider.Provider over Ollama's native /api/chat
// endpoint. An empty apiKey means local-mode (no Authorization header sent);
// a non-empty apiKey is sent as a Bearer token (cloud-mode).
type Provider struct {
	apiKey  string       // Bearer token for Ollama Cloud; empty selects local mode
	baseURL string       // e.g. DefaultLocalBaseURL or DefaultCloudBaseURL; trailing slashes are trimmed before use
	client  *http.Client // shared HTTP client; constructed with no global timeout in newNative
}
|
|
|
|
// newNative constructs a native Ollama provider. Callers should use the
|
|
// package-level New() constructor or the v2 llm.Ollama() / llm.OllamaCloud()
|
|
// helpers.
|
|
func newNative(apiKey, baseURL string) *Provider {
|
|
return &Provider{
|
|
apiKey: apiKey,
|
|
baseURL: baseURL,
|
|
client: &http.Client{},
|
|
}
|
|
}
|
|
|
|
// nativeChatRequest is the JSON body POSTed to /api/chat.
type nativeChatRequest struct {
	Model    string              `json:"model"`           // model name, copied from provider.Request.Model
	Messages []nativeChatMessage `json:"messages"`        // full conversation history, in order
	Tools    []nativeToolDef     `json:"tools,omitempty"` // tool definitions; omitted when the request declares none
	Stream   bool                `json:"stream"`          // true for NDJSON streaming, false for one-shot
	// Think is polymorphic — Ollama accepts true/false or "low"/"medium"/"high".
	// encodeThink produces the raw JSON; nil omits the field entirely.
	Think json.RawMessage `json:"think,omitempty"`
	// Options carries sampler settings; keys written by buildChatRequest are
	// "temperature", "top_p", "num_predict" (max tokens), and "stop".
	Options map[string]any `json:"options,omitempty"`
}
|
|
|
|
// nativeChatMessage is one entry in the messages array on the wire. It also
// carries assistant tool calls and tool-role responses.
type nativeChatMessage struct {
	Role       string           `json:"role"`                   // "system", "user", "assistant", or "tool"
	Content    string           `json:"content,omitempty"`      // text content; thinking text travels separately
	Images     []string         `json:"images,omitempty"`       // base64-encoded image payloads (see imageToBase64)
	ToolCalls  []nativeToolCall `json:"tool_calls,omitempty"`   // assistant-issued tool calls
	ToolCallID string           `json:"tool_call_id,omitempty"` // on tool-role messages: which call this responds to
	Thinking   string           `json:"thinking,omitempty"`     // model reasoning text, separate from Content
}
|
|
|
|
// nativeToolCall mirrors Ollama's tool-call wire shape: a function with name
// and JSON-encoded arguments. Ollama's spec doesn't require an id, but some
// builds and some streaming chunks include one — we accept it on both wire and
// internal sides.
type nativeToolCall struct {
	ID       string             `json:"id,omitempty"` // optional; synthesized downstream when absent (see toolCallID)
	Function nativeFunctionCall `json:"function"`     // the called function's name/index/arguments
}
|
|
|
|
// nativeFunctionCall is the function half of a tool call on the wire.
type nativeFunctionCall struct {
	// Index, when present, correlates streamed argument deltas that belong
	// to the same call (see streamToolKey). Pointer so absence is distinguishable
	// from index 0.
	Index *int   `json:"index,omitempty"`
	Name  string `json:"name,omitempty"` // may arrive late in a continuation chunk during streaming
	// Arguments is kept raw: it may be a JSON object, or a JSON-encoded
	// string fragment on some builds (see decodeArgumentDelta /
	// rawMessageToArgString).
	Arguments json.RawMessage `json:"arguments,omitempty"`
}
|
|
|
|
// nativeChatResponse is the JSON body returned from a non-streaming /api/chat
// call (and is also the per-line shape during streaming).
type nativeChatResponse struct {
	Model      string            `json:"model,omitempty"`
	Message    nativeChatMessage `json:"message"`               // delta during streaming; full message otherwise
	Done       bool              `json:"done"`                  // true on the final streaming chunk
	DoneReason string            `json:"done_reason,omitempty"` // accepted from the wire; not surfaced by this file
	// Token counts map onto provider.Usage (prompt -> input, eval -> output).
	PromptEvalCount int   `json:"prompt_eval_count,omitempty"`
	EvalCount       int   `json:"eval_count,omitempty"`
	TotalDuration   int64 `json:"total_duration,omitempty"` // parsed but not currently consumed here
}
|
|
|
|
// nativeToolDef is the wire shape of a tool definition sent to Ollama.
type nativeToolDef struct {
	Type     string            `json:"type"` // always "function" (set by buildChatRequest)
	Function nativeFunctionDef `json:"function"`
}

// nativeFunctionDef describes one callable function within a tool definition.
type nativeFunctionDef struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	// Parameters is the function's JSON-schema, passed through verbatim
	// from provider.Tool.Schema.
	Parameters map[string]any `json:"parameters,omitempty"`
}
|
|
|
|
// encodeThink converts a go-llm Reasoning string ("", "low", "medium",
// "high", or the literal strings "true"/"false") into Ollama's polymorphic
// `think` field. Returns nil for the empty string so the field is omitted.
func encodeThink(reasoning string) json.RawMessage {
	if reasoning == "" {
		// nil makes the omitempty tag drop the field entirely.
		return nil
	}
	if reasoning == "true" || reasoning == "false" {
		// Booleans travel as bare JSON literals, not quoted strings.
		return json.RawMessage(reasoning)
	}
	// Effort levels ("low" / "medium" / "high") are encoded as a JSON string.
	encoded, _ := json.Marshal(reasoning)
	return encoded
}
|
|
|
|
// Complete performs a non-streaming chat completion via /api/chat.
|
|
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
|
body, err := p.buildChatRequest(req, false)
|
|
if err != nil {
|
|
return provider.Response{}, err
|
|
}
|
|
|
|
httpResp, err := p.doChatRequest(ctx, body)
|
|
if err != nil {
|
|
return provider.Response{}, err
|
|
}
|
|
defer httpResp.Body.Close()
|
|
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
b, _ := io.ReadAll(httpResp.Body)
|
|
return provider.Response{}, fmt.Errorf("ollama: HTTP %d: %s", httpResp.StatusCode, string(b))
|
|
}
|
|
|
|
var chat nativeChatResponse
|
|
if err := json.NewDecoder(httpResp.Body).Decode(&chat); err != nil {
|
|
return provider.Response{}, fmt.Errorf("ollama: decode response: %w", err)
|
|
}
|
|
|
|
resp := provider.Response{
|
|
Text: chat.Message.Content,
|
|
Thinking: chat.Message.Thinking,
|
|
}
|
|
for i, tc := range chat.Message.ToolCalls {
|
|
resp.ToolCalls = append(resp.ToolCalls, provider.ToolCall{
|
|
ID: toolCallID(tc, i),
|
|
Name: tc.Function.Name,
|
|
Arguments: rawMessageToArgString(tc.Function.Arguments),
|
|
})
|
|
}
|
|
if chat.PromptEvalCount > 0 || chat.EvalCount > 0 {
|
|
resp.Usage = &provider.Usage{
|
|
InputTokens: chat.PromptEvalCount,
|
|
OutputTokens: chat.EvalCount,
|
|
TotalTokens: chat.PromptEvalCount + chat.EvalCount,
|
|
}
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// Stream performs a streaming chat completion via /api/chat with
// `stream: true`, parsing NDJSON line-by-line. Tool-call argument deltas are
// accumulated across chunks keyed by id (or function index) and finalized
// when the upstream Done flag arrives.
//
// Event order: zero or more Thinking/Text/ToolStart/ToolDelta events as
// chunks arrive, then either a single Error event (on decode/read failure)
// or ToolEnd events (one per accumulated call) followed by a single Done
// event carrying the assembled Response. The events channel is always
// closed before return.
//
// NOTE(review): sends on `events` do not select on ctx.Done(); a consumer
// that stops receiving mid-stream would block this goroutine — confirm the
// provider.Stream contract guarantees the channel is drained.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	defer close(events)

	body, err := p.buildChatRequest(req, true)
	if err != nil {
		return err
	}

	httpResp, err := p.doChatRequest(ctx, body)
	if err != nil {
		return err
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		b, _ := io.ReadAll(httpResp.Body)
		return fmt.Errorf("ollama: HTTP %d: %s", httpResp.StatusCode, string(b))
	}

	scanner := bufio.NewScanner(httpResp.Body)
	// Ollama can emit multi-KB lines on tool-call deltas. Generous buffer.
	const maxLineSize = 4 * 1024 * 1024
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	// toolAcc accumulates one tool call's identity and argument fragments
	// across multiple stream chunks.
	type toolAcc struct {
		id    string
		name  string
		args  strings.Builder
		index int // ToolIndex emitted on stream events
	}
	// tools correlates chunks to accumulators (key from streamToolKey);
	// toolOrder preserves first-seen order for finalization.
	tools := map[string]*toolAcc{}
	var toolOrder []*toolAcc

	var (
		fullText     strings.Builder // concatenation of all Content deltas
		fullThinking strings.Builder // concatenation of all Thinking deltas
		usage        *provider.Usage // set from the final (Done) chunk, if counts present
		streamErr    error           // first decode/read error; stops the loop
	)

	for scanner.Scan() {
		line := scanner.Bytes()
		// Skip blank lines between NDJSON records.
		if len(bytes.TrimSpace(line)) == 0 {
			continue
		}
		var chunk nativeChatResponse
		if err := json.Unmarshal(line, &chunk); err != nil {
			streamErr = fmt.Errorf("ollama: decode stream chunk: %w", err)
			break
		}

		// Thinking and text deltas are accumulated and forwarded verbatim.
		if chunk.Message.Thinking != "" {
			fullThinking.WriteString(chunk.Message.Thinking)
			events <- provider.StreamEvent{
				Type: provider.StreamEventThinking,
				Text: chunk.Message.Thinking,
			}
		}
		if chunk.Message.Content != "" {
			fullText.WriteString(chunk.Message.Content)
			events <- provider.StreamEvent{
				Type: provider.StreamEventText,
				Text: chunk.Message.Content,
			}
		}

		for pos, tc := range chunk.Message.ToolCalls {
			key := streamToolKey(tc, pos)
			acc, exists := tools[key]
			if !exists {
				// First sighting of this call: allocate an accumulator and
				// announce it with a ToolStart event.
				acc = &toolAcc{
					id:    tc.ID,
					name:  tc.Function.Name,
					index: len(toolOrder),
				}
				if acc.id == "" {
					// Ollama often omits ids; synthesize a stable one.
					acc.id = fmt.Sprintf("tc_%d", acc.index)
				}
				tools[key] = acc
				toolOrder = append(toolOrder, acc)
				events <- provider.StreamEvent{
					Type:      provider.StreamEventToolStart,
					ToolIndex: acc.index,
					ToolCall: &provider.ToolCall{
						ID:   acc.id,
						Name: acc.name,
					},
				}
			} else {
				// Continuation chunk may carry the tool's name late; capture it.
				if tc.Function.Name != "" && acc.name == "" {
					acc.name = tc.Function.Name
				}
			}

			// Arguments may be a string fragment or a whole object; either
			// way the decoded text is appended and forwarded as a delta.
			delta := decodeArgumentDelta(tc.Function.Arguments)
			if delta != "" {
				acc.args.WriteString(delta)
				events <- provider.StreamEvent{
					Type:      provider.StreamEventToolDelta,
					ToolIndex: acc.index,
					ToolCall: &provider.ToolCall{
						Arguments: delta,
					},
				}
			}
		}

		// The final chunk carries Done plus the token counts.
		if chunk.Done {
			if chunk.PromptEvalCount > 0 || chunk.EvalCount > 0 {
				usage = &provider.Usage{
					InputTokens:  chunk.PromptEvalCount,
					OutputTokens: chunk.EvalCount,
					TotalTokens:  chunk.PromptEvalCount + chunk.EvalCount,
				}
			}
			break
		}
	}

	// Scanner errors (e.g. a line over maxLineSize, or a broken connection)
	// surface here; the first error wins.
	if err := scanner.Err(); err != nil && streamErr == nil {
		streamErr = fmt.Errorf("ollama: stream read: %w", err)
	}

	if streamErr != nil {
		events <- provider.StreamEvent{
			Type:  provider.StreamEventError,
			Error: streamErr,
		}
		return streamErr
	}

	// Finalize accumulated tool calls.
	finalCalls := make([]provider.ToolCall, 0, len(toolOrder))
	for _, acc := range toolOrder {
		args := acc.args.String()
		if args == "" {
			// Normalize absent arguments to an empty JSON object.
			args = "{}"
		}
		final := provider.ToolCall{
			ID:        acc.id,
			Name:      acc.name,
			Arguments: args,
		}
		finalCalls = append(finalCalls, final)
		events <- provider.StreamEvent{
			Type:      provider.StreamEventToolEnd,
			ToolIndex: acc.index,
			ToolCall:  &final,
		}
	}

	// Single terminal Done event with the assembled response and usage.
	events <- provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			Thinking:  fullThinking.String(),
			ToolCalls: finalCalls,
			Usage:     usage,
		},
	}
	return nil
}
|
|
|
|
// streamToolKey computes a stable map key correlating tool-call deltas
|
|
// across stream chunks. Prefer the wire id, fall back to function index,
|
|
// finally fall back to the tool's position in the chunk's tool_calls array
|
|
// (a single-tool stream collapses cleanly under any strategy).
|
|
func streamToolKey(tc nativeToolCall, position int) string {
|
|
if tc.ID != "" {
|
|
return "id:" + tc.ID
|
|
}
|
|
if tc.Function.Index != nil {
|
|
return fmt.Sprintf("idx:%d", *tc.Function.Index)
|
|
}
|
|
return fmt.Sprintf("pos:%d", position)
|
|
}
|
|
|
|
// decodeArgumentDelta returns the string fragment to append when a streamed
// tool-call chunk includes arguments. Ollama may emit arguments either as a
// JSON-encoded string fragment (chunk-by-chunk concatenation, openaicompat
// style) or as a complete object value (one-shot delivery). We accept both:
// strings are unwrapped, objects/arrays pass through verbatim. Empty and
// JSON-null inputs yield "" (nothing to append).
func decodeArgumentDelta(raw json.RawMessage) string {
	trimmed := bytes.TrimSpace(raw)
	if len(trimmed) == 0 || string(trimmed) == "null" {
		return ""
	}
	if trimmed[0] == '"' {
		// JSON-encoded string fragment: unwrap so concatenated deltas
		// reconstruct the raw argument text.
		var fragment string
		if json.Unmarshal(trimmed, &fragment) == nil {
			return fragment
		}
	}
	// Object, array, or any other value: pass through verbatim.
	return string(trimmed)
}
|
|
|
|
// buildChatRequest converts a provider.Request into the native wire body
|
|
// JSON. stream toggles the stream flag (true for /api/chat streaming).
|
|
func (p *Provider) buildChatRequest(req provider.Request, stream bool) ([]byte, error) {
|
|
wire := nativeChatRequest{
|
|
Model: req.Model,
|
|
Stream: stream,
|
|
Think: encodeThink(req.Reasoning),
|
|
}
|
|
|
|
for _, msg := range req.Messages {
|
|
m, err := convertMessage(msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
wire.Messages = append(wire.Messages, m)
|
|
}
|
|
|
|
for _, t := range req.Tools {
|
|
wire.Tools = append(wire.Tools, nativeToolDef{
|
|
Type: "function",
|
|
Function: nativeFunctionDef{
|
|
Name: t.Name,
|
|
Description: t.Description,
|
|
Parameters: t.Schema,
|
|
},
|
|
})
|
|
}
|
|
|
|
if req.Temperature != nil || req.MaxTokens != nil || req.TopP != nil || len(req.Stop) > 0 {
|
|
wire.Options = map[string]any{}
|
|
if req.Temperature != nil {
|
|
wire.Options["temperature"] = *req.Temperature
|
|
}
|
|
if req.TopP != nil {
|
|
wire.Options["top_p"] = *req.TopP
|
|
}
|
|
if req.MaxTokens != nil {
|
|
wire.Options["num_predict"] = *req.MaxTokens
|
|
}
|
|
if len(req.Stop) > 0 {
|
|
wire.Options["stop"] = req.Stop
|
|
}
|
|
}
|
|
|
|
return json.Marshal(wire)
|
|
}
|
|
|
|
// doChatRequest POSTs the wire body to /api/chat and returns the raw HTTP
|
|
// response. The caller is responsible for closing the response body.
|
|
func (p *Provider) doChatRequest(ctx context.Context, body []byte) (*http.Response, error) {
|
|
url := strings.TrimRight(p.baseURL, "/") + "/api/chat"
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: build request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
if p.apiKey != "" {
|
|
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
|
|
}
|
|
resp, err := p.client.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: HTTP request: %w", err)
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// convertMessage maps a provider.Message into a native wire message.
|
|
func convertMessage(msg provider.Message) (nativeChatMessage, error) {
|
|
out := nativeChatMessage{
|
|
Role: msg.Role,
|
|
Content: msg.Content,
|
|
ToolCallID: msg.ToolCallID,
|
|
}
|
|
|
|
for _, img := range msg.Images {
|
|
b64, err := imageToBase64(img)
|
|
if err != nil {
|
|
return nativeChatMessage{}, err
|
|
}
|
|
if b64 != "" {
|
|
out.Images = append(out.Images, b64)
|
|
}
|
|
}
|
|
|
|
for i, tc := range msg.ToolCalls {
|
|
raw := json.RawMessage(strings.TrimSpace(tc.Arguments))
|
|
if len(raw) == 0 {
|
|
raw = json.RawMessage(`{}`)
|
|
}
|
|
// Preserve a stable index so streaming peers can correlate deltas.
|
|
idx := i
|
|
out.ToolCalls = append(out.ToolCalls, nativeToolCall{
|
|
ID: tc.ID,
|
|
Function: nativeFunctionCall{
|
|
Index: &idx,
|
|
Name: tc.Name,
|
|
Arguments: raw,
|
|
},
|
|
})
|
|
}
|
|
|
|
return out, nil
|
|
}
|
|
|
|
// imageToBase64 returns the base64-encoded payload of an image, fetching
|
|
// URL-only images over HTTP if no inline base64 is supplied.
|
|
func imageToBase64(img provider.Image) (string, error) {
|
|
if img.Base64 != "" {
|
|
return img.Base64, nil
|
|
}
|
|
if img.URL == "" {
|
|
return "", nil
|
|
}
|
|
resp, err := http.Get(img.URL)
|
|
if err != nil {
|
|
return "", fmt.Errorf("ollama: fetch image %q: %w", img.URL, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return "", fmt.Errorf("ollama: fetch image %q: HTTP %d", img.URL, resp.StatusCode)
|
|
}
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("ollama: read image %q: %w", img.URL, err)
|
|
}
|
|
return base64.StdEncoding.EncodeToString(data), nil
|
|
}
|
|
|
|
// rawMessageToArgString converts a JSON-encoded arguments value into the
// string form the provider package uses for ToolCall.Arguments. Object/array
// values pass through verbatim; bare string values (some Ollama builds emit
// pre-stringified arguments) are unwrapped; empty, whitespace-only, and JSON
// null inputs normalize to "{}" so callers always receive argument text that
// parses as a JSON object.
func rawMessageToArgString(raw json.RawMessage) string {
	trimmed := strings.TrimSpace(string(raw))
	// Treat absent and JSON-null arguments as an empty object. (Previously
	// `null` fell through and became the literal string "null", which is
	// inconsistent with decodeArgumentDelta's null handling on the
	// streaming path and is not a usable arguments object.)
	if trimmed == "" || trimmed == "null" {
		return "{}"
	}
	if trimmed[0] == '"' {
		// Pre-stringified arguments: unwrap the JSON string.
		var s string
		if err := json.Unmarshal([]byte(trimmed), &s); err == nil {
			return s
		}
	}
	return trimmed
}
|
|
|
|
// toolCallID returns a stable identifier for a tool call. Ollama's native
|
|
// API typically does not include an id, so we synthesize one from the index
|
|
// when missing.
|
|
func toolCallID(tc nativeToolCall, index int) string {
|
|
if tc.ID != "" {
|
|
return tc.ID
|
|
}
|
|
if tc.Function.Index != nil {
|
|
return fmt.Sprintf("tc_%d", *tc.Function.Index)
|
|
}
|
|
return fmt.Sprintf("tc_%d", index)
|
|
}
|