Refactor the entire system to be more contextual so that conversation flow can be managed more easily

Steve Dudenhoeffer 2025-03-16 22:38:58 -04:00
parent 0d909edd44
commit 7f5e34e437
9 changed files with 377 additions and 61 deletions

context.go (new file, +97)

@@ -0,0 +1,97 @@
package go_llm
import (
"context"
"time"
)
type Context struct {
context.Context
request Request
response *ResponseChoice
toolcall *ToolCall
}
func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request {
var res Request
res.Toolbox = c.request.Toolbox
res.Temperature = c.request.Temperature
res.Conversation = make([]Input, len(c.request.Conversation))
copy(res.Conversation, c.request.Conversation)
// append each new message from this request to the conversation
for _, msg := range c.request.Messages {
res.Conversation = append(res.Conversation, msg)
}
// if there was a response, add its tool calls and any assistant content to the conversation
if c.response != nil {
for _, call := range c.response.Calls {
res.Conversation = append(res.Conversation, call)
}
if c.response.Content != "" || c.response.Refusal != "" {
res.Conversation = append(res.Conversation, Message{
Role: RoleAssistant,
Text: c.response.Content,
})
}
}
// if there are tool results, then we need to add those to the conversation
for _, result := range toolResults {
res.Conversation = append(res.Conversation, result)
}
return res
}
func NewContext(ctx context.Context, request Request, response *ResponseChoice, toolcall *ToolCall) *Context {
return &Context{Context: ctx, request: request, response: response, toolcall: toolcall}
}
func (c *Context) Request() Request {
return c.request
}
func (c *Context) WithContext(ctx context.Context) *Context {
return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall}
}
func (c *Context) WithRequest(request Request) *Context {
return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall}
}
func (c *Context) WithResponse(response *ResponseChoice) *Context {
return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall}
}
func (c *Context) WithToolCall(toolcall *ToolCall) *Context {
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall}
}
func (c *Context) Deadline() (deadline time.Time, ok bool) {
return c.Context.Deadline()
}
func (c *Context) Done() <-chan struct{} {
return c.Context.Done()
}
func (c *Context) Err() error {
return c.Context.Err()
}
func (c *Context) Value(key any) any {
if key == "request" {
return c.request
}
return c.Context.Value(key)
}
func (c *Context) WithTimeout(timeout time.Duration) (*Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(c.Context, timeout)
return c.WithContext(ctx), cancel
}
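The Context type above is the hub of this refactor: it embeds a context.Context and carries the originating request, the model's response, and (during dispatch) the active tool call. A minimal sketch of one follow-up turn, assuming the ChatCompletion interface exposes ChatComplete the way openaiImpl implements it; the nextTurn helper and its arguments are hypothetical:

func nextTurn(parent context.Context, llm ChatCompletion, req Request, choice ResponseChoice, results []ToolCallResponse) (Response, error) {
    c := NewContext(parent, req, &choice, nil)

    // Bound the follow-up completion; cancel releases the timer.
    c, cancel := c.WithTimeout(30 * time.Second)
    defer cancel()

    // ToNewRequest copies the toolbox, temperature, and conversation, then
    // folds in the response's calls/content and the tool results.
    next := c.ToNewRequest(results...)
    return llm.ChatComplete(c, next)
}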


@@ -31,7 +31,7 @@ type Function struct {
definition *jsonschema.Definition
}
-func (f *Function) Execute(ctx context.Context, input string) (string, error) {
+func (f *Function) Execute(ctx *Context, input string) (string, error) {
if !f.fn.IsValid() {
return "", fmt.Errorf("function %s is not implemented", f.Name)
}
@@ -46,7 +46,7 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) {
}
// now we can call the function
-exec := func(ctx context.Context) (string, error) {
+exec := func(ctx *Context) (string, error) {
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
if len(out) != 2 {
@@ -62,7 +62,7 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) {
var cancel context.CancelFunc
if f.Timeout > 0 {
-ctx, cancel = context.WithTimeout(ctx, f.Timeout)
+ctx, cancel = ctx.WithTimeout(f.Timeout)
defer cancel()
}
@@ -90,3 +90,15 @@ type FunctionCall struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
func (fc *FunctionCall) toRaw() map[string]any {
res := map[string]any{
"name": fc.Name,
}
if fc.Arguments != "" {
res["arguments"] = fc.Arguments
}
return res
}


@@ -1,7 +1,6 @@
package go_llm
import (
-"context"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
"reflect"
)
@@ -13,7 +12,7 @@ import (
// The struct parameters can have the following tags:
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
-func NewFunction[T any](name string, description string, fn func(context.Context, T) (string, error)) *Function {
+func NewFunction[T any](name string, description string, fn func(*Context, T) (string, error)) *Function {
var o T
res := Function{

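Under the new signature, tool implementations receive the richer *Context instead of a bare context.Context. A sketch of defining a tool with the generic above; the weatherParams type, its description tag, and the body are hypothetical:

type weatherParams struct {
    City string `description:"the city to report the weather for"`
}

var getWeather = NewFunction(
    "get_weather",
    "return the current weather for a city",
    func(ctx *Context, p weatherParams) (string, error) {
        // *Context still satisfies context.Context, so it can be passed
        // unchanged to any cancellation-aware downstream call.
        return "sunny in " + p.City, nil
    },
)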
go.sum (+2)

@@ -37,8 +37,6 @@ github.com/liushuangls/go-anthropic/v2 v2.13.0 h1:f7KJ54IHxIpHPPhrCzs3SrdP2PfErX
github.com/liushuangls/go-anthropic/v2 v2.13.0/go.mod h1:5ZwRLF5TQ+y5s/MC9Z1IJYx9WUFgQCKfqFM2xreIQLk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/sashabaranov/go-openai v1.36.0 h1:fcSrn8uGuorzPWCBp8L0aCR95Zjb/Dd+ZSML0YZy9EI=
-github.com/sashabaranov/go-openai v1.36.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
+github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g=
+github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=

llm.go (+141)

@@ -2,6 +2,7 @@ package go_llm
import (
"context"
+"github.com/sashabaranov/go-openai"
)
type Role string
@@ -18,6 +19,26 @@ type Image struct {
Url string
}
func (i Image) toRaw() map[string]any {
res := map[string]any{
"base64": i.Base64,
"contenttype": i.ContentType,
"url": i.Url,
}
return res
}
func (i *Image) fromRaw(raw map[string]any) Image {
var res Image
res.Base64 = raw["base64"].(string)
res.ContentType = raw["contenttype"].(string)
res.Url = raw["url"].(string)
return res
}
type Message struct {
Role Role
Name string
@@ -25,10 +46,66 @@ type Message struct {
Images []Image
}
-type Request struct {
-Messages []Message
-Toolbox *ToolBox
-Temperature *float32
-}
func (m Message) toRaw() map[string]any {
res := map[string]any{
"role": m.Role,
"name": m.Name,
"text": m.Text,
}
images := make([]map[string]any, 0, len(m.Images))
for _, img := range m.Images {
images = append(images, img.toRaw())
}
res["images"] = images
return res
}
func (m *Message) fromRaw(raw map[string]any) Message {
var res Message
res.Role = Role(raw["role"].(string))
res.Name = raw["name"].(string)
res.Text = raw["text"].(string)
images := raw["images"].([]map[string]any)
for _, img := range images {
var i Image
res.Images = append(res.Images, i.fromRaw(img))
}
return res
}
func (m Message) toChatCompletionMessages() []openai.ChatCompletionMessage {
var res openai.ChatCompletionMessage
res.Role = string(m.Role)
res.Name = m.Name
res.Content = m.Text
for _, img := range m.Images {
if img.Base64 != "" {
res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{
Type: "image_url",
ImageURL: &openai.ChatMessageImageURL{
URL: "data:" + img.ContentType + ";base64," + img.Base64,
},
})
} else if img.Url != "" {
res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{
Type: "image_url",
ImageURL: &openai.ChatMessageImageURL{
URL: img.Url,
},
})
}
}
return []openai.ChatCompletionMessage{res}
}
type ToolCall struct {
@@ -36,16 +113,54 @@ type ToolCall struct {
FunctionCall FunctionCall
}
-type ResponseChoice struct {
-Index int
-Role Role
-Content string
-Refusal string
-Name string
-Calls []ToolCall
-}
func (t ToolCall) toRaw() map[string]any {
res := map[string]any{
"id": t.ID,
}
res["function"] = t.FunctionCall.toRaw()
return res
}
-type Response struct {
-Choices []ResponseChoice
-}
func (t ToolCall) toChatCompletionMessages() []openai.ChatCompletionMessage {
return []openai.ChatCompletionMessage{{
Role: openai.ChatMessageRoleTool,
ToolCallID: t.ID,
}}
}
type ToolCallResponse struct {
ID string
Result string
Error error
}
func (t ToolCallResponse) toRaw() map[string]any {
res := map[string]any{
"id": t.ID,
"result": t.Result,
}
if t.Error != nil {
res["error"] = t.Error.Error()
}
return res
}
func (t ToolCallResponse) toChatCompletionMessages() []openai.ChatCompletionMessage {
var refusal string
if t.Error != nil {
refusal = t.Error.Error()
}
return []openai.ChatCompletionMessage{{
Role: openai.ChatMessageRoleTool,
Content: t.Result,
Refusal: refusal,
ToolCallID: t.ID,
}}
}
type ChatCompletion interface {

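The toRaw/fromRaw pair gives messages a plain-map form, useful for persistence or logging. A minimal round-trip sketch, assuming the methods above and that RoleUser exists among the Role constants; the sample values are illustrative:

msg := Message{
    Role:   RoleUser,
    Name:   "alice",
    Text:   "what is in this picture?",
    Images: []Image{{Url: "https://example.com/cat.png"}},
}

raw := msg.toRaw() // map with "role", "name", "text", and "images" keys

var m Message
restored := m.fromRaw(raw) // rebuilds an equivalent Message value
_ = restored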

@@ -3,6 +3,7 @@ package go_llm
import (
"context"
"fmt"
+"log/slog"
"strings"
oai "github.com/sashabaranov/go-openai"
@@ -15,47 +16,17 @@ type openaiImpl struct {
var _ LLM = openaiImpl{}
-func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
+func (o openaiImpl) newRequestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
res := oai.ChatCompletionRequest{
Model: o.model,
}
+for _, i := range request.Conversation {
+res.Messages = append(res.Messages, i.toChatCompletionMessages()...)
+}
for _, msg := range request.Messages {
-m := oai.ChatCompletionMessage{
-Content: msg.Text,
-Role: string(msg.Role),
-Name: msg.Name,
-}
-for _, img := range msg.Images {
-if img.Base64 != "" {
-m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
-Type: "image_url",
-ImageURL: &oai.ChatMessageImageURL{
-URL: fmt.Sprintf("data:%s;base64,%s", img.ContentType, img.Base64),
-},
-})
-} else if img.Url != "" {
-m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
-Type: "image_url",
-ImageURL: &oai.ChatMessageImageURL{
-URL: img.Url,
-},
-})
-}
-}
-// openai does not allow Content and MultiContent to be set at the same time, so we need to check
-if len(m.MultiContent) > 0 && m.Content != "" {
-m.MultiContent = append([]oai.ChatMessagePart{{
-Type: "text",
-Text: m.Content,
-}}, m.MultiContent...)
-m.Content = ""
-}
-res.Messages = append(res.Messages, m)
+res.Messages = append(res.Messages, msg.toChatCompletionMessages()...)
}
if request.Toolbox != nil {
@@ -130,8 +101,9 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
cl := oai.NewClient(o.key)
-req := o.requestToOpenAIRequest(request)
+req := o.newRequestToOpenAIRequest(request)
+slog.Info("openaiImpl.ChatComplete", "req", fmt.Sprintf("%#v", req))
resp, err := cl.CreateChatCompletion(ctx, req)
fmt.Println("resp:", fmt.Sprintf("%#v", resp))

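The rewritten builder emits the stored Conversation first and then the new Messages, both via toChatCompletionMessages, so a follow-up turn is just a Request carrying both fields. A sketch; the followUp helper and its arguments are hypothetical:

// followUp builds the next request: earlier turns ride in Conversation,
// the fresh user turn goes in Messages.
func followUp(prior []Input, tools *ToolBox, question string) Request {
    return Request{
        Conversation: prior,
        Messages:     []Message{{Role: RoleUser, Text: question}},
        Toolbox:      tools,
    }
}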
request.go (new file, +54)

@@ -0,0 +1,54 @@
package go_llm
import "github.com/sashabaranov/go-openai"
type rawAble interface {
toRaw() map[string]any
fromRaw(raw map[string]any) Input
}
type Input interface {
toChatCompletionMessages() []openai.ChatCompletionMessage
}
type Request struct {
Conversation []Input
Messages []Message
Toolbox *ToolBox
Temperature *float32
}
// NextRequest will take the current request's conversation, messages, the response, and any tool results, and
// return a new request with the conversation updated to include the response and tool results.
func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallResponse) Request {
var res Request
res.Toolbox = req.Toolbox
res.Temperature = req.Temperature
res.Conversation = make([]Input, len(req.Conversation))
copy(res.Conversation, req.Conversation)
// append each new message from this request to the conversation
for _, msg := range req.Messages {
res.Conversation = append(res.Conversation, msg)
}
// if there are tool calls, then we need to add those to the conversation
for _, call := range resp.Calls {
res.Conversation = append(res.Conversation, call)
}
if resp.Content != "" || resp.Refusal != "" {
res.Conversation = append(res.Conversation, Message{
Role: RoleAssistant,
Text: resp.Content,
})
}
// if there are tool results, then we need to add those to the conversation
for _, result := range toolResults {
res.Conversation = append(res.Conversation, result)
}
return res
}

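NextRequest is the pure-data counterpart of Context.ToNewRequest, so a multi-turn tool loop reduces to re-feeding its output. A hedged sketch, again assuming ChatComplete as openaiImpl implements it; tool execution is elided here (see the ToolBox sketch at the end of this diff):

func runConversation(ctx context.Context, llm ChatCompletion, req Request, maxTurns int) (Response, error) {
    var resp Response
    var err error
    for i := 0; i < maxTurns; i++ {
        resp, err = llm.ChatComplete(ctx, req)
        if err != nil || len(resp.Choices) == 0 {
            return resp, err
        }
        choice := resp.Choices[0]
        if len(choice.Calls) == 0 {
            break // the model answered without requesting tools
        }
        var results []ToolCallResponse // filled by executing choice.Calls
        req = req.NextRequest(choice, results)
    }
    return resp, err
}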
response.go (new file, +70)

@@ -0,0 +1,70 @@
package go_llm
import "github.com/sashabaranov/go-openai"
type ResponseChoice struct {
Index int
Role Role
Content string
Refusal string
Name string
Calls []ToolCall
}
func (r ResponseChoice) toRaw() map[string]any {
res := map[string]any{
"index": r.Index,
"role": r.Role,
"content": r.Content,
"refusal": r.Refusal,
"name": r.Name,
}
calls := make([]map[string]any, 0, len(r.Calls))
for _, call := range r.Calls {
calls = append(calls, call.toRaw())
}
res["calls"] = calls
return res
}
func (r ResponseChoice) toChatCompletionMessages() []openai.ChatCompletionMessage {
var res []openai.ChatCompletionMessage
for _, call := range r.Calls {
res = append(res, call.toChatCompletionMessages()...)
}
if r.Refusal != "" || r.Content != "" {
res = append(res, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: r.Content,
Refusal: r.Refusal,
})
}
return res
}
func (r ResponseChoice) toInput() []Input {
var res []Input
for _, call := range r.Calls {
res = append(res, call)
}
if r.Content != "" || r.Refusal != "" {
res = append(res, Message{
Role: RoleAssistant,
Text: r.Content,
})
}
return res
}
type Response struct {
Choices []ResponseChoice
}

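toInput is the low-level alternative to NextRequest for callers that manage Request.Conversation by hand; since it is unexported, the sketch below assumes package-internal code. The appendChoice helper is hypothetical:

// appendChoice records a model reply in the running conversation without
// building a whole new Request.
func appendChoice(req Request, choice ResponseChoice) Request {
    req.Conversation = append(req.Conversation, choice.toInput()...)
    return req
}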

@@ -1,7 +1,6 @@
package go_llm
import (
-"context"
"errors"
"fmt"
"github.com/sashabaranov/go-openai"
@@ -64,7 +63,7 @@ var (
ErrFunctionNotFound = errors.New("function not found")
)
-func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
+func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (string, error) {
f, ok := t.names[functionName]
if !ok {
@@ -74,6 +73,6 @@ func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, para
return f.Execute(ctx, params)
}
-func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
-return t.ExecuteFunction(ctx, toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
+func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (string, error) {
+return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
}
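Execute now stamps the specific call into the Context via WithToolCall before dispatching, so a tool body can see exactly which call invoked it. A sketch of running every call in a choice and collecting the responses that ToNewRequest/NextRequest expect; the executeCalls helper is hypothetical:

func executeCalls(ctx context.Context, req Request, choice ResponseChoice) []ToolCallResponse {
    c := NewContext(ctx, req, &choice, nil)
    var results []ToolCallResponse
    for _, call := range choice.Calls {
        out, err := req.Toolbox.Execute(c, call)
        results = append(results, ToolCallResponse{ID: call.ID, Result: out, Error: err})
    }
    return results
}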