package llm import ( "context" "errors" "testing" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" ) func TestModel_Complete(t *testing.T) { mp := newMockProvider(provider.Response{ Text: "Hello!", Usage: &provider.Usage{ InputTokens: 10, OutputTokens: 5, TotalTokens: 15, }, }) model := newMockModel(mp) resp, err := model.Complete(context.Background(), []Message{UserMessage("Hi")}) if err != nil { t.Fatalf("unexpected error: %v", err) } if resp.Text != "Hello!" { t.Errorf("expected text 'Hello!', got %q", resp.Text) } if resp.Usage == nil { t.Fatal("expected usage, got nil") } if resp.Usage.InputTokens != 10 { t.Errorf("expected input tokens 10, got %d", resp.Usage.InputTokens) } if resp.Usage.OutputTokens != 5 { t.Errorf("expected output tokens 5, got %d", resp.Usage.OutputTokens) } if resp.Usage.TotalTokens != 15 { t.Errorf("expected total tokens 15, got %d", resp.Usage.TotalTokens) } } func TestModel_Complete_WithOptions(t *testing.T) { mp := newMockProvider(provider.Response{Text: "ok"}) model := newMockModel(mp) temp := 0.7 maxTok := 100 topP := 0.9 _, err := model.Complete(context.Background(), []Message{UserMessage("test")}, WithTemperature(temp), WithMaxTokens(maxTok), WithTopP(topP), WithStop("STOP", "END"), ) if err != nil { t.Fatalf("unexpected error: %v", err) } req := mp.lastRequest() if req.Temperature == nil || *req.Temperature != temp { t.Errorf("expected temperature %v, got %v", temp, req.Temperature) } if req.MaxTokens == nil || *req.MaxTokens != maxTok { t.Errorf("expected maxTokens %v, got %v", maxTok, req.MaxTokens) } if req.TopP == nil || *req.TopP != topP { t.Errorf("expected topP %v, got %v", topP, req.TopP) } if len(req.Stop) != 2 || req.Stop[0] != "STOP" || req.Stop[1] != "END" { t.Errorf("expected stop [STOP END], got %v", req.Stop) } } func TestModel_Complete_Error(t *testing.T) { wantErr := errors.New("provider error") mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { return provider.Response{}, wantErr }) model := newMockModel(mp) _, err := model.Complete(context.Background(), []Message{UserMessage("Hi")}) if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, wantErr) { t.Errorf("expected error %v, got %v", wantErr, err) } } func TestModel_Complete_WithTools(t *testing.T) { mp := newMockProvider(provider.Response{Text: "done"}) model := newMockModel(mp) tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) { return "hello", nil }) tb := NewToolBox(tool) _, err := model.Complete(context.Background(), []Message{UserMessage("test")}, WithTools(tb)) if err != nil { t.Fatalf("unexpected error: %v", err) } req := mp.lastRequest() if len(req.Tools) != 1 { t.Fatalf("expected 1 tool, got %d", len(req.Tools)) } if req.Tools[0].Name != "greet" { t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name) } if req.Tools[0].Description != "Says hello" { t.Errorf("expected tool description 'Says hello', got %q", req.Tools[0].Description) } } func TestClient_Model(t *testing.T) { mp := newMockProvider(provider.Response{Text: "hi"}) client := newClient(mp) model := client.Model("test-model") resp, err := model.Complete(context.Background(), []Message{UserMessage("test")}) if err != nil { t.Fatalf("unexpected error: %v", err) } if resp.Text != "hi" { t.Errorf("expected 'hi', got %q", resp.Text) } req := mp.lastRequest() if req.Model != "test-model" { t.Errorf("expected model 'test-model', got %q", req.Model) } } func TestClient_WithMiddleware(t *testing.T) { var called bool mw := func(next CompletionFunc) CompletionFunc { return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { called = true return next(ctx, model, messages, cfg) } } mp := newMockProvider(provider.Response{Text: "ok"}) client := newClient(mp).WithMiddleware(mw) model := client.Model("test-model") _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) if err != nil { t.Fatalf("unexpected error: %v", err) } if !called { t.Error("middleware was not called") } } func TestModel_WithMiddleware(t *testing.T) { var order []string mw1 := func(next CompletionFunc) CompletionFunc { return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { order = append(order, "mw1") return next(ctx, model, messages, cfg) } } mw2 := func(next CompletionFunc) CompletionFunc { return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { order = append(order, "mw2") return next(ctx, model, messages, cfg) } } mp := newMockProvider(provider.Response{Text: "ok"}) model := newMockModel(mp).WithMiddleware(mw1).WithMiddleware(mw2) _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(order) != 2 || order[0] != "mw1" || order[1] != "mw2" { t.Errorf("expected middleware order [mw1 mw2], got %v", order) } } func TestModel_Complete_NoUsage(t *testing.T) { mp := newMockProvider(provider.Response{Text: "no usage"}) model := newMockModel(mp) resp, err := model.Complete(context.Background(), []Message{UserMessage("test")}) if err != nil { t.Fatalf("unexpected error: %v", err) } if resp.Usage != nil { t.Errorf("expected nil usage, got %+v", resp.Usage) } } func TestModel_Complete_ResponseMessage(t *testing.T) { mp := newMockProvider(provider.Response{Text: "response text"}) model := newMockModel(mp) resp, err := model.Complete(context.Background(), []Message{UserMessage("test")}) if err != nil { t.Fatalf("unexpected error: %v", err) } msg := resp.Message() if msg.Role != RoleAssistant { t.Errorf("expected role assistant, got %v", msg.Role) } if msg.Content.Text != "response text" { t.Errorf("expected text 'response text', got %q", msg.Content.Text) } }