answer/pkg/agents/agent.go

239 lines
6.4 KiB
Go

package agents
import (
	"context"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"

	gollm "gitea.stevedudenhoeffer.com/steve/go-llm"
)
// Agent is essentially the bones of a chat agent. It pairs a model with a toolbox and can be used to call the model
// with messages and to execute the tool calls the model returns.
//
// The agent keeps track of how many calls it has made to the model. The counter is held by pointer, so every agent
// derived from this one (plain copies of the value, or the results of the With... helpers) shares the same count.
//
// This package contains a number of helper functions to make it easier to create and use agents.
type Agent struct {
	// model is the chat-completion backend that requests are sent to.
	model gollm.ChatCompletion
	// toolbox is attached to every request built by ToRequest and is used to execute the model's tool calls.
	toolbox *gollm.ToolBox
	// systemPrompt is the first section of the system message; empty means no such section.
	systemPrompt string
	// contextualInformation entries are rendered into the system message after systemPrompt.
	contextualInformation []string
	// systemPromptSuffix is the last section of the system message; empty means no such section.
	systemPromptSuffix string
	// maxCalls, when non-nil, is the limit checked against the shared call counter before calling the model;
	// nil means unlimited.
	maxCalls *int32
	// calls is the shared call counter. It is a pointer so that copies of the Agent update the same counter.
	// It is nil on a zero-value Agent; construct agents with NewAgent.
	calls *atomic.Int32
}
// NewAgent returns an Agent that sends requests to model and offers the tools in toolbox.
//
// The returned Agent owns a freshly allocated call counter. Every Agent derived from it, whether by copying the
// value or through the With... helpers, shares that counter, so Calls reports the total across all of them.
func NewAgent(model gollm.ChatCompletion, toolbox *gollm.ToolBox) Agent {
	counter := new(atomic.Int32)

	var agent Agent
	agent.model = model
	agent.toolbox = toolbox
	agent.calls = counter
	return agent
}
// Calls reports the current value of the call counter shared by this agent and every agent derived from it.
//
// A zero-value Agent (one not created through NewAgent) has no counter. It reports 0 instead of panicking on a
// nil dereference.
func (a Agent) Calls() int32 {
	if a.calls == nil {
		return 0
	}
	return a.calls.Load()
}
// WithModel returns a copy of the agent that sends its requests to model instead.
// The copy keeps sharing the original's call counter.
func (a Agent) WithModel(model gollm.ChatCompletion) Agent {
	derived := a
	derived.model = model
	return derived
}
// WithToolbox returns a copy of the agent whose requests carry toolbox, replacing any previously set toolbox.
// The copy keeps sharing the original's call counter.
func (a Agent) WithToolbox(toolbox *gollm.ToolBox) Agent {
	updated := a
	updated.toolbox = toolbox
	return updated
}
// WithSystemPrompt returns a copy of the agent whose system message begins with systemPrompt, replacing any prompt
// set before. Passing an empty string drops that section. The copy keeps sharing the original's call counter.
func (a Agent) WithSystemPrompt(systemPrompt string) Agent {
	next := a
	next.systemPrompt = systemPrompt
	return next
}
// WithContextualInformation returns a copy of the agent with contextualInformation appended after any contextual
// information it already carries. ToRequest renders these entries into the system message.
//
// A fresh slice is always allocated, so agents derived from the same parent never share a backing array. With a
// plain append, two siblings built from one parent would write into the parent's spare capacity and silently
// overwrite each other's entries.
func (a Agent) WithContextualInformation(contextualInformation []string) Agent {
	merged := make([]string, 0, len(a.contextualInformation)+len(contextualInformation))
	merged = append(merged, a.contextualInformation...)
	merged = append(merged, contextualInformation...)
	a.contextualInformation = merged
	return a
}
// WithSystemPromptSuffix returns a copy of the agent whose system message ends with systemPromptSuffix, replacing
// any suffix set before. Passing an empty string drops that section. The copy keeps sharing the original's call
// counter.
func (a Agent) WithSystemPromptSuffix(systemPromptSuffix string) Agent {
	out := a
	out.systemPromptSuffix = systemPromptSuffix
	return out
}
// WithMaxCalls returns a copy of the agent that refuses to call the model once the incremented shared call counter
// would exceed maxCalls. Because the counter is shared, calls made through related agents count toward this
// limit as well.
func (a Agent) WithMaxCalls(maxCalls int32) Agent {
	limit := maxCalls
	a.maxCalls = &limit
	return a
}
// _readAnyMessages converts a mixed list of messages into gollm.Messages, preserving order. Each entry must be
// either a gollm.Message, which is used as-is, or a string, which becomes a user message with that text. Any other
// type is rejected.
//
// On error the returned slice is nil rather than a partial conversion, so callers cannot accidentally use it.
func (a Agent) _readAnyMessages(messages ...any) ([]gollm.Message, error) {
	res := make([]gollm.Message, 0, len(messages))
	for _, msg := range messages {
		switch v := msg.(type) {
		case gollm.Message:
			res = append(res, v)
		case string:
			res = append(res, gollm.Message{Role: gollm.RoleUser, Text: v})
		default:
			return nil, fmt.Errorf("unknown type %T used as message", msg)
		}
	}
	return res, nil
}
// ToRequest converts the current agent configuration into a gollm.Request carrying the agent's toolbox.
//
// When there is anything to say, the request starts with a single system message. It is assembled from these
// optional sections, separated by a blank line:
//
//  1. the system prompt,
//  2. the contextual information, rendered as a bulleted list with one entry per line,
//  3. the system prompt suffix.
//
// Any messages passed in are appended after the system message. Each one can be either a gollm.Message or a
// string; strings are added as simple user messages. If a message has an unsupported type, a zero Request and an
// error are returned.
func (a Agent) ToRequest(messages ...any) (gollm.Request, error) {
	msgs, err := a._readAnyMessages(messages...)
	if err != nil {
		return gollm.Request{}, fmt.Errorf("failed to read messages: %w", err)
	}

	var sections []string
	if a.systemPrompt != "" {
		sections = append(sections, a.systemPrompt)
	}
	if len(a.contextualInformation) > 0 {
		// One entry per line. Formatting the slice with %v would run entries that contain spaces together
		// ("[a b c d]"), leaving the model unable to tell where one ends and the next begins.
		var b strings.Builder
		b.WriteString("Contextual information you should be aware of:")
		for _, info := range a.contextualInformation {
			b.WriteString("\n- ")
			b.WriteString(info)
		}
		sections = append(sections, b.String())
	}
	if a.systemPromptSuffix != "" {
		sections = append(sections, a.systemPromptSuffix)
	}

	req := gollm.Request{
		Toolbox: a.toolbox,
	}
	if len(sections) > 0 {
		req.Messages = append(req.Messages, gollm.Message{
			Role: gollm.RoleSystem,
			Text: strings.Join(sections, "\n\n"),
		})
	}
	req.Messages = append(req.Messages, msgs...)
	return req, nil
}
// CallModel calls the model with the given messages and returns the raw response.
// msgs can be either gollm.Message values or strings; strings are added as simple user messages.
//
// The request is built before a call is counted, so malformed input never consumes the call budget. If the agent
// has a call limit (see WithMaxCalls) and it has been reached, an error is returned without calling the model and
// the counter increment is rolled back. The shared counter therefore only reflects calls that reached the model.
func (a Agent) CallModel(ctx context.Context, msgs ...any) (gollm.Response, error) {
	req, err := a.ToRequest(msgs...)
	if err != nil {
		return gollm.Response{}, fmt.Errorf("failed to create request: %w", err)
	}

	if n := a.calls.Add(1); a.maxCalls != nil && n > *a.maxCalls {
		// Undo the reservation: this call never reaches the model, so it must not be counted.
		a.calls.Add(-1)
		return gollm.Response{}, fmt.Errorf("max model calls exceeded")
	}

	return a.model.ChatComplete(ctx, req)
}
// CallResults records the outcome of executing a single tool call requested by the model.
type CallResults struct {
	// ID, Function and Arguments are copied from the tool call: its identifier, the name of the function the model
	// asked for, and the argument string it supplied.
	ID, Function, Arguments string
	// Result is the value returned by executing the tool.
	Result any
	// Error is the error returned by executing the tool, or nil.
	Error error
}
// CallAndExecuteResults is returned by CallAndExecute and CallAndExecuteParallel. It holds the text content of the
// model's first choice together with the outcome of each tool call that choice requested.
type CallAndExecuteResults struct {
	// Text is the Content of the first choice in the model response.
	Text string
	// CallResults holds one entry per tool call, at the same index as the call in the response.
	CallResults []CallResults
}
// CallAndExecute sends msgs to the model and then runs each tool call from the response one at a time, in the
// order the model listed them. The results come back in that same order.
func (a Agent) CallAndExecute(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) {
	const runInParallel = false
	return a._callAndExecuteParallel(ctx, runInParallel, msgs...)
}
// CallAndExecuteParallel sends msgs to the model and then runs every tool call from the response concurrently,
// waiting for all of them to finish. Results are still reported in the order the model listed the calls.
func (a Agent) CallAndExecuteParallel(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) {
	const runInParallel = true
	return a._callAndExecuteParallel(ctx, runInParallel, msgs...)
}
// _callAndExecuteParallel calls the model once with msgs and then executes every tool call in the first choice of
// the response. When parallel is true the calls run concurrently; otherwise they run one after another. Each result
// is stored at the same index as the call that produced it, so the output order always matches choice.Calls. A
// failing tool does not stop the others; its error is recorded in that entry's Error field.
//
// The request is built before a call is counted, so malformed input never consumes the call budget. A call refused
// by the max-calls limit has its counter increment rolled back, so it is not counted either.
func (a Agent) _callAndExecuteParallel(ctx context.Context, parallel bool, msgs ...any) (CallAndExecuteResults, error) {
	req, err := a.ToRequest(msgs...)
	if err != nil {
		return CallAndExecuteResults{}, fmt.Errorf("failed to create request: %w", err)
	}

	if n := a.calls.Add(1); a.maxCalls != nil && n > *a.maxCalls {
		// Undo the reservation: this call never reaches the model, so it must not be counted.
		a.calls.Add(-1)
		return CallAndExecuteResults{}, fmt.Errorf("max model calls exceeded")
	}

	response, err := a.model.ChatComplete(ctx, req)
	if err != nil {
		return CallAndExecuteResults{}, fmt.Errorf("error calling model: %w", err)
	}
	if len(response.Choices) == 0 {
		return CallAndExecuteResults{}, fmt.Errorf("no choices found")
	}
	choice := response.Choices[0]

	res := CallAndExecuteResults{
		Text:        choice.Content,
		CallResults: make([]CallResults, len(choice.Calls)),
	}

	// execute runs the i-th tool call and stores its outcome in res.CallResults[i]. Each invocation writes only its
	// own element, so concurrent invocations with distinct i do not conflict.
	execute := func(i int) {
		call := choice.Calls[i] // private copy; its address is handed to the tool context
		result := CallResults{
			ID:        call.ID,
			Function:  call.FunctionCall.Name,
			Arguments: call.FunctionCall.Arguments,
		}
		result.Result, result.Error = req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call)
		res.CallResults[i] = result
	}

	if !parallel {
		for i := range choice.Calls {
			execute(i)
		}
		return res, nil
	}

	var wg sync.WaitGroup
	for i := range choice.Calls {
		wg.Add(1)
		// The index is passed explicitly so correctness does not depend on Go 1.22 per-iteration loop variables.
		go func(i int) {
			defer wg.Done()
			execute(i)
		}(i)
	}
	wg.Wait()
	return res, nil
}