initial commit of untested function stuff
This commit is contained in:
125
schema/GetType.go
Normal file
125
schema/GetType.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that
|
||||
// can be used to generate a json schema and build an object from a parsed json object.
|
||||
func GetType(a any) Type {
|
||||
t := reflect.TypeOf(a)
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
panic("GetType expects a struct")
|
||||
}
|
||||
|
||||
return getObject(t)
|
||||
}
|
||||
|
||||
func getFromType(t reflect.Type, b basic) Type {
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
b.required = false
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.String:
|
||||
b.DataType = jsonschema.String
|
||||
return b
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
b.DataType = jsonschema.Integer
|
||||
return b
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
b.DataType = jsonschema.Integer
|
||||
return b
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
b.DataType = jsonschema.Number
|
||||
return b
|
||||
|
||||
case reflect.Bool:
|
||||
b.DataType = jsonschema.Boolean
|
||||
return b
|
||||
|
||||
case reflect.Struct:
|
||||
o := getObject(t)
|
||||
|
||||
o.basic.required = b.required
|
||||
o.basic.index = b.index
|
||||
o.basic.description = b.description
|
||||
|
||||
return o
|
||||
|
||||
case reflect.Slice:
|
||||
return getArray(t)
|
||||
|
||||
default:
|
||||
panic("unhandled default case for " + t.Kind().String() + " in getFromType")
|
||||
}
|
||||
}
|
||||
|
||||
func getField(f reflect.StructField, index int) Type {
|
||||
b := basic{
|
||||
index: index,
|
||||
required: true,
|
||||
description: "",
|
||||
}
|
||||
|
||||
t := f.Type
|
||||
|
||||
// if the tag "description" is set, use that as the description
|
||||
if desc, ok := f.Tag.Lookup("description"); ok {
|
||||
b.description = desc
|
||||
}
|
||||
|
||||
// now if the tag "enum" is set, we need to create an enum type
|
||||
if v, ok := f.Tag.Lookup("enum"); ok {
|
||||
vals := strings.Split(v, ",")
|
||||
|
||||
for i := 0; i < len(vals); i++ {
|
||||
vals[i] = strings.TrimSpace(vals[i])
|
||||
|
||||
if vals[i] == "" {
|
||||
vals = append(vals[:i], vals[i+1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
return enum{
|
||||
basic: b,
|
||||
values: vals,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return getFromType(t, b)
|
||||
}
|
||||
|
||||
func getObject(t reflect.Type) object {
|
||||
fields := make(map[string]Type, t.NumField())
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fields[field.Name] = getField(field, i)
|
||||
}
|
||||
|
||||
return object{
|
||||
basic: basic{DataType: jsonschema.Object},
|
||||
fields: fields,
|
||||
}
|
||||
}
|
||||
|
||||
func getArray(t reflect.Type) array {
|
||||
res := array{
|
||||
basic: basic{
|
||||
DataType: jsonschema.Array,
|
||||
},
|
||||
}
|
||||
|
||||
res.items = getFromType(t.Elem(), basic{})
|
||||
|
||||
return res
|
||||
}
|
65
schema/array.go
Normal file
65
schema/array.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type array struct {
|
||||
basic
|
||||
|
||||
// items is the schema of the items in the array
|
||||
items Type
|
||||
}
|
||||
|
||||
func (a array) SchemaType() jsonschema.DataType {
|
||||
return jsonschema.Array
|
||||
}
|
||||
|
||||
func (a array) Definition() jsonschema.Definition {
|
||||
def := a.basic.Definition()
|
||||
def.Type = jsonschema.Array
|
||||
i := a.items.Definition()
|
||||
def.Items = &i
|
||||
def.AdditionalProperties = false
|
||||
return def
|
||||
}
|
||||
|
||||
func (a array) FromAny(val any) (reflect.Value, error) {
|
||||
v := reflect.ValueOf(val)
|
||||
|
||||
// first realize we may have a pointer to a slice if this type is not required
|
||||
if !a.required && v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
if v.Kind() != reflect.Slice {
|
||||
return reflect.Value{}, errors.New("expected slice, got " + v.Kind().String())
|
||||
}
|
||||
|
||||
// if the slice is nil, we can just return it
|
||||
if v.IsNil() {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// if the slice is not nil, we need to convert each item
|
||||
items := make([]reflect.Value, v.Len())
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
item, err := a.items.FromAny(v.Index(i).Interface())
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
items[i] = item
|
||||
}
|
||||
|
||||
return reflect.ValueOf(items), nil
|
||||
}
|
||||
|
||||
func (a array) SetValue(obj reflect.Value, val reflect.Value) {
|
||||
if !a.required {
|
||||
val = val.Addr()
|
||||
}
|
||||
obj.Field(a.index).Set(val)
|
||||
}
|
105
schema/basic.go
Normal file
105
schema/basic.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
// just enforcing that basic implements Type
|
||||
var _ Type = basic{}
|
||||
|
||||
type basic struct {
|
||||
jsonschema.DataType
|
||||
|
||||
// index is the position of the parameter in the StructField of the function's parameter struct
|
||||
index int
|
||||
|
||||
// required is a flag that indicates whether the parameter is required in the function's parameter struct.
|
||||
// this is inferred by if the parameter is a pointer type or not.
|
||||
required bool
|
||||
|
||||
// description is a llm-readable description of the parameter passed to openai
|
||||
description string
|
||||
}
|
||||
|
||||
func (b basic) SchemaType() jsonschema.DataType {
|
||||
return b.DataType
|
||||
}
|
||||
|
||||
func (b basic) Definition() jsonschema.Definition {
|
||||
return jsonschema.Definition{
|
||||
Type: b.DataType,
|
||||
Description: b.description,
|
||||
}
|
||||
}
|
||||
|
||||
func (b basic) Required() bool {
|
||||
return b.required
|
||||
}
|
||||
|
||||
func (b basic) Description() string {
|
||||
return b.description
|
||||
}
|
||||
|
||||
func (b basic) FromAny(val any) (reflect.Value, error) {
|
||||
v := reflect.ValueOf(val)
|
||||
|
||||
switch b.DataType {
|
||||
case jsonschema.String:
|
||||
var val = v.String()
|
||||
|
||||
return reflect.ValueOf(val), nil
|
||||
|
||||
case jsonschema.Integer:
|
||||
if v.Kind() == reflect.Float64 {
|
||||
return v.Convert(reflect.TypeOf(int(0))), nil
|
||||
} else if v.Kind() != reflect.Int {
|
||||
return reflect.Value{}, errors.New("expected int, got " + v.Kind().String())
|
||||
} else {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
case jsonschema.Number:
|
||||
if v.Kind() == reflect.Float64 {
|
||||
return v.Convert(reflect.TypeOf(float64(0))), nil
|
||||
} else if v.Kind() != reflect.Float64 {
|
||||
return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String())
|
||||
} else {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
case jsonschema.Boolean:
|
||||
if v.Kind() == reflect.Bool {
|
||||
return v, nil
|
||||
} else if v.Kind() == reflect.String {
|
||||
b, err := strconv.ParseBool(v.String())
|
||||
if err != nil {
|
||||
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
|
||||
}
|
||||
return reflect.ValueOf(b), nil
|
||||
} else {
|
||||
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
|
||||
}
|
||||
|
||||
default:
|
||||
return reflect.Value{}, errors.New("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
// if this basic type is not required that means it's a pointer type
|
||||
// so we need to create a new value of the type of the pointer
|
||||
if !b.required {
|
||||
vv := reflect.New(obj.Field(b.index).Type().Elem())
|
||||
|
||||
// and then set the value of the pointer to the new value
|
||||
vv.Elem().Set(val)
|
||||
|
||||
obj.Field(b.index).Set(vv)
|
||||
return
|
||||
}
|
||||
obj.Field(b.index).Set(val)
|
||||
}
|
47
schema/enum.go
Normal file
47
schema/enum.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type enum struct {
|
||||
basic
|
||||
|
||||
values []string
|
||||
}
|
||||
|
||||
func (e enum) SchemaType() jsonschema.DataType {
|
||||
return jsonschema.String
|
||||
}
|
||||
|
||||
func (e enum) Definition() jsonschema.Definition {
|
||||
def := e.basic.Definition()
|
||||
def.Enum = e.values
|
||||
return def
|
||||
}
|
||||
|
||||
func (e enum) FromAny(val any) (reflect.Value, error) {
|
||||
v := reflect.ValueOf(val)
|
||||
if v.Kind() != reflect.String {
|
||||
return reflect.Value{}, errors.New("expected string, got " + v.Kind().String())
|
||||
}
|
||||
|
||||
s := v.String()
|
||||
if !slices.Contains(e.values, s) {
|
||||
return reflect.Value{}, errors.New("value " + s + " not in enum")
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (e enum) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
if !e.required {
|
||||
val = val.Addr()
|
||||
}
|
||||
obj.Field(e.index).Set(val)
|
||||
}
|
78
schema/object.go
Normal file
78
schema/object.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type object struct {
|
||||
basic
|
||||
|
||||
ref reflect.Type
|
||||
|
||||
fields map[string]Type
|
||||
}
|
||||
|
||||
func (o object) SchemaType() jsonschema.DataType {
|
||||
return jsonschema.Object
|
||||
}
|
||||
|
||||
func (o object) Definition() jsonschema.Definition {
|
||||
def := o.basic.Definition()
|
||||
def.Type = jsonschema.Object
|
||||
def.Properties = make(map[string]jsonschema.Definition)
|
||||
for k, v := range o.fields {
|
||||
def.Properties[k] = v.Definition()
|
||||
}
|
||||
|
||||
def.AdditionalProperties = false
|
||||
return def
|
||||
}
|
||||
|
||||
func (o object) FromAny(val any) (reflect.Value, error) {
|
||||
// if the value is nil, we can't do anything
|
||||
if val == nil {
|
||||
return reflect.Value{}, nil
|
||||
}
|
||||
|
||||
// now make a new object of the type we're trying to parse
|
||||
obj := reflect.New(o.ref).Elem()
|
||||
|
||||
// now we need to iterate over the fields and set the values
|
||||
for k, v := range o.fields {
|
||||
// get the field by name
|
||||
field := obj.FieldByName(k)
|
||||
if !field.IsValid() {
|
||||
return reflect.Value{}, errors.New("field " + k + " not found")
|
||||
}
|
||||
|
||||
// get the value from the map
|
||||
val2, ok := val.(map[string]interface{})[k]
|
||||
if !ok {
|
||||
return reflect.Value{}, errors.New("field " + k + " not found in map")
|
||||
}
|
||||
|
||||
// now we need to convert the value to the correct type
|
||||
val3, err := v.FromAny(val2)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// now we need to set the value on the field
|
||||
v.SetValueOnField(field, val3)
|
||||
|
||||
}
|
||||
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
|
||||
if !o.required {
|
||||
val = val.Addr()
|
||||
}
|
||||
|
||||
obj.Field(o.index).Set(val)
|
||||
}
|
18
schema/type.go
Normal file
18
schema/type.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type Type interface {
|
||||
SchemaType() jsonschema.DataType
|
||||
Definition() jsonschema.Definition
|
||||
|
||||
Required() bool
|
||||
Description() string
|
||||
|
||||
FromAny(any) (reflect.Value, error)
|
||||
SetValueOnField(obj reflect.Value, val reflect.Value)
|
||||
}
|
Reference in New Issue
Block a user