// answer/pkg/toolbox/function.go
//
// 300 lines
// 8.2 KiB
// Go

package toolbox
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"reflect"
"strings"
"sync"
)
// FuncResponse is the successful result of a tool function call.
//
// Result carries the function's result as a string. Source is a second string
// supplied by the function; NOTE(review): presumably it identifies where the
// result came from — confirm with the callers.
type FuncResponse struct {
	Result, Source string
}
// funcCache memoizes function analyses. Entries are keyed by the function's
// reflect.Type, so every function value sharing a signature shares one entry.
// The embedded RWMutex guards m. The zero value is ready to use.
type funcCache struct {
	sync.RWMutex
	m map[reflect.Type]function
}

// get returns the entry stored for value's type, if any. Because entries are
// keyed by type, a hit may have been stored for a different function value of
// the same type. An invalid (zero) reflect.Value never hits.
func (c *funcCache) get(value reflect.Value) (function, bool) {
	if !value.IsValid() {
		return function{}, false
	}
	typ := value.Type()

	c.RLock()
	fn, ok := c.m[typ]
	c.RUnlock()

	// Logged outside the lock, and at debug level: this fires on every hit.
	if ok {
		slog.Debug("cache hit for function", "function", typ.String())
	}
	return fn, ok
}

// set stores fn under value's type, replacing any previous entry. An invalid
// (zero) reflect.Value is ignored because it has no type to key on.
func (c *funcCache) set(value reflect.Value, fn function) {
	if !value.IsValid() {
		return
	}
	c.Lock()
	defer c.Unlock()
	// Lazily allocate so a zero-value funcCache does not panic on first write.
	if c.m == nil {
		c.m = make(map[reflect.Type]function)
	}
	c.m[value.Type()] = fn
}

// cache is the package-wide analysis cache consulted by analyzeFuncFromReflect.
var cache = funcCache{m: map[reflect.Type]function{}}
// arg describes one parameter of a tool function. Each arg is derived from one
// field of the function's argument struct.
type arg struct {
	Name        string       // parameter name: the `name` tag, or the field name
	Type        reflect.Type // the field's scalar type (element type for pointer and slice fields)
	Index       int          // position of the field within the argument struct
	Values      []string     // allowed values from the `values` tag, if any
	Optional    bool         // the field is a pointer
	Array       bool         // the field is a slice
	Description string       // text of the `description` tag
}

// Schema returns a JSON-Schema-style description of the parameter.
//
// The "type" keyword uses JSON Schema primitive names ("string", "integer",
// "number", "boolean") instead of Go names. Names such as "int", "float64", or
// a named type's own name are not valid JSON Schema types. Types are mapped by
// kind, so a named type like `type Color string` is reported as "string".
//
// For arrays, "type" is "array" and both the element type and any "enum" are
// placed under "items", so the allowed values constrain each element rather
// than the array itself. "description" is added when set, and non-optional
// parameters carry "required": true.
//
// NOTE(review): a per-property "required": true is JSON Schema draft-3 style;
// later drafts expect a "required" array on the enclosing object — confirm
// what the consumer of this schema expects. Enum values are always emitted as
// strings, even for integer and number parameters.
func (a arg) Schema() map[string]any {
	// Map a Go kind to its JSON Schema primitive type; kinds without an
	// equivalent fall back to the Go kind name.
	schemaType := func(t reflect.Type) string {
		switch t.Kind() {
		case reflect.String:
			return "string"
		case reflect.Bool:
			return "boolean"
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			return "integer"
		case reflect.Float32, reflect.Float64:
			return "number"
		default:
			return t.Kind().String()
		}
	}

	res := map[string]any{}
	// typed is the map that receives the scalar "type" and "enum": the
	// property itself for scalars, its "items" for arrays.
	typed := res
	if a.Array {
		items := map[string]any{}
		res["type"] = "array"
		res["items"] = items
		typed = items
	}
	if a.Type != nil {
		typed["type"] = schemaType(a.Type)
	}
	if len(a.Values) > 0 {
		typed["enum"] = a.Values
	}
	if !a.Optional {
		res["required"] = true
	}
	if a.Description != "" {
		res["description"] = a.Description
	}
	return res
}
// function is the analyzed form of a tool function: the callable value together
// with the reflection metadata used to build its argument struct.
type function struct {
	fn      reflect.Value  // the function to invoke
	argType reflect.Type   // type of the function's second (struct) parameter
	args    map[string]arg // one entry per field of argType
}

