Move defer statement to ensure browser closure occurs only after assigning the browser to the context. This prevents potential issues of premature resource release.
243 lines · 6.6 KiB · Go
package agents
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
gollm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
|
)
|
|
|
|
// Agent is essentially the bones of a chat agent. It has a model and a toolbox, and can be used to call the model
// with messages and execute the resulting calls.
// The agent will keep track of how many calls it has made to the model, and any agents which inherit from this
// one (e.g.: all made from sharing pointers or from the With... helpers) will share the same call count.
// This package contains a number of helper functions to make it easier to create and use agents.
type Agent struct {
	model   gollm.ChatCompletion // backend used for chat completions
	toolbox gollm.ToolBox        // tools the model is allowed to call

	systemPrompt          string   // base system prompt prepended to every request
	contextualInformation []string // extra context appended to the system prompt in ToRequest
	systemPromptSuffix    string   // text appended after the contextual information

	// maxCalls, when non-nil, caps the shared call counter; CallModel and
	// CallAndExecute return an error once the cap is exceeded.
	maxCalls *int32

	// insertReason is not referenced anywhere in this file — presumably it
	// toggles the synthetic "reason" field behavior; verify against callers.
	insertReason bool

	// calls is a pointer so every agent derived via the With... helpers
	// shares one counter with its parent.
	calls *atomic.Int32
}
|
|
|
|
// NewAgent creates a new agent struct with the given model and toolbox.
|
|
// Any inherited agents (e.g.: all made from sharing pointers or from the With... helpers) from this one will
|
|
// share the same call count.
|
|
func NewAgent(model gollm.ChatCompletion, toolbox gollm.ToolBox) Agent {
|
|
return Agent{
|
|
model: model,
|
|
toolbox: toolbox,
|
|
calls: &atomic.Int32{},
|
|
}
|
|
}
|
|
|
|
func (a Agent) Calls() int32 {
|
|
return a.calls.Load()
|
|
}
|
|
|
|
func (a Agent) WithModel(model gollm.ChatCompletion) Agent {
|
|
a.model = model
|
|
return a
|
|
}
|
|
|
|
func (a Agent) WithToolbox(toolbox gollm.ToolBox) Agent {
|
|
a.toolbox = toolbox.WithSyntheticFieldsAddedToAllFunctions(map[string]string{
|
|
"reason": "The reason you are calling this function. This will be remembered and presenting to the LLM when it continues after the function call.",
|
|
})
|
|
return a
|
|
}
|
|
|
|
func (a Agent) WithSystemPrompt(systemPrompt string) Agent {
|
|
a.systemPrompt = systemPrompt
|
|
return a
|
|
}
|
|
|
|
func (a Agent) WithContextualInformation(contextualInformation []string) Agent {
|
|
a.contextualInformation = append(a.contextualInformation, contextualInformation...)
|
|
return a
|
|
}
|
|
|
|
func (a Agent) WithSystemPromptSuffix(systemPromptSuffix string) Agent {
|
|
a.systemPromptSuffix = systemPromptSuffix
|
|
return a
|
|
}
|
|
|
|
func (a Agent) WithMaxCalls(maxCalls int32) Agent {
|
|
a.maxCalls = &maxCalls
|
|
return a
|
|
}
|
|
|
|
func (a Agent) _readAnyMessages(messages ...any) ([]gollm.Message, error) {
|
|
var res []gollm.Message
|
|
|
|
for _, msg := range messages {
|
|
switch v := msg.(type) {
|
|
case gollm.Message:
|
|
res = append(res, v)
|
|
case string:
|
|
res = append(res, gollm.Message{
|
|
Role: gollm.RoleUser,
|
|
Text: v,
|
|
})
|
|
|
|
default:
|
|
return res, fmt.Errorf("unknown type %T used as message", msg)
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// ToRequest will convert the current agent configuration into a gollm.Request. Any messages passed in will be added
|
|
// to the request at the end. messages can be either a gollm.Message or a string. All string entries will be added as
|
|
// simple user messages.
|
|
func (a Agent) ToRequest(messages ...any) (gollm.Request, error) {
|
|
sysPrompt := a.systemPrompt
|
|
|
|
if len(a.contextualInformation) > 0 {
|
|
if len(sysPrompt) > 0 {
|
|
sysPrompt += "\n\n"
|
|
}
|
|
sysPrompt += fmt.Sprintf(" Contextual information you should be aware of: %v", a.contextualInformation)
|
|
}
|
|
|
|
if len(a.systemPromptSuffix) > 0 {
|
|
if len(sysPrompt) > 0 {
|
|
sysPrompt += "\n\n"
|
|
}
|
|
sysPrompt += a.systemPromptSuffix
|
|
}
|
|
|
|
req := gollm.Request{
|
|
Toolbox: a.toolbox,
|
|
}
|
|
|
|
if len(sysPrompt) > 0 {
|
|
req.Messages = append(req.Messages, gollm.Message{
|
|
Role: gollm.RoleSystem,
|
|
Text: sysPrompt,
|
|
})
|
|
}
|
|
|
|
msgs, err := a._readAnyMessages(messages...)
|
|
if err != nil {
|
|
return req, fmt.Errorf("failed to read messages: %w", err)
|
|
}
|
|
|
|
req.Messages = append(req.Messages, msgs...)
|
|
return req, nil
|
|
}
|
|
|
|
// CallModel calls the model with the given messages and returns the raw response.
|
|
// note that the msgs can be either a gollm.Message or a string. All string entries will be added as simple
|
|
// user messages.
|
|
func (a Agent) CallModel(ctx context.Context, msgs ...any) (gollm.Response, error) {
|
|
calls := a.calls.Add(1)
|
|
if a.maxCalls != nil && calls > *a.maxCalls {
|
|
return gollm.Response{}, fmt.Errorf("max model calls exceeded")
|
|
}
|
|
|
|
req, err := a.ToRequest(msgs...)
|
|
|
|
if err != nil {
|
|
return gollm.Response{}, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
return a.model.ChatComplete(ctx, req)
|
|
}
|
|
|
|
// CallResults holds the outcome of a single tool call requested by the model.
type CallResults struct {
	ID        string // tool-call ID as assigned by the model
	Function  string // name of the function that was invoked
	Arguments string // raw argument payload passed to the function
	Result    any    // value returned by the toolbox execution, if any
	Error     error  // error returned by the toolbox execution, if any
}
|
|
// CallAndExecuteResults bundles the model's text reply with the outcome of
// every tool call it requested, in the order the calls appeared.
type CallAndExecuteResults struct {
	Text        string        // content of the first choice returned by the model
	CallResults []CallResults // one entry per tool call, same order as requested
}
|
|
|
|
// CallAndExecute calls the model with the given messages and executes the resulting calls in serial order. The results
// are returned in the same order as the calls.
// msgs entries can be gollm.Message values or plain strings (sent as user messages).
func (a Agent) CallAndExecute(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) {
	return a._callAndExecuteParallel(ctx, false, msgs...)
}
|
|
|
|
// CallAndExecuteParallel will call the model with the given messages and all the tool calls in the response will be
// executed in parallel. The results will be returned in the same order as the calls.
// msgs entries can be gollm.Message values or plain strings (sent as user messages).
func (a Agent) CallAndExecuteParallel(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) {
	return a._callAndExecuteParallel(ctx, true, msgs...)
}
|
|
|
|
func (a Agent) _callAndExecuteParallel(ctx context.Context, parallel bool, msgs ...any) (CallAndExecuteResults, error) {
|
|
calls := a.calls.Add(1)
|
|
if a.maxCalls != nil && calls > *a.maxCalls {
|
|
return CallAndExecuteResults{}, fmt.Errorf("max model calls exceeded")
|
|
}
|
|
|
|
req, err := a.ToRequest(msgs...)
|
|
|
|
if err != nil {
|
|
return CallAndExecuteResults{}, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
response, err := a.model.ChatComplete(ctx, req)
|
|
if err != nil {
|
|
return CallAndExecuteResults{}, fmt.Errorf("error calling model: %w", err)
|
|
}
|
|
|
|
if len(response.Choices) == 0 {
|
|
return CallAndExecuteResults{}, fmt.Errorf("no choices found")
|
|
}
|
|
|
|
choice := response.Choices[0]
|
|
|
|
var res = CallAndExecuteResults{
|
|
Text: choice.Content,
|
|
CallResults: make([]CallResults, len(choice.Calls)),
|
|
}
|
|
|
|
if parallel {
|
|
var wg sync.WaitGroup
|
|
|
|
for i, call := range choice.Calls {
|
|
wg.Add(1)
|
|
go func() {
|
|
var callRes = CallResults{
|
|
ID: call.ID,
|
|
Function: call.FunctionCall.Name,
|
|
Arguments: call.FunctionCall.Arguments,
|
|
}
|
|
|
|
callRes.Result, callRes.Error = req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call)
|
|
|
|
res.CallResults[i] = callRes
|
|
wg.Done()
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
} else {
|
|
for i, call := range choice.Calls {
|
|
var callRes = CallResults{
|
|
ID: call.ID,
|
|
Function: call.FunctionCall.Name,
|
|
Arguments: call.FunctionCall.Arguments,
|
|
}
|
|
|
|
callRes.Result, callRes.Error = req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call)
|
|
|
|
res.CallResults[i] = callRes
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|