answer/pkg/agents/agent.go
Steve Dudenhoeffer 1be5cc047c Fix browser closure timing in search agent
Move defer statement to ensure browser closure occurs only after assigning the browser to the context. This prevents potential issues of premature resource release.
2025-04-12 02:22:10 -04:00

243 lines
6.6 KiB
Go

package agents
import (
"context"
"fmt"
"sync"
"sync/atomic"
gollm "gitea.stevedudenhoeffer.com/steve/go-llm"
)
// Agent is essentially the bones of a chat agent. It has a model and a toolbox, and can be used to call the model
// with messages and execute the resulting calls.
// The agent will keep track of how many calls it has made to the model, and any agents which inherit from this
// one (e.g.: all made from sharing pointers or from the With... helpers) will share the same call count.
// This package contains a number of helper functions to make it easier to create and use agents.
type Agent struct {
	model   gollm.ChatCompletion // the LLM backend used for chat completions
	toolbox gollm.ToolBox        // functions the model may call; see WithToolbox for the synthetic "reason" field

	systemPrompt          string   // base system prompt, placed first in every request
	contextualInformation []string // extra context appended to the system prompt by ToRequest
	systemPromptSuffix    string   // text appended after prompt and context in ToRequest

	// maxCalls, when non-nil, caps the total model calls recorded on the
	// shared counter; CallModel and CallAndExecute fail once exceeded.
	maxCalls *int32

	// insertReason is not referenced anywhere in this file — presumably it
	// toggles injecting the tool-call "reason" back into the conversation;
	// TODO(review): confirm against the rest of the package.
	insertReason bool

	// calls is a pointer so every agent derived from this one (value copies
	// via the With... helpers) shares the same running call count.
	calls *atomic.Int32
}
// NewAgent constructs an agent around the supplied model and toolbox.
// The returned agent owns a fresh call counter; every agent later derived
// from it (value copies or the With... helpers) shares that same counter.
func NewAgent(model gollm.ChatCompletion, toolbox gollm.ToolBox) Agent {
	var counter atomic.Int32
	return Agent{
		model:   model,
		toolbox: toolbox,
		calls:   &counter,
	}
}
// Calls reports how many model invocations have been recorded on the call
// counter shared by this agent and everything derived from it.
func (a Agent) Calls() int32 {
	n := a.calls.Load()
	return n
}
// WithModel returns a copy of this agent that will talk to the given model.
// The call counter remains shared with the parent agent.
func (a Agent) WithModel(model gollm.ChatCompletion) Agent {
	derived := a
	derived.model = model
	return derived
}
// WithToolbox returns a copy of this agent using the given toolbox.
// A synthetic "reason" field is added to every function in the toolbox so the
// model must state why it is invoking each tool; that reason is replayed to
// the LLM when the conversation continues after the call.
func (a Agent) WithToolbox(toolbox gollm.ToolBox) Agent {
	a.toolbox = toolbox.WithSyntheticFieldsAddedToAllFunctions(map[string]string{
		// Fixed grammar in the LLM-facing description ("presenting" -> "presented").
		"reason": "The reason you are calling this function. This will be remembered and presented to the LLM when it continues after the function call.",
	})
	return a
}
// WithSystemPrompt returns a copy of this agent whose requests will begin
// with the given system prompt.
func (a Agent) WithSystemPrompt(systemPrompt string) Agent {
	derived := a
	derived.systemPrompt = systemPrompt
	return derived
}
// WithContextualInformation returns a copy of this agent with the given
// strings appended to its contextual information (surfaced in the system
// prompt by ToRequest).
//
// The merged slice is built in fresh backing storage: a plain append onto
// the copied slice header could share its backing array with the parent,
// letting two siblings derived from the same agent overwrite each other's
// entries (classic slice-aliasing bug).
func (a Agent) WithContextualInformation(contextualInformation []string) Agent {
	merged := make([]string, 0, len(a.contextualInformation)+len(contextualInformation))
	merged = append(merged, a.contextualInformation...)
	merged = append(merged, contextualInformation...)
	a.contextualInformation = merged
	return a
}
// WithSystemPromptSuffix returns a copy of this agent whose system prompt
// will end with the given suffix (appended by ToRequest).
func (a Agent) WithSystemPromptSuffix(systemPromptSuffix string) Agent {
	derived := a
	derived.systemPromptSuffix = systemPromptSuffix
	return derived
}
// WithMaxCalls returns a copy of this agent that refuses to make more than
// maxCalls model invocations (counted on the shared counter).
func (a Agent) WithMaxCalls(maxCalls int32) Agent {
	limit := maxCalls
	a.maxCalls = &limit
	return a
}
// _readAnyMessages normalizes a mixed argument list into gollm messages.
// Each entry must be either a gollm.Message (kept as-is) or a string
// (wrapped as a user-role message); any other type yields an error.
func (a Agent) _readAnyMessages(messages ...any) ([]gollm.Message, error) {
	var out []gollm.Message
	for _, raw := range messages {
		switch m := raw.(type) {
		case gollm.Message:
			out = append(out, m)
		case string:
			out = append(out, gollm.Message{Role: gollm.RoleUser, Text: m})
		default:
			return out, fmt.Errorf("unknown type %T used as message", raw)
		}
	}
	return out, nil
}
// ToRequest will convert the current agent configuration into a gollm.Request. Any messages passed in will be added
// to the request at the end. messages can be either a gollm.Message or a string. All string entries will be added as
// simple user messages.
func (a Agent) ToRequest(messages ...any) (gollm.Request, error) {
	// Assemble the system prompt: base prompt, then contextual information,
	// then the suffix, each separated by a blank line when both sides exist.
	prompt := a.systemPrompt
	if len(a.contextualInformation) > 0 {
		if len(prompt) > 0 {
			prompt += "\n\n"
		}
		prompt += fmt.Sprintf(" Contextual information you should be aware of: %v", a.contextualInformation)
	}
	if len(a.systemPromptSuffix) > 0 {
		if len(prompt) > 0 {
			prompt += "\n\n"
		}
		prompt += a.systemPromptSuffix
	}

	req := gollm.Request{Toolbox: a.toolbox}
	if len(prompt) > 0 {
		req.Messages = []gollm.Message{{Role: gollm.RoleSystem, Text: prompt}}
	}

	converted, err := a._readAnyMessages(messages...)
	if err != nil {
		return req, fmt.Errorf("failed to read messages: %w", err)
	}
	req.Messages = append(req.Messages, converted...)
	return req, nil
}
// CallModel calls the model with the given messages and returns the raw response.
// note that the msgs can be either a gollm.Message or a string. All string entries will be added as simple
// user messages.
func (a Agent) CallModel(ctx context.Context, msgs ...any) (gollm.Response, error) {
	// Count this invocation against the shared budget before doing any work.
	if used := a.calls.Add(1); a.maxCalls != nil && used > *a.maxCalls {
		return gollm.Response{}, fmt.Errorf("max model calls exceeded")
	}

	req, err := a.ToRequest(msgs...)
	if err != nil {
		return gollm.Response{}, fmt.Errorf("failed to create request: %w", err)
	}
	return a.model.ChatComplete(ctx, req)
}
// CallResults records the outcome of executing a single tool call returned
// by the model.
type CallResults struct {
	ID        string // tool-call ID as reported by the model
	Function  string // name of the function that was invoked
	Arguments string // raw argument payload the model supplied
	Result    any    // value returned by the toolbox execution
	Error     error  // execution error; nil on success
}
// CallAndExecuteResults bundles the model's text reply with the outcomes of
// every tool call it requested, in the order the calls appeared.
type CallAndExecuteResults struct {
	Text        string        // assistant text content from the chosen response
	CallResults []CallResults // one entry per tool call, same order as the response
}
// CallAndExecute calls the model with the given messages and executes the
// resulting tool calls one at a time, in order. Results are returned in the
// same order as the calls.
func (a Agent) CallAndExecute(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) {
	res, err := a._callAndExecuteParallel(ctx, false, msgs...)
	return res, err
}
// CallAndExecuteParallel calls the model with the given messages and runs
// every tool call in the response concurrently. Results are returned in the
// same order as the calls.
func (a Agent) CallAndExecuteParallel(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) {
	res, err := a._callAndExecuteParallel(ctx, true, msgs...)
	return res, err
}
// _callAndExecuteParallel performs one model call (counted against the shared
// call budget) and executes every tool call in the first choice. When
// parallel is true the calls run concurrently; either way each result lands
// at its call's original index in CallResults.
func (a Agent) _callAndExecuteParallel(ctx context.Context, parallel bool, msgs ...any) (CallAndExecuteResults, error) {
	calls := a.calls.Add(1)
	if a.maxCalls != nil && calls > *a.maxCalls {
		return CallAndExecuteResults{}, fmt.Errorf("max model calls exceeded")
	}

	req, err := a.ToRequest(msgs...)
	if err != nil {
		return CallAndExecuteResults{}, fmt.Errorf("failed to create request: %w", err)
	}

	response, err := a.model.ChatComplete(ctx, req)
	if err != nil {
		return CallAndExecuteResults{}, fmt.Errorf("error calling model: %w", err)
	}
	if len(response.Choices) == 0 {
		return CallAndExecuteResults{}, fmt.Errorf("no choices found")
	}

	choice := response.Choices[0]
	res := CallAndExecuteResults{
		Text:        choice.Content,
		CallResults: make([]CallResults, len(choice.Calls)),
	}

	// execute runs the i-th tool call and records its outcome at index i.
	// Taking the call out of the slice here gives each invocation its own
	// copy, so &call never aliases a loop variable.
	execute := func(i int) {
		call := choice.Calls[i]
		out := CallResults{
			ID:        call.ID,
			Function:  call.FunctionCall.Name,
			Arguments: call.FunctionCall.Arguments,
		}
		out.Result, out.Error = req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call)
		res.CallResults[i] = out
	}

	if parallel {
		var wg sync.WaitGroup
		wg.Add(len(choice.Calls))
		for i := range choice.Calls {
			// Shadow i so the goroutine sees its own copy: before Go 1.22
			// the loop variable is shared across iterations and the original
			// closure raced on both i and call.
			i := i
			go func() {
				defer wg.Done() // Done even if Execute panics through recover elsewhere
				execute(i)
			}()
		}
		wg.Wait()
	} else {
		for i := range choice.Calls {
			execute(i)
		}
	}
	return res, nil
}