package toolbox import ( "context" "errors" "fmt" "golang.org/x/sync/errgroup" "github.com/tmc/langchaingo/llms" ) type ToolBox map[string]*Tool var ( ErrToolNotFound = errors.New("tool not found") ) type ToolResults []ToolResult func (r ToolResults) ToMessageContent() llms.MessageContent { var res = llms.MessageContent{ Role: llms.ChatMessageTypeTool, } for _, v := range r { res.Parts = append(res.Parts, v.ToToolCallResponse()) } return res } type ToolResult struct { ID string Name string Result string Source string Error error } func (r ToolResult) ToToolCallResponse() llms.ToolCallResponse { if r.Error != nil { return llms.ToolCallResponse{ ToolCallID: r.ID, Name: r.Name, Content: "error executing: " + r.Error.Error(), } } return llms.ToolCallResponse{ ToolCallID: r.ID, Name: r.Name, Content: r.Result, } } func (tb ToolBox) Execute(ctx context.Context, call llms.ToolCall) (ToolResult, error) { if call.Type != "function" { return ToolResult{}, fmt.Errorf("unsupported tool type: %s", call.Type) } if call.FunctionCall == nil { return ToolResult{}, errors.New("function call is nil") } tool, ok := tb[call.FunctionCall.Name] if !ok { return ToolResult{}, fmt.Errorf("%w: %s", ErrToolNotFound, call.FunctionCall.Name) } res, err := tool.Execute(ctx, call.FunctionCall.Arguments) if err != nil { return ToolResult{ ID: call.ID, Name: tool.Name, Error: err, Source: res.Source, }, nil } return ToolResult{ ID: call.ID, Name: tool.Name, Result: res.Result, Source: res.Source, Error: err, }, nil } func (tb ToolBox) ExecuteAll(ctx Context, calls []llms.ToolCall) (ToolResults, error) { var results []ToolResult for _, call := range calls { res, err := tb.Execute(ctx, call) if err != nil { return results, err } results = append(results, res) } return results, nil } func (tb ToolBox) ExecuteConcurrent(ctx Context, calls []llms.ToolCall) (ToolResults, error) { var results []ToolResult var ch = make(chan ToolResult, len(calls)) var eg = errgroup.Group{} for _, call := range calls { eg.Go(func() error { c, cancel := ctx.WithCancel() defer cancel() res, err := tb.Execute(c, call) if err != nil { ch <- res return nil } return err }) } err := eg.Wait() if err != nil { return results, err } for range calls { results = append(results, <-ch) } return results, nil } type Answers struct { Response llms.MessageContent Answers []Answer } type Answer struct { Answer string Source string ToolCallResponse llms.ToolCallResponse `json:"-"` } func (tb ToolBox) Run(ctx Context, model llms.Model, question string) (Answers, error) { ctx = ctx.WithMessages([]llms.MessageContent{{ Role: llms.ChatMessageTypeGeneric, Parts: []llms.ContentPart{llms.TextPart(question)}, }}) res, err := model.GenerateContent(ctx, ctx.GetMessages()) if err != nil { return Answers{}, err } if res == nil { return Answers{}, errors.New("no response from model") } if len(res.Choices) == 0 { return Answers{}, errors.New("no response from model") } choice := res.Choices[0] response := llms.MessageContent{ Role: llms.ChatMessageTypeAI, Parts: []llms.ContentPart{llms.TextPart(choice.Content)}, } for _, c := range choice.ToolCalls { response.Parts = append(response.Parts, c) } results, err := tb.ExecuteConcurrent(ctx, choice.ToolCalls) if err != nil { return Answers{}, err } var answers []Answer for _, r := range results { if r.Error != nil { answers = append(answers, Answer{ Answer: "error executing: " + r.Error.Error(), Source: r.Source, ToolCallResponse: r.ToToolCallResponse(), }) } else { answers = append(answers, Answer{ Answer: r.Result, Source: r.Source, ToolCallResponse: r.ToToolCallResponse(), }) } } return Answers{ Response: response, Answers: answers, }, nil } func (tb ToolBox) Register(tool *Tool) { tb[tool.Name] = tool }