118 lines
2.8 KiB
Go
118 lines
2.8 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/answer/pkg/toolbox"
|
|
|
|
"github.com/tmc/langchaingo/llms"
|
|
|
|
cache2 "gitea.stevedudenhoeffer.com/steve/answer/pkg/cache"
|
|
"gitea.stevedudenhoeffer.com/steve/answer/pkg/extractor"
|
|
"gitea.stevedudenhoeffer.com/steve/answer/pkg/search"
|
|
)
|
|
|
|
// Options configures the budgets of an Agent created with New.
//
// MaxSearches seeds the agent's search budget and MaxQuestions seeds its
// question budget; both are converted to int32 when handed to the counters.
type Options struct {
	MaxSearches, MaxQuestions int
}
|
|
type Agent struct {
|
|
RemainingSearches *Counter
|
|
RemainingQuestions *Counter
|
|
Search search.Search
|
|
Cache cache2.Cache
|
|
Extractor extractor.Extractor
|
|
Model llms.Model
|
|
}
|
|
|
|
type Answers struct {
|
|
Response llms.MessageContent
|
|
Answers []Answer
|
|
}
|
|
|
|
type Answer struct {
|
|
Answer string
|
|
Source string
|
|
ToolCallResponse llms.ToolCallResponse `json:"-"`
|
|
}
|
|
|
|
// Question wraps the text of a query handed to the agent. Ask also reuses it
// to pass the JSON-encoded answers to the summarization step.
type Question struct {
	// Question is the query text.
	Question string
}
|
|
|
|
func New(o Options) *Agent {
|
|
searches := &Counter{}
|
|
searches.Add(int32(o.MaxSearches))
|
|
|
|
questions := &Counter{}
|
|
questions.Add(int32(o.MaxQuestions))
|
|
|
|
return &Agent{
|
|
RemainingSearches: searches,
|
|
RemainingQuestions: questions,
|
|
}
|
|
}
|
|
|
|
// ErrOutOfSearches is the sentinel error for an exhausted search budget.
// Compare against it with errors.Is.
var ErrOutOfSearches = errors.New("out of searches")

// ErrOutOfQuestions is returned by Agent.Ask when the question budget cannot
// be drawn from. Compare against it with errors.Is.
var ErrOutOfQuestions = errors.New("out of questions")
|
|
|
|
func ask(ctx *Context, q Question) (Answers, error) {
|
|
var tb = toolbox.ToolBox{}
|
|
|
|
if ctx.Agent.RemainingSearches.Load() > 0 {
|
|
tb.Register(SearchTool)
|
|
}
|
|
tb.Register(WolframTool)
|
|
tb.Register(AnswerTool)
|
|
|
|
return tb.Run(ctx, q)
|
|
}
|
|
|
|
var SummarizeAnswers = toolbox.FromFunction(
|
|
func(ctx *Context, args struct {
|
|
Summary string `description:"the summary of the answers"`
|
|
}) (toolbox.FuncResponse, error) {
|
|
return toolbox.FuncResponse{Result: args.Summary}, nil
|
|
}).
|
|
WithName("summarize_answers").
|
|
WithDescription(`You are given previously figured out answers and they are in the format of: [ { "answer": "the answer", "source": "the source of the answer" }, { "answer": "answer 2", "source": "the source for answer2" } ]. You need to summarize the answers into a single string. Be sure to make the summary clear and concise, but include the sources at some point.`)
|
|
|
|
// Ask is an incoming call to the agent, it will create an internal Context from an incoming context.Context
|
|
func (a *Agent) Ask(ctx context.Context, q Question) (string, error) {
|
|
c := From(ctx, a)
|
|
if !a.RemainingQuestions.Take() {
|
|
return "", ErrOutOfQuestions
|
|
}
|
|
|
|
answers, err := ask(c, q)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
tb := toolbox.ToolBox{}
|
|
|
|
tb.Register(SummarizeAnswers)
|
|
|
|
b, err := json.Marshal(answers.Answers)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to marshal answers: %w", err)
|
|
}
|
|
|
|
answers, err = tb.Run(c, Question{Question: string(b)})
|
|
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to summarize answers: %w", err)
|
|
}
|
|
|
|
if len(answers.Answers) == 0 {
|
|
return "", errors.New("no response from model")
|
|
}
|
|
return answers.Answers[0].Answer, nil
|
|
|
|
}
|