package llm

import (
	"context"
	"sync"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// mockProvider is a configurable mock implementation of provider.Provider for testing.
//
// Every call to Complete or Stream is recorded in Requests before the
// configured function (if any) runs. A nil CompleteFunc or StreamFunc falls
// back to a trivial default, so the zero value is usable.
type mockProvider struct {
	// CompleteFunc, if non-nil, produces the result of Complete.
	// When nil, Complete returns a zero provider.Response and a nil error.
	CompleteFunc func(ctx context.Context, req provider.Request) (provider.Response, error)
	// StreamFunc, if non-nil, handles Stream. When nil, Stream closes the
	// events channel without sending anything and returns nil.
	StreamFunc func(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error

	// mu guards Requests.
	mu sync.Mutex
	// Requests holds every request passed to Complete or Stream, in call order.
	// Use lastRequest (or hold mu) when calls may run concurrently.
	Requests []provider.Request
}

// record appends req to m.Requests while holding m.mu.
func (m *mockProvider) record(req provider.Request) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.Requests = append(m.Requests, req)
}

// Complete records req and delegates to CompleteFunc. When CompleteFunc is
// nil it returns a zero provider.Response and a nil error rather than
// panicking on a nil function call, mirroring Stream's nil-StreamFunc fallback.
func (m *mockProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	m.record(req)
	if m.CompleteFunc == nil {
		return provider.Response{}, nil
	}
	return m.CompleteFunc(ctx, req)
}

// Stream records req and delegates to StreamFunc. When StreamFunc is nil it
// closes events and returns nil, so a consumer ranging over the channel
// finishes immediately.
func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	m.record(req)
	if m.StreamFunc != nil {
		return m.StreamFunc(ctx, req, events)
	}
	close(events)
	return nil
}

// lastRequest returns the most recent request recorded by the mock, or a
// zero provider.Request if none has been recorded yet.
func (m *mockProvider) lastRequest() provider.Request {
	m.mu.Lock()
	defer m.mu.Unlock()
	if len(m.Requests) == 0 {
		return provider.Request{}
	}
	return m.Requests[len(m.Requests)-1]
}

// newMockProvider creates a mock whose Complete always returns resp with a nil error.
func newMockProvider(resp provider.Response) *mockProvider {
	return &mockProvider{
		CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
			return resp, nil
		},
	}
}

// newMockProviderFunc creates a mock whose Complete delegates to fn.
func newMockProviderFunc(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *mockProvider {
	return &mockProvider{CompleteFunc: fn}
}

// newMockStreamProvider creates a mock whose Stream sends each of events, in
// order, on the channel, returning ctx.Err() early if ctx is cancelled while
// a send is pending. Its Complete returns a zero provider.Response.
//
// NOTE(review): unlike mockProvider.Stream's nil-StreamFunc fallback, the
// StreamFunc built here does not close the channel; confirm which side
// (provider or caller) is expected to close it.
func newMockStreamProvider(events []provider.StreamEvent) *mockProvider { return &mockProvider{ CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) { return provider.Response{}, nil }, StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error { for _, ev := range events { select { case ch <- ev: case <-ctx.Done(): return ctx.Err() } } return nil }, } } // newMockModel creates a *Model backed by the given mock provider. func newMockModel(p *mockProvider) *Model { return &Model{ provider: p, model: "mock-model", } }