Steve Dudenhoeffer
0d909edd44
Enhanced the `requestToChatHistory` method to include OpenAI schema conversion logic and integrate tools with generative AI schemas. This change improves flexibility when working with different schema types and tool definitions. Adjusted response handling to return a modified model alongside chat sessions and parts.
174 lines
3.7 KiB
Go
174 lines
3.7 KiB
Go
package go_llm
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/sashabaranov/go-openai/jsonschema"
|
|
|
|
"github.com/google/generative-ai-go/genai"
|
|
"google.golang.org/api/option"
|
|
)
|
|
|
|
// google holds the settings used to talk to Google's generative AI (genai)
// service.
type google struct {
	// key is handed to the genai client as its API key; model is the model
	// name requested from that client when a chat completion is performed.
	key, model string
}
|
|
|
|
func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
|
g.model = modelVersion
|
|
|
|
return g, nil
|
|
}
|
|
|
|
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) {
|
|
res := *model
|
|
|
|
var openAiSchemaToGenAISchema func(in jsonschema.Definition) *genai.Schema
|
|
openAiSchemaToGenAISchema = func(in jsonschema.Definition) *genai.Schema {
|
|
res := genai.Schema{}
|
|
|
|
switch in.Type {
|
|
case jsonschema.Object:
|
|
res.Type = genai.TypeObject
|
|
|
|
case jsonschema.Array:
|
|
res.Type = genai.TypeArray
|
|
|
|
case jsonschema.String:
|
|
res.Type = genai.TypeString
|
|
|
|
case jsonschema.Number:
|
|
res.Type = genai.TypeNumber
|
|
|
|
case jsonschema.Boolean:
|
|
res.Type = genai.TypeBoolean
|
|
|
|
default:
|
|
res.Type = genai.TypeUnspecified
|
|
}
|
|
|
|
res.Description = in.Description
|
|
res.Enum = in.Enum
|
|
|
|
res.Required = in.Required
|
|
if in.Properties != nil {
|
|
res.Properties = map[string]*genai.Schema{}
|
|
|
|
for k, v := range in.Properties {
|
|
res.Properties[k] = openAiSchemaToGenAISchema(v)
|
|
}
|
|
}
|
|
|
|
return &res
|
|
}
|
|
if in.Toolbox != nil {
|
|
for _, tool := range in.Toolbox.funcs {
|
|
def := tool.Parameters.Definition()
|
|
|
|
res.Tools = append(res.Tools, &genai.Tool{
|
|
FunctionDeclarations: []*genai.FunctionDeclaration{
|
|
{
|
|
Name: tool.Name,
|
|
Description: tool.Description,
|
|
Parameters: openAiSchemaToGenAISchema(def),
|
|
},
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
cs := res.StartChat()
|
|
|
|
for i, c := range in.Messages {
|
|
content := genai.NewUserContent(genai.Text(c.Text))
|
|
|
|
switch c.Role {
|
|
case RoleAssistant, RoleSystem:
|
|
content.Role = "model"
|
|
|
|
case RoleUser:
|
|
content.Role = "user"
|
|
}
|
|
|
|
// if this is the last message, we want to add to history, we want it to be the parts
|
|
if i == len(in.Messages)-1 {
|
|
return &res, cs, content.Parts
|
|
}
|
|
|
|
cs.History = append(cs.History, content)
|
|
}
|
|
|
|
return &res, cs, nil
|
|
}
|
|
|
|
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
|
|
res := Response{}
|
|
|
|
for _, c := range in.Candidates {
|
|
if c.Content != nil {
|
|
for _, p := range c.Content.Parts {
|
|
switch p.(type) {
|
|
case genai.Text:
|
|
res.Choices = append(res.Choices, ResponseChoice{
|
|
Content: string(p.(genai.Text)),
|
|
})
|
|
|
|
case genai.FunctionCall:
|
|
v := p.(genai.FunctionCall)
|
|
choice := ResponseChoice{}
|
|
|
|
choice.Content = v.Name
|
|
b, e := json.Marshal(v.Args)
|
|
|
|
if e != nil {
|
|
return Response{}, fmt.Errorf("error marshalling args: %w", e)
|
|
}
|
|
|
|
call := ToolCall{
|
|
ID: v.Name,
|
|
FunctionCall: FunctionCall{
|
|
Name: v.Name,
|
|
Arguments: string(b),
|
|
},
|
|
}
|
|
|
|
choice.Calls = append(choice.Calls, call)
|
|
|
|
res.Choices = append(res.Choices, choice)
|
|
|
|
default:
|
|
return Response{}, fmt.Errorf("unknown part type: %T", p)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) {
|
|
cl, err := genai.NewClient(ctx, option.WithAPIKey(g.key))
|
|
|
|
if err != nil {
|
|
return Response{}, fmt.Errorf("error creating genai client: %w", err)
|
|
}
|
|
|
|
model := cl.GenerativeModel(g.model)
|
|
|
|
_, cs, parts := g.requestToChatHistory(req, model)
|
|
|
|
resp, err := cs.SendMessage(ctx, parts...)
|
|
|
|
//parts := g.requestToGoogleRequest(req, model)
|
|
|
|
//resp, err := model.GenerateContent(ctx, parts...)
|
|
|
|
if err != nil {
|
|
return Response{}, fmt.Errorf("error generating content: %w", err)
|
|
}
|
|
|
|
return g.responseToLLMResponse(resp)
|
|
}
|