diff --git a/google.go b/google.go index bef5b9f..59e723f 100644 --- a/google.go +++ b/google.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" + "github.com/sashabaranov/go-openai/jsonschema" + "github.com/google/generative-ai-go/genai" "google.golang.org/api/option" ) @@ -20,8 +22,64 @@ func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) { return g, nil } -func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.ChatSession, []genai.Part) { - cs := model.StartChat() +func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) { + res := *model + + var openAiSchemaToGenAISchema func(in jsonschema.Definition) *genai.Schema + openAiSchemaToGenAISchema = func(in jsonschema.Definition) *genai.Schema { + res := genai.Schema{} + + switch in.Type { + case jsonschema.Object: + res.Type = genai.TypeObject + + case jsonschema.Array: + res.Type = genai.TypeArray + + case jsonschema.String: + res.Type = genai.TypeString + + case jsonschema.Number: + res.Type = genai.TypeNumber + + case jsonschema.Boolean: + res.Type = genai.TypeBoolean + + default: + res.Type = genai.TypeUnspecified + } + + res.Description = in.Description + res.Enum = in.Enum + + res.Required = in.Required + if in.Properties != nil { + res.Properties = map[string]*genai.Schema{} + + for k, v := range in.Properties { + res.Properties[k] = openAiSchemaToGenAISchema(v) + } + } + + return &res + } + if in.Toolbox != nil { + for _, tool := range in.Toolbox.funcs { + def := tool.Parameters.Definition() + + res.Tools = append(res.Tools, &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{ + { + Name: tool.Name, + Description: tool.Description, + Parameters: openAiSchemaToGenAISchema(def), + }, + }, + }) + } + } + + cs := res.StartChat() for i, c := range in.Messages { content := genai.NewUserContent(genai.Text(c.Text)) @@ -36,13 +94,13 @@ func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) ( // if this is the last message, we want to add to history, we want it to be the parts if i == len(in.Messages)-1 { - return cs, content.Parts + return &res, cs, content.Parts } cs.History = append(cs.History, content) } - return cs, nil + return &res, cs, nil } func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) { @@ -99,7 +157,7 @@ func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) model := cl.GenerativeModel(g.model) - cs, parts := g.requestToChatHistory(req, model) + _, cs, parts := g.requestToChatHistory(req, model) resp, err := cs.SendMessage(ctx, parts...)