package go_llm

import (
	"context"
	"fmt"
	"strings"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)
// openaiImpl is the OpenAI-backed implementation of the LLM interface.
// The zero value is not usable: key and model must be populated before use.
type openaiImpl struct {
	key     string // API key, passed to the client via option.WithAPIKey
	model   string // model identifier sent with each chat completion request
	baseUrl string // optional endpoint override; empty means the default OpenAI base URL
}
var _ LLM = openaiImpl{}
func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams {
|
|
res := openai.ChatCompletionNewParams{
|
|
Model: o.model,
|
|
}
|
|
|
|
for _, i := range request.Conversation {
|
|
res.Messages = append(res.Messages, i.toChatCompletionMessages(o.model)...)
|
|
}
|
|
|
|
for _, msg := range request.Messages {
|
|
res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...)
|
|
}
|
|
|
|
for _, tool := range request.Toolbox.functions {
|
|
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
|
|
Type: "function",
|
|
Function: shared.FunctionDefinitionParam{
|
|
Name: tool.Name,
|
|
Description: openai.String(tool.Description),
|
|
Strict: openai.Bool(tool.Strict),
|
|
Parameters: tool.Parameters.OpenAIParameters(),
|
|
},
|
|
})
|
|
}
|
|
|
|
if request.Toolbox.RequiresTool() {
|
|
res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
|
|
OfAuto: openai.String("required"),
|
|
}
|
|
}
|
|
|
|
if request.Temperature != nil {
|
|
res.Temperature = openai.Float(*request.Temperature)
|
|
}
|
|
|
|
return res
|
|
}
func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response {
|
|
var res Response
|
|
|
|
if response == nil {
|
|
return res
|
|
}
|
|
|
|
if len(response.Choices) == 0 {
|
|
return res
|
|
}
|
|
|
|
for _, choice := range response.Choices {
|
|
var toolCalls []ToolCall
|
|
for _, call := range choice.Message.ToolCalls {
|
|
toolCall := ToolCall{
|
|
ID: call.ID,
|
|
FunctionCall: FunctionCall{
|
|
Name: call.Function.Name,
|
|
Arguments: strings.TrimSpace(call.Function.Arguments),
|
|
},
|
|
}
|
|
|
|
toolCalls = append(toolCalls, toolCall)
|
|
|
|
}
|
|
res.Choices = append(res.Choices, ResponseChoice{
|
|
Content: choice.Message.Content,
|
|
Role: Role(choice.Message.Role),
|
|
Refusal: choice.Message.Refusal,
|
|
Calls: toolCalls,
|
|
})
|
|
}
|
|
|
|
return res
|
|
}
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
|
var opts = []option.RequestOption{
|
|
option.WithAPIKey(o.key),
|
|
}
|
|
|
|
if o.baseUrl != "" {
|
|
opts = append(opts, option.WithBaseURL(o.baseUrl))
|
|
}
|
|
|
|
cl := openai.NewClient(opts...)
|
|
|
|
req := o.newRequestToOpenAIRequest(request)
|
|
|
|
resp, err := cl.Chat.Completions.New(ctx, req)
|
|
//resp, err := cl.CreateChatCompletion(ctx, req)
|
|
|
|
if err != nil {
|
|
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
|
|
}
|
|
|
|
return o.responseToLLMResponse(resp), nil
|
|
}
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
|
return openaiImpl{
|
|
key: o.key,
|
|
model: modelVersion,
|
|
}, nil
|
|
}