initial commit of untested function stuff

This commit is contained in:
2024-11-08 20:53:12 -05:00
parent f603010dee
commit 37939088ed
15 changed files with 693 additions and 20 deletions

125
schema/GetType.go Normal file
View File

@@ -0,0 +1,125 @@
package schema
import (
"reflect"
"strings"
"github.com/sashabaranov/go-openai/jsonschema"
)
// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that
// can be used to generate a json schema and build an object from a parsed json object.
func GetType(a any) Type {
t := reflect.TypeOf(a)
if t.Kind() != reflect.Struct {
panic("GetType expects a struct")
}
return getObject(t)
}
func getFromType(t reflect.Type, b basic) Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
b.required = false
}
switch t.Kind() {
case reflect.String:
b.DataType = jsonschema.String
return b
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
b.DataType = jsonschema.Integer
return b
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
b.DataType = jsonschema.Integer
return b
case reflect.Float32, reflect.Float64:
b.DataType = jsonschema.Number
return b
case reflect.Bool:
b.DataType = jsonschema.Boolean
return b
case reflect.Struct:
o := getObject(t)
o.basic.required = b.required
o.basic.index = b.index
o.basic.description = b.description
return o
case reflect.Slice:
return getArray(t)
default:
panic("unhandled default case for " + t.Kind().String() + " in getFromType")
}
}
func getField(f reflect.StructField, index int) Type {
b := basic{
index: index,
required: true,
description: "",
}
t := f.Type
// if the tag "description" is set, use that as the description
if desc, ok := f.Tag.Lookup("description"); ok {
b.description = desc
}
// now if the tag "enum" is set, we need to create an enum type
if v, ok := f.Tag.Lookup("enum"); ok {
vals := strings.Split(v, ",")
for i := 0; i < len(vals); i++ {
vals[i] = strings.TrimSpace(vals[i])
if vals[i] == "" {
vals = append(vals[:i], vals[i+1:]...)
}
}
return enum{
basic: b,
values: vals,
}
}
return getFromType(t, b)
}
func getObject(t reflect.Type) object {
fields := make(map[string]Type, t.NumField())
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fields[field.Name] = getField(field, i)
}
return object{
basic: basic{DataType: jsonschema.Object},
fields: fields,
}
}
func getArray(t reflect.Type) array {
res := array{
basic: basic{
DataType: jsonschema.Array,
},
}
res.items = getFromType(t.Elem(), basic{})
return res
}

65
schema/array.go Normal file
View File

@@ -0,0 +1,65 @@
package schema
import (
"errors"
"reflect"
"github.com/sashabaranov/go-openai/jsonschema"
)
type array struct {
basic
// items is the schema of the items in the array
items Type
}
func (a array) SchemaType() jsonschema.DataType {
return jsonschema.Array
}
func (a array) Definition() jsonschema.Definition {
def := a.basic.Definition()
def.Type = jsonschema.Array
i := a.items.Definition()
def.Items = &i
def.AdditionalProperties = false
return def
}
func (a array) FromAny(val any) (reflect.Value, error) {
v := reflect.ValueOf(val)
// first realize we may have a pointer to a slice if this type is not required
if !a.required && v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Slice {
return reflect.Value{}, errors.New("expected slice, got " + v.Kind().String())
}
// if the slice is nil, we can just return it
if v.IsNil() {
return v, nil
}
// if the slice is not nil, we need to convert each item
items := make([]reflect.Value, v.Len())
for i := 0; i < v.Len(); i++ {
item, err := a.items.FromAny(v.Index(i).Interface())
if err != nil {
return reflect.Value{}, err
}
items[i] = item
}
return reflect.ValueOf(items), nil
}
func (a array) SetValue(obj reflect.Value, val reflect.Value) {
if !a.required {
val = val.Addr()
}
obj.Field(a.index).Set(val)
}

105
schema/basic.go Normal file
View File

