answer/pkg/toolbox/toolbox.go

209 lines
4.0 KiB
Go

package toolbox
import (
"context"
"errors"
"fmt"
"golang.org/x/sync/errgroup"
"github.com/tmc/langchaingo/llms"
)
// ToolBox is a registry of tools indexed by tool name. Register stores tools
// under tool.Name and Execute looks them up by the function name of a call.
type ToolBox map[string]*Tool

// ErrToolNotFound is wrapped into the error returned by ToolBox.Execute when
// a call names a tool that is not present in the toolbox. Match it with
// errors.Is.
var ErrToolNotFound = errors.New("tool not found")
// ToolResults is a batch of tool execution results.
type ToolResults []ToolResult

// ToMessageContent packs the whole batch into one tool-role message that
// carries one ToolCallResponse part per result, in slice order. An empty
// batch yields a message with no parts.
func (r ToolResults) ToMessageContent() llms.MessageContent {
	msg := llms.MessageContent{Role: llms.ChatMessageTypeTool}
	for i := range r {
		msg.Parts = append(msg.Parts, r[i].ToToolCallResponse())
	}
	return msg
}
// ToolResult is the outcome of executing a single tool call.
type ToolResult struct {
	// ID is the tool-call ID the result answers; Name is the tool's name;
	// Result is the tool output (empty when Error is set); Source is the
	// source string reported by the tool.
	ID, Name, Result, Source string
	// Error is non-nil when the tool itself failed while running.
	Error error
}
// ToToolCallResponse converts the result into an llms.ToolCallResponse tied
// to the original call ID. If the tool failed, the content is the error text
// prefixed with "error executing: "; otherwise it is the tool output.
func (r ToolResult) ToToolCallResponse() llms.ToolCallResponse {
	content := r.Result
	if r.Error != nil {
		content = "error executing: " + r.Error.Error()
	}
	return llms.ToolCallResponse{
		ToolCallID: r.ID,
		Name:       r.Name,
		Content:    content,
	}
}
// Execute dispatches one tool call to the tool registered under the call's
// function name.
//
// Dispatch problems are returned as the error: a call type other than
// "function", a nil FunctionCall, or an unregistered tool name (the latter
// wraps ErrToolNotFound). A failure of the tool itself is not returned as an
// error. Instead it is stored in ToolResult.Error, with Result left empty and
// Source still copied from the tool's output, and the returned error is nil.
func (tb ToolBox) Execute(ctx context.Context, call llms.ToolCall) (ToolResult, error) {
	if call.Type != "function" {
		return ToolResult{}, fmt.Errorf("unsupported tool type: %s", call.Type)
	}
	fn := call.FunctionCall
	if fn == nil {
		return ToolResult{}, errors.New("function call is nil")
	}
	tool, found := tb[fn.Name]
	if !found {
		return ToolResult{}, fmt.Errorf("%w: %s", ErrToolNotFound, fn.Name)
	}

	out, execErr := tool.Execute(ctx, fn.Arguments)
	result := ToolResult{
		ID:     call.ID,
		Name:   tool.Name,
		Source: out.Source,
	}
	if execErr != nil {
		result.Error = execErr
		return result, nil
	}
	result.Result = out.Result
	return result, nil
}
// ExecuteAll executes calls one after another, in order, using Execute.
//
// It stops at the first dispatch error and returns that error together with
// the results gathered before it. Tool failures do not stop the loop; they
// are recorded in the corresponding ToolResult.Error.
func (tb ToolBox) ExecuteAll(ctx Context, calls []llms.ToolCall) (ToolResults, error) {
	var out ToolResults
	for i := range calls {
		res, err := tb.Execute(ctx, calls[i])
		if err != nil {
			return out, err
		}
		out = append(out, res)
	}
	return out, nil
}
// ExecuteConcurrent runs every call in its own goroutine via Execute and waits
// for all of them to finish.
//
// Results are returned in the same order as calls. A tool that fails while
// running is reported through that result's Error field, not as an error.
// If any call cannot be dispatched (unsupported type, nil function call, or
// unknown tool), the first such error is returned and the results are
// discarded. The other goroutines are not cancelled in that case; they still
// run to completion before this method returns.
func (tb ToolBox) ExecuteConcurrent(ctx Context, calls []llms.ToolCall) (ToolResults, error) {
	// Each goroutine writes only to its own index. The slots are therefore
	// disjoint, no locking is needed, and the output order matches calls
	// regardless of completion order.
	results := make([]ToolResult, len(calls))
	var eg errgroup.Group
	for i, call := range calls {
		i, call := i, call // per-iteration copies; required before Go 1.22
		eg.Go(func() error {
			c, cancel := ctx.WithCancel()
			defer cancel()
			res, err := tb.Execute(c, call)
			if err != nil {
				return err
			}
			results[i] = res
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		return nil, err
	}
	return results, nil
}
// Answers is what Run returns: the model's assistant message plus one Answer
// for each tool call the model requested.
type Answers struct {
	// Response is the assistant message built from the model's first
	// choice: its text content followed by its tool-call parts.
	Response llms.MessageContent
	// Answers holds one entry per executed tool call.
	Answers []Answer
}
// Answer is the rendered result of a single tool call.
type Answer struct {
	// Answer is the tool output, or an "error executing: ..." message when
	// the tool failed; Source is the source string reported by the tool.
	Answer, Source string
	// ToolCallResponse is the same result in llms form. It is excluded from
	// JSON encoding.
	ToolCallResponse llms.ToolCallResponse `json:"-"`
}
// Run sends question to model and executes any tool calls in its reply.
//
// The question is installed on ctx via ctx.WithMessages, and ctx.GetMessages()
// is sent to model.GenerateContent. Only the first choice of the reply is
// used. Its tool calls are executed concurrently with ExecuteConcurrent, and
// each result becomes an Answer whose text matches the ToolCallResponse
// content: the tool output, or "error executing: ..." when the tool failed.
//
// The returned Answers.Response is the assistant message: the choice's text
// followed by its tool-call parts. An error is returned if the model call
// fails, the model returns no choices, or a tool call cannot be dispatched.
// Underlying errors are wrapped with %w, so errors.Is (e.g. with
// ErrToolNotFound) still matches.
func (tb ToolBox) Run(ctx Context, model llms.Model, question string) (Answers, error) {
	ctx = ctx.WithMessages([]llms.MessageContent{{
		Role:  llms.ChatMessageTypeGeneric,
		Parts: []llms.ContentPart{llms.TextPart(question)},
	}})
	res, err := model.GenerateContent(ctx, ctx.GetMessages())
	if err != nil {
		return Answers{}, fmt.Errorf("generating content: %w", err)
	}
	if res == nil || len(res.Choices) == 0 {
		return Answers{}, errors.New("no response from model")
	}
	choice := res.Choices[0]

	response := llms.MessageContent{
		Role:  llms.ChatMessageTypeAI,
		Parts: []llms.ContentPart{llms.TextPart(choice.Content)},
	}
	for _, c := range choice.ToolCalls {
		response.Parts = append(response.Parts, c)
	}

	results, err := tb.ExecuteConcurrent(ctx, choice.ToolCalls)
	if err != nil {
		return Answers{}, fmt.Errorf("executing tool calls: %w", err)
	}

	var answers []Answer
	for _, r := range results {
		// ToToolCallResponse already renders failures as
		// "error executing: ...". Reuse its Content so the Answer text and
		// the response sent back to the model cannot drift apart.
		tcr := r.ToToolCallResponse()
		answers = append(answers, Answer{
			Answer:           tcr.Content,
			Source:           r.Source,
			ToolCallResponse: tcr,
		})
	}
	return Answers{
		Response: response,
		Answers:  answers,
	}, nil
}
// Register stores tool in the toolbox under tool.Name and replaces any tool
// already registered with that name. tool must be non-nil, and tb must be a
// non-nil map, because writing to a nil map panics.
func (tb ToolBox) Register(tool *Tool) {
	name := tool.Name
	tb[name] = tool
}