107 lines
2.2 KiB
Go
107 lines
2.2 KiB
Go
|
package go_llm
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"github.com/google/generative-ai-go/genai"
|
||
|
"google.golang.org/api/option"
|
||
|
)
|
||
|
|
||
|
// google is the chat-completion backend built on Google's generative AI SDK
// (github.com/google/generative-ai-go/genai).
//
// key is the API key handed to option.WithAPIKey when a client is created;
// model is the model name passed to genai.Client.GenerativeModel.
type google struct {
	key, model string
}
|
||
|
|
||
|
func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
||
|
g.model = modelVersion
|
||
|
|
||
|
return g, nil
|
||
|
}
|
||
|
|
||
|
func (g google) requestToGoogleRequest(in Request, model *genai.GenerativeModel) []genai.Part {
|
||
|
|
||
|
if in.Temperature != nil {
|
||
|
model.GenerationConfig.Temperature = in.Temperature
|
||
|
}
|
||
|
|
||
|
res := []genai.Part{}
|
||
|
|
||
|
for _, c := range in.Messages {
|
||
|
res = append(res, genai.Text(c.Text))
|
||
|
}
|
||
|
|
||
|
for _, tool := range in.Toolbox {
|
||
|
panic("google toolbox is todo" + tool.Name)
|
||
|
|
||
|
/*
|
||
|
t := genai.Tool{}
|
||
|
t.FunctionDeclarations = append(t.FunctionDeclarations, &genai.FunctionDeclaration{
|
||
|
Name: tool.Name,
|
||
|
Description: tool.Description,
|
||
|
Parameters: nil, //tool.Parameters,
|
||
|
})
|
||
|
*/
|
||
|
}
|
||
|
|
||
|
return res
|
||
|
}
|
||
|
|
||
|
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
|
||
|
res := Response{}
|
||
|
|
||
|
for _, c := range in.Candidates {
|
||
|
if c.Content != nil {
|
||
|
for _, p := range c.Content.Parts {
|
||
|
switch p.(type) {
|
||
|
case genai.Text:
|
||
|
res.Choices = append(res.Choices, ResponseChoice{
|
||
|
Content: string(p.(genai.Text)),
|
||
|
})
|
||
|
|
||
|
case genai.FunctionCall:
|
||
|
v := p.(genai.FunctionCall)
|
||
|
choice := ResponseChoice{}
|
||
|
|
||
|
choice.Content = v.Name
|
||
|
call := ToolCall{
|
||
|
ID: v.Name,
|
||
|
FunctionCall: FunctionCall{
|
||
|
Name: v.Name,
|
||
|
Arguments: v.Args,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
choice.Calls = append(choice.Calls, call)
|
||
|
|
||
|
res.Choices = append(res.Choices, choice)
|
||
|
|
||
|
default:
|
||
|
return Response{}, fmt.Errorf("unknown part type: %T", p)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return res, nil
|
||
|
}
|
||
|
|
||
|
func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) {
|
||
|
cl, err := genai.NewClient(ctx, option.WithAPIKey(g.key))
|
||
|
|
||
|
if err != nil {
|
||
|
return Response{}, fmt.Errorf("error creating genai client: %w", err)
|
||
|
}
|
||
|
|
||
|
model := cl.GenerativeModel(g.model)
|
||
|
|
||
|
parts := g.requestToGoogleRequest(req, model)
|
||
|
|
||
|
resp, err := model.GenerateContent(ctx, parts...)
|
||
|
|
||
|
if err != nil {
|
||
|
return Response{}, fmt.Errorf("error generating content: %w", err)
|
||
|
}
|
||
|
|
||
|
return g.responseToLLMResponse(resp)
|
||
|
}
|