- Migrate `compress_image.go` to `internal/imageutil` for better encapsulation. - Reorganize LLM provider implementations into distinct packages (`google`, `openai`, and `anthropic`). - Replace `go_llm` package name with `llm`. - Refactor internal APIs for improved clarity, including renaming `anthropic` to `anthropicImpl` and `google` to `googleImpl`. - Add helper methods and restructure message handling for better separation of concerns.
166 lines
3.7 KiB
Go
166 lines
3.7 KiB
Go
package llm
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
|
|
"google.golang.org/genai"
|
|
)
|
|
|
|
// googleImpl is the Gemini-backed implementation of the LLM interface.
// The zero value is not usable: key must hold a Gemini API key and
// model a model identifier (set via ModelVersion) before ChatComplete
// is called.
type googleImpl struct {
	key   string // Gemini API key, passed to genai.NewClient on each ChatComplete
	model string // model identifier forwarded to Models.GenerateContent
}
|
|
|
|
// Compile-time assertion that googleImpl satisfies the LLM interface.
var _ LLM = googleImpl{}
|
|
|
|
func (g googleImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
|
g.model = modelVersion
|
|
|
|
return g, nil
|
|
}
|
|
|
|
func (g googleImpl) requestToContents(in Request) ([]*genai.Content, *genai.GenerateContentConfig) {
|
|
var contents []*genai.Content
|
|
var cfg genai.GenerateContentConfig
|
|
|
|
for _, tool := range in.Toolbox.Functions() {
|
|
cfg.Tools = append(cfg.Tools, &genai.Tool{
|
|
FunctionDeclarations: []*genai.FunctionDeclaration{
|
|
{
|
|
Name: tool.Name,
|
|
Description: tool.Description,
|
|
Parameters: tool.Parameters.GoogleParameters(),
|
|
},
|
|
},
|
|
})
|
|
}
|
|
|
|
if in.Toolbox.RequiresTool() {
|
|
cfg.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
|
|
Mode: genai.FunctionCallingConfigModeAny,
|
|
}}
|
|
}
|
|
|
|
for _, c := range in.Messages {
|
|
var role genai.Role
|
|
switch c.Role {
|
|
case RoleAssistant, RoleSystem:
|
|
role = genai.RoleModel
|
|
case RoleUser:
|
|
role = genai.RoleUser
|
|
}
|
|
|
|
var parts []*genai.Part
|
|
if c.Text != "" {
|
|
parts = append(parts, genai.NewPartFromText(c.Text))
|
|
}
|
|
|
|
for _, img := range c.Images {
|
|
if img.Url != "" {
|
|
// gemini does not support URLs, so we need to download the image and convert it to a blob
|
|
resp, err := http.Get(img.Url)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("error downloading image: %v", err))
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.ContentLength > 20*1024*1024 {
|
|
panic(fmt.Sprintf("image size exceeds 20MB: %d bytes", resp.ContentLength))
|
|
}
|
|
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("error reading image data: %v", err))
|
|
}
|
|
|
|
mimeType := http.DetectContentType(data)
|
|
switch mimeType {
|
|
case "image/jpeg", "image/png", "image/gif":
|
|
// MIME type is valid
|
|
default:
|
|
panic(fmt.Sprintf("unsupported image MIME type: %s", mimeType))
|
|
}
|
|
|
|
parts = append(parts, genai.NewPartFromBytes(data, mimeType))
|
|
} else {
|
|
b, e := base64.StdEncoding.DecodeString(img.Base64)
|
|
if e != nil {
|
|
panic(fmt.Sprintf("error decoding base64: %v", e))
|
|
}
|
|
|
|
parts = append(parts, genai.NewPartFromBytes(b, img.ContentType))
|
|
}
|
|
}
|
|
|
|
contents = append(contents, genai.NewContentFromParts(parts, role))
|
|
}
|
|
|
|
return contents, &cfg
|
|
}
|
|
|
|
func (g googleImpl) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
|
|
res := Response{}
|
|
|
|
for _, c := range in.Candidates {
|
|
var choice ResponseChoice
|
|
var set = false
|
|
if c.Content != nil {
|
|
for _, p := range c.Content.Parts {
|
|
if p.Text != "" {
|
|
set = true
|
|
choice.Content = p.Text
|
|
} else if p.FunctionCall != nil {
|
|
v := p.FunctionCall
|
|
b, e := json.Marshal(v.Args)
|
|
if e != nil {
|
|
return Response{}, fmt.Errorf("error marshalling args: %w", e)
|
|
}
|
|
|
|
call := ToolCall{
|
|
ID: v.Name,
|
|
FunctionCall: FunctionCall{
|
|
Name: v.Name,
|
|
Arguments: string(b),
|
|
},
|
|
}
|
|
|
|
choice.Calls = append(choice.Calls, call)
|
|
set = true
|
|
}
|
|
}
|
|
}
|
|
|
|
if set {
|
|
choice.Role = RoleAssistant
|
|
res.Choices = append(res.Choices, choice)
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (g googleImpl) ChatComplete(ctx context.Context, req Request) (Response, error) {
|
|
cl, err := genai.NewClient(ctx, &genai.ClientConfig{
|
|
APIKey: g.key,
|
|
Backend: genai.BackendGeminiAPI,
|
|
})
|
|
|
|
if err != nil {
|
|
return Response{}, fmt.Errorf("error creating genai client: %w", err)
|
|
}
|
|
|
|
contents, cfg := g.requestToContents(req)
|
|
|
|
resp, err := cl.Models.GenerateContent(ctx, g.model, contents, cfg)
|
|
if err != nil {
|
|
return Response{}, fmt.Errorf("error generating content: %w", err)
|
|
}
|
|
|
|
return g.responseToLLMResponse(resp)
|
|
}
|