From 8801ce594566025f681d35fc96bb6978637300a7 Mon Sep 17 00:00:00 2001
From: Steve Dudenhoeffer
Date: Sun, 25 Jan 2026 01:46:29 -0500
Subject: [PATCH] Add OpenAI-based transcriber implementation

- Introduce `openaiTranscriber` for integrating OpenAI's Whisper audio transcription capabilities.
- Define `Transcriber` interface and associated types (`Transcription`, `TranscriptionOptions`, segments, and words).
- Implement transcription logic supporting features like languages, prompts, temperature, and timestamp granularities.
- Add `audioFileToWav` utility using `ffmpeg` for audio file conversion to WAV format.
- Ensure response parsing for structured and verbose JSON outputs.
---
 openai_transcriber.go | 219 ++++++++++++++++++++++++++++++++++++++++++
 transcriber.go        | 145 ++++++++++++++++++++++++++++
 2 files changed, 364 insertions(+)
 create mode 100644 openai_transcriber.go
 create mode 100644 transcriber.go

diff --git a/openai_transcriber.go b/openai_transcriber.go
new file mode 100644
index 0000000..3a554c3
--- /dev/null
+++ b/openai_transcriber.go
@@ -0,0 +1,219 @@
+package llm
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"strings"
+
+	"github.com/openai/openai-go"
+	"github.com/openai/openai-go/option"
+)
+
+// openaiTranscriber implements Transcriber on top of the OpenAI audio
+// transcription endpoint. key is passed to the SDK as the API key, model is
+// the audio model name sent with every request, and a non-empty baseUrl is
+// handed to option.WithBaseURL; the OpenAITranscriber constructor leaves
+// baseUrl empty.
+type openaiTranscriber struct {
+	key     string
+	model   string
+	baseUrl string
+}
+
+// Compile-time check that openaiTranscriber satisfies Transcriber.
+var _ Transcriber = openaiTranscriber{}
+
+// OpenAITranscriber creates a transcriber backed by OpenAI's audio models.
+// If model is empty, whisper-1 is used by default.
+func OpenAITranscriber(key string, model string) Transcriber {
+	if strings.TrimSpace(model) == "" {
+		model = "whisper-1"
+	}
+	return openaiTranscriber{
+		key:   key,
+		model: model,
+	}
+}
+
+// Transcribe uploads WAV-encoded audio to the OpenAI transcription endpoint
+// and returns the normalized result.
+//
+// An empty opts.ResponseFormat defaults to json for models whose name starts
+// with "gpt-4o" and to verbose_json for every other model. Only json and
+// verbose_json are accepted, and timestamp granularities additionally require
+// verbose_json. Empty Language/Prompt and a nil Temperature are left out of
+// the request.
+//
+// It returns an error when wav is empty, when the options fail the checks
+// above, or when the API call fails (the SDK error is wrapped with %w).
+func (o openaiTranscriber) Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error) {
+	if len(wav) == 0 {
+		return Transcription{}, fmt.Errorf("wav data is empty")
+	}
+
+	format := opts.ResponseFormat
+	if format == "" {
+		if strings.HasPrefix(o.model, "gpt-4o") {
+			format = TranscriptionResponseFormatJSON
+		} else {
+			format = TranscriptionResponseFormatVerboseJSON
+		}
+	}
+
+	if format != TranscriptionResponseFormatJSON && format != TranscriptionResponseFormatVerboseJSON {
+		return Transcription{}, fmt.Errorf("openai transcriber requires response_format json or verbose_json for structured output")
+	}
+
+	if len(opts.TimestampGranularities) > 0 && format != TranscriptionResponseFormatVerboseJSON {
+		return Transcription{}, fmt.Errorf("timestamp granularities require response_format=verbose_json")
+	}
+
+	params := openai.AudioTranscriptionNewParams{
+		// A bare bytes.Reader carries no file name, so the multipart part would
+		// go out under the SDK's generic fallback name and content type. The
+		// service identifies the audio container from the uploaded file's
+		// name, so label the payload explicitly as WAV.
+		File:  openai.File(bytes.NewReader(wav), "audio.wav", "audio/wav"),
+		Model: openai.AudioModel(o.model),
+	}
+
+	if opts.Language != "" {
+		params.Language = openai.String(opts.Language)
+	}
+	if opts.Prompt != "" {
+		params.Prompt = openai.String(opts.Prompt)
+	}
+	if opts.Temperature != nil {
+		params.Temperature = openai.Float(*opts.Temperature)
+	}
+
+	params.ResponseFormat = openai.AudioResponseFormat(format)
+
+	if opts.IncludeLogprobs {
+		params.Include = []openai.TranscriptionInclude{openai.TranscriptionIncludeLogprobs}
+	}
+
+	// Ranging over an empty slice is a no-op, so no length guard is needed.
+	for _, granularity := range opts.TimestampGranularities {
+		params.TimestampGranularities = append(params.TimestampGranularities, string(granularity))
+	}
+
+	clientOptions := []option.RequestOption{
+		option.WithAPIKey(o.key),
+	}
+	if o.baseUrl != "" {
+		clientOptions = append(clientOptions, option.WithBaseURL(o.baseUrl))
+	}
+
+	client := openai.NewClient(clientOptions...)
+	resp, err := client.Audio.Transcriptions.New(ctx, params)
+	if err != nil {
+		return Transcription{}, fmt.Errorf("openai transcription failed: %w", err)
+	}
+
+	return openaiTranscriptionToResult(o.model, resp), nil
+}
+
+// openaiVerboseTranscription is the subset of the transcription response JSON
+// that openaiTranscriptionToResult decodes from the raw response body.
+type openaiVerboseTranscription struct {
+	Text     string                 `json:"text"`
+	Language string                 `json:"language"`
+	Duration float64                `json:"duration"`
+	Segments []openaiVerboseSegment `json:"segments"`
+	Words    []openaiVerboseWord    `json:"words"`
+}
+
+// openaiVerboseSegment is one entry of the "segments" array in the response
+// JSON. The pointer fields stay nil when the key is absent or null.
+type openaiVerboseSegment struct {
+	ID               int                 `json:"id"`
+	Start            float64             `json:"start"`
+	End              float64             `json:"end"`
+	Text             string              `json:"text"`
+	Tokens           []int               `json:"tokens"`
+	AvgLogprob       *float64            `json:"avg_logprob"`
+	CompressionRatio *float64            `json:"compression_ratio"`
+	NoSpeechProb     *float64            `json:"no_speech_prob"`
+	Words            []openaiVerboseWord `json:"words"`
+}
+
+// openaiVerboseWord is one entry of a "words" array in the response JSON.
+type openaiVerboseWord struct {
+	Word  string  `json:"word"`
+	Start float64 `json:"start"`
+	End   float64 `json:"end"`
+}
+
+// openaiTranscriptionToResult converts an SDK transcription response into the
+// provider-neutral Transcription.
+//
+// Provider and Model are always set. A nil resp yields a result with nothing
+// else filled in. Text, logprobs and usage come from the typed response; the
+// raw response JSON is then decoded a second time to pick up language,
+// duration, segments and words. If the raw JSON is empty or does not decode,
+// the result built so far is returned unchanged (best effort, no error).
+func openaiTranscriptionToResult(model string, resp *openai.Transcription) Transcription {
+	result := Transcription{
+		Provider: "openai",
+		Model:    model,
+	}
+	if resp == nil {
+		return result
+	}
+
+	result.Text = resp.Text
+	result.RawJSON = resp.RawJSON()
+
+	for _, logprob := range resp.Logprobs {
+		result.Logprobs = append(result.Logprobs, TranscriptionTokenLogprob{
+			Token:   logprob.Token,
+			Bytes:   logprob.Bytes,
+			Logprob: logprob.Logprob,
+		})
+	}
+
+	// Only keep usage when the response carried a recognized usage type.
+	if usage := openaiUsageToTranscriptionUsage(resp.Usage); usage.Type != "" {
+		result.Usage = usage
+	}
+
+	if result.RawJSON == "" {
+		return result
+	}
+
+	var verbose openaiVerboseTranscription
+	if err := json.Unmarshal([]byte(result.RawJSON), &verbose); err != nil {
+		return result
+	}
+
+	// Prefer the text from the raw JSON, but never replace text with "".
+	if verbose.Text != "" {
+		result.Text = verbose.Text
+	}
+	result.Language = verbose.Language
+	result.DurationSeconds = verbose.Duration
+
+	for _, seg := range verbose.Segments {
+		segment := TranscriptionSegment{
+			ID:    seg.ID,
+			Start: seg.Start,
+			End:   seg.End,
+			Text:  seg.Text,
+			// Copy so the result does not share a backing array with the
+			// decoded value.
+			Tokens:           append([]int(nil), seg.Tokens...),
+			AvgLogprob:       seg.AvgLogprob,
+			CompressionRatio: seg.CompressionRatio,
+			NoSpeechProb:     seg.NoSpeechProb,
+		}
+
+		for _, word := range seg.Words {
+			segment.Words = append(segment.Words, TranscriptionWord{
+				Word:  word.Word,
+				Start: word.Start,
+				End:   word.End,
+			})
+		}
+
+		result.Segments = append(result.Segments, segment)
+	}
+
+	for _, word := range verbose.Words {
+		result.Words = append(result.Words, TranscriptionWord{
+			Word:  word.Word,
+			Start: word.Start,
+			End:   word.End,
+		})
+	}
+
+	return result
+}
+
+// openaiUsageToTranscriptionUsage maps the SDK usage union onto
+// TranscriptionUsage. "tokens" usage fills the token counters, "duration"
+// usage fills Seconds, and any other type yields the zero value.
+func openaiUsageToTranscriptionUsage(usage openai.TranscriptionUsageUnion) TranscriptionUsage {
+	switch usage.Type {
+	case "tokens":
+		tokens := usage.AsTokens()
+		return TranscriptionUsage{
+			Type:         usage.Type,
+			InputTokens:  tokens.InputTokens,
+			OutputTokens: tokens.OutputTokens,
+			TotalTokens:  tokens.TotalTokens,
+			AudioTokens:  tokens.InputTokenDetails.AudioTokens,
+			TextTokens:   tokens.InputTokenDetails.TextTokens,
+		}
+	case "duration":
+		duration := usage.AsDuration()
+		return TranscriptionUsage{
+			Type:    usage.Type,
+			Seconds: duration.Seconds,
+		}
+	default:
+		return TranscriptionUsage{}
+	}
+}
diff --git a/transcriber.go b/transcriber.go
new file mode 100644
index 0000000..9cef7f3
--- /dev/null
+++ b/transcriber.go
@@ -0,0 +1,145 @@
+package llm
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"strings"
+)
+
+// Transcriber abstracts a speech-to-text model implementation.
+type Transcriber interface {
+	Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error)
+}
+
+// TranscriptionResponseFormat controls the output format requested from a transcriber.
+type TranscriptionResponseFormat string
+
+// Response formats a caller can put in TranscriptionOptions.ResponseFormat.
+// openaiTranscriber.Transcribe accepts only the json and verbose_json values
+// and returns an error for the others.
+const (
+	TranscriptionResponseFormatJSON        TranscriptionResponseFormat = "json"
+	TranscriptionResponseFormatVerboseJSON TranscriptionResponseFormat = "verbose_json"
+	TranscriptionResponseFormatText        TranscriptionResponseFormat = "text"
+	TranscriptionResponseFormatSRT         TranscriptionResponseFormat = "srt"
+	TranscriptionResponseFormatVTT         TranscriptionResponseFormat = "vtt"
+)
+
+// TranscriptionTimestampGranularity defines the requested timestamp detail.
+type TranscriptionTimestampGranularity string
+
+// Timestamp granularities a caller can list in
+// TranscriptionOptions.TimestampGranularities.
+const (
+	TranscriptionTimestampGranularityWord    TranscriptionTimestampGranularity = "word"
+	TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment"
+)
+
+// TranscriptionOptions configures transcription behavior.
+// With the OpenAI transcriber, zero-valued fields are left out of the request
+// and an empty ResponseFormat is replaced by a model-dependent default.
+type TranscriptionOptions struct {
+	// Language is forwarded to the provider when non-empty.
+	Language string
+	// Prompt is forwarded to the provider when non-empty.
+	Prompt string
+	// Temperature is forwarded when non-nil; nil leaves it unset.
+	Temperature *float64
+	// ResponseFormat selects the output format; empty lets the transcriber choose.
+	ResponseFormat TranscriptionResponseFormat
+	// TimestampGranularities lists the timestamp detail levels to request.
+	// The OpenAI transcriber rejects a non-empty list unless the format is
+	// verbose_json.
+	TimestampGranularities []TranscriptionTimestampGranularity
+	// IncludeLogprobs asks the provider to return token log probabilities.
+	IncludeLogprobs bool
+}
+
+// Transcription captures a normalized transcription result.
+type Transcription struct {
+	// Provider names the backend that produced the result ("openai" for the
+	// OpenAI transcriber).
+	Provider string
+	// Model is the model name the transcriber was configured with.
+	Model string
+	// Text is the transcribed text.
+	Text string
+	// Language is taken from the "language" key of the response JSON; empty
+	// when the response does not carry it.
+	Language string
+	// DurationSeconds is taken from the "duration" key of the response JSON;
+	// zero when the response does not carry it.
+	DurationSeconds float64
+	// Segments holds the entries of the response's "segments" array, if any.
+	Segments []TranscriptionSegment
+	// Words holds the entries of the response's top-level "words" array, if any.
+	Words []TranscriptionWord
+	// Logprobs holds per-token log probabilities when the response has them.
+	Logprobs []TranscriptionTokenLogprob
+	// Usage is the zero value when the response reports no recognized usage type.
+	Usage TranscriptionUsage
+	// RawJSON is the raw response body as exposed by the provider SDK.
+	RawJSON string
+}
+
+// TranscriptionSegment provides a coarse time-sliced transcription segment.
+// Start and End are presumably offsets in seconds — TODO confirm against the
+// provider documentation. The pointer fields are nil when the provider's
+// response omits them.
+type TranscriptionSegment struct {
+	ID               int
+	Start            float64
+	End              float64
+	Text             string
+	Tokens           []int
+	AvgLogprob       *float64
+	CompressionRatio *float64
+	NoSpeechProb     *float64
+	Words            []TranscriptionWord
+}
+
+// TranscriptionWord provides a word-level timestamp.
+// Confidence is left nil by the OpenAI transcriber, which only fills Word,
+// Start and End.
+type TranscriptionWord struct {
+	Word       string
+	Start      float64
+	End        float64
+	Confidence *float64
+}
+
+// TranscriptionTokenLogprob captures token-level log probability details.
+type TranscriptionTokenLogprob struct {
+	Token   string
+	Bytes   []float64
+	Logprob float64
+}
+
+// TranscriptionUsage reports what a transcription request consumed.
+// Type tells which fields are meaningful: the OpenAI transcriber sets it to
+// "tokens" (token counters filled) or "duration" (Seconds filled), and leaves
+// the whole struct zero otherwise.
+type TranscriptionUsage struct {
+	Type         string
+	InputTokens  int64
+	OutputTokens int64
+	TotalTokens  int64
+	AudioTokens  int64
+	TextTokens   int64
+	Seconds      float64
+}
+
+// TranscribeFile loads the audio at filename as WAV bytes (converting through
+// ffmpeg unless the name already ends in .wav) and passes them to transcriber.
+// It fails when transcriber is nil or when the audio cannot be loaded;
+// otherwise it returns whatever transcriber.Transcribe returns.
+func TranscribeFile(ctx context.Context, filename string, transcriber Transcriber, opts TranscriptionOptions) (Transcription, error) {
+	if transcriber == nil {
+		return Transcription{}, fmt.Errorf("transcriber is nil")
+	}
+
+	audio, loadErr := audioFileToWav(ctx, filename)
+	if loadErr != nil {
+		return Transcription{}, loadErr
+	}
+
+	return transcriber.Transcribe(ctx, audio, opts)
+}
+
+// audioFileToWav returns the contents of filename as WAV bytes.
+// A file whose extension is .wav (any letter case) is read and returned
+// untouched. Anything else is converted by running the ffmpeg binary into a
+// temporary file that is deleted before the function returns. ctx bounds the
+// ffmpeg run only.
+func audioFileToWav(ctx context.Context, filename string) ([]byte, error) {
+	if filename == "" {
+		return nil, fmt.Errorf("filename is empty")
+	}
+
+	if ext := filepath.Ext(filename); strings.EqualFold(ext, ".wav") {
+		raw, readErr := os.ReadFile(filename)
+		if readErr != nil {
+			return nil, fmt.Errorf("read wav file: %w", readErr)
+		}
+		return raw, nil
+	}
+
+	// CreateTemp is used only to reserve a unique path; the handle is closed
+	// right away and ffmpeg overwrites the empty file (hence -y).
+	scratch, err := os.CreateTemp("", "go-llm-audio-*.wav")
+	if err != nil {
+		return nil, fmt.Errorf("create temp wav file: %w", err)
+	}
+	outPath := scratch.Name()
+	_ = scratch.Close()
+	defer os.Remove(outPath)
+
+	ffmpegArgs := []string{
+		"-hide_banner", "-loglevel", "error", "-y",
+		"-i", filename,
+		"-vn", "-f", "wav",
+		outPath,
+	}
+	combined, runErr := exec.CommandContext(ctx, "ffmpeg", ffmpegArgs...).CombinedOutput()
+	if runErr != nil {
+		return nil, fmt.Errorf("ffmpeg convert failed: %w (output: %s)", runErr, strings.TrimSpace(string(combined)))
+	}
+
+	converted, err := os.ReadFile(outPath)
+	if err != nil {
+		return nil, fmt.Errorf("read converted wav file: %w", err)
+	}
+
+	return converted, nil
+}