166 lines
3.6 KiB
Go
166 lines
3.6 KiB
Go
package schema
|
|
|
|
import (
	"errors"
	"math"
	"reflect"
	"strconv"

	"github.com/google/generative-ai-go/genai"
	"github.com/openai/openai-go"
)
|
|
|
|
// just enforcing that basic implements Type
|
|
var _ Type = basic{}
|
|
|
|
// DataType names the schema type of a function parameter. Each value's string
// form is the corresponding JSON Schema type keyword.
type DataType string

// Supported DataType values.
const (
	TypeString  DataType = "string"  // JSON string
	TypeInteger DataType = "integer" // JSON number with no fractional part
	TypeNumber  DataType = "number"  // any JSON number
	TypeBoolean DataType = "boolean" // JSON true / false
	TypeObject  DataType = "object"  // JSON object
	TypeArray   DataType = "array"   // JSON array
)
|
|
|
|
type basic struct {
|
|
DataType
|
|
typeName string
|
|
|
|
// index is the position of the parameter in the StructField of the function's parameter struct
|
|
index int
|
|
|
|
// required is a flag that indicates whether the parameter is required in the function's parameter struct.
|
|
// this is inferred by if the parameter is a pointer type or not.
|
|
required bool
|
|
|
|
// description is a llm-readable description of the parameter passed to openai
|
|
description string
|
|
}
|
|
|
|
func (b basic) OpenAIParameters() openai.FunctionParameters {
|
|
return openai.FunctionParameters{
|
|
"type": b.typeName,
|
|
"description": b.description,
|
|
}
|
|
}
|
|
|
|
func (b basic) GoogleParameters() *genai.Schema {
|
|
var t = genai.TypeUnspecified
|
|
|
|
switch b.DataType {
|
|
case TypeString:
|
|
t = genai.TypeString
|
|
case TypeInteger:
|
|
t = genai.TypeInteger
|
|
case TypeNumber:
|
|
t = genai.TypeNumber
|
|
case TypeBoolean:
|
|
t = genai.TypeBoolean
|
|
case TypeObject:
|
|
t = genai.TypeObject
|
|
case TypeArray:
|
|
t = genai.TypeArray
|
|
default:
|
|
t = genai.TypeUnspecified
|
|
}
|
|
return &genai.Schema{
|
|
Type: t,
|
|
Description: b.description,
|
|
}
|
|
}
|
|
|
|
func (b basic) AnthropicInputSchema() map[string]any {
|
|
var t = "string"
|
|
|
|
switch b.DataType {
|
|
case TypeString:
|
|
t = "string"
|
|
case TypeInteger:
|
|
t = "integer"
|
|
case TypeNumber:
|
|
t = "number"
|
|
case TypeBoolean:
|
|
t = "boolean"
|
|
case TypeObject:
|
|
t = "object"
|
|
case TypeArray:
|
|
t = "array"
|
|
default:
|
|
t = "unknown"
|
|
}
|
|
|
|
return map[string]any{
|
|
"type": t,
|
|
"description": b.description,
|
|
}
|
|
}
|
|
|
|
func (b basic) Required() bool {
|
|
return b.required
|
|
}
|
|
|
|
func (b basic) Description() string {
|
|
return b.description
|
|
}
|
|
|
|
func (b basic) FromAny(val any) (reflect.Value, error) {
|
|
v := reflect.ValueOf(val)
|
|
|
|
switch b.DataType {
|
|
case TypeString:
|
|
var val = v.String()
|
|
|
|
return reflect.ValueOf(val), nil
|
|
|
|
case TypeInteger:
|
|
if v.Kind() == reflect.Float64 {
|
|
return v.Convert(reflect.TypeOf(int(0))), nil
|
|
} else if v.Kind() != reflect.Int {
|
|
return reflect.Value{}, errors.New("expected int, got " + v.Kind().String())
|
|
} else {
|
|
return v, nil
|
|
}
|
|
|
|
case TypeNumber:
|
|
if v.Kind() == reflect.Float64 {
|
|
return v.Convert(reflect.TypeOf(float64(0))), nil
|
|
} else if v.Kind() != reflect.Float64 {
|
|
return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String())
|
|
} else {
|
|
return v, nil
|
|
}
|
|
|
|
case TypeBoolean:
|
|
if v.Kind() == reflect.Bool {
|
|
return v, nil
|
|
} else if v.Kind() == reflect.String {
|
|
b, err := strconv.ParseBool(v.String())
|
|
if err != nil {
|
|
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
|
|
}
|
|
return reflect.ValueOf(b), nil
|
|
} else {
|
|
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
|
|
}
|
|
|
|
default:
|
|
return reflect.Value{}, errors.New("unknown type")
|
|
}
|
|
}
|
|
|
|
func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
|
// if this basic type is not required that means it's a pointer type
|
|
// so we need to create a new value of the type of the pointer
|
|
if !b.required {
|
|
vv := reflect.New(obj.Field(b.index).Type().Elem())
|
|
|
|
// and then set the value of the pointer to the new value
|
|
vv.Elem().Set(val)
|
|
|
|
obj.Field(b.index).Set(vv)
|
|
return
|
|
}
|
|
obj.Field(b.index).Set(val)
|
|
}
|