package agents import ( "context" "fmt" gollm "gitea.stevedudenhoeffer.com/steve/go-llm" "sync" "sync/atomic" ) // Agent is essentially the bones of a chat agent. It has a model and a toolbox, and can be used to call the model // with messages and execute the resulting calls. // The agent will keep track of how many calls it has made to the model, and any agents which inherit from this // one (e.g.: all made from sharing pointers or from the With... helpers) will share the same call count. // This package contains a number of helper functions to make it easier to create and use agents. type Agent struct { model gollm.ChatCompletion toolbox *gollm.ToolBox systemPrompt string contextualInformation []string systemPromptSuffix string maxCalls *int32 calls *atomic.Int32 } // NewAgent creates a new agent struct with the given model and toolbox. // Any inherited agents (e.g.: all made from sharing pointers or from the With... helpers) from this one will // share the same call count. func NewAgent(model gollm.ChatCompletion, toolbox *gollm.ToolBox) Agent { return Agent{ model: model, toolbox: toolbox, calls: &atomic.Int32{}, } } func (a Agent) Calls() int32 { return a.calls.Load() } func (a Agent) WithModel(model gollm.ChatCompletion) Agent { a.model = model return a } func (a Agent) WithToolbox(toolbox *gollm.ToolBox) Agent { a.toolbox = toolbox return a } func (a Agent) WithSystemPrompt(systemPrompt string) Agent { a.systemPrompt = systemPrompt return a } func (a Agent) WithContextualInformation(contextualInformation []string) Agent { a.contextualInformation = append(a.contextualInformation, contextualInformation...) return a } func (a Agent) WithSystemPromptSuffix(systemPromptSuffix string) Agent { a.systemPromptSuffix = systemPromptSuffix return a } func (a Agent) WithMaxCalls(maxCalls int32) Agent { a.maxCalls = &maxCalls return a } func (a Agent) _readAnyMessages(messages ...any) ([]gollm.Message, error) { var res []gollm.Message for _, msg := range messages { switch v := msg.(type) { case gollm.Message: res = append(res, v) case string: res = append(res, gollm.Message{ Role: gollm.RoleUser, Text: v, }) default: return res, fmt.Errorf("unknown type %T used as message", msg) } } return res, nil } // ToRequest will convert the current agent configuration into a gollm.Request. Any messages passed in will be added // to the request at the end. messages can be either a gollm.Message or a string. All string entries will be added as // simple user messages. func (a Agent) ToRequest(messages ...any) (gollm.Request, error) { sysPrompt := a.systemPrompt if len(a.contextualInformation) > 0 { if len(sysPrompt) > 0 { sysPrompt += "\n\n" } sysPrompt += fmt.Sprintf(" Contextual information you should be aware of: %v", a.contextualInformation) } if len(a.systemPromptSuffix) > 0 { if len(sysPrompt) > 0 { sysPrompt += "\n\n" } sysPrompt += a.systemPromptSuffix } req := gollm.Request{ Toolbox: a.toolbox, } if len(sysPrompt) > 0 { req.Messages = append(req.Messages, gollm.Message{ Role: gollm.RoleSystem, Text: sysPrompt, }) } msgs, err := a._readAnyMessages(messages...) if err != nil { return req, fmt.Errorf("failed to read messages: %w", err) } req.Messages = append(req.Messages, msgs...) return req, nil } // CallModel calls the model with the given messages and returns the raw response. // note that the msgs can be either a gollm.Message or a string. All string entries will be added as simple // user messages. func (a Agent) CallModel(ctx context.Context, msgs ...any) (gollm.Response, error) { calls := a.calls.Add(1) if a.maxCalls != nil && calls > *a.maxCalls { return gollm.Response{}, fmt.Errorf("max model calls exceeded") } req, err := a.ToRequest(msgs...) if err != nil { return gollm.Response{}, fmt.Errorf("failed to create request: %w", err) } return a.model.ChatComplete(ctx, req) } type CallResults struct { ID string Function string Arguments string Result any Error error } type CallAndExecuteResults struct { Text string CallResults []CallResults } // CallAndExecute calls the model with the given messages and executes the resulting calls in serial order. The results // are returned in the same order as the calls. func (a Agent) CallAndExecute(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) { return a._callAndExecuteParallel(ctx, false, msgs...) } // CallAndExecuteParallel will call the model with the given messages and all the tool calls in the response will be // executed in parallel. The results will be returned in the same order as the calls. func (a Agent) CallAndExecuteParallel(ctx context.Context, msgs ...any) (CallAndExecuteResults, error) { return a._callAndExecuteParallel(ctx, true, msgs...) } func (a Agent) _callAndExecuteParallel(ctx context.Context, parallel bool, msgs ...any) (CallAndExecuteResults, error) { calls := a.calls.Add(1) if a.maxCalls != nil && calls > *a.maxCalls { return CallAndExecuteResults{}, fmt.Errorf("max model calls exceeded") } req, err := a.ToRequest(msgs...) if err != nil { return CallAndExecuteResults{}, fmt.Errorf("failed to create request: %w", err) } response, err := a.model.ChatComplete(ctx, req) if err != nil { return CallAndExecuteResults{}, fmt.Errorf("error calling model: %w", err) } if len(response.Choices) == 0 { return CallAndExecuteResults{}, fmt.Errorf("no choices found") } choice := response.Choices[0] var res = CallAndExecuteResults{ Text: choice.Content, CallResults: make([]CallResults, len(choice.Calls)), } if parallel { var wg sync.WaitGroup for i, call := range choice.Calls { wg.Add(1) go func() { var callRes = CallResults{ ID: call.ID, Function: call.FunctionCall.Name, Arguments: call.FunctionCall.Arguments, } callRes.Result, callRes.Error = req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call) res.CallResults[i] = callRes wg.Done() }() } wg.Wait() } else { for i, call := range choice.Calls { var callRes = CallResults{ ID: call.ID, Function: call.FunctionCall.Name, Arguments: call.FunctionCall.Arguments, } callRes.Result, callRes.Error = req.Toolbox.Execute(gollm.NewContext(ctx, req, &choice, &call), call) res.CallResults[i] = callRes } } return res, nil }