restructured answers a bit
This commit is contained in:
17
pkg/toolbox/context.go
Normal file
17
pkg/toolbox/context.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package toolbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
)
|
||||
|
||||
type Context interface {
|
||||
context.Context
|
||||
|
||||
WithCancel() (Context, func())
|
||||
WithTimeout(time.Duration) (Context, func())
|
||||
WithMessages([]llms.MessageContent) Context
|
||||
GetMessages() []llms.MessageContent
|
||||
}
|
299
pkg/toolbox/function.go
Normal file
299
pkg/toolbox/function.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package toolbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type FuncResponse struct {
|
||||
Result string
|
||||
Source string
|
||||
}
|
||||
|
||||
type funcCache struct {
|
||||
sync.RWMutex
|
||||
m map[reflect.Type]function
|
||||
}
|
||||
|
||||
func (c *funcCache) get(value reflect.Value) (function, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
fn, ok := c.m[value.Type()]
|
||||
if ok {
|
||||
slog.Info("cache hit for function", "function", value.Type().String())
|
||||
}
|
||||
return fn, ok
|
||||
}
|
||||
|
||||
func (c *funcCache) set(value reflect.Value, fn function) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
c.m[value.Type()] = fn
|
||||
}
|
||||
|
||||
var cache = funcCache{m: map[reflect.Type]function{}}
|
||||
|
||||
type arg struct {
|
||||
Name string
|
||||
Type reflect.Type
|
||||
Index int
|
||||
Values []string
|
||||
Optional bool
|
||||
Array bool
|
||||
Description string
|
||||
}
|
||||
|
||||
func (a arg) Schema() map[string]any {
|
||||
var res = map[string]any{}
|
||||
|
||||
if a.Array {
|
||||
res["type"] = "array"
|
||||
res["items"] = map[string]any{"type": a.Type.Kind().String()}
|
||||
} else {
|
||||
res["type"] = a.Type.Name()
|
||||
}
|
||||
|
||||
if !a.Optional {
|
||||
res["required"] = true
|
||||
}
|
||||
|
||||
if len(a.Values) > 0 {
|
||||
res["enum"] = a.Values
|
||||
}
|
||||
|
||||
if a.Description != "" {
|
||||
res["description"] = a.Description
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type function struct {
|
||||
fn reflect.Value
|
||||
argType reflect.Type
|
||||
args map[string]arg
|
||||
}
|
||||
|
||||
var ErrInvalidFunction = errors.New("invalid function")
|
||||
|
||||
// analyzeFuncFromReflect extracts metadata from a reflect.Value representing a function and returns a structured result.
|
||||
// It maps the function's parameter names to their corresponding reflect.Type and encapsulates them in a function struct.
|
||||
// Returns a function struct with extracted information and an error if the operation fails.
|
||||
// The first parameter to the function must be a *Context.
|
||||
// The second parameter to the function must be a struct, all the fields of which will be passed as arguments to the
|
||||
// function to be analyzed.
|
||||
// Struct tags supported are:
|
||||
// - `name:"<name>"` to specify the name of the parameter (default is the field name)
|
||||
// - `description:"<description>"` to specify a description of the parameter (default is "")
|
||||
// - `values:"<value1>,<value2>,..."` to specify a list of possible values for the parameter (default is "") only for
|
||||
// string, int, and float types
|
||||
//
|
||||
// Allowed types on the struct are:
|
||||
// - string, *string, []string
|
||||
// - int, *int, []int
|
||||
// - float64, *float64, []float64
|
||||
// - bool, *bool, []bool
|
||||
//
|
||||
// Pointer types imply that the parameter is optional.
|
||||
// The function must have at most 2 parameters.
|
||||
// The function must return a string and an error.
|
||||
// The function must be of the form `func(*agent.Context, T) (FuncResponse, error)`.
|
||||
func analyzeFuncFromReflect(fn reflect.Value) (function, error) {
|
||||
if f, ok := cache.get(fn); ok {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
var res function
|
||||
t := fn.Type()
|
||||
args := map[string]arg{}
|
||||
|
||||
for i := 0; i < t.NumIn(); i++ {
|
||||
if i == 0 {
|
||||
if t.In(i).String() != "*agent.Context" {
|
||||
return res, fmt.Errorf("%w: first parameter must be *agent.Context", ErrInvalidFunction)
|
||||
}
|
||||
continue
|
||||
} else if i == 1 {
|
||||
if t.In(i).Kind() != reflect.Struct {
|
||||
return res, fmt.Errorf("%w: second parameter must be a struct", ErrInvalidFunction)
|
||||
}
|
||||
res.argType = t.In(i)
|
||||
|
||||
for j := 0; j < res.argType.NumField(); j++ {
|
||||
field := res.argType.Field(j)
|
||||
|
||||
a := arg{
|
||||
Name: field.Name,
|
||||
Type: field.Type,
|
||||
Index: j,
|
||||
Description: "",
|
||||
}
|
||||
|
||||
ft := field.Type
|
||||
// if it's a pointer, it's optional
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
a.Optional = true
|
||||
ft = ft.Elem()
|
||||
} else if ft.Kind() == reflect.Slice {
|
||||
a.Array = true
|
||||
ft = ft.Elem()
|
||||
}
|
||||
|
||||
if ft.Kind() != reflect.String && ft.Kind() != reflect.Int && ft.Kind() != reflect.Float64 && ft.Kind() != reflect.Bool {
|
||||
return res, fmt.Errorf("%w: unsupported type %s", ErrInvalidFunction, ft.Kind().String())
|
||||
}
|
||||
|
||||
a.Type = ft
|
||||
|
||||
if name, ok := field.Tag.Lookup("name"); ok {
|
||||
a.Name = name
|
||||
a.Name = strings.TrimSpace(a.Name)
|
||||
|
||||
if a.Name == "" {
|
||||
return res, fmt.Errorf("%w: name tag cannot be empty", ErrInvalidFunction)
|
||||
}
|
||||
}
|
||||
if description, ok := field.Tag.Lookup("description"); ok {
|
||||
a.Description = description
|
||||
}
|
||||
if values, ok := field.Tag.Lookup("values"); ok {
|
||||
a.Values = strings.Split(values, ",")
|
||||
for i, v := range a.Values {
|
||||
a.Values[i] = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
if ft.Kind() != reflect.String && ft.Kind() != reflect.Int && ft.Kind() != reflect.Float64 {
|
||||
return res, fmt.Errorf("%w: values tag only supported for string, int, and float types", ErrInvalidFunction)
|
||||
}
|
||||
}
|
||||
|
||||
args[field.Name] = a
|
||||
}
|
||||
} else {
|
||||
return res, fmt.Errorf("%w: function must have at most 2 parameters", ErrInvalidFunction)
|
||||
}
|
||||
}
|
||||
|
||||
// finally ensure that the function returns a FuncResponse and an error
|
||||
if t.NumOut() != 2 || t.Out(0).String() != "agent.FuncResponse" || t.Out(1).String() != "error" {
|
||||
return res, fmt.Errorf("%w: function must return a FuncResponse and an error", ErrInvalidFunction)
|
||||
}
|
||||
|
||||
cache.set(fn, res)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func analyzeFunction[T any, AgentContext any](fn func(AgentContext, T) (FuncResponse, error)) (function, error) {
|
||||
return analyzeFuncFromReflect(reflect.ValueOf(fn))
|
||||
}
|
||||
|
||||
// Execute will execute the given function with the given context and arguments.
|
||||
// Returns the result of the execution and an error if the operation fails.
|
||||
// The arguments must be a JSON-encoded string that represents the struct to be passed to the function.
|
||||
func (f function) Execute(ctx context.Context, args string) (FuncResponse, error) {
|
||||
var m = map[string]any{}
|
||||
|
||||
err := json.Unmarshal([]byte(args), &m)
|
||||
if err != nil {
|
||||
return FuncResponse{}, fmt.Errorf("failed to unmarshal arguments: %w", err)
|
||||
}
|
||||
|
||||
var obj = reflect.New(f.argType).Elem()
|
||||
|
||||
// TODO: ensure that "required" fields are present in the arguments
|
||||
for name, a := range f.args {
|
||||
if v, ok := m[name]; ok {
|
||||
if a.Array {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch a.Type.Kind() {
|
||||
case reflect.String:
|
||||
s := v.([]string)
|
||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(s), len(s))
|
||||
for i, str := range s {
|
||||
slice.Index(i).SetString(str)
|
||||
}
|
||||
obj.Field(a.Index).Set(slice)
|
||||
|
||||
case reflect.Int:
|
||||
i := v.([]int)
|
||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(0)), len(i), len(i))
|
||||
for i, in := range i {
|
||||
slice.Index(i).SetInt(int64(in))
|
||||
}
|
||||
obj.Field(a.Index).Set(slice)
|
||||
|
||||
case reflect.Float64:
|
||||
f := v.([]float64)
|
||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(0.0)), len(f), len(f))
|
||||
for i, fl := range f {
|
||||
slice.Index(i).SetFloat(fl)
|
||||
}
|
||||
obj.Field(a.Index).Set(slice)
|
||||
|
||||
case reflect.Bool:
|
||||
b := v.([]bool)
|
||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(false)), len(b), len(b))
|
||||
for i, b := range b {
|
||||
slice.Index(i).SetBool(b)
|
||||
}
|
||||
obj.Field(a.Index).Set(slice)
|
||||
|
||||
default:
|
||||
return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name)
|
||||
}
|
||||
|
||||
} else if a.Optional {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
switch a.Type.Kind() {
|
||||
case reflect.String:
|
||||
str := v.(string)
|
||||
obj.Field(a.Index).Set(reflect.ValueOf(&str))
|
||||
case reflect.Int:
|
||||
i := v.(int)
|
||||
obj.Field(a.Index).Set(reflect.ValueOf(&i))
|
||||
case reflect.Float64:
|
||||
f := v.(float64)
|
||||
obj.Field(a.Index).Set(reflect.ValueOf(&f))
|
||||
case reflect.Bool:
|
||||
b := v.(bool)
|
||||
obj.Field(a.Index).Set(reflect.ValueOf(&b))
|
||||
default:
|
||||
return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name)
|
||||
}
|
||||
} else {
|
||||
switch a.Type.Kind() {
|
||||
case reflect.String:
|
||||
obj.Field(a.Index).SetString(v.(string))
|
||||
case reflect.Int:
|
||||
obj.Field(a.Index).SetInt(int64(v.(int)))
|
||||
case reflect.Float64:
|
||||
obj.Field(a.Index).SetFloat(v.(float64))
|
||||
case reflect.Bool:
|
||||
obj.Field(a.Index).SetBool(v.(bool))
|
||||
default:
|
||||
return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
res := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), obj})
|
||||
if res[1].IsNil() {
|
||||
return res[0].Interface().(FuncResponse), nil
|
||||
}
|
||||
|
||||
return FuncResponse{}, res[1].Interface().(error)
|
||||
}
|
75
pkg/toolbox/tool.go
Normal file
75
pkg/toolbox/tool.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package toolbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
)
|
||||
|
||||
type Tool struct {
|
||||
Name string
|
||||
Description string
|
||||
Function function
|
||||
}
|
||||
|
||||
func (t *Tool) Tool() llms.Tool {
|
||||
return llms.Tool{
|
||||
Type: "function",
|
||||
Function: t.Definition(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tool) Definition() *llms.FunctionDefinition {
|
||||
var properties = map[string]any{}
|
||||
for name, arg := range t.Function.args {
|
||||
properties[name] = arg.Schema()
|
||||
}
|
||||
|
||||
var res = llms.FunctionDefinition{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: map[string]any{"type": "object", "properties": properties},
|
||||
}
|
||||
return &res
|
||||
}
|
||||
|
||||
// Execute executes the tool with the given context and arguments.
|
||||
// Returns the result of the execution and an error if the operation fails.
|
||||
// The arguments must be a JSON-encoded string that represents the struct to be passed to the function.
|
||||
func (t *Tool) Execute(ctx context.Context, args string) (FuncResponse, error) {
|
||||
return t.Function.Execute(ctx, args)
|
||||
}
|
||||
|
||||
func FromFunction[T any, AgentContext any](fn func(AgentContext, T) (FuncResponse, error)) *Tool {
|
||||
f, err := analyzeFunction(fn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &Tool{
|
||||
Name: reflect.TypeOf(fn).Name(),
|
||||
Description: "This is a tool",
|
||||
Function: f,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tool) WithName(name string) *Tool {
|
||||
t.Name = name
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Tool) WithDescription(description string) *Tool {
|
||||
t.Description = description
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Tool) WithFunction(fn any) *Tool {
|
||||
f, err := analyzeFuncFromReflect(reflect.ValueOf(fn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
t.Function = f
|
||||
return t
|
||||
}
|
208
pkg/toolbox/toolbox.go
Normal file
208
pkg/toolbox/toolbox.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package toolbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
)
|
||||
|
||||
type ToolBox map[string]*Tool
|
||||
|
||||
var (
|
||||
ErrToolNotFound = errors.New("tool not found")
|
||||
)
|
||||
|
||||
type ToolResults []ToolResult
|
||||
|
||||
func (r ToolResults) ToMessageContent() llms.MessageContent {
|
||||
var res = llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeTool,
|
||||
}
|
||||
|
||||
for _, v := range r {
|
||||
res.Parts = append(res.Parts, v.ToToolCallResponse())
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type ToolResult struct {
|
||||
ID string
|
||||
Name string
|
||||
Result string
|
||||
Source string
|
||||
Error error
|
||||
}
|
||||
|
||||
func (r ToolResult) ToToolCallResponse() llms.ToolCallResponse {
|
||||
if r.Error != nil {
|
||||
return llms.ToolCallResponse{
|
||||
ToolCallID: r.ID,
|
||||
Name: r.Name,
|
||||
Content: "error executing: " + r.Error.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
return llms.ToolCallResponse{
|
||||
ToolCallID: r.ID,
|
||||
Name: r.Name,
|
||||
Content: r.Result,
|
||||
}
|
||||
}
|
||||
func (tb ToolBox) Execute(ctx context.Context, call llms.ToolCall) (ToolResult, error) {
|
||||
if call.Type != "function" {
|
||||
return ToolResult{}, fmt.Errorf("unsupported tool type: %s", call.Type)
|
||||
}
|
||||
|
||||
if call.FunctionCall == nil {
|
||||
return ToolResult{}, errors.New("function call is nil")
|
||||
}
|
||||
|
||||
tool, ok := tb[call.FunctionCall.Name]
|
||||
if !ok {
|
||||
return ToolResult{}, fmt.Errorf("%w: %s", ErrToolNotFound, call.FunctionCall.Name)
|
||||
}
|
||||
|
||||
res, err := tool.Execute(ctx, call.FunctionCall.Arguments)
|
||||
if err != nil {
|
||||
return ToolResult{
|
||||
ID: call.ID,
|
||||
Name: tool.Name,
|
||||
Error: err,
|
||||
Source: res.Source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return ToolResult{
|
||||
ID: call.ID,
|
||||
Name: tool.Name,
|
||||
Result: res.Result,
|
||||
Source: res.Source,
|
||||
Error: err,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tb ToolBox) ExecuteAll(ctx Context, calls []llms.ToolCall) (ToolResults, error) {
|
||||
var results []ToolResult
|
||||
|
||||
for _, call := range calls {
|
||||
res, err := tb.Execute(ctx, call)
|
||||
if err != nil {
|
||||
return results, err
|
||||
}
|
||||
|
||||
results = append(results, res)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (tb ToolBox) ExecuteConcurrent(ctx Context, calls []llms.ToolCall) (ToolResults, error) {
|
||||
var results []ToolResult
|
||||
var ch = make(chan ToolResult, len(calls))
|
||||
var eg = errgroup.Group{}
|
||||
|
||||
for _, call := range calls {
|
||||
eg.Go(func() error {
|
||||
c, cancel := ctx.WithCancel()
|
||||
defer cancel()
|
||||
|
||||
res, err := tb.Execute(c, call)
|
||||
if err != nil {
|
||||
ch <- res
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
return results, err
|
||||
}
|
||||
|
||||
for range calls {
|
||||
results = append(results, <-ch)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
type Answers struct {
|
||||
Response llms.MessageContent
|
||||
Answers []Answer
|
||||
}
|
||||
|
||||
type Answer struct {
|
||||
Answer string
|
||||
Source string
|
||||
ToolCallResponse llms.ToolCallResponse `json:"-"`
|
||||
}
|
||||
|
||||
func (tb ToolBox) Run(ctx Context, model llms.Model, question string) (Answers, error) {
|
||||
ctx = ctx.WithMessages([]llms.MessageContent{{
|
||||
Role: llms.ChatMessageTypeGeneric,
|
||||
Parts: []llms.ContentPart{llms.TextPart(question)},
|
||||
}})
|
||||
|
||||
res, err := model.GenerateContent(ctx, ctx.GetMessages())
|
||||
if err != nil {
|
||||
return Answers{}, err
|
||||
}
|
||||
|
||||
if res == nil {
|
||||
return Answers{}, errors.New("no response from model")
|
||||
}
|
||||
|
||||
if len(res.Choices) == 0 {
|
||||
return Answers{}, errors.New("no response from model")
|
||||
}
|
||||
|
||||
choice := res.Choices[0]
|
||||
|
||||
response := llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeAI,
|
||||
Parts: []llms.ContentPart{llms.TextPart(choice.Content)},
|
||||
}
|
||||
|
||||
for _, c := range choice.ToolCalls {
|
||||
response.Parts = append(response.Parts, c)
|
||||
}
|
||||
|
||||
results, err := tb.ExecuteConcurrent(ctx, choice.ToolCalls)
|
||||
if err != nil {
|
||||
return Answers{}, err
|
||||
}
|
||||
|
||||
var answers []Answer
|
||||
|
||||
for _, r := range results {
|
||||
if r.Error != nil {
|
||||
answers = append(answers, Answer{
|
||||
Answer: "error executing: " + r.Error.Error(),
|
||||
Source: r.Source,
|
||||
ToolCallResponse: r.ToToolCallResponse(),
|
||||
})
|
||||
} else {
|
||||
answers = append(answers, Answer{
|
||||
Answer: r.Result,
|
||||
Source: r.Source,
|
||||
ToolCallResponse: r.ToToolCallResponse(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return Answers{
|
||||
Response: response,
|
||||
Answers: answers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tb ToolBox) Register(tool *Tool) {
|
||||
tb[tool.Name] = tool
|
||||
}
|
Reference in New Issue
Block a user