// ErrInvalidFunction is the sentinel wrapped (via %w) by every error that
// reports an unsupported tool-function signature or argument struct.
var ErrInvalidFunction error = errors.New("invalid function")
// analyzeFuncFromReflect validates fn as a tool function and returns the
// metadata needed to call it with JSON-encoded arguments (see function.Execute).
//
// fn must be a non-nil function of the form
//
//	func(*agent.Context, T) (FuncResponse, error)
//
// where T is a struct. Each field of T becomes one parameter of the tool, and
// every field must be exported. The first parameter is matched by its type
// name, "*agent.Context".
//
// Struct tags supported on the fields of T are:
//   - `name:"<name>"` to specify the name of the parameter (default is the
//     field name); names must be non-empty and unique
//   - `description:"<description>"` to specify a description of the parameter
//     (default is "")
//   - `values:"<value1>,<value2>,..."` to specify the allowed values of the
//     parameter; entries are whitespace-trimmed and empty entries are dropped.
//     Only valid for string, int, and float64 types
//
// Allowed field types are:
//   - string, *string, []string
//   - int, *int, []int
//   - float64, *float64, []float64
//   - bool, *bool, []bool
//
// Named types whose kind is one of the above are accepted as well. A pointer
// field marks the parameter as optional; a slice field marks it as an array.
//
// Results are cached per function type. The returned value always wraps fn
// itself, even on a cache hit. Every validation error wraps ErrInvalidFunction.
func analyzeFuncFromReflect(fn reflect.Value) (function, error) {
	if !fn.IsValid() || fn.Kind() != reflect.Func {
		return function{}, fmt.Errorf("%w: expected a function", ErrInvalidFunction)
	}
	if fn.IsNil() {
		return function{}, fmt.Errorf("%w: function is nil", ErrInvalidFunction)
	}
	// The cache is keyed by type, so a hit may hold metadata computed for a
	// different function with the same signature: always attach this fn.
	if cached, ok := cache.get(fn); ok {
		cached.fn = fn
		return cached, nil
	}

	t := fn.Type()
	// Execute always calls with exactly (ctx, argStruct), so anything else
	// would only fail later inside reflect.Value.Call.
	if t.NumIn() != 2 {
		return function{}, fmt.Errorf("%w: function must have exactly 2 parameters", ErrInvalidFunction)
	}
	// Matched by type name rather than by type identity (presumably to avoid
	// importing the agent package from here — confirm).
	if t.In(0).String() != "*agent.Context" {
		return function{}, fmt.Errorf("%w: first parameter must be *agent.Context", ErrInvalidFunction)
	}
	argType := t.In(1)
	if argType.Kind() != reflect.Struct {
		return function{}, fmt.Errorf("%w: second parameter must be a struct", ErrInvalidFunction)
	}
	// FuncResponse is declared in this package (toolbox), so its type string is
	// "toolbox.FuncResponse"; compare the types themselves instead of names.
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	if t.NumOut() != 2 || t.Out(0) != reflect.TypeOf(FuncResponse{}) || t.Out(1) != errorType {
		return function{}, fmt.Errorf("%w: function must return a FuncResponse and an error", ErrInvalidFunction)
	}

	// Keyed by the advertised parameter name, which is what callers send.
	args := make(map[string]arg, argType.NumField())
	for i := 0; i < argType.NumField(); i++ {
		a, err := parseArgField(argType.Field(i), i)
		if err != nil {
			return function{}, err
		}
		if _, dup := args[a.Name]; dup {
			return function{}, fmt.Errorf("%w: duplicate parameter name %q", ErrInvalidFunction, a.Name)
		}
		args[a.Name] = a
	}

	res := function{argType: argType, args: args}
	// Cache only the type-level metadata; fn differs between functions that
	// share this type and is attached after the lookup instead.
	cache.set(fn, res)
	res.fn = fn
	return res, nil
}

