170 lines
3.7 KiB
Go
170 lines
3.7 KiB
Go
package schema
|
|
|
|
import (
	"errors"
	"reflect"
	"sort"

	"github.com/google/generative-ai-go/genai"
	"github.com/openai/openai-go"
)
|
|
|
|
// SyntheticFieldPrefix is prepended to the name of every synthetic field when
// that field is emitted as a schema property, so the synthetic property does
// not collide with a real struct field of the same name.
const SyntheticFieldPrefix = "__"
|
|
|
|
// Object describes a Go struct type as a JSON-schema "object" whose
// properties are the struct's fields.
type Object struct {
	// basic carries the shared schema metadata that the methods of Object
	// read: description (Description()), the required flag and the field's
	// index inside its parent struct.
	basic

	// ref is the Go struct type being described; FromAny instantiates it.
	ref reflect.Type

	// fields maps each property name to its schema. FromAny also looks the
	// key up as a Go struct field name (via FieldByName).
	fields map[string]Type

	// syntheticFields are fields that are not in the struct but are generated by a system.
	// They are keyed by their unprefixed name; SyntheticFieldPrefix is added
	// when they are emitted as schema properties.
	// NOTE(review): the identifier is misspelled ("syneth..."); it is left
	// unchanged because other files in the package may reference it.
	synetheticFields map[string]Type
}
|
|
|
|
func (o Object) WithSyntheticField(name string, description string) Object {
|
|
if o.synetheticFields == nil {
|
|
o.synetheticFields = map[string]Type{}
|
|
}
|
|
|
|
o.synetheticFields[name] = basic{
|
|
DataType: TypeString,
|
|
typeName: "string",
|
|
index: -1,
|
|
required: false,
|
|
description: description,
|
|
}
|
|
|
|
return o
|
|
}
|
|
|
|
func (o Object) SyntheticFields() map[string]Type {
|
|
return o.synetheticFields
|
|
}
|
|
|
|
func (o Object) OpenAIParameters() openai.FunctionParameters {
|
|
var properties = map[string]openai.FunctionParameters{}
|
|
var required []string
|
|
for k, v := range o.fields {
|
|
properties[k] = v.OpenAIParameters()
|
|
if v.Required() {
|
|
required = append(required, k)
|
|
}
|
|
}
|
|
|
|
for k, v := range o.synetheticFields {
|
|
properties[SyntheticFieldPrefix+k] = v.OpenAIParameters()
|
|
if v.Required() {
|
|
required = append(required, SyntheticFieldPrefix+k)
|
|
}
|
|
}
|
|
|
|
var res = openai.FunctionParameters{
|
|
"type": "object",
|
|
"description": o.Description(),
|
|
"properties": properties,
|
|
}
|
|
|
|
if len(required) > 0 {
|
|
res["required"] = required
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
func (o Object) GoogleParameters() *genai.Schema {
|
|
var properties = map[string]*genai.Schema{}
|
|
var required []string
|
|
for k, v := range o.fields {
|
|
properties[k] = v.GoogleParameters()
|
|
if v.Required() {
|
|
required = append(required, k)
|
|
}
|
|
}
|
|
|
|
var res = &genai.Schema{
|
|
Type: genai.TypeObject,
|
|
Description: o.Description(),
|
|
Properties: properties,
|
|
}
|
|
|
|
if len(required) > 0 {
|
|
res.Required = required
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
func (o Object) AnthropicInputSchema() map[string]any {
|
|
var properties = map[string]any{}
|
|
var required []string
|
|
for k, v := range o.fields {
|
|
properties[k] = v.AnthropicInputSchema()
|
|
if v.Required() {
|
|
required = append(required, k)
|
|
}
|
|
}
|
|
|
|
var res = map[string]any{
|
|
"type": "object",
|
|
"description": o.Description(),
|
|
"properties": properties,
|
|
}
|
|
|
|
if len(required) > 0 {
|
|
res["required"] = required
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
// FromAny converts the value from any to the correct type, returning the value, and an error if any
|
|
func (o Object) FromAny(val any) (reflect.Value, error) {
|
|
// if the value is nil, we can't do anything
|
|
if val == nil {
|
|
return reflect.Value{}, nil
|
|
}
|
|
|
|
// now make a new object of the type we're trying to parse
|
|
obj := reflect.New(o.ref).Elem()
|
|
|
|
// now we need to iterate over the fields and set the values
|
|
for k, v := range o.fields {
|
|
// get the field by name
|
|
field := obj.FieldByName(k)
|
|
if !field.IsValid() {
|
|
return reflect.Value{}, errors.New("field " + k + " not found")
|
|
}
|
|
|
|
// get the value from the map
|
|
val2, ok := val.(map[string]interface{})[k]
|
|
if !ok {
|
|
return reflect.Value{}, errors.New("field " + k + " not found in map")
|
|
}
|
|
|
|
// now we need to convert the value to the correct type
|
|
val3, err := v.FromAny(val2)
|
|
if err != nil {
|
|
return reflect.Value{}, err
|
|
}
|
|
|
|
// now we need to set the value on the field
|
|
v.SetValueOnField(field, val3)
|
|
|
|
}
|
|
|
|
return obj, nil
|
|
}
|
|
|
|
func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
|
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
|
|
if !o.required {
|
|
val = val.Addr()
|
|
}
|
|
|
|
obj.Field(o.index).Set(val)
|
|
}
|