Refactor toolbox and function handling to support synthetic fields and improve type definitions

This commit is contained in:
2025-04-12 02:20:40 -04:00
parent 2ae583e9f3
commit 3093b988f8
13 changed files with 288 additions and 160 deletions

View File

@@ -25,27 +25,27 @@ func getFromType(t reflect.Type, b basic) Type {
switch t.Kind() {
case reflect.String:
b.DataType = String
b.DataType = TypeString
b.typeName = "string"
return b
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
b.DataType = Integer
b.DataType = TypeInteger
b.typeName = "integer"
return b
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
b.DataType = Integer
b.DataType = TypeInteger
b.typeName = "integer"
return b
case reflect.Float32, reflect.Float64:
b.DataType = Number
b.DataType = TypeNumber
b.typeName = "number"
return b
case reflect.Bool:
b.DataType = Boolean
b.DataType = TypeBoolean
b.typeName = "boolean"
return b
@@ -92,7 +92,7 @@ func getField(f reflect.StructField, index int) Type {
}
}
b.DataType = String
b.DataType = TypeString
b.typeName = "string"
return enum{
basic: b,
@@ -104,15 +104,26 @@ func getField(f reflect.StructField, index int) Type {
return getFromType(t, b)
}
func getObject(t reflect.Type) object {
func getObject(t reflect.Type) Object {
fields := make(map[string]Type, t.NumField())
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fields[field.Name] = getField(field, i)
if field.Anonymous {
// if the field is anonymous, we need to get the fields of the anonymous struct
// and add them to the object
anon := getObject(field.Type)
for k, v := range anon.fields {
fields[k] = v
}
continue
} else {
fields[field.Name] = getField(field, i)
}
}
return object{
basic: basic{DataType: Object, typeName: "object"},
return Object{
basic: basic{DataType: TypeObject, typeName: "object"},
fields: fields,
}
}
@@ -120,7 +131,7 @@ func getObject(t reflect.Type) object {
func getArray(t reflect.Type) array {
res := array{
basic: basic{
DataType: Array,
DataType: TypeArray,
typeName: "array",
},
}

View File

@@ -15,12 +15,12 @@ var _ Type = basic{}
type DataType string
const (
String DataType = "string"
Integer DataType = "integer"
Number DataType = "number"
Boolean DataType = "boolean"
Object DataType = "object"
Array DataType = "array"
TypeString DataType = "string"
TypeInteger DataType = "integer"
TypeNumber DataType = "number"
TypeBoolean DataType = "boolean"
TypeObject DataType = "object"
TypeArray DataType = "array"
)
type basic struct {
@@ -49,17 +49,17 @@ func (b basic) GoogleParameters() *genai.Schema {
var t = genai.TypeUnspecified
switch b.DataType {
case String:
case TypeString:
t = genai.TypeString
case Integer:
case TypeInteger:
t = genai.TypeInteger
case Number:
case TypeNumber:
t = genai.TypeNumber
case Boolean:
case TypeBoolean:
t = genai.TypeBoolean
case Object:
case TypeObject:
t = genai.TypeObject
case Array:
case TypeArray:
t = genai.TypeArray
default:
t = genai.TypeUnspecified
@@ -82,12 +82,12 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
v := reflect.ValueOf(val)
switch b.DataType {
case String:
case TypeString:
var val = v.String()
return reflect.ValueOf(val), nil
case Integer:
case TypeInteger:
if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(int(0))), nil
} else if v.Kind() != reflect.Int {
@@ -96,7 +96,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
return v, nil
}
case Number:
case TypeNumber:
if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(float64(0))), nil
} else if v.Kind() != reflect.Float64 {
@@ -105,7 +105,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
return v, nil
}
case Boolean:
case TypeBoolean:
if v.Kind() == reflect.Bool {
return v, nil
} else if v.Kind() == reflect.String {

View File

@@ -8,15 +8,44 @@ import (
"github.com/openai/openai-go"
)
type object struct {
const (
// SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent
// collisions with the fields in the struct.
SyntheticFieldPrefix = "__"
)
type Object struct {
basic
ref reflect.Type
fields map[string]Type
// syntheticFields are fields that are not in the struct but are generated by a system.
synetheticFields map[string]Type
}
func (o object) OpenAIParameters() openai.FunctionParameters {
func (o Object) WithSyntheticField(name string, description string) Object {
if o.synetheticFields == nil {
o.synetheticFields = map[string]Type{}
}
o.synetheticFields[name] = basic{
DataType: TypeString,
typeName: "string",
index: -1,
required: false,
description: description,
}
return o
}
func (o Object) SyntheticFields() map[string]Type {
return o.synetheticFields
}
func (o Object) OpenAIParameters() openai.FunctionParameters {
var properties = map[string]openai.FunctionParameters{}
var required []string
for k, v := range o.fields {
@@ -26,6 +55,13 @@ func (o object) OpenAIParameters() openai.FunctionParameters {
}
}
for k, v := range o.synetheticFields {
properties[SyntheticFieldPrefix+k] = v.OpenAIParameters()
if v.Required() {
required = append(required, SyntheticFieldPrefix+k)
}
}
var res = openai.FunctionParameters{
"type": "object",
"description": o.Description(),
@@ -39,7 +75,7 @@ func (o object) OpenAIParameters() openai.FunctionParameters {
return res
}
func (o object) GoogleParameters() *genai.Schema {
func (o Object) GoogleParameters() *genai.Schema {
var properties = map[string]*genai.Schema{}
var required []string
for k, v := range o.fields {
@@ -62,7 +98,8 @@ func (o object) GoogleParameters() *genai.Schema {
return res
}
func (o object) FromAny(val any) (reflect.Value, error) {
// FromAny converts the value from any to the correct type, returning the value, and an error if any
func (o Object) FromAny(val any) (reflect.Value, error) {
// if the value is nil, we can't do anything
if val == nil {
return reflect.Value{}, nil
@@ -99,7 +136,7 @@ func (o object) FromAny(val any) (reflect.Value, error) {
return obj, nil
}
func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) {
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
if !o.required {
val = val.Addr()