go-llm/google.go

package go_llm

import (
	"context"
	"encoding/json"
	"fmt"

	"github.com/google/generative-ai-go/genai"
	"github.com/sashabaranov/go-openai/jsonschema"
	"google.golang.org/api/option"
)
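
// google adapts Google's Gemini models (via the generative-ai-go client) to
// this package's chat-completion interface.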
type google struct {
	key   string
	model string
}
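
// ModelVersion returns a copy of the client configured to use the given model
// name; the receiver is a value, so the original client is left unchanged.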
func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
	g.model = modelVersion
	return g, nil
}
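
// requestToChatHistory translates a Request into genai types: it copies the
// model, registers toolbox functions as genai function declarations, seeds a
// chat session with all but the last message, and returns the parts of the
// last message so the caller can send it.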
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) {
	res := *model

	var openAiSchemaToGenAISchema func(in jsonschema.Definition) *genai.Schema
	openAiSchemaToGenAISchema = func(in jsonschema.Definition) *genai.Schema {
		res := genai.Schema{}

		switch in.Type {
		case jsonschema.Object:
			res.Type = genai.TypeObject
		case jsonschema.Array:
			res.Type = genai.TypeArray
		case jsonschema.String:
			res.Type = genai.TypeString
		case jsonschema.Number:
			res.Type = genai.TypeNumber
		case jsonschema.Boolean:
			res.Type = genai.TypeBoolean
		default:
			res.Type = genai.TypeUnspecified
		}

		res.Description = in.Description
		res.Enum = in.Enum
		res.Required = in.Required

		if in.Properties != nil {
			res.Properties = map[string]*genai.Schema{}
			for k, v := range in.Properties {
				res.Properties[k] = openAiSchemaToGenAISchema(v)
			}
		}

		return &res
	}

	if in.Toolbox != nil {
		for _, tool := range in.Toolbox.funcs {
			def := tool.Parameters.Definition()
			res.Tools = append(res.Tools, &genai.Tool{
				FunctionDeclarations: []*genai.FunctionDeclaration{
					{
						Name:        tool.Name,
						Description: tool.Description,
						Parameters:  openAiSchemaToGenAISchema(def),
					},
				},
			})
		}
	}

	cs := res.StartChat()

	for i, c := range in.Messages {
		content := genai.NewUserContent(genai.Text(c.Text))

		switch c.Role {
		case RoleAssistant, RoleSystem:
			content.Role = "model"
		case RoleUser:
			content.Role = "user"
		}

		// if this is the last message, don't add it to the history; return its
		// parts so the caller can send them
		if i == len(in.Messages)-1 {
			return &res, cs, content.Parts
		}

		cs.History = append(cs.History, content)
	}

	return &res, cs, nil
}
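
// responseToLLMResponse converts a genai response into this package's Response,
// mapping text parts to choice content and function calls to tool calls.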
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
	res := Response{}

	for _, c := range in.Candidates {
		if c.Content != nil {
			for _, p := range c.Content.Parts {
				switch v := p.(type) {
				case genai.Text:
					res.Choices = append(res.Choices, ResponseChoice{
						Content: string(v),
					})

				case genai.FunctionCall:
					choice := ResponseChoice{}
					choice.Content = v.Name

					b, e := json.Marshal(v.Args)
					if e != nil {
						return Response{}, fmt.Errorf("error marshalling args: %w", e)
					}

					call := ToolCall{
						ID: v.Name,
						FunctionCall: FunctionCall{
							Name:      v.Name,
							Arguments: string(b),
						},
					}

					choice.Calls = append(choice.Calls, call)
					res.Choices = append(res.Choices, choice)

				default:
					return Response{}, fmt.Errorf("unknown part type: %T", p)
				}
			}
		}
	}

	return res, nil
}
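
// ChatComplete sends the request to the configured Gemini model and returns
// the converted response.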
func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) {
	cl, err := genai.NewClient(ctx, option.WithAPIKey(g.key))
	if err != nil {
		return Response{}, fmt.Errorf("error creating genai client: %w", err)
	}
	// close the client (and its underlying connection) when we're done
	defer cl.Close()

	model := cl.GenerativeModel(g.model)

	_, cs, parts := g.requestToChatHistory(req, model)

	//parts := g.requestToGoogleRequest(req, model)
	//resp, err := model.GenerateContent(ctx, parts...)
	resp, err := cs.SendMessage(ctx, parts...)
	if err != nil {
		return Response{}, fmt.Errorf("error generating content: %w", err)
	}

	return g.responseToLLMResponse(resp)
}
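
// Example usage (a sketch, not part of the original file): this assumes the
// package's Request carries the Messages field used above and that the message
// type is named Message with the Role/Text fields; that type name, the apiKey
// value, and the model string are illustrative only.
//
//	g := google{key: apiKey, model: "gemini-1.5-flash"}
//	resp, err := g.ChatComplete(ctx, Request{
//		Messages: []Message{{Role: RoleUser, Text: "Say hello"}},
//	})
//	if err != nil {
//		// handle the error
//	}
//	for _, choice := range resp.Choices {
//		fmt.Println(choice.Content)
//	}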