From 7f5e34e437a7a4c293372975ee39bd819851fb79 Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Sun, 16 Mar 2025 22:38:58 -0400 Subject: [PATCH] Refactor entire system to be more contextual so that conversation flow can be more easily managed --- context.go | 97 +++++++++++++++++++++++++++++++++++ function.go | 18 +++++-- functions.go | 3 +- go.sum | 2 - llm.go | 141 ++++++++++++++++++++++++++++++++++++++++++++++----- openai.go | 46 ++++------------- request.go | 54 ++++++++++++++++++++ response.go | 70 +++++++++++++++++++++++++ toolbox.go | 7 ++- 9 files changed, 377 insertions(+), 61 deletions(-) create mode 100644 context.go create mode 100644 request.go create mode 100644 response.go diff --git a/context.go b/context.go new file mode 100644 index 0000000..04a2b27 --- /dev/null +++ b/context.go @@ -0,0 +1,97 @@ +package go_llm + +import ( + "context" + "time" +) + +type Context struct { + context.Context + request Request + response *ResponseChoice + toolcall *ToolCall +} + +func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request { + var res Request + + res.Toolbox = c.request.Toolbox + res.Temperature = c.request.Temperature + + res.Conversation = make([]Input, len(c.request.Conversation)) + copy(res.Conversation, c.request.Conversation) + + // now for every input message, convert those to an Input to add to the conversation + for _, msg := range c.request.Messages { + res.Conversation = append(res.Conversation, msg) + } + + // if there are tool calls, then we need to add those to the conversation + if c.response != nil { + for _, call := range c.response.Calls { + res.Conversation = append(res.Conversation, call) + + if c.response.Content != "" || c.response.Refusal != "" { + res.Conversation = append(res.Conversation, Message{ + Role: RoleAssistant, + Text: c.response.Content, + }) + } + } + } + + // if there are tool results, then we need to add those to the conversation + for _, result := range toolResults { + res.Conversation = append(res.Conversation, result) + } + + return res +} + +func NewContext(ctx context.Context, request Request, response *ResponseChoice, toolcall *ToolCall) *Context { + return &Context{Context: ctx, request: request, response: response, toolcall: toolcall} +} + +func (c *Context) Request() Request { + return c.request +} + +func (c *Context) WithContext(ctx context.Context) *Context { + return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall} +} + +func (c *Context) WithRequest(request Request) *Context { + return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall} +} + +func (c *Context) WithResponse(response *ResponseChoice) *Context { + return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall} +} + +func (c *Context) WithToolCall(toolcall *ToolCall) *Context { + return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall} +} + +func (c *Context) Deadline() (deadline time.Time, ok bool) { + return c.Context.Deadline() +} + +func (c *Context) Done() <-chan struct{} { + return c.Context.Done() +} + +func (c *Context) Err() error { + return c.Context.Err() +} + +func (c *Context) Value(key any) any { + if key == "request" { + return c.request + } + return c.Context.Value(key) +} + +func (c *Context) WithTimeout(timeout time.Duration) (*Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(c.Context, timeout) + return c.WithContext(ctx), cancel +} diff --git a/function.go b/function.go index a6f7134..18c28cb 100644 --- a/function.go +++ b/function.go @@ -31,7 +31,7 @@ type Function struct { definition *jsonschema.Definition } -func (f *Function) Execute(ctx context.Context, input string) (string, error) { +func (f *Function) Execute(ctx *Context, input string) (string, error) { if !f.fn.IsValid() { return "", fmt.Errorf("function %s is not implemented", f.Name) } @@ -46,7 +46,7 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) { } // now we can call the function - exec := func(ctx context.Context) (string, error) { + exec := func(ctx *Context) (string, error) { out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()}) if len(out) != 2 { @@ -62,7 +62,7 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) { var cancel context.CancelFunc if f.Timeout > 0 { - ctx, cancel = context.WithTimeout(ctx, f.Timeout) + ctx, cancel = ctx.WithTimeout(f.Timeout) defer cancel() } @@ -90,3 +90,15 @@ type FunctionCall struct { Name string `json:"name,omitempty"` Arguments string `json:"arguments,omitempty"` } + +func (fc *FunctionCall) toRaw() map[string]any { + res := map[string]interface{}{ + "name": fc.Name, + } + + if fc.Arguments != "" { + res["arguments"] = fc.Arguments + } + + return res +} diff --git a/functions.go b/functions.go index 6b93bb6..a28b4da 100644 --- a/functions.go +++ b/functions.go @@ -1,7 +1,6 @@ package go_llm import ( - "context" "gitea.stevedudenhoeffer.com/steve/go-llm/schema" "reflect" ) @@ -13,7 +12,7 @@ import ( // The struct parameters can have the following tags: // - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for -func NewFunction[T any](name string, description string, fn func(context.Context, T) (string, error)) *Function { +func NewFunction[T any](name string, description string, fn func(*Context, T) (string, error)) *Function { var o T res := Function{ diff --git a/go.sum b/go.sum index 78d3a78..46a13e0 100644 --- a/go.sum +++ b/go.sum @@ -37,8 +37,6 @@ github.com/liushuangls/go-anthropic/v2 v2.13.0 h1:f7KJ54IHxIpHPPhrCzs3SrdP2PfErX github.com/liushuangls/go-anthropic/v2 v2.13.0/go.mod h1:5ZwRLF5TQ+y5s/MC9Z1IJYx9WUFgQCKfqFM2xreIQLk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/sashabaranov/go-openai v1.36.0 h1:fcSrn8uGuorzPWCBp8L0aCR95Zjb/Dd+ZSML0YZy9EI= -github.com/sashabaranov/go-openai v1.36.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g= github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/llm.go b/llm.go index fee880f..5e0d53a 100644 --- a/llm.go +++ b/llm.go @@ -2,6 +2,7 @@ package go_llm import ( "context" + "github.com/sashabaranov/go-openai" ) type Role string @@ -18,6 +19,26 @@ type Image struct { Url string } +func (i Image) toRaw() map[string]any { + res := map[string]any{ + "base64": i.Base64, + "contenttype": i.ContentType, + "url": i.Url, + } + + return res +} + +func (i *Image) fromRaw(raw map[string]any) Image { + var res Image + + res.Base64 = raw["base64"].(string) + res.ContentType = raw["contenttype"].(string) + res.Url = raw["url"].(string) + + return res +} + type Message struct { Role Role Name string @@ -25,10 +46,66 @@ type Message struct { Images []Image } -type Request struct { - Messages []Message - Toolbox *ToolBox - Temperature *float32 +func (m Message) toRaw() map[string]any { + res := map[string]any{ + "role": m.Role, + "name": m.Name, + "text": m.Text, + } + + images := make([]map[string]any, 0, len(m.Images)) + for _, img := range m.Images { + images = append(images, img.toRaw()) + } + + res["images"] = images + + return res +} + +func (m *Message) fromRaw(raw map[string]any) Message { + var res Message + + res.Role = Role(raw["role"].(string)) + res.Name = raw["name"].(string) + res.Text = raw["text"].(string) + + images := raw["images"].([]map[string]any) + for _, img := range images { + var i Image + + res.Images = append(res.Images, i.fromRaw(img)) + } + + return res +} + +func (m Message) toChatCompletionMessages() []openai.ChatCompletionMessage { + var res openai.ChatCompletionMessage + + res.Role = string(m.Role) + res.Name = m.Name + res.Content = m.Text + + for _, img := range m.Images { + if img.Base64 != "" { + res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{ + Type: "image_url", + ImageURL: &openai.ChatMessageImageURL{ + URL: "data:" + img.ContentType + ";base64," + img.Base64, + }, + }) + } else if img.Url != "" { + res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{ + Type: "image_url", + ImageURL: &openai.ChatMessageImageURL{ + URL: img.Url, + }, + }) + } + } + + return []openai.ChatCompletionMessage{res} } type ToolCall struct { @@ -36,16 +113,54 @@ type ToolCall struct { FunctionCall FunctionCall } -type ResponseChoice struct { - Index int - Role Role - Content string - Refusal string - Name string - Calls []ToolCall +func (t ToolCall) toRaw() map[string]any { + res := map[string]any{ + "id": t.ID, + } + + res["function"] = t.FunctionCall.toRaw() + + return res } -type Response struct { - Choices []ResponseChoice + +func (t ToolCall) toChatCompletionMessages() []openai.ChatCompletionMessage { + return []openai.ChatCompletionMessage{{ + Role: openai.ChatMessageRoleTool, + ToolCallID: t.ID, + }} +} + +type ToolCallResponse struct { + ID string + Result string + Error error +} + +func (t ToolCallResponse) toRaw() map[string]any { + res := map[string]any{ + "id": t.ID, + "result": t.Result, + } + + if t.Error != nil { + res["error"] = t.Error.Error() + } + + return res +} + +func (t ToolCallResponse) toChatCompletionMessages() []openai.ChatCompletionMessage { + var refusal string + if t.Error != nil { + refusal = t.Error.Error() + } + + return []openai.ChatCompletionMessage{{ + Role: openai.ChatMessageRoleTool, + Content: t.Result, + Refusal: refusal, + ToolCallID: t.ID, + }} } type ChatCompletion interface { diff --git a/openai.go b/openai.go index 1514634..812f97e 100644 --- a/openai.go +++ b/openai.go @@ -3,6 +3,7 @@ package go_llm import ( "context" "fmt" + "log/slog" "strings" oai "github.com/sashabaranov/go-openai" @@ -15,47 +16,17 @@ type openaiImpl struct { var _ LLM = openaiImpl{} -func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest { +func (o openaiImpl) newRequestToOpenAIRequest(request Request) oai.ChatCompletionRequest { res := oai.ChatCompletionRequest{ Model: o.model, } + for _, i := range request.Conversation { + res.Messages = append(res.Messages, i.toChatCompletionMessages()...) + } + for _, msg := range request.Messages { - m := oai.ChatCompletionMessage{ - Content: msg.Text, - Role: string(msg.Role), - Name: msg.Name, - } - - for _, img := range msg.Images { - if img.Base64 != "" { - m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{ - Type: "image_url", - ImageURL: &oai.ChatMessageImageURL{ - URL: fmt.Sprintf("data:%s;base64,%s", img.ContentType, img.Base64), - }, - }) - } else if img.Url != "" { - m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{ - Type: "image_url", - ImageURL: &oai.ChatMessageImageURL{ - URL: img.Url, - }, - }) - } - } - - // openai does not allow Content and MultiContent to be set at the same time, so we need to check - if len(m.MultiContent) > 0 && m.Content != "" { - m.MultiContent = append([]oai.ChatMessagePart{{ - Type: "text", - Text: m.Content, - }}, m.MultiContent...) - - m.Content = "" - } - - res.Messages = append(res.Messages, m) + res.Messages = append(res.Messages, msg.toChatCompletionMessages()...) } if request.Toolbox != nil { @@ -130,8 +101,9 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) { cl := oai.NewClient(o.key) - req := o.requestToOpenAIRequest(request) + req := o.newRequestToOpenAIRequest(request) + slog.Info("openaiImpl.ChatComplete", "req", fmt.Sprintf("%#v", req)) resp, err := cl.CreateChatCompletion(ctx, req) fmt.Println("resp:", fmt.Sprintf("%#v", resp)) diff --git a/request.go b/request.go new file mode 100644 index 0000000..e3831c2 --- /dev/null +++ b/request.go @@ -0,0 +1,54 @@ +package go_llm + +import "github.com/sashabaranov/go-openai" + +type rawAble interface { + toRaw() map[string]any + fromRaw(raw map[string]any) Input +} + +type Input interface { + toChatCompletionMessages() []openai.ChatCompletionMessage +} +type Request struct { + Conversation []Input + Messages []Message + Toolbox *ToolBox + Temperature *float32 +} + +// NextRequest will take the current request's conversation, messages, the response, and any tool results, and +// return a new request with the conversation updated to include the response and tool results. +func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallResponse) Request { + var res Request + + res.Toolbox = req.Toolbox + res.Temperature = req.Temperature + + res.Conversation = make([]Input, len(req.Conversation)) + copy(res.Conversation, req.Conversation) + + // now for every input message, convert those to an Input to add to the conversation + for _, msg := range req.Messages { + res.Conversation = append(res.Conversation, msg) + } + + // if there are tool calls, then we need to add those to the conversation + for _, call := range resp.Calls { + res.Conversation = append(res.Conversation, call) + } + + if resp.Content != "" || resp.Refusal != "" { + res.Conversation = append(res.Conversation, Message{ + Role: RoleAssistant, + Text: resp.Content, + }) + } + + // if there are tool results, then we need to add those to the conversation + for _, result := range toolResults { + res.Conversation = append(res.Conversation, result) + } + + return res +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..ef88bfa --- /dev/null +++ b/response.go @@ -0,0 +1,70 @@ +package go_llm + +import "github.com/sashabaranov/go-openai" + +type ResponseChoice struct { + Index int + Role Role + Content string + Refusal string + Name string + Calls []ToolCall +} + +func (r ResponseChoice) toRaw() map[string]any { + res := map[string]any{ + "index": r.Index, + "role": r.Role, + "content": r.Content, + "refusal": r.Refusal, + "name": r.Name, + } + + calls := make([]map[string]any, 0, len(r.Calls)) + for _, call := range r.Calls { + calls = append(calls, call.toRaw()) + } + + res["calls"] = calls + + return res +} + +func (r ResponseChoice) toChatCompletionMessages() []openai.ChatCompletionMessage { + var res []openai.ChatCompletionMessage + + for _, call := range r.Calls { + res = append(res, call.toChatCompletionMessages()...) + } + + if r.Refusal != "" || r.Content != "" { + res = append(res, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: r.Content, + Refusal: r.Refusal, + }) + } + + return res +} + +func (r ResponseChoice) toInput() []Input { + var res []Input + + for _, call := range r.Calls { + res = append(res, call) + } + + if r.Content != "" || r.Refusal != "" { + res = append(res, Message{ + Role: RoleAssistant, + Text: r.Content, + }) + } + + return res +} + +type Response struct { + Choices []ResponseChoice +} diff --git a/toolbox.go b/toolbox.go index 0cb34a5..208219f 100644 --- a/toolbox.go +++ b/toolbox.go @@ -1,7 +1,6 @@ package go_llm import ( - "context" "errors" "fmt" "github.com/sashabaranov/go-openai" @@ -64,7 +63,7 @@ var ( ErrFunctionNotFound = errors.New("function not found") ) -func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) { +func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (string, error) { f, ok := t.names[functionName] if !ok { @@ -74,6 +73,6 @@ func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, para return f.Execute(ctx, params) } -func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) { - return t.ExecuteFunction(ctx, toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments) +func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (string, error) { + return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments) }