feat: add DeepSeek, Moonshot, xAI, Groq, Ollama; drop v1; migrate TUI to v2
Five OpenAI-compatible providers join the library as first-class constructors
(llm.DeepSeek, llm.Moonshot, llm.XAI, llm.Groq, llm.Ollama). Their wire-level
implementation is shared via a new v2/openaicompat package, the extracted guts
of the old v2/openai provider. Each provider supplies its own Rules value to
declare per-model constraints (e.g., DeepSeek Reasoner rejects tools and
temperature, Moonshot/xAI accept images only on *-vision* models, Groq rejects
audio input). v2/openai itself becomes a thin wrapper that sets
RestrictTemperature for o-series and gpt-5 models.

A new provider registry (v2/registry.go) exposes llm.Providers() and drives the
TUI's provider picker, so adding a provider in the future is a single-file
change. The TUI at cmd/llm was migrated from v1 to v2 and moved to v2/cmd/llm.
With nothing else depending on v1, the v1 code at the repo root (all .go files,
schema/, internal/, provider/, root go.mod/go.sum) is deleted.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
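Rule violations surface before any network call as a typed error. A minimal
caller-side sketch (the error type, its fields, and the errors.As matching come
from the new package and its tests; the surrounding variable names are
illustrative):

	resp, err := p.Complete(ctx, req)
	var fue *openaicompat.FeatureUnsupportedError
	if errors.As(err, &fue) {
		// e.g. `openaicompat: model "deepseek-reasoner" does not support tools`
		log.Printf("model %s lacks %s; stripping the offending content", fue.Model, fue.Feature)
	}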
@@ -0,0 +1,537 @@
// Package openaicompat implements a shared chat-completion Provider for any
// service that speaks the OpenAI Chat Completions API (OpenAI itself, DeepSeek,
// Moonshot, xAI, Groq, Ollama, and friends).
//
// Most providers differ from vanilla OpenAI only in endpoint URL and a handful
// of per-model quirks (e.g., "this model is text-only", "this model doesn't
// accept tools", "drop temperature on reasoning models"). Those quirks are
// captured declaratively via Rules, so a concrete provider package is usually
// a one-function wrapper that calls New with its own base URL and Rules.
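//
// As a sketch, a hypothetical provider wrapper could look like the following
// (the endpoint URL and the vision check are illustrative assumptions, not a
// real provider from this repository):
//
//	func New(apiKey string) *openaicompat.Provider {
//		return openaicompat.New(apiKey, "https://api.example.com/v1", openaicompat.Rules{
//			SupportsVision: func(model string) bool {
//				return strings.Contains(model, "vision")
//			},
//		})
//	}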
package openaicompat

import (
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"net/http"
	"path"
	"strings"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
	"github.com/openai/openai-go/shared"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// Rules encodes provider-specific constraints on top of the OpenAI wire
// protocol. The zero value means "no restrictions" and behaves like vanilla
// OpenAI. Individual fields are documented inline.
type Rules struct {
	// MaxImagesPerMessage rejects requests in which any single message
	// carries more than this many images. 0 means "no cap".
	MaxImagesPerMessage int

	// MaxAudioPerMessage rejects requests in which any single message
	// carries more than this many audio attachments. 0 means "no cap".
	MaxAudioPerMessage int

	// SupportsVision, when non-nil, is consulted for every request that
	// includes any image attachments. If it returns false for the request's
	// model, the call fails with a FeatureUnsupportedError before hitting
	// the network.
	SupportsVision func(model string) bool

	// SupportsTools, when non-nil, is consulted for every request that
	// includes any tool definitions. If it returns false for the model,
	// the call fails with a FeatureUnsupportedError before hitting the
	// network.
	SupportsTools func(model string) bool

	// SupportsAudio, when non-nil, is consulted for every request that
	// includes any audio attachments. If it returns false for the model,
	// the call fails with a FeatureUnsupportedError.
	SupportsAudio func(model string) bool

	// RestrictTemperature, when non-nil and returning true for the request's
	// model, causes the Temperature field to be silently dropped from the
	// outgoing request. Used by the OpenAI o-series and gpt-5* models, which
	// reject a user-provided temperature.
	RestrictTemperature func(model string) bool

	// CustomizeRequest is a last-mile hook invoked after buildRequest but
	// before the call is sent. It receives the fully built OpenAI SDK
	// parameters and may mutate them freely (add headers, flip flags, tweak
	// response_format, etc.).
	CustomizeRequest func(params *openai.ChatCompletionNewParams)
}
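
// As an illustration (a sketch inferred from the commit message, not copied
// from v2/openai), an OpenAI-flavoured Rules value gating temperature might
// read:
//
//	openaicompat.Rules{
//		RestrictTemperature: func(model string) bool {
//			return strings.HasPrefix(model, "o") || strings.HasPrefix(model, "gpt-5")
//		},
//	}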

// FeatureUnsupportedError is returned when a Rules predicate rejects a request
// because the target model does not support a feature the caller included.
type FeatureUnsupportedError struct {
	Feature string
	Model   string
}

func (e *FeatureUnsupportedError) Error() string {
	return fmt.Sprintf("openaicompat: model %q does not support %s", e.Model, e.Feature)
}

// Provider implements provider.Provider for any OpenAI-compatible endpoint.
type Provider struct {
	apiKey  string
	baseURL string
	rules   Rules
}

// New creates a Provider. baseURL may be empty to let the OpenAI SDK use its
// default; in practice concrete provider packages always pass a default.
func New(apiKey, baseURL string, rules Rules) *Provider {
	return &Provider{apiKey: apiKey, baseURL: baseURL, rules: rules}
}

// Complete performs a non-streaming completion.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	if err := p.checkRules(req); err != nil {
		return provider.Response{}, err
	}

	cl := openai.NewClient(p.requestOptions()...)
	oaiReq := p.buildRequest(req)
	if p.rules.CustomizeRequest != nil {
		p.rules.CustomizeRequest(&oaiReq)
	}

	resp, err := cl.Chat.Completions.New(ctx, oaiReq)
	if err != nil {
		return provider.Response{}, fmt.Errorf("openai completion error: %w", err)
	}

	return p.convertResponse(resp), nil
}

// Stream performs a streaming completion.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	if err := p.checkRules(req); err != nil {
		return err
	}

	cl := openai.NewClient(p.requestOptions()...)
	oaiReq := p.buildRequest(req)
	oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}
	if p.rules.CustomizeRequest != nil {
		p.rules.CustomizeRequest(&oaiReq)
	}

	stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)

	var fullText strings.Builder
	var toolCalls []provider.ToolCall
	toolCallArgs := map[int]*strings.Builder{}
	var usage *provider.Usage

	for stream.Next() {
		chunk := stream.Current()

		// Capture usage from the final chunk (present when
		// StreamOptions.IncludeUsage is true).
		if chunk.Usage.TotalTokens > 0 {
			usage = &provider.Usage{
				InputTokens:  int(chunk.Usage.PromptTokens),
				OutputTokens: int(chunk.Usage.CompletionTokens),
				TotalTokens:  int(chunk.Usage.TotalTokens),
				Details:      extractUsageDetails(chunk.Usage),
			}
		}

		for _, choice := range chunk.Choices {
			// Text delta
			if choice.Delta.Content != "" {
				fullText.WriteString(choice.Delta.Content)
				events <- provider.StreamEvent{
					Type: provider.StreamEventText,
					Text: choice.Delta.Content,
				}
			}

			// Tool call deltas
			for _, tc := range choice.Delta.ToolCalls {
				idx := int(tc.Index)

				if tc.ID != "" {
					// New tool call starting
					for len(toolCalls) <= idx {
						toolCalls = append(toolCalls, provider.ToolCall{})
					}
					toolCalls[idx].ID = tc.ID
					toolCalls[idx].Name = tc.Function.Name
					toolCallArgs[idx] = &strings.Builder{}

					events <- provider.StreamEvent{
						Type: provider.StreamEventToolStart,
						ToolCall: &provider.ToolCall{
							ID:   tc.ID,
							Name: tc.Function.Name,
						},
						ToolIndex: idx,
					}
				}

				if tc.Function.Arguments != "" {
					if b, ok := toolCallArgs[idx]; ok {
						b.WriteString(tc.Function.Arguments)
					}
					events <- provider.StreamEvent{
						Type:      provider.StreamEventToolDelta,
						ToolIndex: idx,
						ToolCall: &provider.ToolCall{
							Arguments: tc.Function.Arguments,
						},
					}
				}
			}
		}
	}

	if err := stream.Err(); err != nil {
		return fmt.Errorf("openai stream error: %w", err)
	}

	// Finalize tool calls
	for idx := range toolCalls {
		if b, ok := toolCallArgs[idx]; ok {
			toolCalls[idx].Arguments = b.String()
		}
		events <- provider.StreamEvent{
			Type:      provider.StreamEventToolEnd,
			ToolIndex: idx,
			ToolCall:  &toolCalls[idx],
		}
	}

	events <- provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			ToolCalls: toolCalls,
			Usage:     usage,
		},
	}

	return nil
}

func (p *Provider) requestOptions() []option.RequestOption {
	opts := []option.RequestOption{option.WithAPIKey(p.apiKey)}
	if p.baseURL != "" {
		opts = append(opts, option.WithBaseURL(p.baseURL))
	}
	return opts
}

// checkRules applies all Rules predicates against a request and returns an
// error if any constraint is violated. Runs before any network call.
func (p *Provider) checkRules(req provider.Request) error {
	var hasImages, hasAudio bool
	for _, msg := range req.Messages {
		if len(msg.Images) > 0 {
			hasImages = true
		}
		if len(msg.Audio) > 0 {
			hasAudio = true
		}
		if p.rules.MaxImagesPerMessage > 0 && len(msg.Images) > p.rules.MaxImagesPerMessage {
			return fmt.Errorf("openaicompat: message has %d images, max allowed is %d for model %q",
				len(msg.Images), p.rules.MaxImagesPerMessage, req.Model)
		}
		if p.rules.MaxAudioPerMessage > 0 && len(msg.Audio) > p.rules.MaxAudioPerMessage {
			return fmt.Errorf("openaicompat: message has %d audio attachments, max allowed is %d for model %q",
				len(msg.Audio), p.rules.MaxAudioPerMessage, req.Model)
		}
	}

	if hasImages && p.rules.SupportsVision != nil && !p.rules.SupportsVision(req.Model) {
		return &FeatureUnsupportedError{Feature: "vision", Model: req.Model}
	}
	if hasAudio && p.rules.SupportsAudio != nil && !p.rules.SupportsAudio(req.Model) {
		return &FeatureUnsupportedError{Feature: "audio", Model: req.Model}
	}
	if len(req.Tools) > 0 && p.rules.SupportsTools != nil && !p.rules.SupportsTools(req.Model) {
		return &FeatureUnsupportedError{Feature: "tools", Model: req.Model}
	}
	return nil
}

func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams {
	oaiReq := openai.ChatCompletionNewParams{
		Model: req.Model,
	}

	for _, msg := range req.Messages {
		oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model))
	}

	for _, tool := range req.Tools {
		oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{
			Type: "function",
			Function: shared.FunctionDefinitionParam{
				Name:        tool.Name,
				Description: openai.String(tool.Description),
				Parameters:  openai.FunctionParameters(tool.Schema),
			},
		})
	}

	if req.Temperature != nil {
		if p.rules.RestrictTemperature == nil || !p.rules.RestrictTemperature(req.Model) {
			oaiReq.Temperature = openai.Float(*req.Temperature)
		}
	}

	if req.MaxTokens != nil {
		oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens))
	}

	if req.TopP != nil {
		oaiReq.TopP = openai.Float(*req.TopP)
	}

	// Note: only the first stop sequence is forwarded; any additional
	// entries in req.Stop are currently ignored.
	if len(req.Stop) > 0 {
		oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])}
	}

	return oaiReq
}

func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion {
	var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
	var textContent param.Opt[string]

	for _, img := range msg.Images {
		var url string
		if img.Base64 != "" {
			url = "data:" + img.ContentType + ";base64," + img.Base64
		} else if img.URL != "" {
			url = img.URL
		}
		if url != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfImageURL: &openai.ChatCompletionContentPartImageParam{
						ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
							URL: url,
						},
					},
				},
			)
		}
	}

	for _, aud := range msg.Audio {
		var b64Data string
		var format string

		if aud.Base64 != "" {
			b64Data = aud.Base64
			format = audioFormat(aud.ContentType)
		} else if aud.URL != "" {
			resp, err := http.Get(aud.URL)
			if err != nil {
				continue
			}
			data, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				continue
			}
			b64Data = base64.StdEncoding.EncodeToString(data)
			ct := resp.Header.Get("Content-Type")
			if ct == "" {
				ct = aud.ContentType
			}
			if ct == "" {
				ct = audioFormatFromURL(aud.URL)
			}
			format = audioFormat(ct)
		}

		if b64Data != "" && format != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{
						InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
							Data:   b64Data,
							Format: format,
						},
					},
				},
			)
		}
	}

	if msg.Content != "" {
		if len(arrayOfContentParts) > 0 {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfText: &openai.ChatCompletionContentPartTextParam{
						Text: msg.Content,
					},
				},
			)
		} else {
			textContent = openai.String(msg.Content)
		}
	}

	// Determine whether this model takes developer messages instead of
	// system messages. Heuristic: hyphenated model names whose first segment
	// starts with 'o' (o1-mini, o3-mini, ...) are treated as o-series.
	useDeveloper := false
	parts := strings.Split(model, "-")
	if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' {
		useDeveloper = true
	}

	switch msg.Role {
	case "system":
		if useDeveloper {
			return openai.ChatCompletionMessageParamUnion{
				OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
					Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
						OfString: textContent,
					},
				},
			}
		}
		return openai.ChatCompletionMessageParamUnion{
			OfSystem: &openai.ChatCompletionSystemMessageParam{
				Content: openai.ChatCompletionSystemMessageParamContentUnion{
					OfString: textContent,
				},
			},
		}

	case "user":
		return openai.ChatCompletionMessageParamUnion{
			OfUser: &openai.ChatCompletionUserMessageParam{
				Content: openai.ChatCompletionUserMessageParamContentUnion{
					OfString:              textContent,
					OfArrayOfContentParts: arrayOfContentParts,
				},
			},
		}

	case "assistant":
		as := &openai.ChatCompletionAssistantMessageParam{}
		if msg.Content != "" {
			as.Content.OfString = openai.String(msg.Content)
		}
		for _, tc := range msg.ToolCalls {
			as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
				ID: tc.ID,
				Function: openai.ChatCompletionMessageToolCallFunctionParam{
					Name:      tc.Name,
					Arguments: tc.Arguments,
				},
			})
		}
		return openai.ChatCompletionMessageParamUnion{OfAssistant: as}

	case "tool":
		return openai.ChatCompletionMessageParamUnion{
			OfTool: &openai.ChatCompletionToolMessageParam{
				ToolCallID: msg.ToolCallID,
				Content: openai.ChatCompletionToolMessageParamContentUnion{
					OfString: openai.String(msg.Content),
				},
			},
		}
	}

	// Fallback to user message
	return openai.ChatCompletionMessageParamUnion{
		OfUser: &openai.ChatCompletionUserMessageParam{
			Content: openai.ChatCompletionUserMessageParamContentUnion{
				OfString: textContent,
			},
		},
	}
}

func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response {
	var res provider.Response

	if resp == nil || len(resp.Choices) == 0 {
		return res
	}

	choice := resp.Choices[0]
	res.Text = choice.Message.Content

	for _, tc := range choice.Message.ToolCalls {
		res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
			ID:        tc.ID,
			Name:      tc.Function.Name,
			Arguments: strings.TrimSpace(tc.Function.Arguments),
		})
	}

	if resp.Usage.TotalTokens > 0 {
		res.Usage = &provider.Usage{
			InputTokens:  int(resp.Usage.PromptTokens),
			OutputTokens: int(resp.Usage.CompletionTokens),
			TotalTokens:  int(resp.Usage.TotalTokens),
		}
		res.Usage.Details = extractUsageDetails(resp.Usage)
	}

	return res
}

// audioFormat converts a MIME type to an OpenAI audio format string ("wav" or
// "mp3"), defaulting to "wav" for anything unrecognized.
func audioFormat(contentType string) string {
	ct := strings.ToLower(contentType)
	switch {
	case strings.Contains(ct, "wav"):
		return "wav"
	case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"):
		return "mp3"
	default:
		return "wav"
	}
}

// extractUsageDetails extracts provider-specific detail tokens from an OpenAI
// CompletionUsage.
func extractUsageDetails(usage openai.CompletionUsage) map[string]int {
	details := map[string]int{}
	if usage.CompletionTokensDetails.ReasoningTokens > 0 {
		details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens)
	}
	if usage.CompletionTokensDetails.AudioTokens > 0 {
		details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens)
	}
	if usage.PromptTokensDetails.CachedTokens > 0 {
		details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens)
	}
	if usage.PromptTokensDetails.AudioTokens > 0 {
		details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens)
	}
	if len(details) == 0 {
		return nil
	}
	return details
}

// audioFormatFromURL guesses an audio MIME type from a URL's file extension,
// defaulting to "audio/wav"; the result is fed back through audioFormat.
func audioFormatFromURL(u string) string {
	ext := strings.ToLower(path.Ext(u))
	switch ext {
	case ".mp3":
		return "audio/mp3"
	case ".wav":
		return "audio/wav"
	default:
		return "audio/wav"
	}
}
@@ -0,0 +1,313 @@
package openaicompat_test

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/openai/openai-go"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// newTestServer returns an httptest server that captures the raw request body
// on POST /chat/completions and returns a canned OpenAI response so Complete()
// succeeds. Use the returned body pointer to assert on what the provider sent.
func newTestServer(t *testing.T) (*httptest.Server, *[]byte) {
	t.Helper()
	var body []byte
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/completions" {
			http.NotFound(w, r)
			return
		}
		b, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read body: %v", err)
		}
		body = b
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{
			"id": "cmpl-1",
			"object": "chat.completion",
			"choices": [{
				"index": 0,
				"message": {"role":"assistant","content":"ok"},
				"finish_reason": "stop"
			}],
			"usage": {"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}
		}`)
	}))
	return srv, &body
}

func textReq(model, content string) provider.Request {
	return provider.Request{
		Model:    model,
		Messages: []provider.Message{{Role: "user", Content: content}},
	}
}

func TestComplete_ZeroRulesPassesThrough(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	temp := 0.7
	req := textReq("gpt-4o", "hi")
	req.Temperature = &temp

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
	resp, err := p.Complete(context.Background(), req)
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if resp.Text != "ok" {
		t.Errorf("Text = %q, want %q", resp.Text, "ok")
	}

	// Temperature should be present since RestrictTemperature is nil.
	var parsed map[string]any
	if err := json.Unmarshal(*body, &parsed); err != nil {
		t.Fatalf("unmarshal request body: %v", err)
	}
	if _, ok := parsed["temperature"]; !ok {
		t.Errorf("expected temperature in request body, got: %s", *body)
	}
}

func TestComplete_RestrictTemperatureDropsField(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	temp := 0.7
	req := textReq("o1", "hi")
	req.Temperature = &temp

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		RestrictTemperature: func(m string) bool { return strings.HasPrefix(m, "o") },
	})
	if _, err := p.Complete(context.Background(), req); err != nil {
		t.Fatalf("Complete: %v", err)
	}

	var parsed map[string]any
	if err := json.Unmarshal(*body, &parsed); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if _, ok := parsed["temperature"]; ok {
		t.Errorf("temperature should be dropped for o1, got: %s", *body)
	}
}

func TestComplete_SupportsVisionRejectsWhenFalse(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model: "deepseek-chat",
		Messages: []provider.Message{{
			Role:    "user",
			Content: "describe",
			Images:  []provider.Image{{URL: "https://example.com/a.png"}},
		}},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsVision: func(string) bool { return false },
	})
	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if fue.Feature != "vision" || fue.Model != "deepseek-chat" {
		t.Errorf("unexpected err: %+v", fue)
	}
}

func TestComplete_SupportsToolsRejectsWhenFalse(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model:    "deepseek-reasoner",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
		Tools: []provider.ToolDef{
			{Name: "get_weather", Description: "weather", Schema: map[string]any{"type": "object"}},
		},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsTools: func(m string) bool { return !strings.Contains(m, "reasoner") },
	})
	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if fue.Feature != "tools" {
		t.Errorf("feature = %q, want tools", fue.Feature)
	}
}

func TestComplete_SupportsAudioRejectsWhenFalse(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model: "groq-llama",
		Messages: []provider.Message{{
			Role:  "user",
			Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}},
		}},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsAudio: func(string) bool { return false },
	})
	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if fue.Feature != "audio" {
		t.Errorf("feature = %q, want audio", fue.Feature)
	}
}

func TestComplete_MaxImagesPerMessage(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model: "anything",
		Messages: []provider.Message{{
			Role: "user",
			Images: []provider.Image{
				{URL: "a"}, {URL: "b"}, {URL: "c"},
			},
		}},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{MaxImagesPerMessage: 2})
	_, err := p.Complete(context.Background(), req)
	if err == nil || !strings.Contains(err.Error(), "max allowed is 2") {
		t.Fatalf("want max-images error, got %v", err)
	}

	// Exactly at limit succeeds.
	req.Messages[0].Images = req.Messages[0].Images[:2]
	if _, err := p.Complete(context.Background(), req); err != nil {
		t.Errorf("at-limit request should succeed, got %v", err)
	}
}

func TestComplete_CustomizeRequestInvoked(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	called := false
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		CustomizeRequest: func(params *openai.ChatCompletionNewParams) {
			called = true
			// Confirm we receive a non-empty built request.
			if params.Model != "gpt-4o" {
				t.Errorf("CustomizeRequest saw model %q, want gpt-4o", params.Model)
			}
			// Mutation here should end up on the wire.
			params.User = openai.String("test-user")
		},
	})
	if _, err := p.Complete(context.Background(), textReq("gpt-4o", "hi")); err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if !called {
		t.Fatal("CustomizeRequest hook was not invoked")
	}
	if !strings.Contains(string(*body), `"user":"test-user"`) {
		t.Errorf("mutation from CustomizeRequest not reflected on wire: %s", *body)
	}
}

func TestStream_EmitsDoneAndText(t *testing.T) {
	// SSE stream: two content deltas, a usage-bearing finish chunk, then [DONE].
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		for _, line := range []string{
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hel"}}]}`,
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"lo"}}]}`,
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`,
			`data: [DONE]`,
		} {
			_, _ = io.WriteString(w, line+"\n\n")
			if flusher != nil {
				flusher.Flush()
			}
		}
	}))
	defer srv.Close()

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
	events := make(chan provider.StreamEvent, 16)
	go func() {
		_ = p.Stream(context.Background(), textReq("gpt-4o", "hi"), events)
		close(events)
	}()

	var text strings.Builder
	var sawDone bool
	var doneUsage *provider.Usage
	for ev := range events {
		switch ev.Type {
		case provider.StreamEventText:
			text.WriteString(ev.Text)
		case provider.StreamEventDone:
			sawDone = true
			if ev.Response != nil {
				doneUsage = ev.Response.Usage
			}
		}
	}
	if text.String() != "hello" {
		t.Errorf("got text %q, want %q", text.String(), "hello")
	}
	if !sawDone {
		t.Fatal("no Done event emitted")
	}
	if doneUsage == nil || doneUsage.TotalTokens != 3 {
		t.Errorf("usage on Done = %+v, want TotalTokens=3", doneUsage)
	}
}

func TestStream_RulesCheckedBeforeNetwork(t *testing.T) {
	// Server should never be hit when rules reject up front.
	hit := false
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hit = true
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer srv.Close()

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsVision: func(string) bool { return false },
	})
	req := provider.Request{
		Model: "no-vision-model",
		Messages: []provider.Message{{
			Role:   "user",
			Images: []provider.Image{{URL: "a"}},
		}},
	}
	events := make(chan provider.StreamEvent, 4)
	err := p.Stream(context.Background(), req, events)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if hit {
		t.Error("server was contacted despite Rules violation")
	}
}