@@ -0,0 +1,105 @@
package schema
import (
"errors"
"reflect"
"strconv"
"github.com/sashabaranov/go-openai/jsonschema"
)
// just enforcing that basic implements Type
var _ Type = basic{}
type basic struct {
jsonschema.DataType
// index is the position of the parameter in the StructField of the function's parameter struct
index int
// required is a flag that indicates whether the parameter is required in the function's parameter struct.
// this is inferred by if the parameter is a pointer type or not.
required bool
// description is a llm-readable description of the parameter passed to openai
description string
}
func (b basic) SchemaType() jsonschema.DataType {
return b.DataType
}
func (b basic) Definition() jsonschema.Definition {
return jsonschema.Definition{
Type: b.DataType,
Description: b.description,
}
}
func (b basic) Required() bool {
return b.required
}
func (b basic) Description() string {
return b.description
}
func (b basic) FromAny(val any) (reflect.Value, error) {
v := reflect.ValueOf(val)
switch b.DataType {
case jsonschema.String:
var val = v.String()
return reflect.ValueOf(val), nil
case jsonschema.Integer:
if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(int(0))), nil
} else if v.Kind() != reflect.Int {
return reflect.Value{}, errors.New("expected int, got " + v.Kind().String())
} else {
return v, nil
}
case jsonschema.Number:
if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(float64(0))), nil
} else if v.Kind() != reflect.Float64 {
return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String())
} else {
return v, nil
}
case jsonschema.Boolean:
if v.Kind() == reflect.Bool {
return v, nil
} else if v.Kind() == reflect.String {
b, err := strconv.ParseBool(v.String())
if err != nil {
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
}
return reflect.ValueOf(b), nil
} else {
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
}
default:
return reflect.Value{}, errors.New("unknown type")
}
}
func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) {
// if this basic type is not required that means it's a pointer type
// so we need to create a new value of the type of the pointer
if !b.required {
vv := reflect.New(obj.Field(b.index).Type().Elem())
// and then set the value of the pointer to the new value
vv.Elem().Set(val)
obj.Field(b.index).Set(vv)
return
}
obj.Field(b.index).Set(val)
}

47
schema/enum.go Normal file
View File

@@ -0,0 +1,47 @@
package schema
import (
"errors"
"reflect"
"golang.org/x/exp/slices"
"github.com/sashabaranov/go-openai/jsonschema"
)
type enum struct {
basic
values []string
}
func (e enum) SchemaType() jsonschema.DataType {
return jsonschema.String
}
func (e enum) Definition() jsonschema.Definition {
def := e.basic.Definition()
def.Enum = e.values
return def
}
func (e enum) FromAny(val any) (reflect.Value, error) {
v := reflect.ValueOf(val)
if v.Kind() != reflect.String {
return reflect.Value{}, errors.New("expected string, got " + v.Kind().String())
}
s := v.String()
if !slices.Contains(e.values, s) {
return reflect.Value{}, errors.New("value " + s + " not in enum")
}
return v, nil
}
func (e enum) SetValueOnField(obj reflect.Value, val reflect.Value) {
if !e.required {
val = val.Addr()
}
obj.Field(e.index).Set(val)
}

78
schema/object.go Normal file
View File

@@ -0,0 +1,78 @@
package schema
import (
"errors"
"reflect"
"github.com/sashabaranov/go-openai/jsonschema"
)
type object struct {
basic
ref reflect.Type
fields map[string]Type
}
func (o object) SchemaType() jsonschema.DataType {
return jsonschema.Object
}
func (o object) Definition() jsonschema.Definition {
def := o.basic.Definition()
def.Type = jsonschema.Object
def.Properties = make(map[string]jsonschema.Definition)
for k, v := range o.fields {
def.Properties[k] = v.Definition()
}
def.AdditionalProperties = false
return def
}
func (o object) FromAny(val any) (reflect.Value, error) {
// if the value is nil, we can't do anything
if val == nil {
return reflect.Value{}, nil
}
// now make a new object of the type we're trying to parse
obj := reflect.New(o.ref).Elem()
// now we need to iterate over the fields and set the values
for k, v := range o.fields {
// get the field by name
field := obj.FieldByName(k)
if !field.IsValid() {
return reflect.Value{}, errors.New("field " + k + " not found")
}
// get the value from the map
val2, ok := val.(map[string]interface{})[k]
if !ok {
return reflect.Value{}, errors.New("field " + k + " not found in map")
}
// now we need to convert the value to the correct type
val3, err := v.FromAny(val2)
if err != nil {
return reflect.Value{}, err
}
// now we need to set the value on the field
v.SetValueOnField(field, val3)
}
return obj, nil
}
func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
if !o.required {
val = val.Addr()
}
obj.Field(o.index).Set(val)
}

18
schema/type.go Normal file
View File

@@ -0,0 +1,18 @@
package schema
import (
"reflect"
"github.com/sashabaranov/go-openai/jsonschema"
)
type Type interface {
SchemaType() jsonschema.DataType
Definition() jsonschema.Definition
Required() bool
Description() string
FromAny(any) (reflect.Value, error)
SetValueOnField(obj reflect.Value, val reflect.Value)
}