Refactor entire system to be more contextual so that conversation flow can be more easily managed
This commit is contained in:
parent
0d909edd44
commit
7f5e34e437
97
context.go
Normal file
97
context.go
Normal file
@ -0,0 +1,97 @@
|
||||
package go_llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
context.Context
|
||||
request Request
|
||||
response *ResponseChoice
|
||||
toolcall *ToolCall
|
||||
}
|
||||
|
||||
func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request {
|
||||
var res Request
|
||||
|
||||
res.Toolbox = c.request.Toolbox
|
||||
res.Temperature = c.request.Temperature
|
||||
|
||||
res.Conversation = make([]Input, len(c.request.Conversation))
|
||||
copy(res.Conversation, c.request.Conversation)
|
||||
|
||||
// now for every input message, convert those to an Input to add to the conversation
|
||||
for _, msg := range c.request.Messages {
|
||||
res.Conversation = append(res.Conversation, msg)
|
||||
}
|
||||
|
||||
// if there are tool calls, then we need to add those to the conversation
|
||||
if c.response != nil {
|
||||
for _, call := range c.response.Calls {
|
||||
res.Conversation = append(res.Conversation, call)
|
||||
|
||||
if c.response.Content != "" || c.response.Refusal != "" {
|
||||
res.Conversation = append(res.Conversation, Message{
|
||||
Role: RoleAssistant,
|
||||
Text: c.response.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if there are tool results, then we need to add those to the conversation
|
||||
for _, result := range toolResults {
|
||||
res.Conversation = append(res.Conversation, result)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func NewContext(ctx context.Context, request Request, response *ResponseChoice, toolcall *ToolCall) *Context {
|
||||
return &Context{Context: ctx, request: request, response: response, toolcall: toolcall}
|
||||
}
|
||||
|
||||
func (c *Context) Request() Request {
|
||||
return c.request
|
||||
}
|
||||
|
||||
func (c *Context) WithContext(ctx context.Context) *Context {
|
||||
return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall}
|
||||
}
|
||||
|
||||
func (c *Context) WithRequest(request Request) *Context {
|
||||
return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall}
|
||||
}
|
||||
|
||||
func (c *Context) WithResponse(response *ResponseChoice) *Context {
|
||||
return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall}
|
||||
}
|
||||
|
||||
func (c *Context) WithToolCall(toolcall *ToolCall) *Context {
|
||||
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall}
|
||||
}
|
||||
|
||||
func (c *Context) Deadline() (deadline time.Time, ok bool) {
|
||||
return c.Context.Deadline()
|
||||
}
|
||||
|
||||
func (c *Context) Done() <-chan struct{} {
|
||||
return c.Context.Done()
|
||||
}
|
||||
|
||||
func (c *Context) Err() error {
|
||||
return c.Context.Err()
|
||||
}
|
||||
|
||||
func (c *Context) Value(key any) any {
|
||||
if key == "request" {
|
||||
return c.request
|
||||
}
|
||||
return c.Context.Value(key)
|
||||
}
|
||||
|
||||
func (c *Context) WithTimeout(timeout time.Duration) (*Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithTimeout(c.Context, timeout)
|
||||
return c.WithContext(ctx), cancel
|
||||
}
|
18
function.go
18
function.go
@ -31,7 +31,7 @@ type Function struct {
|
||||
definition *jsonschema.Definition
|
||||
}
|
||||
|
||||
func (f *Function) Execute(ctx context.Context, input string) (string, error) {
|
||||
func (f *Function) Execute(ctx *Context, input string) (string, error) {
|
||||
if !f.fn.IsValid() {
|
||||
return "", fmt.Errorf("function %s is not implemented", f.Name)
|
||||
}
|
||||
@ -46,7 +46,7 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) {
|
||||
}
|
||||
|
||||
// now we can call the function
|
||||
exec := func(ctx context.Context) (string, error) {
|
||||
exec := func(ctx *Context) (string, error) {
|
||||
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
|
||||
|
||||
if len(out) != 2 {
|
||||
@ -62,7 +62,7 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) {
|
||||
|
||||
var cancel context.CancelFunc
|
||||
if f.Timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, f.Timeout)
|
||||
ctx, cancel = ctx.WithTimeout(f.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
@ -90,3 +90,15 @@ type FunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
func (fc *FunctionCall) toRaw() map[string]any {
|
||||
res := map[string]interface{}{
|
||||
"name": fc.Name,
|
||||
}
|
||||
|
||||
if fc.Arguments != "" {
|
||||
res["arguments"] = fc.Arguments
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package go_llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
|
||||
"reflect"
|
||||
)
|
||||
@ -13,7 +12,7 @@ import (
|
||||
// The struct parameters can have the following tags:
|
||||
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
|
||||
|
||||
func NewFunction[T any](name string, description string, fn func(context.Context, T) (string, error)) *Function {
|
||||
func NewFunction[T any](name string, description string, fn func(*Context, T) (string, error)) *Function {
|
||||
var o T
|
||||
|
||||
res := Function{
|
||||
|
2
go.sum
2
go.sum
@ -37,8 +37,6 @@ github.com/liushuangls/go-anthropic/v2 v2.13.0 h1:f7KJ54IHxIpHPPhrCzs3SrdP2PfErX
|
||||
github.com/liushuangls/go-anthropic/v2 v2.13.0/go.mod h1:5ZwRLF5TQ+y5s/MC9Z1IJYx9WUFgQCKfqFM2xreIQLk=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sashabaranov/go-openai v1.36.0 h1:fcSrn8uGuorzPWCBp8L0aCR95Zjb/Dd+ZSML0YZy9EI=
|
||||
github.com/sashabaranov/go-openai v1.36.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g=
|
||||
github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
|
141
llm.go
141
llm.go
@ -2,6 +2,7 @@ package go_llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type Role string
|
||||
@ -18,6 +19,26 @@ type Image struct {
|
||||
Url string
|
||||
}
|
||||
|
||||
func (i Image) toRaw() map[string]any {
|
||||
res := map[string]any{
|
||||
"base64": i.Base64,
|
||||
"contenttype": i.ContentType,
|
||||
"url": i.Url,
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (i *Image) fromRaw(raw map[string]any) Image {
|
||||
var res Image
|
||||
|
||||
res.Base64 = raw["base64"].(string)
|
||||
res.ContentType = raw["contenttype"].(string)
|
||||
res.Url = raw["url"].(string)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role Role
|
||||
Name string
|
||||
@ -25,10 +46,66 @@ type Message struct {
|
||||
Images []Image
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Messages []Message
|
||||
Toolbox *ToolBox
|
||||
Temperature *float32
|
||||
func (m Message) toRaw() map[string]any {
|
||||
res := map[string]any{
|
||||
"role": m.Role,
|
||||
"name": m.Name,
|
||||
"text": m.Text,
|
||||
}
|
||||
|
||||
images := make([]map[string]any, 0, len(m.Images))
|
||||
for _, img := range m.Images {
|
||||
images = append(images, img.toRaw())
|
||||
}
|
||||
|
||||
res["images"] = images
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *Message) fromRaw(raw map[string]any) Message {
|
||||
var res Message
|
||||
|
||||
res.Role = Role(raw["role"].(string))
|
||||
res.Name = raw["name"].(string)
|
||||
res.Text = raw["text"].(string)
|
||||
|
||||
images := raw["images"].([]map[string]any)
|
||||
for _, img := range images {
|
||||
var i Image
|
||||
|
||||
res.Images = append(res.Images, i.fromRaw(img))
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (m Message) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
||||
var res openai.ChatCompletionMessage
|
||||
|
||||
res.Role = string(m.Role)
|
||||
res.Name = m.Name
|
||||
res.Content = m.Text
|
||||
|
||||
for _, img := range m.Images {
|
||||
if img.Base64 != "" {
|
||||
res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{
|
||||
Type: "image_url",
|
||||
ImageURL: &openai.ChatMessageImageURL{
|
||||
URL: "data:" + img.ContentType + ";base64," + img.Base64,
|
||||
},
|
||||
})
|
||||
} else if img.Url != "" {
|
||||
res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{
|
||||
Type: "image_url",
|
||||
ImageURL: &openai.ChatMessageImageURL{
|
||||
URL: img.Url,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return []openai.ChatCompletionMessage{res}
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
@ -36,16 +113,54 @@ type ToolCall struct {
|
||||
FunctionCall FunctionCall
|
||||
}
|
||||
|
||||
type ResponseChoice struct {
|
||||
Index int
|
||||
Role Role
|
||||
Content string
|
||||
Refusal string
|
||||
Name string
|
||||
Calls []ToolCall
|
||||
func (t ToolCall) toRaw() map[string]any {
|
||||
res := map[string]any{
|
||||
"id": t.ID,
|
||||
}
|
||||
type Response struct {
|
||||
Choices []ResponseChoice
|
||||
|
||||
res["function"] = t.FunctionCall.toRaw()
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (t ToolCall) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
||||
return []openai.ChatCompletionMessage{{
|
||||
Role: openai.ChatMessageRoleTool,
|
||||
ToolCallID: t.ID,
|
||||
}}
|
||||
}
|
||||
|
||||
type ToolCallResponse struct {
|
||||
ID string
|
||||
Result string
|
||||
Error error
|
||||
}
|
||||
|
||||
func (t ToolCallResponse) toRaw() map[string]any {
|
||||
res := map[string]any{
|
||||
"id": t.ID,
|
||||
"result": t.Result,
|
||||
}
|
||||
|
||||
if t.Error != nil {
|
||||
res["error"] = t.Error.Error()
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (t ToolCallResponse) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
||||
var refusal string
|
||||
if t.Error != nil {
|
||||
refusal = t.Error.Error()
|
||||
}
|
||||
|
||||
return []openai.ChatCompletionMessage{{
|
||||
Role: openai.ChatMessageRoleTool,
|
||||
Content: t.Result,
|
||||
Refusal: refusal,
|
||||
ToolCallID: t.ID,
|
||||
}}
|
||||
}
|
||||
|
||||
type ChatCompletion interface {
|
||||
|
46
openai.go
46
openai.go
@ -3,6 +3,7 @@ package go_llm
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
oai "github.com/sashabaranov/go-openai"
|
||||
@ -15,47 +16,17 @@ type openaiImpl struct {
|
||||
|
||||
var _ LLM = openaiImpl{}
|
||||
|
||||
func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
||||
func (o openaiImpl) newRequestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
||||
res := oai.ChatCompletionRequest{
|
||||
Model: o.model,
|
||||
}
|
||||
|
||||
for _, i := range request.Conversation {
|
||||
res.Messages = append(res.Messages, i.toChatCompletionMessages()...)
|
||||
}
|
||||
|
||||
for _, msg := range request.Messages {
|
||||
m := oai.ChatCompletionMessage{
|
||||
Content: msg.Text,
|
||||
Role: string(msg.Role),
|
||||
Name: msg.Name,
|
||||
}
|
||||
|
||||
for _, img := range msg.Images {
|
||||
if img.Base64 != "" {
|
||||
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
|
||||
Type: "image_url",
|
||||
ImageURL: &oai.ChatMessageImageURL{
|
||||
URL: fmt.Sprintf("data:%s;base64,%s", img.ContentType, img.Base64),
|
||||
},
|
||||
})
|
||||
} else if img.Url != "" {
|
||||
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
|
||||
Type: "image_url",
|
||||
ImageURL: &oai.ChatMessageImageURL{
|
||||
URL: img.Url,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// openai does not allow Content and MultiContent to be set at the same time, so we need to check
|
||||
if len(m.MultiContent) > 0 && m.Content != "" {
|
||||
m.MultiContent = append([]oai.ChatMessagePart{{
|
||||
Type: "text",
|
||||
Text: m.Content,
|
||||
}}, m.MultiContent...)
|
||||
|
||||
m.Content = ""
|
||||
}
|
||||
|
||||
res.Messages = append(res.Messages, m)
|
||||
res.Messages = append(res.Messages, msg.toChatCompletionMessages()...)
|
||||
}
|
||||
|
||||
if request.Toolbox != nil {
|
||||
@ -130,8 +101,9 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
|
||||
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
||||
cl := oai.NewClient(o.key)
|
||||
|
||||
req := o.requestToOpenAIRequest(request)
|
||||
req := o.newRequestToOpenAIRequest(request)
|
||||
|
||||
slog.Info("openaiImpl.ChatComplete", "req", fmt.Sprintf("%#v", req))
|
||||
resp, err := cl.CreateChatCompletion(ctx, req)
|
||||
|
||||
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
|
||||
|
54
request.go
Normal file
54
request.go
Normal file
@ -0,0 +1,54 @@
|
||||
package go_llm
|
||||
|
||||
import "github.com/sashabaranov/go-openai"
|
||||
|
||||
type rawAble interface {
|
||||
toRaw() map[string]any
|
||||
fromRaw(raw map[string]any) Input
|
||||
}
|
||||
|
||||
type Input interface {
|
||||
toChatCompletionMessages() []openai.ChatCompletionMessage
|
||||
}
|
||||
type Request struct {
|
||||
Conversation []Input
|
||||
Messages []Message
|
||||
Toolbox *ToolBox
|
||||
Temperature *float32
|
||||
}
|
||||
|
||||
// NextRequest will take the current request's conversation, messages, the response, and any tool results, and
|
||||
// return a new request with the conversation updated to include the response and tool results.
|
||||
func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallResponse) Request {
|
||||
var res Request
|
||||
|
||||
res.Toolbox = req.Toolbox
|
||||
res.Temperature = req.Temperature
|
||||
|
||||
res.Conversation = make([]Input, len(req.Conversation))
|
||||
copy(res.Conversation, req.Conversation)
|
||||
|
||||
// now for every input message, convert those to an Input to add to the conversation
|
||||
for _, msg := range req.Messages {
|
||||
res.Conversation = append(res.Conversation, msg)
|
||||
}
|
||||
|
||||
// if there are tool calls, then we need to add those to the conversation
|
||||
for _, call := range resp.Calls {
|
||||
res.Conversation = append(res.Conversation, call)
|
||||
}
|
||||
|
||||
if resp.Content != "" || resp.Refusal != "" {
|
||||
res.Conversation = append(res.Conversation, Message{
|
||||
Role: RoleAssistant,
|
||||
Text: resp.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// if there are tool results, then we need to add those to the conversation
|
||||
for _, result := range toolResults {
|
||||
res.Conversation = append(res.Conversation, result)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
70
response.go
Normal file
70
response.go
Normal file
@ -0,0 +1,70 @@
|
||||
package go_llm
|
||||
|
||||
import "github.com/sashabaranov/go-openai"
|
||||
|
||||
type ResponseChoice struct {
|
||||
Index int
|
||||
Role Role
|
||||
Content string
|
||||
Refusal string
|
||||
Name string
|
||||
Calls []ToolCall
|
||||
}
|
||||
|
||||
func (r ResponseChoice) toRaw() map[string]any {
|
||||
res := map[string]any{
|
||||
"index": r.Index,
|
||||
"role": r.Role,
|
||||
"content": r.Content,
|
||||
"refusal": r.Refusal,
|
||||
"name": r.Name,
|
||||
}
|
||||
|
||||
calls := make([]map[string]any, 0, len(r.Calls))
|
||||
for _, call := range r.Calls {
|
||||
calls = append(calls, call.toRaw())
|
||||
}
|
||||
|
||||
res["calls"] = calls
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (r ResponseChoice) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
||||
var res []openai.ChatCompletionMessage
|
||||
|
||||
for _, call := range r.Calls {
|
||||
res = append(res, call.toChatCompletionMessages()...)
|
||||
}
|
||||
|
||||
if r.Refusal != "" || r.Content != "" {
|
||||
res = append(res, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: r.Content,
|
||||
Refusal: r.Refusal,
|
||||
})
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (r ResponseChoice) toInput() []Input {
|
||||
var res []Input
|
||||
|
||||
for _, call := range r.Calls {
|
||||
res = append(res, call)
|
||||
}
|
||||
|
||||
if r.Content != "" || r.Refusal != "" {
|
||||
res = append(res, Message{
|
||||
Role: RoleAssistant,
|
||||
Text: r.Content,
|
||||
})
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Choices []ResponseChoice
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package go_llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
@ -64,7 +63,7 @@ var (
|
||||
ErrFunctionNotFound = errors.New("function not found")
|
||||
)
|
||||
|
||||
func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
|
||||
func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (string, error) {
|
||||
f, ok := t.names[functionName]
|
||||
|
||||
if !ok {
|
||||
@ -74,6 +73,6 @@ func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, para
|
||||
return f.Execute(ctx, params)
|
||||
}
|
||||
|
||||
func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
|
||||
return t.ExecuteFunction(ctx, toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
|
||||
func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (string, error) {
|
||||
return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user