Steve Dudenhoeffer
388a44fa79
Replace message handling with a chat session model, aligning the logic with new API requirements. Adjust functions to properly build chat history and send messages via chat sessions, improving compatibility and extensibility.
116 lines
2.4 KiB
Go
116 lines
2.4 KiB
Go
package go_llm
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/google/generative-ai-go/genai"
|
|
"google.golang.org/api/option"
|
|
)
|
|
|
|
// google implements ChatCompletion on top of the genai client library.
type google struct {
	// key is the API key handed to option.WithAPIKey; model is the model
	// name handed to Client.GenerativeModel.
	key, model string
}
|
|
|
|
func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
|
g.model = modelVersion
|
|
|
|
return g, nil
|
|
}
|
|
|
|
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.ChatSession, []genai.Part) {
|
|
cs := model.StartChat()
|
|
|
|
for i, c := range in.Messages {
|
|
content := genai.NewUserContent(genai.Text(c.Text))
|
|
|
|
switch c.Role {
|
|
case RoleAssistant, RoleSystem:
|
|
content.Role = "model"
|
|
|
|
case RoleUser:
|
|
content.Role = "user"
|
|
}
|
|
|
|
// if this is the last message, we want to add to history, we want it to be the parts
|
|
if i == len(in.Messages)-1 {
|
|
return cs, content.Parts
|
|
}
|
|
|
|
cs.History = append(cs.History, content)
|
|
}
|
|
|
|
return cs, nil
|
|
}
|
|
|
|
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
|
|
res := Response{}
|
|
|
|
for _, c := range in.Candidates {
|
|
if c.Content != nil {
|
|
for _, p := range c.Content.Parts {
|
|
switch p.(type) {
|
|
case genai.Text:
|
|
res.Choices = append(res.Choices, ResponseChoice{
|
|
Content: string(p.(genai.Text)),
|
|
})
|
|
|
|
case genai.FunctionCall:
|
|
v := p.(genai.FunctionCall)
|
|
choice := ResponseChoice{}
|
|
|
|
choice.Content = v.Name
|
|
b, e := json.Marshal(v.Args)
|
|
|
|
if e != nil {
|
|
return Response{}, fmt.Errorf("error marshalling args: %w", e)
|
|
}
|
|
|
|
call := ToolCall{
|
|
ID: v.Name,
|
|
FunctionCall: FunctionCall{
|
|
Name: v.Name,
|
|
Arguments: string(b),
|
|
},
|
|
}
|
|
|
|
choice.Calls = append(choice.Calls, call)
|
|
|
|
res.Choices = append(res.Choices, choice)
|
|
|
|
default:
|
|
return Response{}, fmt.Errorf("unknown part type: %T", p)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) {
|
|
cl, err := genai.NewClient(ctx, option.WithAPIKey(g.key))
|
|
|
|
if err != nil {
|
|
return Response{}, fmt.Errorf("error creating genai client: %w", err)
|
|
}
|
|
|
|
model := cl.GenerativeModel(g.model)
|
|
|
|
cs, parts := g.requestToChatHistory(req, model)
|
|
|
|
resp, err := cs.SendMessage(ctx, parts...)
|
|
|
|
//parts := g.requestToGoogleRequest(req, model)
|
|
|
|
//resp, err := model.GenerateContent(ctx, parts...)
|
|
|
|
if err != nil {
|
|
return Response{}, fmt.Errorf("error generating content: %w", err)
|
|
}
|
|
|
|
return g.responseToLLMResponse(resp)
|
|
}
|