239 lines
6.4 KiB
Go
239 lines
6.4 KiB
Go
package agents
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
gollm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
|
"sync"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// Agent is essentially the bones of a chat agent. It has a model and a toolbox, and can be used to call the model
// with messages and execute the resulting calls.
// The agent will keep track of how many calls it has made to the model, and any agents which inherit from this
// one (e.g.: all made from sharing pointers or from the With... helpers) will share the same call count.
// This package contains a number of helper functions to make it easier to create and use agents.
type Agent struct {
	// model is the chat-completion backend every request is sent to.
	model gollm.ChatCompletion

	// toolbox holds the tools offered to the model on each request.
	toolbox *gollm.ToolBox

	// systemPrompt is the base system prompt; contextual information and the
	// suffix below are appended to it when building a request (see ToRequest).
	systemPrompt string

	// contextualInformation is extra context folded into the system prompt.
	contextualInformation []string

	// systemPromptSuffix is appended to the system prompt after the
	// contextual information.
	systemPromptSuffix string

	// maxCalls, when non-nil, caps the total number of model calls; once the
	// shared counter exceeds it, CallModel and the CallAndExecute variants
	// return an error instead of calling the model.
	maxCalls *int32

	// calls counts model calls. It is a pointer so that every agent derived
	// via the With... helpers shares the same counter. A zero-value Agent
	// leaves it nil; construct agents with NewAgent.
	calls *atomic.Int32
}
|
|
|
|
// NewAgent creates a new agent struct with the given model and toolbox.
|
|
// Any inherited agents (e.g.: all made from sharing pointers or from the With... helpers) from this one will
|
|
// share the same call count.
|
|
func NewAgent(model gollm.ChatCompletion, toolbox *gollm.ToolBox) Agent {
|
|
return Agent{
|
|
model: model,
|
|
toolbox: toolbox,
|
|
calls: &atomic.Int32{},
|
|
}
|
|
}
|
|
|
|
func (a Agent) Calls() int32 {
|
|
return a.calls.Load()
|
|
}
|
|
|
|
func (a Agent) WithModel(model gollm.ChatCompletion) Agent {
|
|
a.model = model
|
|
return a
|
|
}
|
|
|
|
func (a Agent) WithToolbox(toolbox *gollm.ToolBox) Agent {
|
|
a.toolbox = toolbox
|
|
return a
|
|
}
|
|
|
|
func (a Agent) WithSystemPrompt(systemPrompt string) Agent {
|
|
a.systemPrompt = systemPrompt
|
|
return a
|
|
}
|
|
|
|
func (a Agent) WithContextualInformation(contextualInformation []string) Agent {
|
|
a.contextualInformation = append(a.contextualInformation, contextualInformation...)
|
|
return a
|
|
}
|
|
|
|
func (a Agent) WithSystemPromptSuffix(systemPromptSuffix string) Agent {
|
|
a.systemPromptSuffix = systemPromptSuffix
|
|
return a
|
|
}
|
|
|
|
func (a Agent) WithMaxCalls(maxCalls int32) Agent {
|
|
a.maxCalls = &maxCalls
|
|
return a
|
|
}
|
|
|
|
func (a Agent) _readAnyMessages(messages ...any) ([]gollm.Message, error) {
|
|
var res []gollm.Message
|
|
|
|
for _, msg := range messages {
|
|
switch v := msg.(type) {
|
|
case gollm.Message:
|
|
res = append(res, v)
|
|
case string:
|
|
res = append(res, gollm.Message{
|
|
Role: gollm.RoleUser,
|
|
Text: v,
|
|
})
|
|
|
|
default:
|
|
return res, fmt.Errorf("unknown type %T used as message", msg)
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// ToRequest will convert the current agent configuration into a gollm.Request. Any messages passed in will be added
|
|
// to the request at the end. messages can be either a gollm.Message or a string. All string entries will be added as
|
|
// simple user messages.
|
|
func (a Agent) ToRequest(messages ...any) (gollm.Request, error) {
|
|
sysPrompt := a.systemPrompt
|
|
|
|
if len(a.contextualInformation) > 0 {
|
|
if len(sysPrompt) > 0 {
|
|
sysPrompt += "\n\n"
|
|
}
|
|
sysPrompt += fmt.Sprintf(" Contextual information you should be aware of: %v", a.contextualInformation)
|
|
}
|
|
|
|
if len(a.systemPromptSuffix) > 0 {
|
|
if len(sysPrompt) > 0 {
|
|
sysPrompt += "\n\n"
|
|
}
|
|
sysPrompt += a.systemPromptSuffix
|
|
}
|
|
|
|
req := gollm.Request{
|
|
Toolbox: a.toolbox,
|
|
}
|
|
|
|
if len(sysPrompt) > 0 {
|
|
req.Messages = append(req.Messages, gollm.Message{
|
|
Role: gollm.RoleSystem,
|
|
Text: sysPrompt,
|
|
})
|
|
}
|
|
|
|
msgs, err := a._readAnyMessages(messages...)
|
|
if err != nil {
|
|
return req, fmt.Errorf("failed to read messages: %w", err)
|
|
}
|
|
|
|
req.Messages = append(req.Messages, msgs...)
|
|
return req, nil
|
|
}
|
|
|
|
// CallModel calls the model with the given messages and returns the raw response.
|
|
// note that the msgs can be either a gollm.Message or a string. All string entries will be added as simple
|
|
// user messages.
|
|
func (a Agent) CallModel(ctx context.Context, msgs ...any) (gollm.Response, error) {
|
|
calls := a.calls.Add(1)
|
|
if a.maxCalls != nil && calls > *a.maxCalls {
|
|
return gollm.Response{}, fmt.Errorf("max model calls exceeded")
|
|
}
|
|
|
|
req, err := a.ToRequest(msgs...)
|
|
|
|
if err != nil {
|
|
return gollm.Response{}, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
return a.model.ChatComplete(ctx, req)
|
|
}
|
|
|
|
// CallResults holds the outcome of executing a single tool call requested by
// the model.
type CallResults struct {
	// ID is the tool-call identifier from the model's response.
	ID string
	// Function is the name of the tool that was invoked.
	Function string
	// Arguments is the raw argument payload the model supplied for the call.
	Arguments string
	// Result is the value returned by the tool, if any.
	Result any
	// Error is the execution error for this call, if any.
	Error error
}

// CallAndExecuteResults bundles the model's text reply with the results of
// every tool call it requested, in the same order as the calls.
type CallAndExecuteResults struct {
	// Text is the content of the first choice returned by the model.
	Text string
	// CallResults has one entry per tool call in the response.
	CallResults []CallResults
}
|
|
|
|
// CallAndExecute calls the model with the given messages and executes the resulting calls in serial order. The results
|
|
// are returned in the same order as the calls.
|
|
func (a Agent) CallAndExecute(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) {
|
|
return a._callAndExecuteParallel(ctx, false, msgs...)
|
|
}
|
|
|
|
// CallAndExecuteParallel will call the model with the given messages and all the tool calls in the response will be
|
|
// executed in parallel. The results will be returned in the same order as the calls.
|
|
func (a Agent) CallAndExecuteParallel(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) {
|
|
return a._callAndExecuteParallel(ctx, true, msgs...)
|
|
}
|
|
|
|
func (a Agent) _callAndExecuteParallel(ctx context.Context, parallel bool, msgs ...any) (CallAndExecuteResults, error) {
|
|
calls := a.calls.Add(1)
|
|
if a.maxCalls != nil && calls > *a.maxCalls {
|
|
return CallAndExecuteResults{}, fmt.Errorf("max model calls exceeded")
|
|
}
|
|
|
|
req, err := a.ToRequest(msgs...)
|
|
|
|
if err != nil {
|
|
return CallAndExecuteResults{}, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
response, err := a.model.ChatComplete(ctx, req)
|
|
if err != nil {
|
|
return CallAndExecuteResults{}, fmt.Errorf("error calling model: %w", err)
|
|
}
|
|
|
|
if len(response.Choices) == 0 {
|
|
return CallAndExecuteResults{}, fmt.Errorf("no choices found")
|
|
}
|
|
|
|
choice := response.Choices[0]
|
|
|
|
var res = CallAndExecuteResults{
|
|
Text: choice.Content,
|
|
CallResults: make([]CallResults, len(choice.Calls)),
|
|
}
|
|
|
|
if parallel {
|
|
var wg sync.WaitGroup
|
|
|
|
for i, call := range choice.Calls {
|
|
wg.Add(1)
|
|
go func() {
|
|
var callRes = CallResults{
|
|
ID: call.ID,
|
|
Function: call.FunctionCall.Name,
|
|
Arguments: call.FunctionCall.Arguments,
|
|
}
|
|
|
|
callRes.Result, callRes.Error = req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call)
|
|
|
|
res.CallResults[i] = callRes
|
|
wg.Done()
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
} else {
|
|
for i, call := range choice.Calls {
|
|
var callRes = CallResults{
|
|
ID: call.ID,
|
|
Function: call.FunctionCall.Name,
|
|
Arguments: call.FunctionCall.Arguments,
|
|
}
|
|
|
|
callRes.Result, callRes.Error = req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call)
|
|
|
|
res.CallResults[i] = callRes
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|