// Source path: go-llm/schema/basic.go

package schema
import (
"errors"
"reflect"
"strconv"
"github.com/sashabaranov/go-openai/jsonschema"
)
// Compile-time assertion that basic satisfies the Type interface: the build
// fails on this line if basic is missing any method that Type requires.
var _ Type = basic{}
// basic is the Type implementation for a single scalar parameter whose schema
// type is one of the jsonschema data types handled by FromAny (string,
// integer, number or boolean).
type basic struct {
// DataType is the embedded JSON schema type of the parameter. It is returned
// by SchemaType, copied into Definition, and selects the conversion FromAny
// performs.
jsonschema.DataType
// index is the position of the parameter's field within the function's
// parameter struct; SetValueOnField passes it to reflect.Value.Field.
index int
// required indicates whether the parameter is required in the function's
// parameter struct. It is inferred from whether the field is a pointer type:
// SetValueOnField treats every non-required field as a pointer field.
required bool
// description is an LLM-readable description of the parameter; it is exposed
// through Description and placed in the Definition sent to OpenAI.
description string
}
// SchemaType reports the JSON schema data type carried by this parameter.
func (b basic) SchemaType() jsonschema.DataType {
	dataType := b.DataType
	return dataType
}
// Definition builds the jsonschema.Definition for this parameter. Only the
// Type and Description fields are populated; every other field keeps its zero
// value.
func (b basic) Definition() jsonschema.Definition {
	var def jsonschema.Definition
	def.Type = b.DataType
	def.Description = b.description
	return def
}
// Required reports whether the parameter is flagged as required.
func (b basic) Required() bool {
	isRequired := b.required
	return isRequired
}
// Description returns the LLM-readable description stored for the parameter.
func (b basic) Description() string {
	text := b.description
	return text
}
// FromAny converts a loosely typed value into a reflect.Value that matches
// this parameter's schema type.
//
// Accepted inputs per schema type:
//   - String: a value of string kind; the result is a plain string.
//   - Integer: a value of int kind (returned unchanged) or of float64 kind
//     (converted to int, which discards any fractional part).
//   - Number: a value of float64 kind; the result is a plain float64.
//   - Boolean: a value of bool kind (returned unchanged) or a string that
//     strconv.ParseBool accepts.
//
// Any other input, including nil, and any schema type not listed above yields
// the zero reflect.Value together with a non-nil error.
func (b basic) FromAny(val any) (reflect.Value, error) {
	v := reflect.ValueOf(val)
	kind := v.Kind()

	switch b.DataType {
	case jsonschema.String:
		// reflect.Value.String never fails: for a non-string value it returns
		// a placeholder such as "<int Value>". Check the kind explicitly so
		// that a mistyped argument is reported instead of silently becoming
		// that placeholder text.
		if kind != reflect.String {
			return reflect.Value{}, errors.New("expected string, got " + kind.String())
		}
		return reflect.ValueOf(v.String()), nil

	case jsonschema.Integer:
		switch kind {
		case reflect.Int:
			return v, nil
		case reflect.Float64:
			// NOTE(review): presumably floats are accepted because decoded
			// JSON numbers arrive as float64 — confirm against the caller.
			return v.Convert(reflect.TypeOf(int(0))), nil
		default:
			return reflect.Value{}, errors.New("expected int, got " + kind.String())
		}

	case jsonschema.Number:
		if kind != reflect.Float64 {
			return reflect.Value{}, errors.New("expected float64, got " + kind.String())
		}
		// Convert normalises named float64 types to plain float64.
		return v.Convert(reflect.TypeOf(float64(0))), nil

	case jsonschema.Boolean:
		switch kind {
		case reflect.Bool:
			return v, nil
		case reflect.String:
			parsed, err := strconv.ParseBool(v.String())
			if err != nil {
				return reflect.Value{}, errors.New("expected bool, got " + kind.String())
			}
			return reflect.ValueOf(parsed), nil
		default:
			return reflect.Value{}, errors.New("expected bool, got " + kind.String())
		}

	default:
		return reflect.Value{}, errors.New("unknown type")
	}
}
// SetValueOnField stores val in the field of obj located at b.index.
//
// A required parameter is assigned directly. A non-required parameter is
// treated as a pointer field: a fresh value of the pointed-to type is
// allocated, val is stored in it, and the field is set to that new pointer.
func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) {
	field := obj.Field(b.index)

	if b.required {
		field.Set(val)
		return
	}

	// Optional parameter: allocate the pointee, fill it, then point the field at it.
	ptr := reflect.New(field.Type().Elem())
	ptr.Elem().Set(val)
	field.Set(ptr)
}