package ollama import ( "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" "sync/atomic" "testing" ) func TestTags_ParsesResponse(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/tags" { t.Errorf("unexpected path: %s", r.URL.Path) } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(TagsResponse{ Models: []ModelInfo{ {Name: "qwen3:30b", Model: "qwen3:30b", Size: 19000000000}, {Name: "nomic-embed-text", Model: "nomic-embed-text", Size: 300000000}, }, }) })) defer srv.Close() client := NewClient(srv.URL, "") resp, err := client.Tags(context.Background()) if err != nil { t.Fatalf("Tags: %v", err) } if len(resp.Models) != 2 { t.Fatalf("got %d models, want 2", len(resp.Models)) } if resp.Models[0].Name != "qwen3:30b" { t.Errorf("first model = %q, want %q", resp.Models[0].Name, "qwen3:30b") } } func TestPs_ParsesResponse(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/ps" { t.Errorf("unexpected path: %s", r.URL.Path) } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(PsResponse{ Models: []RunningModel{ {Name: "nomic-embed-text", Model: "nomic-embed-text", Size: 300000000}, }, }) })) defer srv.Close() client := NewClient(srv.URL, "") resp, err := client.Ps(context.Background()) if err != nil { t.Fatalf("Ps: %v", err) } if len(resp.Models) != 1 { t.Fatalf("got %d models, want 1", len(resp.Models)) } } func TestChat_NonStreaming(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, _ := io.ReadAll(r.Body) var req ChatRequest if err := json.Unmarshal(body, &req); err != nil { t.Errorf("unmarshal: %v", err) } if req.Stream != nil && *req.Stream { t.Error("expected stream=false") } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(ChatResponse{ Model: "qwen3:30b", Done: true, Message: &Message{Role: "assistant", Content: "Hello!"}, }) })) defer srv.Close() client := NewClient(srv.URL, "") resp, ch, err := client.Chat(context.Background(), ChatRequest{ Model: "qwen3:30b", Messages: []Message{{Role: "user", Content: "hi"}}, }, false) if err != nil { t.Fatalf("Chat: %v", err) } if ch != nil { t.Error("expected nil channel for non-streaming") } if resp == nil { t.Fatal("expected non-nil response") } if resp.Message.Content != "Hello!" { t.Errorf("content = %q, want %q", resp.Message.Content, "Hello!") } } func TestChat_Streaming(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/x-ndjson") flusher := w.(http.Flusher) chunks := []ChatResponse{ {Model: "qwen3:30b", Done: false, Message: &Message{Role: "assistant", Content: "Hel"}}, {Model: "qwen3:30b", Done: false, Message: &Message{Role: "assistant", Content: "lo"}}, {Model: "qwen3:30b", Done: true, DoneReason: "stop"}, } for _, c := range chunks { b, _ := json.Marshal(c) w.Write(b) w.Write([]byte("\n")) flusher.Flush() } })) defer srv.Close() client := NewClient(srv.URL, "") resp, ch, err := client.Chat(context.Background(), ChatRequest{ Model: "qwen3:30b", Messages: []Message{{Role: "user", Content: "hi"}}, }, true) if err != nil { t.Fatalf("Chat: %v", err) } if resp != nil { t.Error("expected nil response for streaming") } if ch == nil { t.Fatal("expected non-nil channel for streaming") } var collected []ChatResponse for c := range ch { collected = append(collected, c) } if len(collected) != 3 { t.Fatalf("got %d chunks, want 3", len(collected)) } if !collected[2].Done { t.Error("last chunk should have done=true") } } func TestEmbed_ParsesResponse(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/embed" { t.Errorf("unexpected path: %s", r.URL.Path) } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(EmbedResponse{ Model: "nomic-embed-text", Embeddings: [][]float64{{0.1, 0.2, 0.3}}, }) })) defer srv.Close() client := NewClient(srv.URL, "") resp, err := client.Embed(context.Background(), EmbedRequest{ Model: "nomic-embed-text", Input: json.RawMessage(`"test input"`), }) if err != nil { t.Fatalf("Embed: %v", err) } if len(resp.Embeddings) != 1 { t.Fatalf("got %d embeddings, want 1", len(resp.Embeddings)) } if len(resp.Embeddings[0]) != 3 { t.Errorf("embedding dim = %d, want 3", len(resp.Embeddings[0])) } } func TestClient_SetsAuthToken(t *testing.T) { var gotAuth string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotAuth = r.Header.Get("Authorization") w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(TagsResponse{}) })) defer srv.Close() client := NewClient(srv.URL, "my-secret-token") _, err := client.Tags(context.Background()) if err != nil { t.Fatalf("Tags: %v", err) } if gotAuth != "Bearer my-secret-token" { t.Errorf("auth header = %q, want %q", gotAuth, "Bearer my-secret-token") } } func TestClient_ConnectionError(t *testing.T) { // Use a server that immediately closes. client := NewClient("http://127.0.0.1:1", "") _, err := client.Tags(context.Background()) if err == nil { t.Fatal("expected error for unreachable target") } var connErr *ConnectionError if !errors.As(err, &connErr) { t.Errorf("expected *ConnectionError, got %T: %v", err, err) } } func TestClient_HTTPError(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, `{"error":"bad model"}`, http.StatusBadRequest) })) defer srv.Close() client := NewClient(srv.URL, "") _, err := client.Tags(context.Background()) if err == nil { t.Fatal("expected error for 400 response") } var httpErr *HTTPError if !errors.As(err, &httpErr) { t.Errorf("expected *HTTPError, got %T: %v", err, err) } if httpErr.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want %d", httpErr.StatusCode, http.StatusBadRequest) } } func TestRawChat_ForwardsBody(t *testing.T) { var gotBody []byte srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotBody, _ = io.ReadAll(r.Body) w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"done":true}`)) })) defer srv.Close() client := NewClient(srv.URL, "") body := []byte(`{"model":"qwen3:30b","messages":[{"role":"user","content":"test"}]}`) resp, err := client.RawChat(context.Background(), body) if err != nil { t.Fatalf("RawChat: %v", err) } defer resp.Body.Close() if string(gotBody) != string(body) { t.Errorf("body mismatch: got %s, want %s", gotBody, body) } } func TestRawEmbed_ForwardsBody(t *testing.T) { var callCount atomic.Int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount.Add(1) if r.URL.Path != "/api/embed" { t.Errorf("unexpected path: %s", r.URL.Path) } w.Header().Set("Content-Type", "application/json") w.Write([]byte(fmt.Sprintf(`{"model":"nomic-embed-text","embeddings":[[0.%d]]}`, callCount.Load()))) })) defer srv.Close() client := NewClient(srv.URL, "") body := []byte(`{"model":"nomic-embed-text","input":"test"}`) resp, err := client.RawEmbed(context.Background(), body) if err != nil { t.Fatalf("RawEmbed: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) } }