Compare commits
6 Commits
langchain
...
5b338d4129
Author | SHA1 | Date | |
---|---|---|---|
5b338d4129 | |||
bdaa3b7d96 | |||
bf16b4b0cd | |||
568a0e99e4 | |||
4c106d32cb | |||
1d47cf5758 |
80
go.mod
80
go.mod
@@ -4,31 +4,27 @@ go 1.23.2
|
|||||||
|
|
||||||
replace github.com/rocketlaunchr/google-search => github.com/chrisjoyce911/google-search v0.0.0-20230910003754-e501aedf805a
|
replace github.com/rocketlaunchr/google-search => github.com/chrisjoyce911/google-search v0.0.0-20230910003754-e501aedf805a
|
||||||
|
|
||||||
replace gitea.stevedudenhoeffer.com/steve/go-llm => ../go-llm
|
//replace gitea.stevedudenhoeffer.com/steve/go-llm => ../go-llm
|
||||||
|
|
||||||
require (
|
require (
|
||||||
gitea.stevedudenhoeffer.com/steve/go-extractor v0.0.0-20250123020607-964a98a5a884
|
gitea.stevedudenhoeffer.com/steve/go-extractor v0.0.0-20250315044602-7c0e44a22f2c
|
||||||
gitea.stevedudenhoeffer.com/steve/go-llm v0.0.0-20250123045620-0d909edd44d9
|
gitea.stevedudenhoeffer.com/steve/go-llm v0.0.0-20250317023858-7f5e34e437a7
|
||||||
github.com/Edw590/go-wolfram v0.0.0-20241010091529-fb9031908c5d
|
github.com/Edw590/go-wolfram v0.0.0-20241010091529-fb9031908c5d
|
||||||
github.com/advancedlogic/GoOse v0.0.0-20231203033844-ae6b36caf275
|
github.com/advancedlogic/GoOse v0.0.0-20231203033844-ae6b36caf275
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/playwright-community/playwright-go v0.5001.0
|
github.com/playwright-community/playwright-go v0.5001.0
|
||||||
github.com/rocketlaunchr/google-search v1.1.6
|
github.com/rocketlaunchr/google-search v1.1.6
|
||||||
github.com/tmc/langchaingo v0.1.13
|
|
||||||
github.com/urfave/cli v1.22.16
|
github.com/urfave/cli v1.22.16
|
||||||
golang.org/x/sync v0.11.0
|
go.starlark.net v0.0.0-20250225190231-0d3f41d403af
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cloud.google.com/go v0.118.3 // indirect
|
cloud.google.com/go v0.119.0 // indirect
|
||||||
cloud.google.com/go/ai v0.10.0 // indirect
|
cloud.google.com/go/ai v0.10.1 // indirect
|
||||||
cloud.google.com/go/auth v0.15.0 // indirect
|
cloud.google.com/go/auth v0.15.0 // indirect
|
||||||
cloud.google.com/go/auth/oauth2adapt v0.2.7 // indirect
|
cloud.google.com/go/auth/oauth2adapt v0.2.7 // indirect
|
||||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||||
cloud.google.com/go/longrunning v0.6.4 // indirect
|
cloud.google.com/go/longrunning v0.6.6 // indirect
|
||||||
github.com/Masterminds/goutils v1.1.1 // indirect
|
|
||||||
github.com/Masterminds/semver/v3 v3.2.0 // indirect
|
|
||||||
github.com/Masterminds/sprig/v3 v3.2.3 // indirect
|
|
||||||
github.com/PuerkitoBio/goquery v1.10.2 // indirect
|
github.com/PuerkitoBio/goquery v1.10.2 // indirect
|
||||||
github.com/andybalholm/cascadia v1.3.3 // indirect
|
github.com/andybalholm/cascadia v1.3.3 // indirect
|
||||||
github.com/antchfx/htmlquery v1.3.4 // indirect
|
github.com/antchfx/htmlquery v1.3.4 // indirect
|
||||||
@@ -36,13 +32,11 @@ require (
|
|||||||
github.com/antchfx/xpath v1.3.3 // indirect
|
github.com/antchfx/xpath v1.3.3 // indirect
|
||||||
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect
|
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect
|
github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect
|
||||||
github.com/deckarep/golang-set/v2 v2.7.0 // indirect
|
github.com/deckarep/golang-set/v2 v2.8.0 // indirect
|
||||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
|
||||||
github.com/fatih/set v0.2.1 // indirect
|
github.com/fatih/set v0.2.1 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/gigawattio/window v0.0.0-20180317192513-0f5467e35573 // indirect
|
github.com/gigawattio/window v0.0.0-20180317192513-0f5467e35573 // indirect
|
||||||
github.com/go-jose/go-jose/v3 v3.0.3 // indirect
|
github.com/go-jose/go-jose/v3 v3.0.4 // indirect
|
||||||
github.com/go-logr/logr v1.4.2 // indirect
|
github.com/go-logr/logr v1.4.2 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-resty/resty/v2 v2.16.5 // indirect
|
github.com/go-resty/resty/v2 v2.16.5 // indirect
|
||||||
@@ -57,54 +51,38 @@ require (
|
|||||||
github.com/google/generative-ai-go v0.19.0 // indirect
|
github.com/google/generative-ai-go v0.19.0 // indirect
|
||||||
github.com/google/s2a-go v0.1.9 // indirect
|
github.com/google/s2a-go v0.1.9 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
|
||||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
||||||
github.com/goph/emperror v0.17.2 // indirect
|
|
||||||
github.com/huandu/xstrings v1.3.3 // indirect
|
|
||||||
github.com/imdario/mergo v0.3.13 // indirect
|
|
||||||
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 // indirect
|
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
|
||||||
github.com/kennygrant/sanitize v1.2.4 // indirect
|
github.com/kennygrant/sanitize v1.2.4 // indirect
|
||||||
github.com/liushuangls/go-anthropic/v2 v2.13.1 // indirect
|
github.com/liushuangls/go-anthropic/v2 v2.14.1 // indirect
|
||||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||||
github.com/mitchellh/copystructure v1.0.0 // indirect
|
|
||||||
github.com/mitchellh/reflectwalk v1.0.0 // indirect
|
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
|
||||||
github.com/nikolalohinski/gonja v1.5.3 // indirect
|
|
||||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.0.9 // indirect
|
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
|
|
||||||
github.com/rivo/uniseg v0.4.7 // indirect
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||||
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect
|
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect
|
||||||
github.com/sashabaranov/go-openai v1.37.0 // indirect
|
github.com/sashabaranov/go-openai v1.38.0 // indirect
|
||||||
github.com/shopspring/decimal v1.2.0 // indirect
|
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
|
||||||
github.com/spf13/cast v1.3.1 // indirect
|
|
||||||
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
|
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
|
||||||
github.com/temoto/robotstxt v1.1.2 // indirect
|
github.com/temoto/robotstxt v1.1.2 // indirect
|
||||||
github.com/yargevad/filepathx v1.0.0 // indirect
|
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect
|
||||||
go.opentelemetry.io/otel v1.34.0 // indirect
|
go.opentelemetry.io/otel v1.35.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.34.0 // indirect
|
go.opentelemetry.io/otel/metric v1.35.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.34.0 // indirect
|
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||||
go.starlark.net v0.0.0-20230302034142-4b1e35fe2254 // indirect
|
golang.org/x/crypto v0.36.0 // indirect
|
||||||
golang.org/x/crypto v0.34.0 // indirect
|
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
golang.org/x/net v0.37.0 // indirect
|
||||||
golang.org/x/net v0.35.0 // indirect
|
golang.org/x/oauth2 v0.28.0 // indirect
|
||||||
golang.org/x/oauth2 v0.26.0 // indirect
|
golang.org/x/sync v0.12.0 // indirect
|
||||||
golang.org/x/sys v0.30.0 // indirect
|
golang.org/x/sys v0.31.0 // indirect
|
||||||
golang.org/x/text v0.22.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
golang.org/x/time v0.10.0 // indirect
|
golang.org/x/time v0.11.0 // indirect
|
||||||
google.golang.org/api v0.222.0 // indirect
|
google.golang.org/api v0.226.0 // indirect
|
||||||
google.golang.org/appengine v1.6.8 // indirect
|
google.golang.org/appengine v1.6.8 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250219182151-9fdb1cabc7b2 // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20250313205543-e70fdf4c4cb4 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 // indirect
|
||||||
google.golang.org/grpc v1.70.0 // indirect
|
google.golang.org/grpc v1.71.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.5 // indirect
|
google.golang.org/protobuf v1.36.5 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
|
||||||
)
|
)
|
||||||
|
@@ -1,117 +0,0 @@
|
|||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/answer/pkg/toolbox"
|
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/llms"
|
|
||||||
|
|
||||||
cache2 "gitea.stevedudenhoeffer.com/steve/answer/pkg/cache"
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/answer/pkg/extractor"
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/answer/pkg/search"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Options struct {
|
|
||||||
MaxSearches int
|
|
||||||
MaxQuestions int
|
|
||||||
}
|
|
||||||
type Agent struct {
|
|
||||||
RemainingSearches *Counter
|
|
||||||
RemainingQuestions *Counter
|
|
||||||
Search search.Search
|
|
||||||
Cache cache2.Cache
|
|
||||||
Extractor extractor.Extractor
|
|
||||||
Model llms.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
type Answers struct {
|
|
||||||
Response llms.MessageContent
|
|
||||||
Answers []Answer
|
|
||||||
}
|
|
||||||
|
|
||||||
type Answer struct {
|
|
||||||
Answer string
|
|
||||||
Source string
|
|
||||||
ToolCallResponse llms.ToolCallResponse `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Question struct {
|
|
||||||
Question string
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(o Options) *Agent {
|
|
||||||
searches := &Counter{}
|
|
||||||
searches.Add(int32(o.MaxSearches))
|
|
||||||
|
|
||||||
questions := &Counter{}
|
|
||||||
questions.Add(int32(o.MaxQuestions))
|
|
||||||
|
|
||||||
return &Agent{
|
|
||||||
RemainingSearches: searches,
|
|
||||||
RemainingQuestions: questions,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrOutOfSearches = errors.New("out of searches")
|
|
||||||
ErrOutOfQuestions = errors.New("out of questions")
|
|
||||||
)
|
|
||||||
|
|
||||||
func ask(ctx *Context, q Question) (Answers, error) {
|
|
||||||
var tb = toolbox.ToolBox{}
|
|
||||||
|
|
||||||
if ctx.Agent.RemainingSearches.Load() > 0 {
|
|
||||||
tb.Register(SearchTool)
|
|
||||||
}
|
|
||||||
tb.Register(WolframTool)
|
|
||||||
tb.Register(AnswerTool)
|
|
||||||
|
|
||||||
return tb.Run(ctx, q)
|
|
||||||
}
|
|
||||||
|
|
||||||
var SummarizeAnswers = toolbox.FromFunction(
|
|
||||||
func(ctx *Context, args struct {
|
|
||||||
Summary string `description:"the summary of the answers"`
|
|
||||||
}) (toolbox.FuncResponse, error) {
|
|
||||||
return toolbox.FuncResponse{Result: args.Summary}, nil
|
|
||||||
}).
|
|
||||||
WithName("summarize_answers").
|
|
||||||
WithDescription(`You are given previously figured out answers and they are in the format of: [ { "answer": "the answer", "source": "the source of the answer" }, { "answer": "answer 2", "source": "the source for answer2" } ]. You need to summarize the answers into a single string. Be sure to make the summary clear and concise, but include the sources at some point.`)
|
|
||||||
|
|
||||||
// Ask is an incoming call to the agent, it will create an internal Context from an incoming context.Context
|
|
||||||
func (a *Agent) Ask(ctx context.Context, q Question) (string, error) {
|
|
||||||
c := From(ctx, a)
|
|
||||||
if !a.RemainingQuestions.Take() {
|
|
||||||
return "", ErrOutOfQuestions
|
|
||||||
}
|
|
||||||
|
|
||||||
answers, err := ask(c, q)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
tb := toolbox.ToolBox{}
|
|
||||||
|
|
||||||
tb.Register(SummarizeAnswers)
|
|
||||||
|
|
||||||
b, err := json.Marshal(answers.Answers)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to marshal answers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
answers, err = tb.Run(c, Question{Question: string(b)})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to summarize answers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(answers.Answers) == 0 {
|
|
||||||
return "", errors.New("no response from model")
|
|
||||||
}
|
|
||||||
return answers.Answers[0].Answer, nil
|
|
||||||
|
|
||||||
}
|
|
@@ -1,12 +0,0 @@
|
|||||||
package agent
|
|
||||||
|
|
||||||
import "gitea.stevedudenhoeffer.com/steve/answer/pkg/toolbox"
|
|
||||||
|
|
||||||
var AnswerTool = toolbox.FromFunction(
|
|
||||||
func(ctx *Context, args struct {
|
|
||||||
Answer string `description:"the answer to the question"`
|
|
||||||
}) (toolbox.FuncResponse, error) {
|
|
||||||
return toolbox.FuncResponse{Result: args.Answer}, nil
|
|
||||||
}).
|
|
||||||
WithName("answer").
|
|
||||||
WithDescription("Answer the question")
|
|
@@ -1,48 +0,0 @@
|
|||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/answer/pkg/toolbox"
|
|
||||||
)
|
|
||||||
|
|
||||||
var AskTool = toolbox.FromFunction(
|
|
||||||
func(ctx *Context, args struct {
|
|
||||||
Question string `description:"the question to answer"`
|
|
||||||
}) (toolbox.FuncResponse, error) {
|
|
||||||
var q Question
|
|
||||||
|
|
||||||
q.Question = args.Question
|
|
||||||
ctx = ctx.WithQuestion(q)
|
|
||||||
|
|
||||||
answers, err := ask(ctx, q)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return toolbox.FuncResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tb := toolbox.ToolBox{}
|
|
||||||
tb.Register(SummarizeAnswers)
|
|
||||||
|
|
||||||
b, err := json.Marshal(answers.Answers)
|
|
||||||
if err != nil {
|
|
||||||
return toolbox.FuncResponse{}, fmt.Errorf("failed to marshal answers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
q = Question{Question: string(b)}
|
|
||||||
ctx = ctx.WithQuestion(q)
|
|
||||||
answers, err = tb.Run(ctx, q)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return toolbox.FuncResponse{}, fmt.Errorf("failed to summarize answers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(answers.Answers) == 0 {
|
|
||||||
return toolbox.FuncResponse{}, fmt.Errorf("no response from model")
|
|
||||||
}
|
|
||||||
|
|
||||||
return toolbox.FuncResponse{Result: answers.Answers[0].Answer}, nil
|
|
||||||
}).
|
|
||||||
WithName("ask").
|
|
||||||
WithDescription("Ask the agent a question, this is useful for splitting a question into multiple parts")
|
|
@@ -1,112 +0,0 @@
|
|||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/answer/pkg/toolbox"
|
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/llms"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Context struct {
|
|
||||||
context.Context
|
|
||||||
|
|
||||||
Messages []llms.MessageContent
|
|
||||||
Agent *Agent
|
|
||||||
}
|
|
||||||
|
|
||||||
func From(ctx context.Context, a *Agent) *Context {
|
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Context{
|
|
||||||
Context: ctx,
|
|
||||||
Agent: a,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) WithMessage(m llms.MessageContent) *Context {
|
|
||||||
c.Messages = append(c.Messages, m)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) WithMessages(m ...llms.MessageContent) *Context {
|
|
||||||
c.Messages = append(c.Messages, m...)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) WithToolResults(r ...toolbox.ToolResult) *Context {
|
|
||||||
msg := llms.MessageContent{
|
|
||||||
Role: llms.ChatMessageTypeTool,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, v := range r {
|
|
||||||
res := v.Result
|
|
||||||
if v.Error != nil {
|
|
||||||
res = "error executing: " + v.Error.Error()
|
|
||||||
}
|
|
||||||
msg.Parts = append(msg.Parts, llms.ToolCallResponse{
|
|
||||||
ToolCallID: v.ID,
|
|
||||||
Name: v.Name,
|
|
||||||
Content: res,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.WithMessage(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) WithAgent(a *Agent) *Context {
|
|
||||||
return &Context{
|
|
||||||
Context: c.Context,
|
|
||||||
Agent: a,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) WithCancel() (*Context, func()) {
|
|
||||||
ctx, cancel := context.WithCancel(c.Context)
|
|
||||||
return &Context{
|
|
||||||
Context: ctx,
|
|
||||||
Agent: c.Agent,
|
|
||||||
}, cancel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) WithDeadline(deadline time.Time) (*Context, func()) {
|
|
||||||
ctx, cancel := context.WithDeadline(c.Context, deadline)
|
|
||||||
return &Context{
|
|
||||||
Context: ctx,
|
|
||||||
Agent: c.Agent,
|
|
||||||
}, cancel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) WithTimeout(timeout time.Duration) (*Context, func()) {
|
|
||||||
ctx, cancel := context.WithTimeout(c.Context, timeout)
|
|
||||||
return &Context{
|
|
||||||
Context: ctx,
|
|
||||||
Agent: c.Agent,
|
|
||||||
}, cancel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) WithValue(key, value interface{}) *Context {
|
|
||||||
return &Context{
|
|
||||||
Context: context.WithValue(c.Context, key, value),
|
|
||||||
Agent: c.Agent,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) Done() <-chan struct{} {
|
|
||||||
return c.Context.Done()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) Err() error {
|
|
||||||
return c.Context.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) Value(key interface{}) interface{} {
|
|
||||||
return c.Context.Value(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) Deadline() (time.Time, bool) {
|
|
||||||
return c.Context.Deadline()
|
|
||||||
}
|
|
@@ -1,25 +0,0 @@
|
|||||||
package agent
|
|
||||||
|
|
||||||
import "sync/atomic"
|
|
||||||
|
|
||||||
type Counter struct {
|
|
||||||
atomic.Int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// Take will attempt to take an item from the counter. If the counter is zero, it will return false.
|
|
||||||
func (c *Counter) Take() bool {
|
|
||||||
for {
|
|
||||||
current := c.Load()
|
|
||||||
if current <= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if c.CompareAndSwap(current, current-1) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return will return an item to the counter.
|
|
||||||
func (c *Counter) Return() {
|
|
||||||
c.Add(1)
|
|
||||||
}
|
|
@@ -1,13 +0,0 @@
|
|||||||
package agent
|
|
||||||
|
|
||||||
import "gitea.stevedudenhoeffer.com/steve/answer/pkg/toolbox"
|
|
||||||
|
|
||||||
var SearchTool = toolbox.FromFunction(
|
|
||||||
func(ctx *Context, args struct {
|
|
||||||
SearchFor string `description:"what to search for"`
|
|
||||||
Question string `description:"the question to answer with the search results"`
|
|
||||||
}) (toolbox.FuncResponse, error) {
|
|
||||||
return toolbox.FuncResponse{}, nil
|
|
||||||
}).
|
|
||||||
WithName("search").
|
|
||||||
WithDescription("Search the web and read a few articles to find the answer to the question")
|
|
@@ -1,34 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/llms"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testFunction(args struct{ a, b int }) {
|
|
||||||
// This is a test function
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
v := reflect.New(reflect.TypeOf(testFunction))
|
|
||||||
|
|
||||||
t := reflect.TypeOf(testFunction)
|
|
||||||
|
|
||||||
for i := 0; i < t.NumIn(); i++ {
|
|
||||||
param := t.In(i)
|
|
||||||
|
|
||||||
llms.MessageContent{
|
|
||||||
Role: llms.ChatMessageTypeTool,
|
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.ToolCallResponse{
|
|
||||||
Name: "testFunction",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if param.Type().Kind() == reflect.Struct {
|
|
||||||
|
|
||||||
}
|
|
||||||
println(param.Name(), param.Kind().String())
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,30 +0,0 @@
|
|||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/answer/pkg/toolbox"
|
|
||||||
|
|
||||||
"github.com/Edw590/go-wolfram"
|
|
||||||
)
|
|
||||||
|
|
||||||
var WolframTool = toolbox.FromFunction(
|
|
||||||
func(ctx *Context, args struct {
|
|
||||||
Query string `description:"what to ask wolfram alpha"`
|
|
||||||
}) (toolbox.FuncResponse, error) {
|
|
||||||
var cl = wolfram.Client{
|
|
||||||
AppID: os.Getenv("WOLFRAM_APPID"),
|
|
||||||
}
|
|
||||||
|
|
||||||
unit := wolfram.Imperial
|
|
||||||
|
|
||||||
a, err := cl.GetShortAnswerQuery(args.Query, unit, 10)
|
|
||||||
if err != nil {
|
|
||||||
return toolbox.FuncResponse{}, fmt.Errorf("failed to get short answer from wolfram: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return toolbox.FuncResponse{Result: a, Source: "Wolfram|Alpha"}, nil
|
|
||||||
}).
|
|
||||||
WithName("wolfram").
|
|
||||||
WithDescription("ask wolfram alpha for the answer")
|
|
@@ -2,17 +2,16 @@ package answer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/Edw590/go-wolfram"
|
||||||
|
"go.starlark.net/lib/math"
|
||||||
|
"go.starlark.net/starlark"
|
||||||
|
"go.starlark.net/syntax"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/agents"
|
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/llms"
|
|
||||||
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/answer/pkg/cache"
|
"gitea.stevedudenhoeffer.com/steve/answer/pkg/cache"
|
||||||
"gitea.stevedudenhoeffer.com/steve/answer/pkg/extractor"
|
"gitea.stevedudenhoeffer.com/steve/answer/pkg/extractor"
|
||||||
"gitea.stevedudenhoeffer.com/steve/answer/pkg/search"
|
"gitea.stevedudenhoeffer.com/steve/answer/pkg/search"
|
||||||
@@ -27,7 +26,7 @@ type Question struct {
|
|||||||
// Question is the question to answer
|
// Question is the question to answer
|
||||||
Question string
|
Question string
|
||||||
|
|
||||||
Model llms.Model
|
Model gollm.ChatCompletion
|
||||||
|
|
||||||
Search search.Search
|
Search search.Search
|
||||||
|
|
||||||
@@ -37,6 +36,8 @@ type Question struct {
|
|||||||
// Answers is a list of answers to a question
|
// Answers is a list of answers to a question
|
||||||
type Answers []string
|
type Answers []string
|
||||||
|
|
||||||
|
const DefaultPrompt = "You are being asked to answer a question. You must respond with a function. You can answer it if you know the answer, or if some functions exist you can use those to help you find the answer."
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
// MaxSearches is the maximum possible number of searches to execute for this question. If this is set to 5, the function could
|
// MaxSearches is the maximum possible number of searches to execute for this question. If this is set to 5, the function could
|
||||||
// search up to 5 possible times to find an answer.
|
// search up to 5 possible times to find an answer.
|
||||||
@@ -55,6 +56,19 @@ type Options struct {
|
|||||||
// The "answer" and "no_answer" functions are not included in this callback.
|
// The "answer" and "no_answer" functions are not included in this callback.
|
||||||
// Return an error to stop the function from being called.
|
// Return an error to stop the function from being called.
|
||||||
OnNewFunction func(ctx context.Context, funcName string, question string, parameter string) error
|
OnNewFunction func(ctx context.Context, funcName string, question string, parameter string) error
|
||||||
|
|
||||||
|
// SystemPrompt is the prompt to use when asking the system to answer a question.
|
||||||
|
// If this is empty, DefaultPrompt will be used.
|
||||||
|
SystemPrompt string
|
||||||
|
|
||||||
|
// ExtraSystemPrompts is a list of extra prompts to use when asking the system to answer a question. Use these for
|
||||||
|
// variety in the prompts, or passing in some useful contextually relevant information.
|
||||||
|
// All of these will be used in addition to the SystemPrompt.
|
||||||
|
ExtraSystemPrompts []string
|
||||||
|
|
||||||
|
// WolframAppID is the Wolfram Alpha App ID to use when searching Wolfram Alpha for answers. If not set, the
|
||||||
|
// wolfram function will not be available.
|
||||||
|
WolframAppID string
|
||||||
}
|
}
|
||||||
|
|
||||||
var DefaultOptions = Options{
|
var DefaultOptions = Options{
|
||||||
@@ -68,39 +82,6 @@ type Result struct {
|
|||||||
Error error
|
Error error
|
||||||
}
|
}
|
||||||
|
|
||||||
func fanExecuteToolCalls(ctx context.Context, toolBox *gollm.ToolBox, calls []gollm.ToolCall) []Result {
|
|
||||||
var results []Result
|
|
||||||
var resultsOutput = make(chan Result, len(calls))
|
|
||||||
|
|
||||||
fnCall := func(call gollm.ToolCall) Result {
|
|
||||||
str, err := toolBox.Execute(ctx, call)
|
|
||||||
if err != nil {
|
|
||||||
return Result{
|
|
||||||
Error: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Result{
|
|
||||||
Result: str,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, call := range calls {
|
|
||||||
go func(call gollm.ToolCall) {
|
|
||||||
resultsOutput <- fnCall(call)
|
|
||||||
}(call)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(calls); i++ {
|
|
||||||
result := <-resultsOutput
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
close(resultsOutput)
|
|
||||||
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
|
|
||||||
type article struct {
|
type article struct {
|
||||||
URL string
|
URL string
|
||||||
Title string
|
Title string
|
||||||
@@ -150,83 +131,45 @@ func extractArticle(ctx context.Context, c cache.Cache, u *url.URL) (res article
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func doesTextAnswerQuestion(ctx context.Context, q Question, text string) (string, error) {
|
func doesTextAnswerQuestion(ctx *gollm.Context, q Question, text string) (string, error) {
|
||||||
var availableTools = []llms.Tool{
|
fnAnswer := gollm.NewFunction(
|
||||||
{
|
"answer",
|
||||||
Type: "function",
|
"The answer from the given text that answers the question.",
|
||||||
Function: &llms.FunctionDefinition{
|
func(ctx *gollm.Context, args struct {
|
||||||
Name: "answer",
|
Answer string `description:"the answer to the question, the answer should come from the text"`
|
||||||
Description: "The answer from the given text that answers the question.",
|
}) (string, error) {
|
||||||
Parameters: map[string]any{
|
return args.Answer, nil
|
||||||
"type": "object",
|
})
|
||||||
"properties": map[string]any{
|
|
||||||
"answer": map[string]any{
|
fnNoAnswer := gollm.NewFunction(
|
||||||
"type": "string",
|
"no_answer",
|
||||||
"description": "the answer to the question, the answer should come from the text",
|
"Indicate that the text does not answer the question.",
|
||||||
},
|
func(ctx *gollm.Context, args struct {
|
||||||
},
|
Ignored string `description:"ignored, just here to make sure the function is called. Fill with anything."`
|
||||||
},
|
}) (string, error) {
|
||||||
},
|
return "", nil
|
||||||
},
|
})
|
||||||
{
|
|
||||||
Type: "function",
|
req := gollm.Request{
|
||||||
Function: &llms.FunctionDefinition{
|
Messages: []gollm.Message{
|
||||||
Name: "no_answer",
|
{
|
||||||
Description: "Indicate that the text does not answer the question.",
|
Role: gollm.RoleSystem,
|
||||||
Parameters: map[string]any{
|
Text: "Evaluate the given text to see if it answers the question from the user. The text is as follows:",
|
||||||
"type": "object",
|
},
|
||||||
"properties": map[string]any{
|
{
|
||||||
"ignored": map[string]any{
|
Role: gollm.RoleSystem,
|
||||||
"type": "string",
|
Text: text,
|
||||||
"description": "ignored, just here to make sure the function is called. Fill with anything.",
|
},
|
||||||
},
|
{
|
||||||
},
|
Role: gollm.RoleUser,
|
||||||
},
|
Text: q.Question,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Toolbox: gollm.NewToolBox(fnAnswer, fnNoAnswer),
|
||||||
}
|
}
|
||||||
|
|
||||||
answer := func(ctx context.Context, args string) (string, error) {
|
res, err := q.Model.ChatComplete(ctx, req)
|
||||||
type answer struct {
|
|
||||||
Answer string `json:"answer"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var a answer
|
|
||||||
|
|
||||||
err := json.Unmarshal([]byte(args), &a)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return a.Answer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
noAnswer := func(ctx context.Context, ignored string) (string, error) {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var req = []llms.MessageContent{
|
|
||||||
{
|
|
||||||
Role: llms.ChatMessageTypeSystem,
|
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.TextPart("Evaluate the given text to see if it answers the question from the user. The text is as follows:"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: llms.ChatMessageTypeSystem,
|
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.TextPart(text),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: llms.ChatMessageTypeHuman,
|
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.TextPart(q.Question),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := q.Model.GenerateContent(ctx, req, llms.WithTools(availableTools))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -235,28 +178,14 @@ func doesTextAnswerQuestion(ctx context.Context, q Question, text string) (strin
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(res.Choices[0].ToolCalls) == 0 {
|
if len(res.Choices[0].Calls) == 0 {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, call := range res.Choices[0].ToolCalls {
|
return req.Toolbox.Execute(ctx, res.Choices[0].Calls[0])
|
||||||
switch call.FunctionCall.Name {
|
|
||||||
case "answer":
|
|
||||||
return answer(ctx, call.FunctionCall.Arguments)
|
|
||||||
|
|
||||||
case "no_answer":
|
|
||||||
return noAnswer(ctx, call.FunctionCall.Arguments)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return "", fmt.Errorf("unknown function %s", call.FunctionCall.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func functionSearch(ctx context.Context, q Question, searchTerm string) (string, error) {
|
func functionSearch(ctx *gollm.Context, q Question, searchTerm string) (string, error) {
|
||||||
|
|
||||||
slog.Info("searching", "search", searchTerm, "question", q)
|
slog.Info("searching", "search", searchTerm, "question", q)
|
||||||
res, err := q.Search.Search(ctx, searchTerm)
|
res, err := q.Search.Search(ctx, searchTerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -305,11 +234,11 @@ func functionSearch(ctx context.Context, q Question, searchTerm string) (string,
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func functionThink(ctx context.Context, q Question) (string, error) {
|
func functionThink(ctx *gollm.Context, q Question) (string, error) {
|
||||||
fnAnswer := gollm.NewFunction(
|
fnAnswer := gollm.NewFunction(
|
||||||
"answer",
|
"answer",
|
||||||
"Answer the question.",
|
"Answer the question.",
|
||||||
func(ctx context.Context, args struct {
|
func(ctx *gollm.Context, args struct {
|
||||||
Answer string `description:"the answer to the question"`
|
Answer string `description:"the answer to the question"`
|
||||||
}) (string, error) {
|
}) (string, error) {
|
||||||
return args.Answer, nil
|
return args.Answer, nil
|
||||||
@@ -348,30 +277,206 @@ func functionThink(ctx context.Context, q Question) (string, error) {
|
|||||||
return req.Toolbox.Execute(ctx, res.Choices[0].Calls[0])
|
return req.Toolbox.Execute(ctx, res.Choices[0].Calls[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendResponse(req []llms.MessageContent, response *llms.ContentResponse) ([]llms.MessageContent, error) {
|
func (o Options) Answer(ctx context.Context, q Question) (Answers, error) {
|
||||||
if response == nil {
|
fnSearch := gollm.NewFunction(
|
||||||
return req, nil
|
"search",
|
||||||
|
"Search the web for an answer to a question. You can call this function up to "+fmt.Sprint(o.MaxSearches)+" times.",
|
||||||
|
func(ctx *gollm.Context, args struct {
|
||||||
|
SearchQuery string `description:"what to search the web for for this question"`
|
||||||
|
Question string `description:"what question(s) you are trying to answer with this search"`
|
||||||
|
}) (string, error) {
|
||||||
|
q2 := q
|
||||||
|
q2.Question = args.Question
|
||||||
|
|
||||||
|
return functionSearch(ctx, q2, args.SearchQuery)
|
||||||
|
})
|
||||||
|
|
||||||
|
fnThink := gollm.NewFunction(
|
||||||
|
"think",
|
||||||
|
"Think about a question. This is useful for breaking down complex questions into smaller parts that are easier to answer.",
|
||||||
|
func(ctx *gollm.Context, args struct {
|
||||||
|
Question string `json:"question" description:"the question to think about"`
|
||||||
|
}) (string, error) {
|
||||||
|
q2 := q
|
||||||
|
q2.Question = args.Question
|
||||||
|
|
||||||
|
return functionThink(ctx, q2)
|
||||||
|
})
|
||||||
|
|
||||||
|
fnAnswer := gollm.NewFunction(
|
||||||
|
"answer",
|
||||||
|
"You definitively answer a question, if you call this it means you know the answer and do not need to search for it or use any other function to find it",
|
||||||
|
func(ctx *gollm.Context, args struct {
|
||||||
|
Answer string `json:"answer" description:"the answer to the question"`
|
||||||
|
}) (string, error) {
|
||||||
|
return args.Answer, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
var fnWolfram *gollm.Function
|
||||||
|
|
||||||
|
if o.WolframAppID != "" {
|
||||||
|
fnWolfram = gollm.NewFunction(
|
||||||
|
"wolfram",
|
||||||
|
"Search Wolfram Alpha for an answer to a question.",
|
||||||
|
func(ctx *gollm.Context, args struct {
|
||||||
|
Question string `description:"the question to search for"`
|
||||||
|
}) (string, error) {
|
||||||
|
cl := wolfram.Client{
|
||||||
|
AppID: o.WolframAppID,
|
||||||
|
}
|
||||||
|
unit := wolfram.Imperial
|
||||||
|
|
||||||
|
return cl.GetShortAnswerQuery(args.Question, unit, 10)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(response.Choices) == 0 {
|
fnCalculate := gollm.NewFunction(
|
||||||
return req, nil
|
"calculate",
|
||||||
|
"Calculate a mathematical expression using starlark.",
|
||||||
|
func(ctx *gollm.Context, args struct {
|
||||||
|
Expression string `description:"the mathematical expression to calculate, in starlark format"`
|
||||||
|
}) (string, error) {
|
||||||
|
fileOpts := syntax.FileOptions{}
|
||||||
|
v, err := starlark.EvalOptions(&fileOpts, &starlark.Thread{Name: "main"}, "input", args.Expression, math.Module.Members)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return v.String(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
var funcs = []*gollm.Function{fnAnswer, fnCalculate}
|
||||||
|
|
||||||
|
if fnWolfram != nil {
|
||||||
|
funcs = append(funcs, fnWolfram)
|
||||||
}
|
}
|
||||||
|
|
||||||
choice := response.Choices[0]
|
if o.MaxSearches > 0 {
|
||||||
|
funcs = append(funcs, fnSearch)
|
||||||
assistantResponse := llms.TextParts(llms.ChatMessageTypeAI, choice.Content)
|
|
||||||
|
|
||||||
for _, tc := range choice.ToolCalls {
|
|
||||||
assistantResponse.Parts = append(assistantResponse.Parts, tc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return req, nil
|
if o.MaxThinks > 0 {
|
||||||
|
funcs = append(funcs, fnThink)
|
||||||
|
}
|
||||||
|
|
||||||
|
var temp float32 = 0.8
|
||||||
|
|
||||||
|
var messages []gollm.Message
|
||||||
|
|
||||||
|
if o.SystemPrompt != "" {
|
||||||
|
messages = append(messages, gollm.Message{
|
||||||
|
Role: gollm.RoleSystem,
|
||||||
|
Text: o.SystemPrompt,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
messages = append(messages, gollm.Message{
|
||||||
|
Role: gollm.RoleSystem,
|
||||||
|
Text: DefaultPrompt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prompt := range o.ExtraSystemPrompts {
|
||||||
|
messages = append(messages, gollm.Message{
|
||||||
|
Role: gollm.RoleSystem,
|
||||||
|
Text: prompt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.Question != "" {
|
||||||
|
messages = append(messages, gollm.Message{
|
||||||
|
Role: gollm.RoleUser,
|
||||||
|
Text: q.Question,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
req := gollm.Request{
|
||||||
|
Messages: messages,
|
||||||
|
Toolbox: gollm.NewToolBox(funcs...),
|
||||||
|
Temperature: &temp,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := q.Model.ChatComplete(ctx, req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Choices) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Choices) > o.MaxSearches {
|
||||||
|
res.Choices = res.Choices[:o.MaxSearches]
|
||||||
|
}
|
||||||
|
|
||||||
|
var answers []Result
|
||||||
|
for _, choice := range res.Choices {
|
||||||
|
fnChoice := func(choice gollm.ResponseChoice) []Result {
|
||||||
|
var calls []Result
|
||||||
|
var callsOutput = make(chan Result, len(choice.Calls))
|
||||||
|
fnCall := func(call gollm.ToolCall) Result {
|
||||||
|
str, err := req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return Result{
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Result{
|
||||||
|
Result: str,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, call := range choice.Calls {
|
||||||
|
go func(call gollm.ToolCall) {
|
||||||
|
if o.OnNewFunction != nil {
|
||||||
|
err := o.OnNewFunction(ctx, call.FunctionCall.Name, q.Question, call.FunctionCall.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
callsOutput <- Result{
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
callsOutput <- fnCall(call)
|
||||||
|
}(call)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(choice.Calls); i++ {
|
||||||
|
result := <-callsOutput
|
||||||
|
calls = append(calls, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
close(callsOutput)
|
||||||
|
|
||||||
|
slog.Info("calls", "calls", calls)
|
||||||
|
return calls
|
||||||
|
}
|
||||||
|
|
||||||
|
answers = append(answers, fnChoice(choice)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
var results []string
|
||||||
|
|
||||||
|
for _, answer := range answers {
|
||||||
|
if answer.Error != nil {
|
||||||
|
errs = append(errs, answer.Error)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
results = append(results, answer.Result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return nil, errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o Options) Answer(ctx context.Context, q Question) (string, error) {
|
func Answer(ctx context.Context, q Question) (Answers, error) {
|
||||||
a := agents.NewConversationalAgent()
|
|
||||||
}
|
|
||||||
|
|
||||||
func Answer(ctx context.Context, q Question) (string, error) {
|
|
||||||
return DefaultOptions.Answer(ctx, q)
|
return DefaultOptions.Answer(ctx, q)
|
||||||
}
|
}
|
||||||
|
@@ -1,17 +0,0 @@
|
|||||||
package toolbox
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/llms"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Context interface {
|
|
||||||
context.Context
|
|
||||||
|
|
||||||
WithCancel() (Context, func())
|
|
||||||
WithTimeout(time.Duration) (Context, func())
|
|
||||||
WithMessages([]llms.MessageContent) Context
|
|
||||||
GetMessages() []llms.MessageContent
|
|
||||||
}
|
|
@@ -1,299 +0,0 @@
|
|||||||
package toolbox
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type FuncResponse struct {
|
|
||||||
Result string
|
|
||||||
Source string
|
|
||||||
}
|
|
||||||
|
|
||||||
type funcCache struct {
|
|
||||||
sync.RWMutex
|
|
||||||
m map[reflect.Type]function
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *funcCache) get(value reflect.Value) (function, bool) {
|
|
||||||
c.RLock()
|
|
||||||
defer c.RUnlock()
|
|
||||||
|
|
||||||
fn, ok := c.m[value.Type()]
|
|
||||||
if ok {
|
|
||||||
slog.Info("cache hit for function", "function", value.Type().String())
|
|
||||||
}
|
|
||||||
return fn, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *funcCache) set(value reflect.Value, fn function) {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
c.m[value.Type()] = fn
|
|
||||||
}
|
|
||||||
|
|
||||||
var cache = funcCache{m: map[reflect.Type]function{}}
|
|
||||||
|
|
||||||
type arg struct {
|
|
||||||
Name string
|
|
||||||
Type reflect.Type
|
|
||||||
Index int
|
|
||||||
Values []string
|
|
||||||
Optional bool
|
|
||||||
Array bool
|
|
||||||
Description string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a arg) Schema() map[string]any {
|
|
||||||
var res = map[string]any{}
|
|
||||||
|
|
||||||
if a.Array {
|
|
||||||
res["type"] = "array"
|
|
||||||
res["items"] = map[string]any{"type": a.Type.Kind().String()}
|
|
||||||
} else {
|
|
||||||
res["type"] = a.Type.Name()
|
|
||||||
}
|
|
||||||
|
|
||||||
if !a.Optional {
|
|
||||||
res["required"] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(a.Values) > 0 {
|
|
||||||
res["enum"] = a.Values
|
|
||||||
}
|
|
||||||
|
|
||||||
if a.Description != "" {
|
|
||||||
res["description"] = a.Description
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
type function struct {
|
|
||||||
fn reflect.Value
|
|
||||||
argType reflect.Type
|
|
||||||
args map[string]arg
|
|
||||||
}
|
|
||||||
|
|
||||||
var ErrInvalidFunction = errors.New("invalid function")
|
|
||||||
|
|
||||||
// analyzeFuncFromReflect extracts metadata from a reflect.Value representing a function and returns a structured result.
|
|
||||||
// It maps the function's parameter names to their corresponding reflect.Type and encapsulates them in a function struct.
|
|
||||||
// Returns a function struct with extracted information and an error if the operation fails.
|
|
||||||
// The first parameter to the function must be a *Context.
|
|
||||||
// The second parameter to the function must be a struct, all the fields of which will be passed as arguments to the
|
|
||||||
// function to be analyzed.
|
|
||||||
// Struct tags supported are:
|
|
||||||
// - `name:"<name>"` to specify the name of the parameter (default is the field name)
|
|
||||||
// - `description:"<description>"` to specify a description of the parameter (default is "")
|
|
||||||
// - `values:"<value1>,<value2>,..."` to specify a list of possible values for the parameter (default is "") only for
|
|
||||||
// string, int, and float types
|
|
||||||
//
|
|
||||||
// Allowed types on the struct are:
|
|
||||||
// - string, *string, []string
|
|
||||||
// - int, *int, []int
|
|
||||||
// - float64, *float64, []float64
|
|
||||||
// - bool, *bool, []bool
|
|
||||||
//
|
|
||||||
// Pointer types imply that the parameter is optional.
|
|
||||||
// The function must have at most 2 parameters.
|
|
||||||
// The function must return a string and an error.
|
|
||||||
// The function must be of the form `func(*agent.Context, T) (FuncResponse, error)`.
|
|
||||||
func analyzeFuncFromReflect(fn reflect.Value) (function, error) {
|
|
||||||
if f, ok := cache.get(fn); ok {
|
|
||||||
return f, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var res function
|
|
||||||
t := fn.Type()
|
|
||||||
args := map[string]arg{}
|
|
||||||
|
|
||||||
for i := 0; i < t.NumIn(); i++ {
|
|
||||||
if i == 0 {
|
|
||||||
if t.In(i).String() != "*agent.Context" {
|
|
||||||
return res, fmt.Errorf("%w: first parameter must be *agent.Context", ErrInvalidFunction)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
} else if i == 1 {
|
|
||||||
if t.In(i).Kind() != reflect.Struct {
|
|
||||||
return res, fmt.Errorf("%w: second parameter must be a struct", ErrInvalidFunction)
|
|
||||||
}
|
|
||||||
res.argType = t.In(i)
|
|
||||||
|
|
||||||
for j := 0; j < res.argType.NumField(); j++ {
|
|
||||||
field := res.argType.Field(j)
|
|
||||||
|
|
||||||
a := arg{
|
|
||||||
Name: field.Name,
|
|
||||||
Type: field.Type,
|
|
||||||
Index: j,
|
|
||||||
Description: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
ft := field.Type
|
|
||||||
// if it's a pointer, it's optional
|
|
||||||
if ft.Kind() == reflect.Ptr {
|
|
||||||
a.Optional = true
|
|
||||||
ft = ft.Elem()
|
|
||||||
} else if ft.Kind() == reflect.Slice {
|
|
||||||
a.Array = true
|
|
||||||
ft = ft.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if ft.Kind() != reflect.String && ft.Kind() != reflect.Int && ft.Kind() != reflect.Float64 && ft.Kind() != reflect.Bool {
|
|
||||||
return res, fmt.Errorf("%w: unsupported type %s", ErrInvalidFunction, ft.Kind().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
a.Type = ft
|
|
||||||
|
|
||||||
if name, ok := field.Tag.Lookup("name"); ok {
|
|
||||||
a.Name = name
|
|
||||||
a.Name = strings.TrimSpace(a.Name)
|
|
||||||
|
|
||||||
if a.Name == "" {
|
|
||||||
return res, fmt.Errorf("%w: name tag cannot be empty", ErrInvalidFunction)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if description, ok := field.Tag.Lookup("description"); ok {
|
|
||||||
a.Description = description
|
|
||||||
}
|
|
||||||
if values, ok := field.Tag.Lookup("values"); ok {
|
|
||||||
a.Values = strings.Split(values, ",")
|
|
||||||
for i, v := range a.Values {
|
|
||||||
a.Values[i] = strings.TrimSpace(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ft.Kind() != reflect.String && ft.Kind() != reflect.Int && ft.Kind() != reflect.Float64 {
|
|
||||||
return res, fmt.Errorf("%w: values tag only supported for string, int, and float types", ErrInvalidFunction)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
args[field.Name] = a
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return res, fmt.Errorf("%w: function must have at most 2 parameters", ErrInvalidFunction)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// finally ensure that the function returns a FuncResponse and an error
|
|
||||||
if t.NumOut() != 2 || t.Out(0).String() != "agent.FuncResponse" || t.Out(1).String() != "error" {
|
|
||||||
return res, fmt.Errorf("%w: function must return a FuncResponse and an error", ErrInvalidFunction)
|
|
||||||
}
|
|
||||||
|
|
||||||
cache.set(fn, res)
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func analyzeFunction[T any, AgentContext any](fn func(AgentContext, T) (FuncResponse, error)) (function, error) {
|
|
||||||
return analyzeFuncFromReflect(reflect.ValueOf(fn))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute will execute the given function with the given context and arguments.
|
|
||||||
// Returns the result of the execution and an error if the operation fails.
|
|
||||||
// The arguments must be a JSON-encoded string that represents the struct to be passed to the function.
|
|
||||||
func (f function) Execute(ctx context.Context, args string) (FuncResponse, error) {
|
|
||||||
var m = map[string]any{}
|
|
||||||
|
|
||||||
err := json.Unmarshal([]byte(args), &m)
|
|
||||||
if err != nil {
|
|
||||||
return FuncResponse{}, fmt.Errorf("failed to unmarshal arguments: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var obj = reflect.New(f.argType).Elem()
|
|
||||||
|
|
||||||
// TODO: ensure that "required" fields are present in the arguments
|
|
||||||
for name, a := range f.args {
|
|
||||||
if v, ok := m[name]; ok {
|
|
||||||
if a.Array {
|
|
||||||
if v == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch a.Type.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
s := v.([]string)
|
|
||||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(s), len(s))
|
|
||||||
for i, str := range s {
|
|
||||||
slice.Index(i).SetString(str)
|
|
||||||
}
|
|
||||||
obj.Field(a.Index).Set(slice)
|
|
||||||
|
|
||||||
case reflect.Int:
|
|
||||||
i := v.([]int)
|
|
||||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(0)), len(i), len(i))
|
|
||||||
for i, in := range i {
|
|
||||||
slice.Index(i).SetInt(int64(in))
|
|
||||||
}
|
|
||||||
obj.Field(a.Index).Set(slice)
|
|
||||||
|
|
||||||
case reflect.Float64:
|
|
||||||
f := v.([]float64)
|
|
||||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(0.0)), len(f), len(f))
|
|
||||||
for i, fl := range f {
|
|
||||||
slice.Index(i).SetFloat(fl)
|
|
||||||
}
|
|
||||||
obj.Field(a.Index).Set(slice)
|
|
||||||
|
|
||||||
case reflect.Bool:
|
|
||||||
b := v.([]bool)
|
|
||||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(false)), len(b), len(b))
|
|
||||||
for i, b := range b {
|
|
||||||
slice.Index(i).SetBool(b)
|
|
||||||
}
|
|
||||||
obj.Field(a.Index).Set(slice)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name)
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if a.Optional {
|
|
||||||
if v == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
switch a.Type.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
str := v.(string)
|
|
||||||
obj.Field(a.Index).Set(reflect.ValueOf(&str))
|
|
||||||
case reflect.Int:
|
|
||||||
i := v.(int)
|
|
||||||
obj.Field(a.Index).Set(reflect.ValueOf(&i))
|
|
||||||
case reflect.Float64:
|
|
||||||
f := v.(float64)
|
|
||||||
obj.Field(a.Index).Set(reflect.ValueOf(&f))
|
|
||||||
case reflect.Bool:
|
|
||||||
b := v.(bool)
|
|
||||||
obj.Field(a.Index).Set(reflect.ValueOf(&b))
|
|
||||||
default:
|
|
||||||
return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
switch a.Type.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
obj.Field(a.Index).SetString(v.(string))
|
|
||||||
case reflect.Int:
|
|
||||||
obj.Field(a.Index).SetInt(int64(v.(int)))
|
|
||||||
case reflect.Float64:
|
|
||||||
obj.Field(a.Index).SetFloat(v.(float64))
|
|
||||||
case reflect.Bool:
|
|
||||||
obj.Field(a.Index).SetBool(v.(bool))
|
|
||||||
default:
|
|
||||||
return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), obj})
|
|
||||||
if res[1].IsNil() {
|
|
||||||
return res[0].Interface().(FuncResponse), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return FuncResponse{}, res[1].Interface().(error)
|
|
||||||
}
|
|
@@ -1,75 +0,0 @@
|
|||||||
package toolbox
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/llms"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
Name string
|
|
||||||
Description string
|
|
||||||
Function function
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tool) Tool() llms.Tool {
|
|
||||||
return llms.Tool{
|
|
||||||
Type: "function",
|
|
||||||
Function: t.Definition(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tool) Definition() *llms.FunctionDefinition {
|
|
||||||
var properties = map[string]any{}
|
|
||||||
for name, arg := range t.Function.args {
|
|
||||||
properties[name] = arg.Schema()
|
|
||||||
}
|
|
||||||
|
|
||||||
var res = llms.FunctionDefinition{
|
|
||||||
Name: t.Name,
|
|
||||||
Description: t.Description,
|
|
||||||
Parameters: map[string]any{"type": "object", "properties": properties},
|
|
||||||
}
|
|
||||||
return &res
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute executes the tool with the given context and arguments.
|
|
||||||
// Returns the result of the execution and an error if the operation fails.
|
|
||||||
// The arguments must be a JSON-encoded string that represents the struct to be passed to the function.
|
|
||||||
func (t *Tool) Execute(ctx context.Context, args string) (FuncResponse, error) {
|
|
||||||
return t.Function.Execute(ctx, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
func FromFunction[T any, AgentContext any](fn func(AgentContext, T) (FuncResponse, error)) *Tool {
|
|
||||||
f, err := analyzeFunction(fn)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Tool{
|
|
||||||
Name: reflect.TypeOf(fn).Name(),
|
|
||||||
Description: "This is a tool",
|
|
||||||
Function: f,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tool) WithName(name string) *Tool {
|
|
||||||
t.Name = name
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tool) WithDescription(description string) *Tool {
|
|
||||||
t.Description = description
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tool) WithFunction(fn any) *Tool {
|
|
||||||
f, err := analyzeFuncFromReflect(reflect.ValueOf(fn))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Function = f
|
|
||||||
return t
|
|
||||||
}
|
|
@@ -1,208 +0,0 @@
|
|||||||
package toolbox
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/llms"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ToolBox map[string]*Tool
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrToolNotFound = errors.New("tool not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type ToolResults []ToolResult
|
|
||||||
|
|
||||||
func (r ToolResults) ToMessageContent() llms.MessageContent {
|
|
||||||
var res = llms.MessageContent{
|
|
||||||
Role: llms.ChatMessageTypeTool,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, v := range r {
|
|
||||||
res.Parts = append(res.Parts, v.ToToolCallResponse())
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolResult struct {
|
|
||||||
ID string
|
|
||||||
Name string
|
|
||||||
Result string
|
|
||||||
Source string
|
|
||||||
Error error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r ToolResult) ToToolCallResponse() llms.ToolCallResponse {
|
|
||||||
if r.Error != nil {
|
|
||||||
return llms.ToolCallResponse{
|
|
||||||
ToolCallID: r.ID,
|
|
||||||
Name: r.Name,
|
|
||||||
Content: "error executing: " + r.Error.Error(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return llms.ToolCallResponse{
|
|
||||||
ToolCallID: r.ID,
|
|
||||||
Name: r.Name,
|
|
||||||
Content: r.Result,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func (tb ToolBox) Execute(ctx context.Context, call llms.ToolCall) (ToolResult, error) {
|
|
||||||
if call.Type != "function" {
|
|
||||||
return ToolResult{}, fmt.Errorf("unsupported tool type: %s", call.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
if call.FunctionCall == nil {
|
|
||||||
return ToolResult{}, errors.New("function call is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
tool, ok := tb[call.FunctionCall.Name]
|
|
||||||
if !ok {
|
|
||||||
return ToolResult{}, fmt.Errorf("%w: %s", ErrToolNotFound, call.FunctionCall.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := tool.Execute(ctx, call.FunctionCall.Arguments)
|
|
||||||
if err != nil {
|
|
||||||
return ToolResult{
|
|
||||||
ID: call.ID,
|
|
||||||
Name: tool.Name,
|
|
||||||
Error: err,
|
|
||||||
Source: res.Source,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return ToolResult{
|
|
||||||
ID: call.ID,
|
|
||||||
Name: tool.Name,
|
|
||||||
Result: res.Result,
|
|
||||||
Source: res.Source,
|
|
||||||
Error: err,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb ToolBox) ExecuteAll(ctx Context, calls []llms.ToolCall) (ToolResults, error) {
|
|
||||||
var results []ToolResult
|
|
||||||
|
|
||||||
for _, call := range calls {
|
|
||||||
res, err := tb.Execute(ctx, call)
|
|
||||||
if err != nil {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
|
|
||||||
results = append(results, res)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb ToolBox) ExecuteConcurrent(ctx Context, calls []llms.ToolCall) (ToolResults, error) {
|
|
||||||
var results []ToolResult
|
|
||||||
var ch = make(chan ToolResult, len(calls))
|
|
||||||
var eg = errgroup.Group{}
|
|
||||||
|
|
||||||
for _, call := range calls {
|
|
||||||
eg.Go(func() error {
|
|
||||||
c, cancel := ctx.WithCancel()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
res, err := tb.Execute(c, call)
|
|
||||||
if err != nil {
|
|
||||||
ch <- res
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
err := eg.Wait()
|
|
||||||
if err != nil {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for range calls {
|
|
||||||
results = append(results, <-ch)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Answers struct {
|
|
||||||
Response llms.MessageContent
|
|
||||||
Answers []Answer
|
|
||||||
}
|
|
||||||
|
|
||||||
type Answer struct {
|
|
||||||
Answer string
|
|
||||||
Source string
|
|
||||||
ToolCallResponse llms.ToolCallResponse `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb ToolBox) Run(ctx Context, model llms.Model, question string) (Answers, error) {
|
|
||||||
ctx = ctx.WithMessages([]llms.MessageContent{{
|
|
||||||
Role: llms.ChatMessageTypeGeneric,
|
|
||||||
Parts: []llms.ContentPart{llms.TextPart(question)},
|
|
||||||
}})
|
|
||||||
|
|
||||||
res, err := model.GenerateContent(ctx, ctx.GetMessages())
|
|
||||||
if err != nil {
|
|
||||||
return Answers{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if res == nil {
|
|
||||||
return Answers{}, errors.New("no response from model")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(res.Choices) == 0 {
|
|
||||||
return Answers{}, errors.New("no response from model")
|
|
||||||
}
|
|
||||||
|
|
||||||
choice := res.Choices[0]
|
|
||||||
|
|
||||||
response := llms.MessageContent{
|
|
||||||
Role: llms.ChatMessageTypeAI,
|
|
||||||
Parts: []llms.ContentPart{llms.TextPart(choice.Content)},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range choice.ToolCalls {
|
|
||||||
response.Parts = append(response.Parts, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
results, err := tb.ExecuteConcurrent(ctx, choice.ToolCalls)
|
|
||||||
if err != nil {
|
|
||||||
return Answers{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var answers []Answer
|
|
||||||
|
|
||||||
for _, r := range results {
|
|
||||||
if r.Error != nil {
|
|
||||||
answers = append(answers, Answer{
|
|
||||||
Answer: "error executing: " + r.Error.Error(),
|
|
||||||
Source: r.Source,
|
|
||||||
ToolCallResponse: r.ToToolCallResponse(),
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
answers = append(answers, Answer{
|
|
||||||
Answer: r.Result,
|
|
||||||
Source: r.Source,
|
|
||||||
ToolCallResponse: r.ToToolCallResponse(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Answers{
|
|
||||||
Response: response,
|
|
||||||
Answers: answers,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb ToolBox) Register(tool *Tool) {
|
|
||||||
tb[tool.Name] = tool
|
|
||||||
}
|
|
Reference in New Issue
Block a user