From 090b28d95614e0f68b8e3923101820942bbb7fea Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Tue, 25 Feb 2025 22:56:32 -0500 Subject: [PATCH] Update LLM integration and add new agent tools and utilities Refactored LLM handling to use updated langchaingo models and tools, replacing gollm dependencies. Introduced agent-related utilities, tools, and counters for better modular functionality. Added a parser for LLM model configuration and revamped the answering mechanism with enhanced support for tool-based interaction. --- go.mod | 4 + pkg/agent/agent.go | 115 ++++++++++++++++ pkg/agent/answertool.go | 10 ++ pkg/agent/ask.go | 46 +++++++ pkg/agent/context.go | 110 +++++++++++++++ pkg/agent/counter.go | 25 ++++ pkg/agent/function.go | 298 ++++++++++++++++++++++++++++++++++++++++ pkg/agent/search.go | 11 ++ pkg/agent/test/main.go | 34 +++++ pkg/agent/tool.go | 74 ++++++++++ pkg/agent/toolbox.go | 188 +++++++++++++++++++++++++ pkg/agent/wolfram.go | 28 ++++ pkg/answer/answer.go | 280 +++++++++++++++---------------------- 13 files changed, 1051 insertions(+), 172 deletions(-) create mode 100644 pkg/agent/agent.go create mode 100644 pkg/agent/answertool.go create mode 100644 pkg/agent/ask.go create mode 100644 pkg/agent/context.go create mode 100644 pkg/agent/counter.go create mode 100644 pkg/agent/function.go create mode 100644 pkg/agent/search.go create mode 100644 pkg/agent/test/main.go create mode 100644 pkg/agent/tool.go create mode 100644 pkg/agent/toolbox.go create mode 100644 pkg/agent/wolfram.go diff --git a/go.mod b/go.mod index 3e502a6..2bb91e2 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.7 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect cloud.google.com/go/longrunning v0.6.4 // indirect + github.com/Edw590/go-wolfram v0.0.0-20241010091529-fb9031908c5d // indirect github.com/PuerkitoBio/goquery v1.10.2 // indirect github.com/andybalholm/cascadia v1.3.3 // indirect github.com/antchfx/htmlquery v1.3.4 // indirect @@ -31,6 +32,7 @@ require ( github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect github.com/deckarep/golang-set/v2 v2.7.0 // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/fatih/set v0.2.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/gigawattio/window v0.0.0-20180317192513-0f5467e35573 // indirect @@ -57,12 +59,14 @@ require ( github.com/mattn/go-runewidth v0.0.16 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pkoukk/tiktoken-go v0.1.6 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect github.com/sashabaranov/go-openai v1.37.0 // indirect github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect github.com/temoto/robotstxt v1.1.2 // indirect + github.com/tmc/langchaingo v0.1.13 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go new file mode 100644 index 0000000..915293e --- /dev/null +++ b/pkg/agent/agent.go @@ -0,0 +1,115 @@ +package agent + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/tmc/langchaingo/llms" + + cache2 "gitea.stevedudenhoeffer.com/steve/answer/pkg/cache" + "gitea.stevedudenhoeffer.com/steve/answer/pkg/extractor" + "gitea.stevedudenhoeffer.com/steve/answer/pkg/search" +) + +type Options struct { + MaxSearches int + MaxQuestions int +} +type Agent struct { + RemainingSearches *Counter + RemainingQuestions *Counter + Search search.Search + Cache cache2.Cache + Extractor extractor.Extractor + Model llms.Model +} + +type Answers struct { + Response llms.MessageContent + Answers []Answer +} + +type Answer struct { + Answer string + Source string + ToolCallResponse llms.ToolCallResponse `json:"-"` +} + +type Question struct { + Question string +} + +func New(o Options) *Agent { + searches := &Counter{} + searches.Add(int32(o.MaxSearches)) + + questions := &Counter{} + questions.Add(int32(o.MaxQuestions)) + + return &Agent{ + RemainingSearches: searches, + RemainingQuestions: questions, + } +} + +var ( + ErrOutOfSearches = errors.New("out of searches") + ErrOutOfQuestions = errors.New("out of questions") +) + +func ask(ctx *Context, q Question) (Answers, error) { + var tb = ToolBox{} + + if ctx.Agent.RemainingSearches.Load() > 0 { + tb.Register(SearchTool) + } + tb.Register(WolframTool) + tb.Register(AnswerTool) + + return tb.Run(ctx, q) +} + +var SummarizeAnswers = FromFunction( + func(ctx *Context, args struct { + Summary string `description:"the summary of the answers"` + }) (FuncResponse, error) { + return FuncResponse{Result: args.Summary}, nil + }). + WithName("summarize_answers"). + WithDescription(`You are given previously figured out answers and they are in the format of: [ { "answer": "the answer", "source": "the source of the answer" }, { "answer": "answer 2", "source": "the source for answer2" } ]. You need to summarize the answers into a single string. Be sure to make the summary clear and concise, but include the sources at some point.`) + +// Ask is an incoming call to the agent, it will create an internal Context from an incoming context.Context +func (a *Agent) Ask(ctx context.Context, q Question) (string, error) { + c := From(ctx, a) + if !a.RemainingQuestions.Take() { + return "", ErrOutOfQuestions + } + + answers, err := ask(c, q) + if err != nil { + return "", err + } + + tb := ToolBox{} + + tb.Register(SummarizeAnswers) + + b, err := json.Marshal(answers.Answers) + if err != nil { + return "", fmt.Errorf("failed to marshal answers: %w", err) + } + + answers, err = tb.Run(c, Question{Question: string(b)}) + + if err != nil { + return "", fmt.Errorf("failed to summarize answers: %w", err) + } + + if len(answers.Answers) == 0 { + return "", errors.New("no response from model") + } + return answers.Answers[0].Answer, nil + +} diff --git a/pkg/agent/answertool.go b/pkg/agent/answertool.go new file mode 100644 index 0000000..66f4edb --- /dev/null +++ b/pkg/agent/answertool.go @@ -0,0 +1,10 @@ +package agent + +var AnswerTool = FromFunction( + func(ctx *Context, args struct { + Answer string `description:"the answer to the question"` + }) (FuncResponse, error) { + return FuncResponse{Result: args.Answer}, nil + }). + WithName("answer"). + WithDescription("Answer the question") diff --git a/pkg/agent/ask.go b/pkg/agent/ask.go new file mode 100644 index 0000000..55359b3 --- /dev/null +++ b/pkg/agent/ask.go @@ -0,0 +1,46 @@ +package agent + +import ( + "encoding/json" + "fmt" +) + +var AskTool = FromFunction( + func(ctx *Context, args struct { + Question string `description:"the question to answer"` + }) (FuncResponse, error) { + var q Question + + q.Question = args.Question + ctx = ctx.WithQuestion(q) + + answers, err := ask(ctx, q) + + if err != nil { + return FuncResponse{}, err + } + + tb := ToolBox{} + tb.Register(SummarizeAnswers) + + b, err := json.Marshal(answers.Answers) + if err != nil { + return FuncResponse{}, fmt.Errorf("failed to marshal answers: %w", err) + } + + q = Question{Question: string(b)} + ctx = ctx.WithQuestion(q) + answers, err = tb.Run(ctx, q) + + if err != nil { + return FuncResponse{}, fmt.Errorf("failed to summarize answers: %w", err) + } + + if len(answers.Answers) == 0 { + return FuncResponse{}, fmt.Errorf("no response from model") + } + + return FuncResponse{Result: answers.Answers[0].Answer}, nil + }). + WithName("ask"). + WithDescription("Ask the agent a question, this is useful for splitting a question into multiple parts") diff --git a/pkg/agent/context.go b/pkg/agent/context.go new file mode 100644 index 0000000..2ec781d --- /dev/null +++ b/pkg/agent/context.go @@ -0,0 +1,110 @@ +package agent + +import ( + "context" + "time" + + "github.com/tmc/langchaingo/llms" +) + +type Context struct { + context.Context + + Messages []llms.MessageContent + Agent *Agent +} + +func From(ctx context.Context, a *Agent) *Context { + if ctx == nil { + ctx = context.Background() + } + + return &Context{ + Context: ctx, + Agent: a, + } +} + +func (c *Context) WithMessage(m llms.MessageContent) *Context { + c.Messages = append(c.Messages, m) + return c +} + +func (c *Context) WithMessages(m ...llms.MessageContent) *Context { + c.Messages = append(c.Messages, m...) + return c +} + +func (c *Context) WithToolResults(r ...ToolResult) *Context { + msg := llms.MessageContent{ + Role: llms.ChatMessageTypeTool, + } + + for _, v := range r { + res := v.Result + if v.Error != nil { + res = "error executing: " + v.Error.Error() + } + msg.Parts = append(msg.Parts, llms.ToolCallResponse{ + ToolCallID: v.ID, + Name: v.Name, + Content: res, + }) + } + + return c.WithMessage(msg) +} + +func (c *Context) WithAgent(a *Agent) *Context { + return &Context{ + Context: c.Context, + Agent: a, + } +} + +func (c *Context) WithCancel() (*Context, func()) { + ctx, cancel := context.WithCancel(c.Context) + return &Context{ + Context: ctx, + Agent: c.Agent, + }, cancel +} + +func (c *Context) WithDeadline(deadline time.Time) (*Context, func()) { + ctx, cancel := context.WithDeadline(c.Context, deadline) + return &Context{ + Context: ctx, + Agent: c.Agent, + }, cancel +} + +func (c *Context) WithTimeout(timeout time.Duration) (*Context, func()) { + ctx, cancel := context.WithTimeout(c.Context, timeout) + return &Context{ + Context: ctx, + Agent: c.Agent, + }, cancel +} + +func (c *Context) WithValue(key, value interface{}) *Context { + return &Context{ + Context: context.WithValue(c.Context, key, value), + Agent: c.Agent, + } +} + +func (c *Context) Done() <-chan struct{} { + return c.Context.Done() +} + +func (c *Context) Err() error { + return c.Context.Err() +} + +func (c *Context) Value(key interface{}) interface{} { + return c.Context.Value(key) +} + +func (c *Context) Deadline() (time.Time, bool) { + return c.Context.Deadline() +} diff --git a/pkg/agent/counter.go b/pkg/agent/counter.go new file mode 100644 index 0000000..e1b197b --- /dev/null +++ b/pkg/agent/counter.go @@ -0,0 +1,25 @@ +package agent + +import "sync/atomic" + +type Counter struct { + atomic.Int32 +} + +// Take will attempt to take an item from the counter. If the counter is zero, it will return false. +func (c *Counter) Take() bool { + for { + current := c.Load() + if current <= 0 { + return false + } + if c.CompareAndSwap(current, current-1) { + return true + } + } +} + +// Return will return an item to the counter. +func (c *Counter) Return() { + c.Add(1) +} diff --git a/pkg/agent/function.go b/pkg/agent/function.go new file mode 100644 index 0000000..dc64d40 --- /dev/null +++ b/pkg/agent/function.go @@ -0,0 +1,298 @@ +package agent + +import ( + "encoding/json" + "errors" + "fmt" + "log/slog" + "reflect" + "strings" + "sync" +) + +type FuncResponse struct { + Result string + Source string +} + +type funcCache struct { + sync.RWMutex + m map[reflect.Type]function +} + +func (c *funcCache) get(value reflect.Value) (function, bool) { + c.RLock() + defer c.RUnlock() + + fn, ok := c.m[value.Type()] + if ok { + slog.Info("cache hit for function", "function", value.Type().String()) + } + return fn, ok +} + +func (c *funcCache) set(value reflect.Value, fn function) { + c.Lock() + defer c.Unlock() + + c.m[value.Type()] = fn +} + +var cache = funcCache{m: map[reflect.Type]function{}} + +type arg struct { + Name string + Type reflect.Type + Index int + Values []string + Optional bool + Array bool + Description string +} + +func (a arg) Schema() map[string]any { + var res = map[string]any{} + + if a.Array { + res["type"] = "array" + res["items"] = map[string]any{"type": a.Type.Kind().String()} + } else { + res["type"] = a.Type.Name() + } + + if !a.Optional { + res["required"] = true + } + + if len(a.Values) > 0 { + res["enum"] = a.Values + } + + if a.Description != "" { + res["description"] = a.Description + } + + return res +} + +type function struct { + fn reflect.Value + argType reflect.Type + args map[string]arg +} + +var ErrInvalidFunction = errors.New("invalid function") + +// analyzeFuncFromReflect extracts metadata from a reflect.Value representing a function and returns a structured result. +// It maps the function's parameter names to their corresponding reflect.Type and encapsulates them in a function struct. +// Returns a function struct with extracted information and an error if the operation fails. +// The first parameter to the function must be a *Context. +// The second parameter to the function must be a struct, all the fields of which will be passed as arguments to the +// function to be analyzed. +// Struct tags supported are: +// - `name:""` to specify the name of the parameter (default is the field name) +// - `description:""` to specify a description of the parameter (default is "") +// - `values:",,..."` to specify a list of possible values for the parameter (default is "") only for +// string, int, and float types +// +// Allowed types on the struct are: +// - string, *string, []string +// - int, *int, []int +// - float64, *float64, []float64 +// - bool, *bool, []bool +// +// Pointer types imply that the parameter is optional. +// The function must have at most 2 parameters. +// The function must return a string and an error. +// The function must be of the form `func(*agent.Context, T) (FuncResponse, error)`. +func analyzeFuncFromReflect(fn reflect.Value) (function, error) { + if f, ok := cache.get(fn); ok { + return f, nil + } + + var res function + t := fn.Type() + args := map[string]arg{} + + for i := 0; i < t.NumIn(); i++ { + if i == 0 { + if t.In(i).String() != "*agent.Context" { + return res, fmt.Errorf("%w: first parameter must be *agent.Context", ErrInvalidFunction) + } + continue + } else if i == 1 { + if t.In(i).Kind() != reflect.Struct { + return res, fmt.Errorf("%w: second parameter must be a struct", ErrInvalidFunction) + } + res.argType = t.In(i) + + for j := 0; j < res.argType.NumField(); j++ { + field := res.argType.Field(j) + + a := arg{ + Name: field.Name, + Type: field.Type, + Index: j, + Description: "", + } + + ft := field.Type + // if it's a pointer, it's optional + if ft.Kind() == reflect.Ptr { + a.Optional = true + ft = ft.Elem() + } else if ft.Kind() == reflect.Slice { + a.Array = true + ft = ft.Elem() + } + + if ft.Kind() != reflect.String && ft.Kind() != reflect.Int && ft.Kind() != reflect.Float64 && ft.Kind() != reflect.Bool { + return res, fmt.Errorf("%w: unsupported type %s", ErrInvalidFunction, ft.Kind().String()) + } + + a.Type = ft + + if name, ok := field.Tag.Lookup("name"); ok { + a.Name = name + a.Name = strings.TrimSpace(a.Name) + + if a.Name == "" { + return res, fmt.Errorf("%w: name tag cannot be empty", ErrInvalidFunction) + } + } + if description, ok := field.Tag.Lookup("description"); ok { + a.Description = description + } + if values, ok := field.Tag.Lookup("values"); ok { + a.Values = strings.Split(values, ",") + for i, v := range a.Values { + a.Values[i] = strings.TrimSpace(v) + } + + if ft.Kind() != reflect.String && ft.Kind() != reflect.Int && ft.Kind() != reflect.Float64 { + return res, fmt.Errorf("%w: values tag only supported for string, int, and float types", ErrInvalidFunction) + } + } + + args[field.Name] = a + } + } else { + return res, fmt.Errorf("%w: function must have at most 2 parameters", ErrInvalidFunction) + } + } + + // finally ensure that the function returns a FuncResponse and an error + if t.NumOut() != 2 || t.Out(0).String() != "agent.FuncResponse" || t.Out(1).String() != "error" { + return res, fmt.Errorf("%w: function must return a FuncResponse and an error", ErrInvalidFunction) + } + + cache.set(fn, res) + return res, nil +} + +func analyzeFunction[T any](fn func(*Context, T) (FuncResponse, error)) (function, error) { + return analyzeFuncFromReflect(reflect.ValueOf(fn)) +} + +// Execute will execute the given function with the given context and arguments. +// Returns the result of the execution and an error if the operation fails. +// The arguments must be a JSON-encoded string that represents the struct to be passed to the function. +func (f function) Execute(ctx *Context, args string) (FuncResponse, error) { + var m = map[string]any{} + + err := json.Unmarshal([]byte(args), &m) + if err != nil { + return FuncResponse{}, fmt.Errorf("failed to unmarshal arguments: %w", err) + } + + var obj = reflect.New(f.argType).Elem() + + // TODO: ensure that "required" fields are present in the arguments + for name, a := range f.args { + if v, ok := m[name]; ok { + if a.Array { + if v == nil { + continue + } + + switch a.Type.Kind() { + case reflect.String: + s := v.([]string) + slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(s), len(s)) + for i, str := range s { + slice.Index(i).SetString(str) + } + obj.Field(a.Index).Set(slice) + + case reflect.Int: + i := v.([]int) + slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(0)), len(i), len(i)) + for i, in := range i { + slice.Index(i).SetInt(int64(in)) + } + obj.Field(a.Index).Set(slice) + + case reflect.Float64: + f := v.([]float64) + slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(0.0)), len(f), len(f)) + for i, fl := range f { + slice.Index(i).SetFloat(fl) + } + obj.Field(a.Index).Set(slice) + + case reflect.Bool: + b := v.([]bool) + slice := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(false)), len(b), len(b)) + for i, b := range b { + slice.Index(i).SetBool(b) + } + obj.Field(a.Index).Set(slice) + + default: + return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name) + } + + } else if a.Optional { + if v == nil { + continue + } + switch a.Type.Kind() { + case reflect.String: + str := v.(string) + obj.Field(a.Index).Set(reflect.ValueOf(&str)) + case reflect.Int: + i := v.(int) + obj.Field(a.Index).Set(reflect.ValueOf(&i)) + case reflect.Float64: + f := v.(float64) + obj.Field(a.Index).Set(reflect.ValueOf(&f)) + case reflect.Bool: + b := v.(bool) + obj.Field(a.Index).Set(reflect.ValueOf(&b)) + default: + return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name) + } + } else { + switch a.Type.Kind() { + case reflect.String: + obj.Field(a.Index).SetString(v.(string)) + case reflect.Int: + obj.Field(a.Index).SetInt(int64(v.(int))) + case reflect.Float64: + obj.Field(a.Index).SetFloat(v.(float64)) + case reflect.Bool: + obj.Field(a.Index).SetBool(v.(bool)) + default: + return FuncResponse{}, fmt.Errorf("unsupported type %s for field %s", a.Type.Kind().String(), name) + } + } + } + } + + res := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), obj}) + if res[1].IsNil() { + return res[0].Interface().(FuncResponse), nil + } + + return FuncResponse{}, res[1].Interface().(error) +} diff --git a/pkg/agent/search.go b/pkg/agent/search.go new file mode 100644 index 0000000..99dcf83 --- /dev/null +++ b/pkg/agent/search.go @@ -0,0 +1,11 @@ +package agent + +var SearchTool = FromFunction( + func(ctx *Context, args struct { + SearchFor string `description:"what to search for"` + Question string `description:"the question to answer with the search results"` + }) (FuncResponse, error) { + return FuncResponse{}, nil + }). + WithName("search"). + WithDescription("Search the web and read a few articles to find the answer to the question") diff --git a/pkg/agent/test/main.go b/pkg/agent/test/main.go new file mode 100644 index 0000000..6b04c32 --- /dev/null +++ b/pkg/agent/test/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "reflect" + + "github.com/tmc/langchaingo/llms" +) + +func testFunction(args struct{ a, b int }) { + // This is a test function +} + +func main() { + v := reflect.New(reflect.TypeOf(testFunction)) + + t := reflect.TypeOf(testFunction) + + for i := 0; i < t.NumIn(); i++ { + param := t.In(i) + + llms.MessageContent{ + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.ToolCallResponse{ + Name: "testFunction", + }, + }, + } + if param.Type().Kind() == reflect.Struct { + + } + println(param.Name(), param.Kind().String()) + } +} diff --git a/pkg/agent/tool.go b/pkg/agent/tool.go new file mode 100644 index 0000000..5a526ab --- /dev/null +++ b/pkg/agent/tool.go @@ -0,0 +1,74 @@ +package agent + +import ( + "reflect" + + "github.com/tmc/langchaingo/llms" +) + +type Tool struct { + Name string + Description string + Function function +} + +func (t *Tool) Tool() llms.Tool { + return llms.Tool{ + Type: "function", + Function: t.Definition(), + } +} + +func (t *Tool) Definition() *llms.FunctionDefinition { + var properties = map[string]any{} + for name, arg := range t.Function.args { + properties[name] = arg.Schema() + } + + var res = llms.FunctionDefinition{ + Name: t.Name, + Description: t.Description, + Parameters: map[string]any{"type": "object", "properties": properties}, + } + return &res +} + +// Execute executes the tool with the given context and arguments. +// Returns the result of the execution and an error if the operation fails. +// The arguments must be a JSON-encoded string that represents the struct to be passed to the function. +func (t *Tool) Execute(ctx *Context, args string) (FuncResponse, error) { + return t.Function.Execute(ctx, args) +} + +func FromFunction[T any](fn func(*Context, T) (FuncResponse, error)) *Tool { + f, err := analyzeFunction(fn) + if err != nil { + panic(err) + } + + return &Tool{ + Name: reflect.TypeOf(fn).Name(), + Description: "This is a tool", + Function: f, + } +} + +func (t *Tool) WithName(name string) *Tool { + t.Name = name + return t +} + +func (t *Tool) WithDescription(description string) *Tool { + t.Description = description + return t +} + +func (t *Tool) WithFunction(fn any) *Tool { + f, err := analyzeFuncFromReflect(reflect.ValueOf(fn)) + if err != nil { + panic(err) + } + + t.Function = f + return t +} diff --git a/pkg/agent/toolbox.go b/pkg/agent/toolbox.go new file mode 100644 index 0000000..8e5a98e --- /dev/null +++ b/pkg/agent/toolbox.go @@ -0,0 +1,188 @@ +package agent + +import ( + "errors" + "fmt" + + "golang.org/x/sync/errgroup" + + "github.com/tmc/langchaingo/llms" +) + +type ToolBox map[string]*Tool + +var ( + ErrToolNotFound = errors.New("tool not found") +) + +type ToolResults []ToolResult + +func (r ToolResults) ToMessageContent() llms.MessageContent { + var res = llms.MessageContent{ + Role: llms.ChatMessageTypeTool, + } + + for _, v := range r { + res.Parts = append(res.Parts, v.ToToolCallResponse()) + } + + return res +} + +type ToolResult struct { + ID string + Name string + Result string + Source string + Error error +} + +func (r ToolResult) ToToolCallResponse() llms.ToolCallResponse { + if r.Error != nil { + return llms.ToolCallResponse{ + ToolCallID: r.ID, + Name: r.Name, + Content: "error executing: " + r.Error.Error(), + } + } + + return llms.ToolCallResponse{ + ToolCallID: r.ID, + Name: r.Name, + Content: r.Result, + } +} +func (tb ToolBox) Execute(ctx *Context, call llms.ToolCall) (ToolResult, error) { + if call.Type != "function" { + return ToolResult{}, fmt.Errorf("unsupported tool type: %s", call.Type) + } + + if call.FunctionCall == nil { + return ToolResult{}, errors.New("function call is nil") + } + + tool, ok := tb[call.FunctionCall.Name] + if !ok { + return ToolResult{}, fmt.Errorf("%w: %s", ErrToolNotFound, call.FunctionCall.Name) + } + + res, err := tool.Execute(ctx, call.FunctionCall.Arguments) + + return ToolResult{ + ID: call.ID, + Name: tool.Name, + Result: res.Result, + Source: res.Source, + Error: err, + }, nil +} + +func (tb ToolBox) ExecuteAll(ctx *Context, calls []llms.ToolCall) (ToolResults, error) { + var results []ToolResult + + for _, call := range calls { + res, err := tb.Execute(ctx, call) + if err != nil { + return results, err + } + + results = append(results, res) + } + + return results, nil +} + +func (tb ToolBox) ExecuteConcurrent(ctx *Context, calls []llms.ToolCall) (ToolResults, error) { + var results []ToolResult + var ch = make(chan ToolResult, len(calls)) + var eg = errgroup.Group{} + + for _, call := range calls { + eg.Go(func() error { + c, cancel := ctx.WithCancel() + defer cancel() + + res, err := tb.Execute(c, call) + if err != nil { + ch <- res + return nil + } + + return err + }) + } + + err := eg.Wait() + if err != nil { + return results, err + } + + for range calls { + results = append(results, <-ch) + } + + return results, nil +} + +func (tb ToolBox) Run(ctx *Context, q Question) (Answers, error) { + ctx.Messages = append(ctx.Messages, llms.MessageContent{ + Role: llms.ChatMessageTypeGeneric, + Parts: []llms.ContentPart{llms.TextPart(q.Question)}, + }) + + res, err := ctx.Agent.Model.GenerateContent(ctx, ctx.Messages) + if err != nil { + return Answers{}, err + } + + if res == nil { + return Answers{}, errors.New("no response from model") + } + + if len(res.Choices) == 0 { + return Answers{}, errors.New("no response from model") + } + + choice := res.Choices[0] + + response := llms.MessageContent{ + Role: llms.ChatMessageTypeAI, + Parts: []llms.ContentPart{llms.TextPart(choice.Content)}, + } + + for _, c := range choice.ToolCalls { + response.Parts = append(response.Parts, c) + } + + results, err := tb.ExecuteConcurrent(ctx, choice.ToolCalls) + if err != nil { + return Answers{}, err + } + + var answers []Answer + + for _, r := range results { + if r.Error != nil { + answers = append(answers, Answer{ + Answer: "error executing: " + r.Error.Error(), + Source: r.Source, + ToolCallResponse: r.ToToolCallResponse(), + }) + } else { + answers = append(answers, Answer{ + Answer: r.Result, + Source: r.Source, + ToolCallResponse: r.ToToolCallResponse(), + }) + } + } + + return Answers{ + Response: response, + Answers: answers, + }, nil +} + +func (tb ToolBox) Register(tool *Tool) { + tb[tool.Name] = tool +} diff --git a/pkg/agent/wolfram.go b/pkg/agent/wolfram.go new file mode 100644 index 0000000..75c6764 --- /dev/null +++ b/pkg/agent/wolfram.go @@ -0,0 +1,28 @@ +package agent + +import ( + "fmt" + "os" + + "github.com/Edw590/go-wolfram" +) + +var WolframTool = FromFunction( + func(ctx *Context, args struct { + Query string `description:"what to ask wolfram alpha"` + }) (FuncResponse, error) { + var cl = wolfram.Client{ + AppID: os.Getenv("WOLFRAM_APPID"), + } + + unit := wolfram.Imperial + + a, err := cl.GetShortAnswerQuery(args.Query, unit, 10) + if err != nil { + return FuncResponse{}, fmt.Errorf("failed to get short answer from wolfram: %w", err) + } + + return FuncResponse{Result: a, Source: "Wolfram|Alpha"}, nil + }). + WithName("wolfram"). + WithDescription("ask wolfram alpha for the answer") diff --git a/pkg/answer/answer.go b/pkg/answer/answer.go index d6541d1..7dad6d2 100644 --- a/pkg/answer/answer.go +++ b/pkg/answer/answer.go @@ -2,12 +2,17 @@ package answer import ( "context" + "encoding/json" "errors" "fmt" "log/slog" "net/url" "strings" + "github.com/tmc/langchaingo/agents" + + "github.com/tmc/langchaingo/llms" + "gitea.stevedudenhoeffer.com/steve/answer/pkg/cache" "gitea.stevedudenhoeffer.com/steve/answer/pkg/extractor" "gitea.stevedudenhoeffer.com/steve/answer/pkg/search" @@ -22,7 +27,7 @@ type Question struct { // Question is the question to answer Question string - Model gollm.ChatCompletion + Model llms.Model Search search.Search @@ -146,44 +151,82 @@ func extractArticle(ctx context.Context, c cache.Cache, u *url.URL) (res article } func doesTextAnswerQuestion(ctx context.Context, q Question, text string) (string, error) { - fnAnswer := gollm.NewFunction( - "answer", - "The answer from the given text that answers the question.", - func(ctx context.Context, args struct { - Answer string `description:"the answer to the question, the answer should come from the text"` - }) (string, error) { - return args.Answer, nil - }) - - fnNoAnswer := gollm.NewFunction( - "no_answer", - "Indicate that the text does not answer the question.", - func(ctx context.Context, args struct { - Ignored string `description:"ignored, just here to make sure the function is called. Fill with anything."` - }) (string, error) { - return "", nil - }) - - req := gollm.Request{ - Messages: []gollm.Message{ - { - Role: gollm.RoleSystem, - Text: "Evaluate the given text to see if it answers the question from the user. The text is as follows:", - }, - { - Role: gollm.RoleSystem, - Text: text, - }, - { - Role: gollm.RoleUser, - Text: q.Question, + var availableTools = []llms.Tool{ + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "answer", + Description: "The answer from the given text that answers the question.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "answer": map[string]any{ + "type": "string", + "description": "the answer to the question, the answer should come from the text", + }, + }, + }, + }, + }, + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "no_answer", + Description: "Indicate that the text does not answer the question.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "ignored": map[string]any{ + "type": "string", + "description": "ignored, just here to make sure the function is called. Fill with anything.", + }, + }, + }, }, }, - Toolbox: gollm.NewToolBox(fnAnswer, fnNoAnswer), } - res, err := q.Model.ChatComplete(ctx, req) + answer := func(ctx context.Context, args string) (string, error) { + type answer struct { + Answer string `json:"answer"` + } + var a answer + + err := json.Unmarshal([]byte(args), &a) + if err != nil { + return "", err + } + + return a.Answer, nil + } + + noAnswer := func(ctx context.Context, ignored string) (string, error) { + return "", nil + } + + var req = []llms.MessageContent{ + { + Role: llms.ChatMessageTypeSystem, + Parts: []llms.ContentPart{ + llms.TextPart("Evaluate the given text to see if it answers the question from the user. The text is as follows:"), + }, + }, + { + Role: llms.ChatMessageTypeSystem, + Parts: []llms.ContentPart{ + llms.TextPart(text), + }, + }, + { + Role: llms.ChatMessageTypeHuman, + Parts: []llms.ContentPart{ + llms.TextPart(q.Question), + }, + }, + } + + res, err := q.Model.GenerateContent(ctx, req, llms.WithTools(availableTools)) if err != nil { return "", err } @@ -192,11 +235,24 @@ func doesTextAnswerQuestion(ctx context.Context, q Question, text string) (strin return "", nil } - if len(res.Choices[0].Calls) == 0 { + if len(res.Choices[0].ToolCalls) == 0 { return "", nil } - return req.Toolbox.Execute(ctx, res.Choices[0].Calls[0]) + for _, call := range res.Choices[0].ToolCalls { + switch call.FunctionCall.Name { + case "answer": + return answer(ctx, call.FunctionCall.Arguments) + + case "no_answer": + return noAnswer(ctx, call.FunctionCall.Arguments) + + default: + return "", fmt.Errorf("unknown function %s", call.FunctionCall.Name) + } + } + + return "", nil } func functionSearch(ctx context.Context, q Question, searchTerm string) (string, error) { @@ -292,150 +348,30 @@ func functionThink(ctx context.Context, q Question) (string, error) { return req.Toolbox.Execute(ctx, res.Choices[0].Calls[0]) } -func (o Options) Answer(ctx context.Context, q Question) (Answers, error) { - fnSearch := gollm.NewFunction( - "search", - "Search the web for an answer to a question. You can call this function up to "+fmt.Sprint(o.MaxSearches)+" times.", - func(ctx context.Context, args struct { - SearchQuery string `description:"what to search the web for for this question"` - Question string `description:"what question(s) you are trying to answer with this search"` - }) (string, error) { - q2 := q - q2.Question = args.Question - - return functionSearch(ctx, q2, args.SearchQuery) - }) - - fnThink := gollm.NewFunction( - "think", - "Think about a question. This is useful for breaking down complex questions into smaller parts that are easier to answer.", - func(ctx context.Context, args struct { - Question string `json:"question" description:"the question to think about"` - }) (string, error) { - q2 := q - q2.Question = args.Question - - return functionThink(ctx, q2) - }) - - fnAnswer := gollm.NewFunction( - "answer", - "You definitively answer a question, if you call this it means you know the answer and do not need to search for it or use any other function to find it", - func(ctx context.Context, args struct { - Answer string `json:"answer" description:"the answer to the question"` - }) (string, error) { - return args.Answer, nil - }) - - var funcs = []*gollm.Function{fnAnswer} - - if o.MaxSearches > 0 { - funcs = append(funcs, fnSearch) +func appendResponse(req []llms.MessageContent, response *llms.ContentResponse) ([]llms.MessageContent, error) { + if response == nil { + return req, nil } - if o.MaxThinks > 0 { - funcs = append(funcs, fnThink) + if len(response.Choices) == 0 { + return req, nil } - var temp float32 = 0.8 + choice := response.Choices[0] - req := gollm.Request{ - Messages: []gollm.Message{ - { - Role: gollm.RoleSystem, - Text: "You are being asked to answer a question. You must respond with a function. You can answer it if you know the answer, or if some functions exist you can use those to help you find the answer.", - }, - { - Role: gollm.RoleUser, - Text: q.Question, - }, - }, - Toolbox: gollm.NewToolBox(funcs...), - Temperature: &temp, + assistantResponse := llms.TextParts(llms.ChatMessageTypeAI, choice.Content) + + for _, tc := range choice.ToolCalls { + assistantResponse.Parts = append(assistantResponse.Parts, tc) } - res, err := q.Model.ChatComplete(ctx, req) - - if err != nil { - return nil, err - } - - if len(res.Choices) == 0 { - return nil, nil - } - - if len(res.Choices) > o.MaxSearches { - res.Choices = res.Choices[:o.MaxSearches] - } - - var answers []Result - for _, choice := range res.Choices { - fnChoice := func(choice gollm.ResponseChoice) []Result { - var calls []Result - var callsOutput = make(chan Result, len(choice.Calls)) - fnCall := func(call gollm.ToolCall) Result { - str, err := req.Toolbox.Execute(ctx, call) - - if err != nil { - return Result{ - Error: err, - } - } - - return Result{ - Result: str, - } - } - - for _, call := range choice.Calls { - go func(call gollm.ToolCall) { - if o.OnNewFunction != nil { - err := o.OnNewFunction(ctx, call.FunctionCall.Name, q.Question, call.FunctionCall.Arguments) - if err != nil { - callsOutput <- Result{ - Error: err, - } - return - } - } - callsOutput <- fnCall(call) - }(call) - } - - for i := 0; i < len(choice.Calls); i++ { - result := <-callsOutput - calls = append(calls, result) - } - - close(callsOutput) - - slog.Info("calls", "calls", calls) - return calls - } - - answers = append(answers, fnChoice(choice)...) - } - - var errs []error - var results []string - - for _, answer := range answers { - if answer.Error != nil { - errs = append(errs, answer.Error) - continue - } - - results = append(results, answer.Result) - } - - if len(errs) > 0 { - return nil, errors.Join(errs...) - } - - return results, nil - + return req, nil } -func Answer(ctx context.Context, q Question) (Answers, error) { +func (o Options) Answer(ctx context.Context, q Question) (string, error) { + a := agents.NewConversationalAgent() +} + +func Answer(ctx context.Context, q Question) (string, error) { return DefaultOptions.Answer(ctx, q) }