diff --git a/google.go b/google.go index 9d7062c..bef5b9f 100644 --- a/google.go +++ b/google.go @@ -20,34 +20,29 @@ func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) { return g, nil } -func (g google) requestToGoogleRequest(in Request, model *genai.GenerativeModel) []genai.Part { +func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.ChatSession, []genai.Part) { + cs := model.StartChat() - if in.Temperature != nil { - model.GenerationConfig.Temperature = in.Temperature - } + for i, c := range in.Messages { + content := genai.NewUserContent(genai.Text(c.Text)) - res := []genai.Part{} + switch c.Role { + case RoleAssistant, RoleSystem: + content.Role = "model" - for _, c := range in.Messages { - res = append(res, genai.Text(c.Text)) - } - - if in.Toolbox != nil { - for _, tool := range in.Toolbox.funcs { - panic("google ToolBox is todo" + tool.Name) - - /* - t := genai.Tool{} - t.FunctionDeclarations = append(t.FunctionDeclarations, &genai.FunctionDeclaration{ - Name: tool.Name, - Description: tool.Description, - Parameters: nil, //tool.Parameters, - }) - */ + case RoleUser: + content.Role = "user" } + + // if this is the last message, we want to add to history, we want it to be the parts + if i == len(in.Messages)-1 { + return cs, content.Parts + } + + cs.History = append(cs.History, content) } - return res + return cs, nil } func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) { @@ -104,9 +99,13 @@ func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) model := cl.GenerativeModel(g.model) - parts := g.requestToGoogleRequest(req, model) + cs, parts := g.requestToChatHistory(req, model) - resp, err := model.GenerateContent(ctx, parts...) + resp, err := cs.SendMessage(ctx, parts...) + + //parts := g.requestToGoogleRequest(req, model) + + //resp, err := model.GenerateContent(ctx, parts...) if err != nil { return Response{}, fmt.Errorf("error generating content: %w", err)