From 0d909edd44d94210d2f1bd6eca2b7253b354cad0 Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Wed, 22 Jan 2025 23:56:20 -0500 Subject: [PATCH] Refactor Google LLM adapter to support tool schemas. Enhanced the `requestToChatHistory` method to include OpenAI schema conversion logic and integrate tools with generative AI schemas. This change improves flexibility when working with different schema types and tool definitions. Adjusted response handling to return a modified model alongside chat sessions and parts. --- google.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/google.go b/google.go index bef5b9f..59e723f 100644 --- a/google.go +++ b/google.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" + "github.com/sashabaranov/go-openai/jsonschema" + "github.com/google/generative-ai-go/genai" "google.golang.org/api/option" ) @@ -20,8 +22,64 @@ func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) { return g, nil } -func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.ChatSession, []genai.Part) { - cs := model.StartChat() +func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) { + res := *model + + var openAiSchemaToGenAISchema func(in jsonschema.Definition) *genai.Schema + openAiSchemaToGenAISchema = func(in jsonschema.Definition) *genai.Schema { + res := genai.Schema{} + + switch in.Type { + case jsonschema.Object: + res.Type = genai.TypeObject + + case jsonschema.Array: + res.Type = genai.TypeArray + + case jsonschema.String: + res.Type = genai.TypeString + + case jsonschema.Number: + res.Type = genai.TypeNumber + + case jsonschema.Boolean: + res.Type = genai.TypeBoolean + + default: + res.Type = genai.TypeUnspecified + } + + res.Description = in.Description + res.Enum = in.Enum + + res.Required = in.Required + if in.Properties != nil { + res.Properties = map[string]*genai.Schema{} + + for k, v := range in.Properties { + res.Properties[k] = openAiSchemaToGenAISchema(v) + } + } + + return &res + } + if in.Toolbox != nil { + for _, tool := range in.Toolbox.funcs { + def := tool.Parameters.Definition() + + res.Tools = append(res.Tools, &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{ + { + Name: tool.Name, + Description: tool.Description, + Parameters: openAiSchemaToGenAISchema(def), + }, + }, + }) + } + } + + cs := res.StartChat() for i, c := range in.Messages { content := genai.NewUserContent(genai.Text(c.Text)) @@ -36,13 +94,13 @@ func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) ( // if this is the last message, we want to add to history, we want it to be the parts if i == len(in.Messages)-1 { - return cs, content.Parts + return &res, cs, content.Parts } cs.History = append(cs.History, content) } - return cs, nil + return &res, cs, nil } func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) { @@ -99,7 +157,7 @@ func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) model := cl.GenerativeModel(g.model) - cs, parts := g.requestToChatHistory(req, model) + _, cs, parts := g.requestToChatHistory(req, model) resp, err := cs.SendMessage(ctx, parts...)