go-llm/google.go

107 lines
2.2 KiB
Go
Raw Normal View History

2024-10-06 22:16:26 -04:00
package go_llm
import (
"context"
"fmt"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
)
// google is a chat-completion backend built on Google's genai client
// (github.com/google/generative-ai-go/genai).
type google struct {
	// key is the API key handed to option.WithAPIKey when a client is built;
	// model names the generative model requested from that client.
	key, model string
}
// ModelVersion returns a copy of g that targets the model named modelVersion.
// Because the receiver is a value, the original g is not modified. The error
// result is always nil.
func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
	next := g
	next.model = modelVersion
	return next, nil
}
func (g google) requestToGoogleRequest(in Request, model *genai.GenerativeModel) []genai.Part {
if in.Temperature != nil {
model.GenerationConfig.Temperature = in.Temperature
}
res := []genai.Part{}
for _, c := range in.Messages {
res = append(res, genai.Text(c.Text))
}
for _, tool := range in.Toolbox {
panic("google ToolBox is todo" + tool.Name)
2024-10-06 22:16:26 -04:00
/*
t := genai.Tool{}
t.FunctionDeclarations = append(t.FunctionDeclarations, &genai.FunctionDeclaration{
Name: tool.Name,
Description: tool.Description,
Parameters: nil, //tool.Parameters,
})
*/
}
return res
}
// responseToLLMResponse flattens a genai response into a Response.
//
// Every part of every candidate becomes its own ResponseChoice:
//   - a genai.Text part becomes a choice whose Content is the text;
//   - a genai.FunctionCall part becomes a choice whose Content is the function
//     name and whose Calls holds one ToolCall. That ToolCall's ID and Name are
//     both the function name, and its Arguments are the call's Args.
//
// Candidates with nil Content are skipped. Any other part type aborts the
// conversion and returns an empty Response with an error. A nil in is
// reported as an error rather than dereferenced.
//
// NOTE(review): because choices are per part rather than per candidate, a
// candidate with several parts (e.g. several function calls) yields several
// choices. Confirm this is what callers expect.
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
	if in == nil {
		return Response{}, fmt.Errorf("nil genai response")
	}

	res := Response{}
	for _, c := range in.Candidates {
		if c.Content == nil {
			continue
		}
		for _, p := range c.Content.Parts {
			switch v := p.(type) {
			case genai.Text:
				res.Choices = append(res.Choices, ResponseChoice{
					Content: string(v),
				})
			case genai.FunctionCall:
				call := ToolCall{
					ID: v.Name,
					FunctionCall: FunctionCall{
						Name:      v.Name,
						Arguments: v.Args,
					},
				}
				choice := ResponseChoice{Content: v.Name}
				choice.Calls = append(choice.Calls, call)
				res.Choices = append(res.Choices, choice)
			default:
				return Response{}, fmt.Errorf("unknown part type: %T", p)
			}
		}
	}
	return res, nil
}
// ChatComplete sends req to the model named by g.model and converts the reply
// into a Response.
//
// A new genai client, authenticated with g.key, is created for each call and
// closed before returning. Requests that carry tools are rejected with an
// error because tool calling is not implemented for this backend yet.
// Client-creation and generation failures are returned wrapped with %w.
func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) {
	// requestToGoogleRequest cannot express tools (it panics on them); report
	// that as an ordinary error instead of crashing the caller.
	if len(req.Toolbox) > 0 {
		return Response{}, fmt.Errorf("google: tool calling is not supported yet")
	}

	cl, err := genai.NewClient(ctx, option.WithAPIKey(g.key))
	if err != nil {
		return Response{}, fmt.Errorf("error creating genai client: %w", err)
	}
	// Without Close, every call leaked the client's resources. A close failure
	// cannot change the result already obtained, so it is deliberately ignored.
	defer func() { _ = cl.Close() }()

	model := cl.GenerativeModel(g.model)
	parts := g.requestToGoogleRequest(req, model)

	resp, err := model.GenerateContent(ctx, parts...)
	if err != nil {
		return Response{}, fmt.Errorf("error generating content: %w", err)
	}
	return g.responseToLLMResponse(resp)
}