package schema import ( "errors" "reflect" "strconv" "github.com/google/generative-ai-go/genai" "github.com/openai/openai-go" ) // just enforcing that basic implements Type var _ Type = basic{} type DataType string const ( TypeString DataType = "string" TypeInteger DataType = "integer" TypeNumber DataType = "number" TypeBoolean DataType = "boolean" TypeObject DataType = "object" TypeArray DataType = "array" ) type basic struct { DataType typeName string // index is the position of the parameter in the StructField of the function's parameter struct index int // required is a flag that indicates whether the parameter is required in the function's parameter struct. // this is inferred by if the parameter is a pointer type or not. required bool // description is a llm-readable description of the parameter passed to openai description string } func (b basic) OpenAIParameters() openai.FunctionParameters { return openai.FunctionParameters{ "type": b.typeName, "description": b.description, } } func (b basic) GoogleParameters() *genai.Schema { var t = genai.TypeUnspecified switch b.DataType { case TypeString: t = genai.TypeString case TypeInteger: t = genai.TypeInteger case TypeNumber: t = genai.TypeNumber case TypeBoolean: t = genai.TypeBoolean case TypeObject: t = genai.TypeObject case TypeArray: t = genai.TypeArray default: t = genai.TypeUnspecified } return &genai.Schema{ Type: t, Description: b.description, } } func (b basic) Required() bool { return b.required } func (b basic) Description() string { return b.description } func (b basic) FromAny(val any) (reflect.Value, error) { v := reflect.ValueOf(val) switch b.DataType { case TypeString: var val = v.String() return reflect.ValueOf(val), nil case TypeInteger: if v.Kind() == reflect.Float64 { return v.Convert(reflect.TypeOf(int(0))), nil } else if v.Kind() != reflect.Int { return reflect.Value{}, errors.New("expected int, got " + v.Kind().String()) } else { return v, nil } case TypeNumber: if v.Kind() == reflect.Float64 { return v.Convert(reflect.TypeOf(float64(0))), nil } else if v.Kind() != reflect.Float64 { return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String()) } else { return v, nil } case TypeBoolean: if v.Kind() == reflect.Bool { return v, nil } else if v.Kind() == reflect.String { b, err := strconv.ParseBool(v.String()) if err != nil { return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String()) } return reflect.ValueOf(b), nil } else { return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String()) } default: return reflect.Value{}, errors.New("unknown type") } } func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) { // if this basic type is not required that means it's a pointer type // so we need to create a new value of the type of the pointer if !b.required { vv := reflect.New(obj.Field(b.index).Type().Elem()) // and then set the value of the pointer to the new value vv.Elem().Set(val) obj.Field(b.index).Set(vv) return } obj.Field(b.index).Set(val) }