go-llm/google.go

package go_llm

import (
	"context"
	"encoding/json"
	"fmt"

	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/option"
)
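
// google is a ChatCompletion implementation backed by Google's
// generative-ai-go (Gemini) client.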
type google struct {
	key   string
	model string
}
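
// ModelVersion returns a copy of the client configured to use the given
// model version.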
func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
	g.model = modelVersion
	return g, nil
}
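
// requestToChatHistory converts a Request into a chat session on the given
// model: every message except the last is appended to the session history,
// and the last message's parts are returned so they can be sent with
// ChatSession.SendMessage.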
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.ChatSession, []genai.Part) {
	cs := model.StartChat()

	for i, c := range in.Messages {
		content := genai.NewUserContent(genai.Text(c.Text))

		switch c.Role {
		case RoleAssistant, RoleSystem:
			content.Role = "model"

		case RoleUser:
			content.Role = "user"
		}

		// The last message is not added to the history; its parts are
		// returned for the caller to send through the chat session.
		if i == len(in.Messages)-1 {
			return cs, content.Parts
		}

		cs.History = append(cs.History, content)
	}

	return cs, nil
}
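
// responseToLLMResponse converts a Gemini response into a Response, mapping
// text parts to plain choices and function-call parts to tool calls.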
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
	res := Response{}

	for _, c := range in.Candidates {
		if c.Content == nil {
			continue
		}

		for _, p := range c.Content.Parts {
			switch p := p.(type) {
			case genai.Text:
				res.Choices = append(res.Choices, ResponseChoice{
					Content: string(p),
				})

			case genai.FunctionCall:
				choice := ResponseChoice{}
				choice.Content = p.Name

				b, e := json.Marshal(p.Args)
				if e != nil {
					return Response{}, fmt.Errorf("error marshalling args: %w", e)
				}

				call := ToolCall{
					ID: p.Name,
					FunctionCall: FunctionCall{
						Name:      p.Name,
						Arguments: string(b),
					},
				}

				choice.Calls = append(choice.Calls, call)
				res.Choices = append(res.Choices, choice)

			default:
				return Response{}, fmt.Errorf("unknown part type: %T", p)
			}
		}
	}

	return res, nil
}
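
// ChatComplete creates a Gemini client, replays the request as a chat
// session, sends the final message, and converts the result.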
func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) {
	cl, err := genai.NewClient(ctx, option.WithAPIKey(g.key))
	if err != nil {
		return Response{}, fmt.Errorf("error creating genai client: %w", err)
	}
	defer cl.Close()

	model := cl.GenerativeModel(g.model)

	//parts := g.requestToGoogleRequest(req, model)
	//resp, err := model.GenerateContent(ctx, parts...)
	cs, parts := g.requestToChatHistory(req, model)

	resp, err := cs.SendMessage(ctx, parts...)
	if err != nil {
		return Response{}, fmt.Errorf("error generating content: %w", err)
	}

	return g.responseToLLMResponse(resp)
}
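
// Usage sketch (in-package, illustrative only). The Request/Message field
// names come from the code above (Messages, Text, Role, RoleUser); the
// "Message" type name, the composite-literal shape, and the model name are
// assumptions, not part of this file:
//
//	g := google{key: apiKey, model: "gemini-1.5-flash"}
//	resp, err := g.ChatComplete(ctx, Request{Messages: []Message{
//		{Role: RoleUser, Text: "Hello!"},
//	}})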