go-llm/schema/GetType.go

126 lines
2.4 KiB
Go
Raw Normal View History

package schema
import (
"reflect"
"strings"
"github.com/sashabaranov/go-openai/jsonschema"
)
// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that
// can be used to generate a json schema and build an object from a parsed json object.
//
// It panics with "GetType expects a struct" when a is nil or its dynamic type
// is not a struct. Field types that cannot be described also cause a panic
// while the struct is being converted.
func GetType(a any) Type {
	t := reflect.TypeOf(a)
	// reflect.TypeOf(nil) returns a nil Type; calling Kind on it would panic
	// with an unhelpful nil-pointer error, so report the intended message.
	if t == nil || t.Kind() != reflect.Struct {
		panic("GetType expects a struct")
	}
	return getObject(t)
}
// getFromType converts the Go type t into a schema Type, carrying over the
// field metadata (index, required flag, description) held in b.
//
// A pointer type is dereferenced once and marks the value as not required.
// Scalars (string, signed/unsigned integers, floats, bool) are returned as b
// with the matching DataType set. Structs become objects and slices become
// arrays; in both cases the metadata from b is copied onto the result so a
// struct- or slice-typed field keeps its index, required flag and description.
// Any other kind panics.
func getFromType(t reflect.Type, b basic) Type {
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
		b.required = false
	}

	switch t.Kind() {
	case reflect.String:
		b.DataType = jsonschema.String
		return b
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		b.DataType = jsonschema.Integer
		return b
	case reflect.Float32, reflect.Float64:
		b.DataType = jsonschema.Number
		return b
	case reflect.Bool:
		b.DataType = jsonschema.Boolean
		return b
	case reflect.Struct:
		o := getObject(t)
		o.basic.required = b.required
		o.basic.index = b.index
		o.basic.description = b.description
		return o
	case reflect.Slice:
		// Mirror the struct case: without this, slice fields lost their
		// description, required flag and position in the parent struct.
		arr := getArray(t)
		arr.basic.required = b.required
		arr.basic.index = b.index
		arr.basic.description = b.description
		return arr
	default:
		panic("unhandled default case for " + t.Kind().String() + " in getFromType")
	}
}
// getField builds the Type for a single struct field. index is the field's
// position within its parent struct and is recorded on the resulting Type.
//
// Fields start out required; pointer-typed fields are treated as optional.
// The optional `description` struct tag becomes the Type's description. If the
// `enum` struct tag is present, the field becomes an enum whose allowed values
// are the tag's comma-separated entries, each trimmed of surrounding
// whitespace, with empty entries dropped (e.g. `enum:"a, b,,c"` yields a, b, c).
// Otherwise the field's Go type is converted by getFromType.
func getField(f reflect.StructField, index int) Type {
	b := basic{
		index:    index,
		required: true,
	}

	if desc, ok := f.Tag.Lookup("description"); ok {
		b.description = desc
	}

	if v, ok := f.Tag.Lookup("enum"); ok {
		// Pointer fields are optional, matching getFromType's pointer handling.
		if f.Type.Kind() == reflect.Ptr {
			b.required = false
		}

		// Filter into a fresh slice: deleting from the slice while walking it
		// by index skips whichever element shifts into the removed slot.
		parts := strings.Split(v, ",")
		vals := make([]string, 0, len(parts))
		for _, p := range parts {
			if p = strings.TrimSpace(p); p != "" {
				vals = append(vals, p)
			}
		}

		return enum{
			basic:  b,
			values: vals,
		}
	}

	return getFromType(f.Type, b)
}
// getObject describes the struct type t as an object Type. Every field of t is
// converted with getField, receiving its position in the struct as its index,
// and stored under the field's Go name. The returned object's basic has only
// DataType set; index, description and required flag are left at zero values.
//
// NOTE(review): unexported and embedded fields are included as-is; confirm
// that is intended for schema generation and object building.
func getObject(t reflect.Type) object {
	n := t.NumField()
	props := make(map[string]Type, n)
	for idx := 0; idx < n; idx++ {
		sf := t.Field(idx)
		props[sf.Name] = getField(sf, idx)
	}

	return object{
		fields: props,
		basic:  basic{DataType: jsonschema.Object},
	}
}
// getArray describes the slice type t as an array Type. The item schema is
// derived from t's element type starting from a zero basic, and the array's
// own basic carries only its DataType.
func getArray(t reflect.Type) array {
	return array{
		basic: basic{DataType: jsonschema.Array},
		items: getFromType(t.Elem(), basic{}),
	}
}