package toolbox

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"reflect"
	"strings"
	"sync"
)

// FuncResponse is the value a tool function returns alongside its error.
type FuncResponse struct {
	Result string
	Source string
}

// funcCache memoizes function analysis results keyed by the function's
// reflect.Type. Because the key is the type, functions that share a signature
// share an entry; the fn field of a cached value must therefore not be trusted
// to refer to the caller's function.
type funcCache struct {
	sync.RWMutex
	m map[reflect.Type]function
}

// get returns the cached analysis for value's type, if present.
func (c *funcCache) get(value reflect.Value) (function, bool) {
	c.RLock()
	defer c.RUnlock()
	fn, ok := c.m[value.Type()]
	if ok {
		// Debug rather than Info: this fires on every lookup of a known function.
		slog.Debug("cache hit for function", "function", value.Type().String())
	}
	return fn, ok
}

// set stores fn as the analysis result for value's type.
func (c *funcCache) set(value reflect.Value, fn function) {
	c.Lock()
	defer c.Unlock()
	if c.m == nil {
		// Make the zero-value funcCache usable instead of panicking on a nil map.
		c.m = map[reflect.Type]function{}
	}
	c.m[value.Type()] = fn
}

var cache = funcCache{m: map[reflect.Type]function{}}

// arg describes one field of a tool function's argument struct.
type arg struct {
	Name        string       // parameter name: the field name, or the `name` tag when present
	Type        reflect.Type // element type: any pointer or slice wrapper has been removed
	Index       int          // index of the field within the argument struct
	Values      []string     // allowed values from the `values` tag, if any
	Optional    bool         // true when the field is a pointer
	Array       bool         // true when the field is a slice
	Description string       // text of the `description` tag, if any
}

// Schema returns a JSON-Schema style description of the argument: its
// "type" (with "items" for arrays), "enum" when Values is set, "description"
// when set, and "required": true for non-optional arguments.
func (a arg) Schema() map[string]any {
	var res = map[string]any{}
	if a.Array {
		res["type"] = "array"
		res["items"] = map[string]any{"type": a.schemaType()}
	} else {
		res["type"] = a.schemaType()
	}
	if !a.Optional {
		res["required"] = true
	}
	if len(a.Values) > 0 {
		res["enum"] = a.Values
	}
	if a.Description != "" {
		res["description"] = a.Description
	}
	return res
}

// schemaType returns the JSON Schema type name for a.Type. It switches on the
// kind (not the type name) so named types such as `type Color string` map
// correctly; unsupported kinds fall back to their Go kind name.
func (a arg) schemaType() string {
	switch a.Type.Kind() {
	case reflect.String:
		return "string"
	case reflect.Int:
		return "integer"
	case reflect.Float64:
		return "number"
	case reflect.Bool:
		return "boolean"
	default:
		return a.Type.Kind().String()
	}
}

// function is the analyzed form of a tool function.
type function struct {
	fn      reflect.Value  // the function value that Execute calls
	argType reflect.Type   // struct type of the function's second parameter
	args    map[string]arg // arguments derived from argType's fields
}

// ErrInvalidFunction is wrapped by the errors returned when a function does not
// have the shape analyzeFuncFromReflect requires.
var ErrInvalidFunction = errors.New("invalid function")

