Refactor toolbox and function handling to support synthetic fields and improve type definitions
This commit is contained in:
@@ -25,27 +25,27 @@ func getFromType(t reflect.Type, b basic) Type {
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.String:
|
||||
b.DataType = String
|
||||
b.DataType = TypeString
|
||||
b.typeName = "string"
|
||||
return b
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
b.DataType = Integer
|
||||
b.DataType = TypeInteger
|
||||
b.typeName = "integer"
|
||||
return b
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
b.DataType = Integer
|
||||
b.DataType = TypeInteger
|
||||
b.typeName = "integer"
|
||||
return b
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
b.DataType = Number
|
||||
b.DataType = TypeNumber
|
||||
b.typeName = "number"
|
||||
return b
|
||||
|
||||
case reflect.Bool:
|
||||
b.DataType = Boolean
|
||||
b.DataType = TypeBoolean
|
||||
b.typeName = "boolean"
|
||||
return b
|
||||
|
||||
@@ -92,7 +92,7 @@ func getField(f reflect.StructField, index int) Type {
|
||||
}
|
||||
}
|
||||
|
||||
b.DataType = String
|
||||
b.DataType = TypeString
|
||||
b.typeName = "string"
|
||||
return enum{
|
||||
basic: b,
|
||||
@@ -104,15 +104,26 @@ func getField(f reflect.StructField, index int) Type {
|
||||
return getFromType(t, b)
|
||||
}
|
||||
|
||||
func getObject(t reflect.Type) object {
|
||||
func getObject(t reflect.Type) Object {
|
||||
fields := make(map[string]Type, t.NumField())
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fields[field.Name] = getField(field, i)
|
||||
|
||||
if field.Anonymous {
|
||||
// if the field is anonymous, we need to get the fields of the anonymous struct
|
||||
// and add them to the object
|
||||
anon := getObject(field.Type)
|
||||
for k, v := range anon.fields {
|
||||
fields[k] = v
|
||||
}
|
||||
continue
|
||||
} else {
|
||||
fields[field.Name] = getField(field, i)
|
||||
}
|
||||
}
|
||||
|
||||
return object{
|
||||
basic: basic{DataType: Object, typeName: "object"},
|
||||
return Object{
|
||||
basic: basic{DataType: TypeObject, typeName: "object"},
|
||||
fields: fields,
|
||||
}
|
||||
}
|
||||
@@ -120,7 +131,7 @@ func getObject(t reflect.Type) object {
|
||||
func getArray(t reflect.Type) array {
|
||||
res := array{
|
||||
basic: basic{
|
||||
DataType: Array,
|
||||
DataType: TypeArray,
|
||||
typeName: "array",
|
||||
},
|
||||
}
|
||||
|
@@ -15,12 +15,12 @@ var _ Type = basic{}
|
||||
type DataType string
|
||||
|
||||
const (
|
||||
String DataType = "string"
|
||||
Integer DataType = "integer"
|
||||
Number DataType = "number"
|
||||
Boolean DataType = "boolean"
|
||||
Object DataType = "object"
|
||||
Array DataType = "array"
|
||||
TypeString DataType = "string"
|
||||
TypeInteger DataType = "integer"
|
||||
TypeNumber DataType = "number"
|
||||
TypeBoolean DataType = "boolean"
|
||||
TypeObject DataType = "object"
|
||||
TypeArray DataType = "array"
|
||||
)
|
||||
|
||||
type basic struct {
|
||||
@@ -49,17 +49,17 @@ func (b basic) GoogleParameters() *genai.Schema {
|
||||
var t = genai.TypeUnspecified
|
||||
|
||||
switch b.DataType {
|
||||
case String:
|
||||
case TypeString:
|
||||
t = genai.TypeString
|
||||
case Integer:
|
||||
case TypeInteger:
|
||||
t = genai.TypeInteger
|
||||
case Number:
|
||||
case TypeNumber:
|
||||
t = genai.TypeNumber
|
||||
case Boolean:
|
||||
case TypeBoolean:
|
||||
t = genai.TypeBoolean
|
||||
case Object:
|
||||
case TypeObject:
|
||||
t = genai.TypeObject
|
||||
case Array:
|
||||
case TypeArray:
|
||||
t = genai.TypeArray
|
||||
default:
|
||||
t = genai.TypeUnspecified
|
||||
@@ -82,12 +82,12 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
|
||||
v := reflect.ValueOf(val)
|
||||
|
||||
switch b.DataType {
|
||||
case String:
|
||||
case TypeString:
|
||||
var val = v.String()
|
||||
|
||||
return reflect.ValueOf(val), nil
|
||||
|
||||
case Integer:
|
||||
case TypeInteger:
|
||||
if v.Kind() == reflect.Float64 {
|
||||
return v.Convert(reflect.TypeOf(int(0))), nil
|
||||
} else if v.Kind() != reflect.Int {
|
||||
@@ -96,7 +96,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
case Number:
|
||||
case TypeNumber:
|
||||
if v.Kind() == reflect.Float64 {
|
||||
return v.Convert(reflect.TypeOf(float64(0))), nil
|
||||
} else if v.Kind() != reflect.Float64 {
|
||||
@@ -105,7 +105,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
case Boolean:
|
||||
case TypeBoolean:
|
||||
if v.Kind() == reflect.Bool {
|
||||
return v, nil
|
||||
} else if v.Kind() == reflect.String {
|
||||
|
@@ -8,15 +8,44 @@ import (
|
||||
"github.com/openai/openai-go"
|
||||
)
|
||||
|
||||
type object struct {
|
||||
const (
|
||||
// SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent
|
||||
// collisions with the fields in the struct.
|
||||
SyntheticFieldPrefix = "__"
|
||||
)
|
||||
|
||||
type Object struct {
|
||||
basic
|
||||
|
||||
ref reflect.Type
|
||||
|
||||
fields map[string]Type
|
||||
|
||||
// syntheticFields are fields that are not in the struct but are generated by a system.
|
||||
synetheticFields map[string]Type
|
||||
}
|
||||
|
||||
func (o object) OpenAIParameters() openai.FunctionParameters {
|
||||
func (o Object) WithSyntheticField(name string, description string) Object {
|
||||
if o.synetheticFields == nil {
|
||||
o.synetheticFields = map[string]Type{}
|
||||
}
|
||||
|
||||
o.synetheticFields[name] = basic{
|
||||
DataType: TypeString,
|
||||
typeName: "string",
|
||||
index: -1,
|
||||
required: false,
|
||||
description: description,
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
||||
|
||||
func (o Object) SyntheticFields() map[string]Type {
|
||||
return o.synetheticFields
|
||||
}
|
||||
|
||||
func (o Object) OpenAIParameters() openai.FunctionParameters {
|
||||
var properties = map[string]openai.FunctionParameters{}
|
||||
var required []string
|
||||
for k, v := range o.fields {
|
||||
@@ -26,6 +55,13 @@ func (o object) OpenAIParameters() openai.FunctionParameters {
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range o.synetheticFields {
|
||||
properties[SyntheticFieldPrefix+k] = v.OpenAIParameters()
|
||||
if v.Required() {
|
||||
required = append(required, SyntheticFieldPrefix+k)
|
||||
}
|
||||
}
|
||||
|
||||
var res = openai.FunctionParameters{
|
||||
"type": "object",
|
||||
"description": o.Description(),
|
||||
@@ -39,7 +75,7 @@ func (o object) OpenAIParameters() openai.FunctionParameters {
|
||||
return res
|
||||
}
|
||||
|
||||
func (o object) GoogleParameters() *genai.Schema {
|
||||
func (o Object) GoogleParameters() *genai.Schema {
|
||||
var properties = map[string]*genai.Schema{}
|
||||
var required []string
|
||||
for k, v := range o.fields {
|
||||
@@ -62,7 +98,8 @@ func (o object) GoogleParameters() *genai.Schema {
|
||||
return res
|
||||
}
|
||||
|
||||
func (o object) FromAny(val any) (reflect.Value, error) {
|
||||
// FromAny converts the value from any to the correct type, returning the value, and an error if any
|
||||
func (o Object) FromAny(val any) (reflect.Value, error) {
|
||||
// if the value is nil, we can't do anything
|
||||
if val == nil {
|
||||
return reflect.Value{}, nil
|
||||
@@ -99,7 +136,7 @@ func (o object) FromAny(val any) (reflect.Value, error) {
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
|
||||
if !o.required {
|
||||
val = val.Addr()
|
||||
|
Reference in New Issue
Block a user