// parseArgField converts one field of a tool's argument struct (at position
// index) into an arg, applying the `name`, `description`, and `values` tags.
func parseArgField(field reflect.StructField, index int) (arg, error) {
	// reflect cannot set unexported fields, so Execute could never fill them.
	if !field.IsExported() {
		return arg{}, fmt.Errorf("%w: field %s must be exported", ErrInvalidFunction, field.Name)
	}

	a := arg{Name: field.Name, Index: index}
	ft := field.Type
	switch ft.Kind() {
	case reflect.Ptr:
		a.Optional = true
		ft = ft.Elem()
	case reflect.Slice:
		a.Array = true
		ft = ft.Elem()
	}
	switch ft.Kind() {
	case reflect.String, reflect.Int, reflect.Float64, reflect.Bool:
	default:
		return arg{}, fmt.Errorf("%w: unsupported type %s for field %s", ErrInvalidFunction, ft.Kind().String(), field.Name)
	}
	a.Type = ft

	if name, ok := field.Tag.Lookup("name"); ok {
		a.Name = strings.TrimSpace(name)
		if a.Name == "" {
			return arg{}, fmt.Errorf("%w: name tag cannot be empty", ErrInvalidFunction)
		}
	}
	a.Description = field.Tag.Get("description")
	if values, ok := field.Tag.Lookup("values"); ok {
		if ft.Kind() == reflect.Bool {
			return arg{}, fmt.Errorf("%w: values tag only supported for string, int, and float types", ErrInvalidFunction)
		}
		for _, v := range strings.Split(values, ",") {
			if v = strings.TrimSpace(v); v != "" {
				a.Values = append(a.Values, v)
			}
		}
	}
	return a, nil
}
// analyzeFunction is the type-safe front end to analyzeFuncFromReflect. The
// generic signature lets the compiler enforce the two-parameter,
// (FuncResponse, error) shape. The reflective analysis still checks the
// concrete context type and that T is a struct.
func analyzeFunction[T any, AgentContext any](fn func(AgentContext, T) (FuncResponse, error)) (function, error) {
	value := reflect.ValueOf(fn)
	return analyzeFuncFromReflect(value)
}
// Execute calls the wrapped function with ctx and the arguments decoded from args.
//
// args must be a JSON object whose keys are the parameter names (the `name`
// tag, or the field name when no tag is set). An empty or all-whitespace
// string is treated as "{}". Keys that match no parameter are ignored.
//
// Each value is decoded with encoding/json directly into its field of a fresh
// argument struct. Numbers, arrays, and pointers are therefore converted to the
// field's Go type, and a value of the wrong type produces an error instead of
// a panic.
//
// A parameter that is absent or JSON null is left at its zero value when it is
// optional (pointer field) or an array (slice field). For any other parameter,
// an error is returned.
//
// ctx is passed as the function's first argument, so it must be non-nil and its
// dynamic type must be assignable to that parameter's type.
//
// If the function returns a non-nil error, Execute returns a zero FuncResponse
// and that error unchanged. Otherwise it returns the function's FuncResponse.
func (f function) Execute(ctx context.Context, args string) (FuncResponse, error) {
	// reflect.Value.Call panics on a zero Value or mismatched arguments, and the
	// results are type-asserted below, so check the signature up front.
	if !f.fn.IsValid() || f.fn.Kind() != reflect.Func || f.fn.IsNil() || f.argType == nil {
		return FuncResponse{}, errors.New("function is not initialized")
	}
	fnType := f.fn.Type()
	if fnType.NumIn() != 2 || fnType.In(1) != f.argType ||
		fnType.NumOut() != 2 ||
		fnType.Out(0) != reflect.TypeOf(FuncResponse{}) ||
		fnType.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
		return FuncResponse{}, fmt.Errorf("function of type %s does not have the expected signature", fnType)
	}
	ctxValue := reflect.ValueOf(ctx)
	if !ctxValue.IsValid() || !ctxValue.Type().AssignableTo(fnType.In(0)) {
		return FuncResponse{}, fmt.Errorf("context of type %T cannot be passed as %s", ctx, fnType.In(0))
	}

	// Treat an empty argument string like an empty JSON object.
	if strings.TrimSpace(args) == "" {
		args = "{}"
	}
	// Decode only the top level here. Decoding into map[string]any would
	// produce float64 and []any values that cannot be assigned to int or typed
	// slice fields, so each value is decoded straight into its field below.
	var raw map[string]json.RawMessage
	if err := json.Unmarshal([]byte(args), &raw); err != nil {
		return FuncResponse{}, fmt.Errorf("failed to unmarshal arguments: %w", err)
	}

	obj := reflect.New(f.argType).Elem()
	for _, a := range f.args {
		msg, present := raw[a.Name]
		if present {
			s := strings.TrimSpace(string(msg))
			present = s != "" && s != "null"
		}
		if !present {
			if !a.Optional && !a.Array {
				return FuncResponse{}, fmt.Errorf("missing required argument %q", a.Name)
			}
			continue
		}
		if a.Index < 0 || a.Index >= obj.NumField() || !obj.Field(a.Index).CanSet() {
			return FuncResponse{}, fmt.Errorf("argument %q is backed by a field that cannot be set", a.Name)
		}
		// obj comes from reflect.New(...).Elem(), so its fields are addressable.
		field := obj.Field(a.Index)
		if err := json.Unmarshal(msg, field.Addr().Interface()); err != nil {
			return FuncResponse{}, fmt.Errorf("invalid value for argument %q: %w", a.Name, err)
		}
	}

	out := f.fn.Call([]reflect.Value{ctxValue, obj})
	if !out[1].IsNil() {
		return FuncResponse{}, out[1].Interface().(error)
	}
	return out[0].Interface().(FuncResponse), nil
}