Consolidated a bunch of reused code to agents
This commit is contained in:
238
pkg/agents/agent.go
Normal file
238
pkg/agents/agent.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package agents
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
gollm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Agent is essentially the bones of a chat agent. It has a model and a toolbox, and can be used to call the model
|
||||
// with messages and execute the resulting calls.
|
||||
// The agent will keep track of how many calls it has made to the model, and any agents which inherit from this
|
||||
// one (e.g.: all made from sharing pointers or from the With... helpers) will share the same call count.
|
||||
// This package contains a number of helper functions to make it easier to create and use agents.
|
||||
type Agent struct {
|
||||
model gollm.ChatCompletion
|
||||
toolbox *gollm.ToolBox
|
||||
systemPrompt string
|
||||
contextualInformation []string
|
||||
systemPromptSuffix string
|
||||
maxCalls *int32
|
||||
|
||||
calls *atomic.Int32
|
||||
}
|
||||
|
||||
// NewAgent creates a new agent struct with the given model and toolbox.
|
||||
// Any inherited agents (e.g.: all made from sharing pointers or from the With... helpers) from this one will
|
||||
// share the same call count.
|
||||
func NewAgent(model gollm.ChatCompletion, toolbox *gollm.ToolBox) Agent {
|
||||
return Agent{
|
||||
model: model,
|
||||
toolbox: toolbox,
|
||||
calls: &atomic.Int32{},
|
||||
}
|
||||
}
|
||||
|
||||
func (a Agent) Calls() int32 {
|
||||
return a.calls.Load()
|
||||
}
|
||||
|
||||
func (a Agent) WithModel(model gollm.ChatCompletion) Agent {
|
||||
a.model = model
|
||||
return a
|
||||
}
|
||||
|
||||
func (a Agent) WithToolbox(toolbox *gollm.ToolBox) Agent {
|
||||
a.toolbox = toolbox
|
||||
return a
|
||||
}
|
||||
|
||||
func (a Agent) WithSystemPrompt(systemPrompt string) Agent {
|
||||
a.systemPrompt = systemPrompt
|
||||
return a
|
||||
}
|
||||
|
||||
func (a Agent) WithContextualInformation(contextualInformation []string) Agent {
|
||||
a.contextualInformation = append(a.contextualInformation, contextualInformation...)
|
||||
return a
|
||||
}
|
||||
|
||||
func (a Agent) WithSystemPromptSuffix(systemPromptSuffix string) Agent {
|
||||
a.systemPromptSuffix = systemPromptSuffix
|
||||
return a
|
||||
}
|
||||
|
||||
func (a Agent) WithMaxCalls(maxCalls int32) Agent {
|
||||
a.maxCalls = &maxCalls
|
||||
return a
|
||||
}
|
||||
|
||||
func (a Agent) _readAnyMessages(messages ...any) ([]gollm.Message, error) {
|
||||
var res []gollm.Message
|
||||
|
||||
for _, msg := range messages {
|
||||
switch v := msg.(type) {
|
||||
case gollm.Message:
|
||||
res = append(res, v)
|
||||
case string:
|
||||
res = append(res, gollm.Message{
|
||||
Role: gollm.RoleUser,
|
||||
Text: v,
|
||||
})
|
||||
|
||||
default:
|
||||
return res, fmt.Errorf("unknown type %T used as message", msg)
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// ToRequest will convert the current agent configuration into a gollm.Request. Any messages passed in will be added
|
||||
// to the request at the end. messages can be either a gollm.Message or a string. All string entries will be added as
|
||||
// simple user messages.
|
||||
func (a Agent) ToRequest(messages ...any) (gollm.Request, error) {
|
||||
sysPrompt := a.systemPrompt
|
||||
|
||||
if len(a.contextualInformation) > 0 {
|
||||
if len(sysPrompt) > 0 {
|
||||
sysPrompt += "\n\n"
|
||||
}
|
||||
sysPrompt += fmt.Sprintf(" Contextual information you should be aware of: %v", a.contextualInformation)
|
||||
}
|
||||
|
||||
if len(a.systemPromptSuffix) > 0 {
|
||||
if len(sysPrompt) > 0 {
|
||||
sysPrompt += "\n\n"
|
||||
}
|
||||
sysPrompt += a.systemPromptSuffix
|
||||
}
|
||||
|
||||
req := gollm.Request{
|
||||
Toolbox: a.toolbox,
|
||||
}
|
||||
|
||||
if len(sysPrompt) > 0 {
|
||||
req.Messages = append(req.Messages, gollm.Message{
|
||||
Role: gollm.RoleSystem,
|
||||
Text: sysPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
msgs, err := a._readAnyMessages(messages...)
|
||||
if err != nil {
|
||||
return req, fmt.Errorf("failed to read messages: %w", err)
|
||||
}
|
||||
|
||||
req.Messages = append(req.Messages, msgs...)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// CallModel calls the model with the given messages and returns the raw response.
|
||||
// note that the msgs can be either a gollm.Message or a string. All string entries will be added as simple
|
||||
// user messages.
|
||||
func (a Agent) CallModel(ctx context.Context, msgs ...any) (gollm.Response, error) {
|
||||
calls := a.calls.Add(1)
|
||||
if a.maxCalls != nil && calls > *a.maxCalls {
|
||||
return gollm.Response{}, fmt.Errorf("max model calls exceeded")
|
||||
}
|
||||
|
||||
req, err := a.ToRequest(msgs...)
|
||||
|
||||
if err != nil {
|
||||
return gollm.Response{}, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
return a.model.ChatComplete(ctx, req)
|
||||
}
|
||||
|
||||
type CallResults struct {
|
||||
ID string
|
||||
Function string
|
||||
Arguments string
|
||||
Result any
|
||||
Error error
|
||||
}
|
||||
type CallAndExecuteResults struct {
|
||||
Text string
|
||||
CallResults []CallResults
|
||||
}
|
||||
|
||||
// CallAndExecute calls the model with the given messages and executes the resulting calls in serial order. The results
|
||||
// are returned in the same order as the calls.
|
||||
func (a Agent) CallAndExecute(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) {
|
||||
return a._callAndExecuteParallel(ctx, false, msgs...)
|
||||
}
|
||||
|
||||
// CallAndExecuteParallel will call the model with the given messages and all the tool calls in the response will be
|
||||
// executed in parallel. The results will be returned in the same order as the calls.
|
||||
func (a Agent) CallAndExecuteParallel(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) {
|
||||
return a._callAndExecuteParallel(ctx, true, msgs...)
|
||||
}
|
||||
|
||||
func (a Agent) _callAndExecuteParallel(ctx context.Context, parallel bool, msgs ...any) (CallAndExecuteResults, error) {
|
||||
calls := a.calls.Add(1)
|
||||
if a.maxCalls != nil && calls > *a.maxCalls {
|
||||
return CallAndExecuteResults{}, fmt.Errorf("max model calls exceeded")
|
||||
}
|
||||
|
||||
req, err := a.ToRequest(msgs...)
|
||||
|
||||
if err != nil {
|
||||
return CallAndExecuteResults{}, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
response, err := a.model.ChatComplete(ctx, req)
|
||||
if err != nil {
|
||||
return CallAndExecuteResults{}, fmt.Errorf("error calling model: %w", err)
|
||||
}
|
||||
|
||||
if len(response.Choices) == 0 {
|
||||
return CallAndExecuteResults{}, fmt.Errorf("no choices found")
|
||||
}
|
||||
|
||||
choice := response.Choices[0]
|
||||
|
||||
var res = CallAndExecuteResults{
|
||||
Text: choice.Content,
|
||||
CallResults: make([]CallResults, len(choice.Calls)),
|
||||
}
|
||||
|
||||
if parallel {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, call := range choice.Calls {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
var callRes = CallResults{
|
||||
ID: call.ID,
|
||||
Function: call.FunctionCall.Name,
|
||||
Arguments: call.FunctionCall.Arguments,
|
||||
}
|
||||
|
||||
callRes.Result, callRes.Error = req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call)
|
||||
|
||||
res.CallResults[i] = callRes
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
} else {
|
||||
for i, call := range choice.Calls {
|
||||
var callRes = CallResults{
|
||||
ID: call.ID,
|
||||
Function: call.FunctionCall.Name,
|
||||
Arguments: call.FunctionCall.Arguments,
|
||||
}
|
||||
|
||||
callRes.Result, callRes.Error = req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call)
|
||||
|
||||
res.CallResults[i] = callRes
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
Reference in New Issue
Block a user