feat(v2/ollama): implement native Complete() with tools, vision, thinking
Non-streaming /api/chat support, including:

- Vision via `images: []base64`
- Tool calls on assistant messages, plus tool-role response messages
- A `think` field accepting string reasoning levels (or "true"/"false")
- An Authorization header when apiKey is non-empty (cloud mode)

Tool-call arguments are sent on the wire as JSON objects and surfaced as JSON-string Arguments on provider.ToolCall. Tool calls are assigned synthetic IDs (tc_<index>) when Ollama omits one, so the round trip back as an assistant tool_calls message plus a tool-role message remains correlated.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+208
-5
@@ -5,10 +5,14 @@
|
|||||||
package ollama
|
package ollama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
)
|
)
|
||||||
@@ -119,15 +123,214 @@ func encodeThink(reasoning string) json.RawMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errNotImplemented = errors.New("ollama native provider: not implemented")
|
|
||||||
|
|
||||||
// Complete performs a non-streaming chat completion via /api/chat.
|
// Complete performs a non-streaming chat completion via /api/chat.
|
||||||
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
return provider.Response{}, errNotImplemented
|
body, err := p.buildChatRequest(req, false)
|
||||||
|
if err != nil {
|
||||||
|
return provider.Response{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpResp, err := p.doChatRequest(ctx, body)
|
||||||
|
if err != nil {
|
||||||
|
return provider.Response{}, err
|
||||||
|
}
|
||||||
|
defer httpResp.Body.Close()
|
||||||
|
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
|
return provider.Response{}, fmt.Errorf("ollama: HTTP %d: %s", httpResp.StatusCode, string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
var chat nativeChatResponse
|
||||||
|
if err := json.NewDecoder(httpResp.Body).Decode(&chat); err != nil {
|
||||||
|
return provider.Response{}, fmt.Errorf("ollama: decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := provider.Response{
|
||||||
|
Text: chat.Message.Content,
|
||||||
|
Thinking: chat.Message.Thinking,
|
||||||
|
}
|
||||||
|
for i, tc := range chat.Message.ToolCalls {
|
||||||
|
resp.ToolCalls = append(resp.ToolCalls, provider.ToolCall{
|
||||||
|
ID: toolCallID(tc, i),
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Arguments: rawMessageToArgString(tc.Function.Arguments),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if chat.PromptEvalCount > 0 || chat.EvalCount > 0 {
|
||||||
|
resp.Usage = &provider.Usage{
|
||||||
|
InputTokens: chat.PromptEvalCount,
|
||||||
|
OutputTokens: chat.EvalCount,
|
||||||
|
TotalTokens: chat.PromptEvalCount + chat.EvalCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream performs a streaming chat completion via /api/chat with
|
// Stream performs a streaming chat completion via /api/chat with
|
||||||
// `stream: true`, parsing NDJSON line-by-line.
|
// `stream: true`, parsing NDJSON line-by-line.
|
||||||
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
||||||
return errNotImplemented
|
return fmt.Errorf("ollama native provider: Stream not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildChatRequest converts a provider.Request into the native wire body
|
||||||
|
// JSON. stream toggles the stream flag (true for /api/chat streaming).
|
||||||
|
func (p *Provider) buildChatRequest(req provider.Request, stream bool) ([]byte, error) {
|
||||||
|
wire := nativeChatRequest{
|
||||||
|
Model: req.Model,
|
||||||
|
Stream: stream,
|
||||||
|
Think: encodeThink(req.Reasoning),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
m, err := convertMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
wire.Messages = append(wire.Messages, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range req.Tools {
|
||||||
|
wire.Tools = append(wire.Tools, nativeToolDef{
|
||||||
|
Type: "function",
|
||||||
|
Function: nativeFunctionDef{
|
||||||
|
Name: t.Name,
|
||||||
|
Description: t.Description,
|
||||||
|
Parameters: t.Schema,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Temperature != nil || req.MaxTokens != nil || req.TopP != nil || len(req.Stop) > 0 {
|
||||||
|
wire.Options = map[string]any{}
|
||||||
|
if req.Temperature != nil {
|
||||||
|
wire.Options["temperature"] = *req.Temperature
|
||||||
|
}
|
||||||
|
if req.TopP != nil {
|
||||||
|
wire.Options["top_p"] = *req.TopP
|
||||||
|
}
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
wire.Options["num_predict"] = *req.MaxTokens
|
||||||
|
}
|
||||||
|
if len(req.Stop) > 0 {
|
||||||
|
wire.Options["stop"] = req.Stop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(wire)
|
||||||
|
}
|
||||||
|
|
||||||
|
// doChatRequest POSTs the wire body to /api/chat and returns the raw HTTP
|
||||||
|
// response. The caller is responsible for closing the response body.
|
||||||
|
func (p *Provider) doChatRequest(ctx context.Context, body []byte) (*http.Response, error) {
|
||||||
|
url := strings.TrimRight(p.baseURL, "/") + "/api/chat"
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ollama: build request: %w", err)
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
if p.apiKey != "" {
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||||
|
}
|
||||||
|
resp, err := p.client.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ollama: HTTP request: %w", err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertMessage maps a provider.Message into a native wire message.
|
||||||
|
func convertMessage(msg provider.Message) (nativeChatMessage, error) {
|
||||||
|
out := nativeChatMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: msg.Content,
|
||||||
|
ToolCallID: msg.ToolCallID,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, img := range msg.Images {
|
||||||
|
b64, err := imageToBase64(img)
|
||||||
|
if err != nil {
|
||||||
|
return nativeChatMessage{}, err
|
||||||
|
}
|
||||||
|
if b64 != "" {
|
||||||
|
out.Images = append(out.Images, b64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range msg.ToolCalls {
|
||||||
|
raw := json.RawMessage(strings.TrimSpace(tc.Arguments))
|
||||||
|
if len(raw) == 0 {
|
||||||
|
raw = json.RawMessage(`{}`)
|
||||||
|
}
|
||||||
|
// Preserve a stable index so streaming peers can correlate deltas.
|
||||||
|
idx := i
|
||||||
|
out.ToolCalls = append(out.ToolCalls, nativeToolCall{
|
||||||
|
ID: tc.ID,
|
||||||
|
Function: nativeFunctionCall{
|
||||||
|
Index: &idx,
|
||||||
|
Name: tc.Name,
|
||||||
|
Arguments: raw,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// imageToBase64 returns the base64-encoded payload of an image, fetching
|
||||||
|
// URL-only images over HTTP if no inline base64 is supplied.
|
||||||
|
func imageToBase64(img provider.Image) (string, error) {
|
||||||
|
if img.Base64 != "" {
|
||||||
|
return img.Base64, nil
|
||||||
|
}
|
||||||
|
if img.URL == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
resp, err := http.Get(img.URL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("ollama: fetch image %q: %w", img.URL, err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return "", fmt.Errorf("ollama: fetch image %q: HTTP %d", img.URL, resp.StatusCode)
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("ollama: read image %q: %w", img.URL, err)
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rawMessageToArgString converts a JSON-encoded arguments value into the
// string form the provider package uses for ToolCall.Arguments.
//
//   - Missing, blank or JSON null arguments become "{}". json.RawMessage
//     stores a JSON null verbatim as `null`, so it has to be handled
//     explicitly rather than relying on the empty check.
//   - A bare JSON string (some Ollama builds emit pre-stringified arguments)
//     is unwrapped; if the unwrapped text is blank it also becomes "{}".
//   - Anything else (objects, arrays, malformed input) is returned verbatim
//     with surrounding whitespace removed.
func rawMessageToArgString(raw json.RawMessage) string {
	const emptyArgs = "{}"

	trimmed := strings.TrimSpace(string(raw))
	switch {
	case trimmed == "", trimmed == "null":
		return emptyArgs
	case trimmed[0] == '"':
		var s string
		if err := json.Unmarshal([]byte(trimmed), &s); err == nil {
			if strings.TrimSpace(s) == "" {
				return emptyArgs
			}
			return s
		}
		// Not a well-formed JSON string: fall through and hand the input
		// back unchanged.
	}
	return trimmed
}
|
||||||
|
|
||||||
|
// toolCallID returns a stable identifier for a tool call. Ollama's native
|
||||||
|
// API typically does not include an id, so we synthesize one from the index
|
||||||
|
// when missing.
|
||||||
|
func toolCallID(tc nativeToolCall, index int) string {
|
||||||
|
if tc.ID != "" {
|
||||||
|
return tc.ID
|
||||||
|
}
|
||||||
|
if tc.Function.Index != nil {
|
||||||
|
return fmt.Sprintf("tc_%d", *tc.Function.Index)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("tc_%d", index)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,347 @@
|
|||||||
|
package ollama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
// captureRequest holds what the test server saw on its most recent inbound
// request, so tests can assert on the request the provider actually sent.
type captureRequest struct {
	// Request line.
	method, path string

	// Selected request headers.
	authHeader, contentType string

	// Raw request payload and its best-effort JSON decoding.
	body       []byte
	parsedBody map[string]any
}
|
||||||
|
|
||||||
|
func newTestServer(t *testing.T, captured *captureRequest, status int, respBody string, respContentType string) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
captured.method = r.Method
|
||||||
|
captured.path = r.URL.Path
|
||||||
|
captured.authHeader = r.Header.Get("Authorization")
|
||||||
|
captured.contentType = r.Header.Get("Content-Type")
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
captured.body = body
|
||||||
|
_ = json.Unmarshal(body, &captured.parsedBody)
|
||||||
|
if respContentType == "" {
|
||||||
|
respContentType = "application/json"
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", respContentType)
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_, _ = w.Write([]byte(respBody))
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteBasic(t *testing.T) {
|
||||||
|
resp := `{
|
||||||
|
"model": "kimi-k2.5",
|
||||||
|
"message": {"role": "assistant", "content": "hello there"},
|
||||||
|
"done": true,
|
||||||
|
"done_reason": "stop",
|
||||||
|
"prompt_eval_count": 10,
|
||||||
|
"eval_count": 3
|
||||||
|
}`
|
||||||
|
cap := &captureRequest{}
|
||||||
|
srv := newTestServer(t, cap, 200, resp, "")
|
||||||
|
|
||||||
|
p := newNative("test-key", srv.URL)
|
||||||
|
got, err := p.Complete(context.Background(), provider.Request{
|
||||||
|
Model: "kimi-k2.5",
|
||||||
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Complete: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cap.method != "POST" {
|
||||||
|
t.Errorf("method: want POST, got %q", cap.method)
|
||||||
|
}
|
||||||
|
if cap.path != "/api/chat" {
|
||||||
|
t.Errorf("path: want /api/chat, got %q", cap.path)
|
||||||
|
}
|
||||||
|
if cap.authHeader != "Bearer test-key" {
|
||||||
|
t.Errorf("auth header: want %q, got %q", "Bearer test-key", cap.authHeader)
|
||||||
|
}
|
||||||
|
if cap.contentType != "application/json" {
|
||||||
|
t.Errorf("content-type: want application/json, got %q", cap.contentType)
|
||||||
|
}
|
||||||
|
if cap.parsedBody["model"] != "kimi-k2.5" {
|
||||||
|
t.Errorf("body.model: want kimi-k2.5, got %v", cap.parsedBody["model"])
|
||||||
|
}
|
||||||
|
if cap.parsedBody["stream"] != false {
|
||||||
|
t.Errorf("body.stream: want false, got %v", cap.parsedBody["stream"])
|
||||||
|
}
|
||||||
|
msgs, _ := cap.parsedBody["messages"].([]any)
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
t.Fatalf("messages: want 1 entry, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
m0, _ := msgs[0].(map[string]any)
|
||||||
|
if m0["role"] != "user" || m0["content"] != "hi" {
|
||||||
|
t.Errorf("first message: want role=user content=hi, got %v", m0)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Text != "hello there" {
|
||||||
|
t.Errorf("Text: want %q, got %q", "hello there", got.Text)
|
||||||
|
}
|
||||||
|
if got.Usage == nil {
|
||||||
|
t.Fatal("Usage: want non-nil")
|
||||||
|
}
|
||||||
|
if got.Usage.InputTokens != 10 || got.Usage.OutputTokens != 3 {
|
||||||
|
t.Errorf("Usage: want input=10 output=3, got input=%d output=%d", got.Usage.InputTokens, got.Usage.OutputTokens)
|
||||||
|
}
|
||||||
|
if got.Usage.TotalTokens != 13 {
|
||||||
|
t.Errorf("Usage.TotalTokens: want 13, got %d", got.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteNoAuthHeaderWhenLocal(t *testing.T) {
|
||||||
|
resp := `{"message":{"role":"assistant","content":"ok"},"done":true}`
|
||||||
|
cap := &captureRequest{}
|
||||||
|
srv := newTestServer(t, cap, 200, resp, "")
|
||||||
|
|
||||||
|
p := newNative("", srv.URL)
|
||||||
|
if _, err := p.Complete(context.Background(), provider.Request{
|
||||||
|
Model: "llama3.2",
|
||||||
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("Complete: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cap.authHeader != "" {
|
||||||
|
t.Errorf("auth header: want empty (local mode), got %q", cap.authHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVisionImagesEncoded(t *testing.T) {
|
||||||
|
resp := `{"message":{"role":"assistant","content":"a cat"},"done":true}`
|
||||||
|
cap := &captureRequest{}
|
||||||
|
srv := newTestServer(t, cap, 200, resp, "")
|
||||||
|
|
||||||
|
p := newNative("", srv.URL)
|
||||||
|
if _, err := p.Complete(context.Background(), provider.Request{
|
||||||
|
Model: "llava",
|
||||||
|
Messages: []provider.Message{{
|
||||||
|
Role: "user",
|
||||||
|
Content: "what's in this?",
|
||||||
|
Images: []provider.Image{
|
||||||
|
{Base64: "AAAA", ContentType: "image/png"},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("Complete: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs, _ := cap.parsedBody["messages"].([]any)
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
t.Fatalf("messages: want 1, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
m0, _ := msgs[0].(map[string]any)
|
||||||
|
imgs, _ := m0["images"].([]any)
|
||||||
|
if len(imgs) != 1 {
|
||||||
|
t.Fatalf("images: want 1 entry, got %d (msg=%v)", len(imgs), m0)
|
||||||
|
}
|
||||||
|
if imgs[0] != "AAAA" {
|
||||||
|
t.Errorf("images[0]: want raw base64 AAAA, got %v", imgs[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThinkingField(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
reasoning string
|
||||||
|
want any // expected value of "think" in body, or nil if absent
|
||||||
|
}{
|
||||||
|
{"absent", "", nil},
|
||||||
|
{"high", "high", "high"},
|
||||||
|
{"low", "low", "low"},
|
||||||
|
{"medium", "medium", "medium"},
|
||||||
|
{"true", "true", true},
|
||||||
|
{"false", "false", false},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
resp := `{"message":{"role":"assistant","content":"ok"},"done":true}`
|
||||||
|
cap := &captureRequest{}
|
||||||
|
srv := newTestServer(t, cap, 200, resp, "")
|
||||||
|
|
||||||
|
p := newNative("", srv.URL)
|
||||||
|
_, err := p.Complete(context.Background(), provider.Request{
|
||||||
|
Model: "kimi-k2.5",
|
||||||
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||||
|
Reasoning: c.reasoning,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Complete: %v", err)
|
||||||
|
}
|
||||||
|
got, present := cap.parsedBody["think"]
|
||||||
|
if c.want == nil {
|
||||||
|
if present {
|
||||||
|
t.Errorf("think field should be absent, got %v", got)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !present {
|
||||||
|
t.Fatalf("think field absent; want %v", c.want)
|
||||||
|
}
|
||||||
|
if got != c.want {
|
||||||
|
t.Errorf("think: want %v (%T), got %v (%T)", c.want, c.want, got, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolRoundTrip(t *testing.T) {
|
||||||
|
t.Run("response tool_calls convert to provider.Response", func(t *testing.T) {
|
||||||
|
resp := `{
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{"function": {"name": "search", "arguments": {"query": "foo"}}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"done": true,
|
||||||
|
"prompt_eval_count": 5,
|
||||||
|
"eval_count": 2
|
||||||
|
}`
|
||||||
|
cap := &captureRequest{}
|
||||||
|
srv := newTestServer(t, cap, 200, resp, "")
|
||||||
|
|
||||||
|
p := newNative("", srv.URL)
|
||||||
|
got, err := p.Complete(context.Background(), provider.Request{
|
||||||
|
Model: "kimi-k2.5",
|
||||||
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||||
|
Tools: []provider.ToolDef{
|
||||||
|
{
|
||||||
|
Name: "search",
|
||||||
|
Description: "Run a search",
|
||||||
|
Schema: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"query": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Complete: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify request shape: tools array present.
|
||||||
|
toolsArr, _ := cap.parsedBody["tools"].([]any)
|
||||||
|
if len(toolsArr) != 1 {
|
||||||
|
t.Fatalf("tools: want 1 entry, got %d", len(toolsArr))
|
||||||
|
}
|
||||||
|
t0, _ := toolsArr[0].(map[string]any)
|
||||||
|
if t0["type"] != "function" {
|
||||||
|
t.Errorf("tools[0].type: want function, got %v", t0["type"])
|
||||||
|
}
|
||||||
|
fn, _ := t0["function"].(map[string]any)
|
||||||
|
if fn["name"] != "search" {
|
||||||
|
t.Errorf("tools[0].function.name: want search, got %v", fn["name"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify response conversion.
|
||||||
|
if len(got.ToolCalls) != 1 {
|
||||||
|
t.Fatalf("ToolCalls: want 1, got %d", len(got.ToolCalls))
|
||||||
|
}
|
||||||
|
tc := got.ToolCalls[0]
|
||||||
|
if tc.Name != "search" {
|
||||||
|
t.Errorf("ToolCall.Name: want search, got %q", tc.Name)
|
||||||
|
}
|
||||||
|
// Arguments should be valid JSON containing query=foo
|
||||||
|
var args map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil {
|
||||||
|
t.Fatalf("ToolCall.Arguments not valid JSON: %v (got %q)", err, tc.Arguments)
|
||||||
|
}
|
||||||
|
if args["query"] != "foo" {
|
||||||
|
t.Errorf("ToolCall.Arguments.query: want foo, got %v", args["query"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("subsequent request includes assistant tool_calls and tool-role response", func(t *testing.T) {
|
||||||
|
resp := `{"message":{"role":"assistant","content":"done"},"done":true}`
|
||||||
|
cap := &captureRequest{}
|
||||||
|
srv := newTestServer(t, cap, 200, resp, "")
|
||||||
|
|
||||||
|
p := newNative("", srv.URL)
|
||||||
|
_, err := p.Complete(context.Background(), provider.Request{
|
||||||
|
Model: "kimi-k2.5",
|
||||||
|
Messages: []provider.Message{
|
||||||
|
{Role: "user", Content: "search foo"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []provider.ToolCall{{
|
||||||
|
ID: "tc1",
|
||||||
|
Name: "search",
|
||||||
|
Arguments: `{"query":"foo"}`,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
ToolCallID: "tc1",
|
||||||
|
Content: `{"result":"bar"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Complete: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs, _ := cap.parsedBody["messages"].([]any)
|
||||||
|
if len(msgs) != 3 {
|
||||||
|
t.Fatalf("messages: want 3, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assistant message must carry tool_calls with the JSON-object arguments.
|
||||||
|
asst, _ := msgs[1].(map[string]any)
|
||||||
|
if asst["role"] != "assistant" {
|
||||||
|
t.Errorf("msgs[1].role: want assistant, got %v", asst["role"])
|
||||||
|
}
|
||||||
|
tc, _ := asst["tool_calls"].([]any)
|
||||||
|
if len(tc) != 1 {
|
||||||
|
t.Fatalf("assistant.tool_calls: want 1, got %d", len(tc))
|
||||||
|
}
|
||||||
|
fn, _ := tc[0].(map[string]any)["function"].(map[string]any)
|
||||||
|
if fn["name"] != "search" {
|
||||||
|
t.Errorf("assistant.tool_calls[0].function.name: want search, got %v", fn["name"])
|
||||||
|
}
|
||||||
|
args, _ := fn["arguments"].(map[string]any)
|
||||||
|
if args["query"] != "foo" {
|
||||||
|
t.Errorf("assistant.tool_calls[0].function.arguments.query: want foo, got %v", args["query"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool-role message must have role=tool, tool_call_id, and content.
|
||||||
|
tool, _ := msgs[2].(map[string]any)
|
||||||
|
if tool["role"] != "tool" {
|
||||||
|
t.Errorf("msgs[2].role: want tool, got %v", tool["role"])
|
||||||
|
}
|
||||||
|
if tool["tool_call_id"] != "tc1" {
|
||||||
|
t.Errorf("msgs[2].tool_call_id: want tc1, got %v", tool["tool_call_id"])
|
||||||
|
}
|
||||||
|
if !strings.Contains(toString(tool["content"]), "bar") {
|
||||||
|
t.Errorf("msgs[2].content: want to contain bar, got %v", tool["content"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// toString renders a decoded JSON value as text for substring assertions:
// strings are returned as-is, anything else is re-encoded as JSON. A value
// that cannot be marshalled yields the empty string.
func toString(v any) string {
	switch val := v.(type) {
	case string:
		return val
	default:
		// The marshal error is intentionally dropped; encoded is nil on
		// failure, which converts to "".
		encoded, _ := json.Marshal(val)
		return string(encoded)
	}
}
|
||||||
Reference in New Issue
Block a user