106 lines
2.5 KiB
Go
106 lines
2.5 KiB
Go
|
package schema
|
||
|
|
||
|
import (
	"errors"
	"math"
	"reflect"
	"strconv"

	"github.com/sashabaranov/go-openai/jsonschema"
)
|
||
|
|
||
|
// Compile-time assertion that the basic value type itself (not only *basic)
// satisfies the Type interface; the build fails if a method is missing.
var _ Type = basic{}
|
||
|
|
||
|
// basic implements Type for scalar JSON Schema parameters (string, integer,
// number and boolean), each backed by a single field of a function's
// parameter struct.
type basic struct {
	// DataType is the JSON Schema type advertised for this parameter.
	jsonschema.DataType

	// index is the position of the parameter's field (its StructField index)
	// within the function's parameter struct.
	index int

	// required reports whether the parameter must be supplied in the
	// function's parameter struct. It is inferred from the field's type:
	// pointer fields are optional (required == false), non-pointer fields
	// are required.
	required bool

	// description is an LLM-readable description of the parameter that is
	// passed to OpenAI as part of the schema.
	description string
}
|
||
|
|
||
|
func (b basic) SchemaType() jsonschema.DataType {
|
||
|
return b.DataType
|
||
|
}
|
||
|
|
||
|
func (b basic) Definition() jsonschema.Definition {
|
||
|
return jsonschema.Definition{
|
||
|
Type: b.DataType,
|
||
|
Description: b.description,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (b basic) Required() bool {
|
||
|
return b.required
|
||
|
}
|
||
|
|
||
|
func (b basic) Description() string {
|
||
|
return b.description
|
||
|
}
|
||
|
|
||
|
func (b basic) FromAny(val any) (reflect.Value, error) {
|
||
|
v := reflect.ValueOf(val)
|
||
|
|
||
|
switch b.DataType {
|
||
|
case jsonschema.String:
|
||
|
var val = v.String()
|
||
|
|
||
|
return reflect.ValueOf(val), nil
|
||
|
|
||
|
case jsonschema.Integer:
|
||
|
if v.Kind() == reflect.Float64 {
|
||
|
return v.Convert(reflect.TypeOf(int(0))), nil
|
||
|
} else if v.Kind() != reflect.Int {
|
||
|
return reflect.Value{}, errors.New("expected int, got " + v.Kind().String())
|
||
|
} else {
|
||
|
return v, nil
|
||
|
}
|
||
|
|
||
|
case jsonschema.Number:
|
||
|
if v.Kind() == reflect.Float64 {
|
||
|
return v.Convert(reflect.TypeOf(float64(0))), nil
|
||
|
} else if v.Kind() != reflect.Float64 {
|
||
|
return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String())
|
||
|
} else {
|
||
|
return v, nil
|
||
|
}
|
||
|
|
||
|
case jsonschema.Boolean:
|
||
|
if v.Kind() == reflect.Bool {
|
||
|
return v, nil
|
||
|
} else if v.Kind() == reflect.String {
|
||
|
b, err := strconv.ParseBool(v.String())
|
||
|
if err != nil {
|
||
|
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
|
||
|
}
|
||
|
return reflect.ValueOf(b), nil
|
||
|
} else {
|
||
|
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
|
||
|
}
|
||
|
|
||
|
default:
|
||
|
return reflect.Value{}, errors.New("unknown type")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||
|
// if this basic type is not required that means it's a pointer type
|
||
|
// so we need to create a new value of the type of the pointer
|
||
|
if !b.required {
|
||
|
vv := reflect.New(obj.Field(b.index).Type().Elem())
|
||
|
|
||
|
// and then set the value of the pointer to the new value
|
||
|
vv.Elem().Set(val)
|
||
|
|
||
|
obj.Field(b.index).Set(vv)
|
||
|
return
|
||
|
}
|
||
|
obj.Field(b.index).Set(val)
|
||
|
}
|