Steve Dudenhoeffer
e7b7aab62e
Implemented a nil check for Toolbox to prevent potential nil pointer dereferences. Cleaned up and reorganized code for better readability and maintainability while keeping placeholder functionality intact.
117 lines
2.4 KiB
Go
117 lines
2.4 KiB
Go
package go_llm
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/google/generative-ai-go/genai"
|
|
"google.golang.org/api/option"
|
|
)
|
|
|
|
// google is the ChatCompletion implementation backed by Google's genai API.
// It is passed around by value; ModelVersion returns a modified copy.
type google struct {
	// key is the API key handed to genai.NewClient via option.WithAPIKey;
	// model is the model name handed to Client.GenerativeModel.
	key, model string
}
|
|
|
|
func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
|
g.model = modelVersion
|
|
|
|
return g, nil
|
|
}
|
|
|
|
func (g google) requestToGoogleRequest(in Request, model *genai.GenerativeModel) []genai.Part {
|
|
|
|
if in.Temperature != nil {
|
|
model.GenerationConfig.Temperature = in.Temperature
|
|
}
|
|
|
|
res := []genai.Part{}
|
|
|
|
for _, c := range in.Messages {
|
|
res = append(res, genai.Text(c.Text))
|
|
}
|
|
|
|
if in.Toolbox != nil {
|
|
for _, tool := range in.Toolbox.funcs {
|
|
panic("google ToolBox is todo" + tool.Name)
|
|
|
|
/*
|
|
t := genai.Tool{}
|
|
t.FunctionDeclarations = append(t.FunctionDeclarations, &genai.FunctionDeclaration{
|
|
Name: tool.Name,
|
|
Description: tool.Description,
|
|
Parameters: nil, //tool.Parameters,
|
|
})
|
|
*/
|
|
}
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
|
|
res := Response{}
|
|
|
|
for _, c := range in.Candidates {
|
|
if c.Content != nil {
|
|
for _, p := range c.Content.Parts {
|
|
switch p.(type) {
|
|
case genai.Text:
|
|
res.Choices = append(res.Choices, ResponseChoice{
|
|
Content: string(p.(genai.Text)),
|
|
})
|
|
|
|
case genai.FunctionCall:
|
|
v := p.(genai.FunctionCall)
|
|
choice := ResponseChoice{}
|
|
|
|
choice.Content = v.Name
|
|
b, e := json.Marshal(v.Args)
|
|
|
|
if e != nil {
|
|
return Response{}, fmt.Errorf("error marshalling args: %w", e)
|
|
}
|
|
|
|
call := ToolCall{
|
|
ID: v.Name,
|
|
FunctionCall: FunctionCall{
|
|
Name: v.Name,
|
|
Arguments: string(b),
|
|
},
|
|
}
|
|
|
|
choice.Calls = append(choice.Calls, call)
|
|
|
|
res.Choices = append(res.Choices, choice)
|
|
|
|
default:
|
|
return Response{}, fmt.Errorf("unknown part type: %T", p)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) {
|
|
cl, err := genai.NewClient(ctx, option.WithAPIKey(g.key))
|
|
|
|
if err != nil {
|
|
return Response{}, fmt.Errorf("error creating genai client: %w", err)
|
|
}
|
|
|
|
model := cl.GenerativeModel(g.model)
|
|
|
|
parts := g.requestToGoogleRequest(req, model)
|
|
|
|
resp, err := model.GenerateContent(ctx, parts...)
|
|
|
|
if err != nil {
|
|
return Response{}, fmt.Errorf("error generating content: %w", err)
|
|
}
|
|
|
|
return g.responseToLLMResponse(resp)
|
|
}
|