// analyzeFuncFromReflect inspects fn, which must be a function value, and returns
// a function describing its arguments, or an error wrapping ErrInvalidFunction.
// The first parameter to the function must be a *agent.Context.
// The second parameter to the function must be a struct, all the fields of which
// will be passed as arguments to the function to be analyzed.
// Struct tags supported are: // - `name:""` to specify the name of the parameter (default is the field name) // - `description:""` to specify a description of the parameter (default is "") // - `values:",,..."` to specify a list of possible values for the parameter (default is "") only for // string, int, and float types // // Allowed types on the struct are: // - string, *string, []string // - int, *int, []int // - float64, *float64, []float64 // - bool, *bool, []bool // // Pointer types imply that the parameter is optional. // The function must have at most 2 parameters. // The function must return a string and an error. // The function must be of the form `func(*agent.Context, T) (FuncResponse, error)`. func analyzeFuncFromReflect(fn reflect.Value) (function, error) { if f, ok := cache.get(fn); ok { return f, nil } var res function t := fn.Type() args := map[string]arg{} for i := 0; i < t.NumIn(); i++ { if i == 0 { if t.In(i).String() != "*agent.Context" { return res, fmt.Errorf("%w: first parameter must be *agent.Context", ErrInvalidFunction) } continue } else if i == 1 { if t.In(i).Kind() != reflect.Struct { return res, fmt.Errorf("%w: second parameter must be a struct", ErrInvalidFunction) } res.argType = t.In(i) for j := 0; j < res.argType.NumField(); j++ { field := res.argType.Field(j) a := arg{ Name: field.Name, Type: field.Type, Index: j, Description: "", } ft := field.Type // if it's a pointer, it's optional if ft.Kind() == reflect.Ptr { a.Optional = true ft = ft.Elem() } else if ft.Kind() == reflect.Slice { a.Array = true ft = ft.Elem() } if ft.Kind() != reflect.String && ft.Kind() != reflect.Int && ft.Kind() != reflect.Float64 && ft.Kind() != reflect.Bool { return res, fmt.Errorf("%w: unsupported type %s", ErrInvalidFunction, ft.Kind().String()) } a.Type = ft if name, ok := field.Tag.Lookup("name"); ok { a.Name = name a.Name = strings.TrimSpace(a.Name) if a.Name == "" { return res, fmt.Errorf("%w: name tag cannot be empty", ErrInvalidFunction) } } if 
description, ok := field.Tag.Lookup("description"); ok { a.Description = description } if values, ok := field.Tag.Lookup("values"); ok { a.Values = strings.Split(values, ",") for i, v := range a.Values { a.Values[i] = strings.TrimSpace(v) } if ft.Kind() != reflect.String && ft.Kind() != reflect.Int && ft.Kind() != reflect.Float64 { return res, fmt.Errorf("%w: values tag only supported for string, int, and float types", ErrInvalidFunction) } } args[field.Name] = a } } else { return res, fmt.Errorf("%w: function must have at most 2 parameters", ErrInvalidFunction) } } // finally ensure that the function returns a FuncResponse and an error if t.NumOut() != 2 || t.Out(0).String() != "agent.FuncResponse" || t.Out(1).String() != "error" { return res, fmt.Errorf("%w: function must return a FuncResponse and an error", ErrInvalidFunction) } cache.set(fn, res) return res, nil } func analyzeFunction[T any, AgentContext any](fn func(AgentContext, T) (FuncResponse, error)) (function, error) { return analyzeFuncFromReflect(reflect.ValueOf(fn)) } // Execute will execute the given function with the given context and arguments. // Returns the result of the execution and an error if the operation fails. // The arguments must be a JSON-encoded string that represents the struct to be passed to the function. 
func (f function) Execute(ctx context.Context, args string) (FuncResponse, error) { var m = map[string]any{} err := json.Unmarshal([]byte(args), &m) if err != nil { return FuncResponse{}, fmt.Errorf("failed to unmarshal arguments: %w", err) } var obj = reflect.New(f.argType).Elem() // TODO: ensure that "required" fields are present in the arguments for name, a := range f.args { if v, ok := m[name]; ok { if a.Array { if v == nil { continue } switch a.Type.Kind() { case reflect.String: s := v.([]string) slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(s), len(s)) for i, str := range s { slice.Index(i).SetString(str) } obj.Field(a.Index).Set(slice) case reflect.Int: i := v.([]int) slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(0)), len(i), len(i)) for i, in := range i { slice.Index(i).SetInt(int64(in)) } obj.Field(a.Index).Set(slice) case reflect.Float64: f := v.([]float64) slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(0.0)), len(f), len(f)) for i, fl := range f { slice.Index(i).SetFloat(fl) } obj.Field(a.Index).Set(slice) case reflect.Bool: b := v.([]bool) slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(false)), len(b), len(b)) for i, b := range b { slice.Index(i).SetBool(b) } obj.Field(a.Index).Set(slice) default: return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name) } } else if a.Optional { if v == nil { continue } switch a.Type.Kind() { case reflect.String: str := v.(string) obj.Field(a.Index).Set(reflect.ValueOf(&str)) case reflect.Int: i := v.(int) obj.Field(a.Index).Set(reflect.ValueOf(&i)) case reflect.Float64: f := v.(float64) obj.Field(a.Index).Set(reflect.ValueOf(&f)) case reflect.Bool: b := v.(bool) obj.Field(a.Index).Set(reflect.ValueOf(&b)) default: return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name) } } else { switch a.Type.Kind() { case reflect.String: obj.Field(a.Index).SetString(v.(string)) case 
reflect.Int: obj.Field(a.Index).SetInt(int64(v.(int))) case reflect.Float64: obj.Field(a.Index).SetFloat(v.(float64)) case reflect.Bool: obj.Field(a.Index).SetBool(v.(bool)) default: return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name) } } } } res := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), obj}) if res[1].IsNil() { return res[0].Interface().(FuncResponse), nil } return FuncResponse{}, res[1].Interface().(error) }