package openai import ( "bufio" "bytes" "context" "encoding/json" "fmt" "io" "net/http" "strings" "gitea.stevedudenhoeffer.com/steve/majordomo/llm" ) // model is one provider-bound target. type model struct { p *Provider id string caps llm.Capabilities } // Capabilities implements llm.Model. func (m *model) Capabilities() llm.Capabilities { return m.caps } // Generate implements llm.Model. func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) { req = req.Apply(opts...) if err := checkRequest(m.caps, req); err != nil { return nil, err } httpResp, err := m.do(ctx, req, false) if err != nil { return nil, err } defer httpResp.Body.Close() if httpResp.StatusCode/100 != 2 { return nil, m.apiError(httpResp) } var wire chatResponse if err := json.NewDecoder(httpResp.Body).Decode(&wire); err != nil { return nil, fmt.Errorf("openai: decode response: %w", err) } return m.toResponse(&wire), nil } // Stream implements llm.Model. func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) { req = req.Apply(opts...) if !m.caps.SupportsStreaming { return nil, fmt.Errorf("%w: streaming not supported by %s/%s", llm.ErrUnsupported, m.p.name, m.id) } if err := checkRequest(m.caps, req); err != nil { return nil, err } httpResp, err := m.do(ctx, req, true) if err != nil { return nil, err } if httpResp.StatusCode/100 != 2 { defer httpResp.Body.Close() return nil, m.apiError(httpResp) } sc := bufio.NewScanner(httpResp.Body) // Why: a single SSE data line carries a whole JSON chunk; tool-call // argument fragments can make lines far larger than Scanner's 64 KiB // default cap. sc.Buffer(make([]byte, 0, 64*1024), 16<<20) return &stream{m: m, body: httpResp.Body, sc: sc}, nil } // do builds and performs the HTTP request. Transport failures are wrapped // raw (never as *llm.APIError) so llm.Classify still sees net.Error, // syscall errnos, and context errors underneath. 
func (m *model) do(ctx context.Context, req llm.Request, stream bool) (*http.Response, error) { if m.p.apiKey == "" { // Why a synthetic 401: the constructor never fails, so a missing // key must surface at request time as the auth failure it is — // permanent under llm.Classify, like a real 401. return nil, &llm.APIError{ Provider: m.p.name, Model: m.id, Status: http.StatusUnauthorized, Code: "missing_api_key", Message: "no API key configured: set OPENAI_API_KEY or use WithAPIKey", } } body, err := json.Marshal(m.buildRequest(req, stream)) if err != nil { return nil, fmt.Errorf("openai: encode request: %w", err) } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, m.p.baseURL+"/chat/completions", bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("openai: build request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+m.p.apiKey) if stream { httpReq.Header.Set("Accept", "text/event-stream") } httpResp, err := m.p.client.Do(httpReq) if err != nil { return nil, fmt.Errorf("openai: do request: %w", err) } return httpResp, nil } // apiError converts a non-2xx response into *llm.APIError, pulling code and // message from the {"error":{...}} body when it parses. func (m *model) apiError(httpResp *http.Response) error { apiErr := &llm.APIError{Provider: m.p.name, Model: m.id, Status: httpResp.StatusCode} body, _ := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) var env errorEnvelope if err := json.Unmarshal(body, &env); err == nil && (env.Error.Message != "" || env.Error.Type != "" || env.Error.Code != "") { apiErr.Message = env.Error.Message apiErr.Code = env.Error.Code if apiErr.Code == "" { apiErr.Code = env.Error.Type } } else { // Why: compat servers emit all sorts of error bodies; a raw snippet // beats silence when the canonical envelope is absent. 
apiErr.Message = strings.TrimSpace(string(body)) } return apiErr } // toResponse maps the wire response onto the canonical llm.Response. func (m *model) toResponse(wire *chatResponse) *llm.Response { resp := &llm.Response{Model: m.p.name + "/" + m.id, Raw: wire} if wire.Usage != nil { resp.Usage = llm.Usage{ InputTokens: wire.Usage.PromptTokens, OutputTokens: wire.Usage.CompletionTokens, } } if len(wire.Choices) == 0 { resp.FinishReason = llm.FinishOther return resp } choice := wire.Choices[0] if choice.Message.Content != "" { resp.Parts = append(resp.Parts, llm.TextPart{Text: choice.Message.Content}) } for i, tc := range choice.Message.ToolCalls { id := tc.ID if id == "" { // Why: ToolResult.ID must echo ToolCall.ID, so calls from compat // servers that omit ids get synthesized ones. id = fmt.Sprintf("call_%d", i) } resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{ ID: id, Name: tc.Function.Name, Arguments: json.RawMessage(tc.Function.Arguments), }) } resp.FinishReason = mapFinish(choice.FinishReason, len(resp.ToolCalls) > 0) return resp } // mapFinish maps a wire finish_reason to the canonical enum. Tool-call // presence wins over the reported reason: a forced (named tool_choice) call // can finish with "stop" while still carrying tool_calls. func mapFinish(reason string, hasToolCalls bool) llm.FinishReason { if hasToolCalls { return llm.FinishToolCalls } switch reason { case "stop": return llm.FinishStop case "length": return llm.FinishLength case "tool_calls": return llm.FinishToolCalls case "content_filter": return llm.FinishContentFilter default: return llm.FinishOther } } // checkRequest enforces the model's effective capabilities. Why enforcement // rather than normalization: a separate media layer resizes/transcodes // images BEFORE requests reach the provider; this check is the honest // backstop that refuses, with llm.ErrUnsupported, what the target // declaredly cannot serve (chains advance past it penalty-free). 
func checkRequest(caps llm.Capabilities, req llm.Request) error {
	// Feature gates: refuse tools or a structured-output schema the target
	// does not declare support for.
	if !caps.SupportsTools && len(req.Tools) > 0 {
		return fmt.Errorf("%w: tools not supported", llm.ErrUnsupported)
	}
	if !caps.SupportsStructured && len(req.Schema) > 0 {
		return fmt.Errorf("%w: structured output not supported", llm.ErrUnsupported)
	}

	// Walk every message part; only image parts are inspected. The first
	// violation found ends the check.
	count := 0
	for _, msg := range req.Messages {
		for _, part := range msg.Parts {
			img, isImage := part.(llm.ImagePart)
			if !isImage {
				continue
			}
			count++

			// Cases are evaluated in order: support, then MIME type, then size.
			switch {
			case !caps.SupportsImages():
				return fmt.Errorf("%w: image input not supported", llm.ErrUnsupported)
			case !caps.MIMEAllowed(img.MIME):
				return fmt.Errorf("%w: image MIME type %q not allowed (allowed: %s)",
					llm.ErrUnsupported, img.MIME, strings.Join(caps.AllowedImageMIME, ", "))
			case caps.MaxImageBytes > 0 && len(img.Data) > caps.MaxImageBytes:
				// The size limit applies only when MaxImageBytes is positive.
				return fmt.Errorf("%w: image is %d bytes, limit is %d",
					llm.ErrUnsupported, len(img.Data), caps.MaxImageBytes)
			}
		}
	}

	// The per-request image count is compared once, after all parts are seen.
	if count > caps.MaxImagesPerReq {
		return fmt.Errorf("%w: request carries %d images, limit is %d",
			llm.ErrUnsupported, count, caps.MaxImagesPerReq)
	}
	return nil
}