209 lines
4.0 KiB
Go
209 lines
4.0 KiB
Go
package toolbox
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
"github.com/tmc/langchaingo/llms"
|
|
)
|
|
|
|
type ToolBox map[string]*Tool
|
|
|
|
var (
|
|
ErrToolNotFound = errors.New("tool not found")
|
|
)
|
|
|
|
type ToolResults []ToolResult
|
|
|
|
func (r ToolResults) ToMessageContent() llms.MessageContent {
|
|
var res = llms.MessageContent{
|
|
Role: llms.ChatMessageTypeTool,
|
|
}
|
|
|
|
for _, v := range r {
|
|
res.Parts = append(res.Parts, v.ToToolCallResponse())
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
// ToolResult is the outcome of executing a single tool call.
type ToolResult struct {
	// ID is the ID of the tool call this result answers; Name is the name of
	// the tool that handled it.
	ID, Name string
	// Result is the tool's output (left empty when Error is set); Source is
	// copied unchanged from the tool's own execution result.
	Result, Source string
	// Error is the error returned by the tool, if any.
	Error error
}
|
|
|
|
func (r ToolResult) ToToolCallResponse() llms.ToolCallResponse {
|
|
if r.Error != nil {
|
|
return llms.ToolCallResponse{
|
|
ToolCallID: r.ID,
|
|
Name: r.Name,
|
|
Content: "error executing: " + r.Error.Error(),
|
|
}
|
|
}
|
|
|
|
return llms.ToolCallResponse{
|
|
ToolCallID: r.ID,
|
|
Name: r.Name,
|
|
Content: r.Result,
|
|
}
|
|
}
|
|
func (tb ToolBox) Execute(ctx context.Context, call llms.ToolCall) (ToolResult, error) {
|
|
if call.Type != "function" {
|
|
return ToolResult{}, fmt.Errorf("unsupported tool type: %s", call.Type)
|
|
}
|
|
|
|
if call.FunctionCall == nil {
|
|
return ToolResult{}, errors.New("function call is nil")
|
|
}
|
|
|
|
tool, ok := tb[call.FunctionCall.Name]
|
|
if !ok {
|
|
return ToolResult{}, fmt.Errorf("%w: %s", ErrToolNotFound, call.FunctionCall.Name)
|
|
}
|
|
|
|
res, err := tool.Execute(ctx, call.FunctionCall.Arguments)
|
|
if err != nil {
|
|
return ToolResult{
|
|
ID: call.ID,
|
|
Name: tool.Name,
|
|
Error: err,
|
|
Source: res.Source,
|
|
}, nil
|
|
}
|
|
|
|
return ToolResult{
|
|
ID: call.ID,
|
|
Name: tool.Name,
|
|
Result: res.Result,
|
|
Source: res.Source,
|
|
Error: err,
|
|
}, nil
|
|
}
|
|
|
|
func (tb ToolBox) ExecuteAll(ctx Context, calls []llms.ToolCall) (ToolResults, error) {
|
|
var results []ToolResult
|
|
|
|
for _, call := range calls {
|
|
res, err := tb.Execute(ctx, call)
|
|
if err != nil {
|
|
return results, err
|
|
}
|
|
|
|
results = append(results, res)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (tb ToolBox) ExecuteConcurrent(ctx Context, calls []llms.ToolCall) (ToolResults, error) {
|
|
var results []ToolResult
|
|
var ch = make(chan ToolResult, len(calls))
|
|
var eg = errgroup.Group{}
|
|
|
|
for _, call := range calls {
|
|
eg.Go(func() error {
|
|
c, cancel := ctx.WithCancel()
|
|
defer cancel()
|
|
|
|
res, err := tb.Execute(c, call)
|
|
if err != nil {
|
|
ch <- res
|
|
return nil
|
|
}
|
|
|
|
return err
|
|
})
|
|
}
|
|
|
|
err := eg.Wait()
|
|
if err != nil {
|
|
return results, err
|
|
}
|
|
|
|
for range calls {
|
|
results = append(results, <-ch)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
type Answers struct {
|
|
Response llms.MessageContent
|
|
Answers []Answer
|
|
}
|
|
|
|
type Answer struct {
|
|
Answer string
|
|
Source string
|
|
ToolCallResponse llms.ToolCallResponse `json:"-"`
|
|
}
|
|
|
|
func (tb ToolBox) Run(ctx Context, model llms.Model, question string) (Answers, error) {
|
|
ctx = ctx.WithMessages([]llms.MessageContent{{
|
|
Role: llms.ChatMessageTypeGeneric,
|
|
Parts: []llms.ContentPart{llms.TextPart(question)},
|
|
}})
|
|
|
|
res, err := model.GenerateContent(ctx, ctx.GetMessages())
|
|
if err != nil {
|
|
return Answers{}, err
|
|
}
|
|
|
|
if res == nil {
|
|
return Answers{}, errors.New("no response from model")
|
|
}
|
|
|
|
if len(res.Choices) == 0 {
|
|
return Answers{}, errors.New("no response from model")
|
|
}
|
|
|
|
choice := res.Choices[0]
|
|
|
|
response := llms.MessageContent{
|
|
Role: llms.ChatMessageTypeAI,
|
|
Parts: []llms.ContentPart{llms.TextPart(choice.Content)},
|
|
}
|
|
|
|
for _, c := range choice.ToolCalls {
|
|
response.Parts = append(response.Parts, c)
|
|
}
|
|
|
|
results, err := tb.ExecuteConcurrent(ctx, choice.ToolCalls)
|
|
if err != nil {
|
|
return Answers{}, err
|
|
}
|
|
|
|
var answers []Answer
|
|
|
|
for _, r := range results {
|
|
if r.Error != nil {
|
|
answers = append(answers, Answer{
|
|
Answer: "error executing: " + r.Error.Error(),
|
|
Source: r.Source,
|
|
ToolCallResponse: r.ToToolCallResponse(),
|
|
})
|
|
} else {
|
|
answers = append(answers, Answer{
|
|
Answer: r.Result,
|
|
Source: r.Source,
|
|
ToolCallResponse: r.ToToolCallResponse(),
|
|
})
|
|
}
|
|
}
|
|
|
|
return Answers{
|
|
Response: response,
|
|
Answers: answers,
|
|
}, nil
|
|
}
|
|
|
|
func (tb ToolBox) Register(tool *Tool) {
|
|
tb[tool.Name] = tool
|
|
}
|