push of current changes
This commit is contained in:
110
openai.go
Normal file
110
openai.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package go_llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
oai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type openai struct {
|
||||
key string
|
||||
model string
|
||||
}
|
||||
|
||||
var _ LLM = openai{}
|
||||
|
||||
func (o openai) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
||||
res := oai.ChatCompletionRequest{
|
||||
Model: o.model,
|
||||
}
|
||||
|
||||
for _, msg := range request.Messages {
|
||||
m := oai.ChatCompletionMessage{
|
||||
Content: msg.Text,
|
||||
Role: string(msg.Role),
|
||||
Name: msg.Name,
|
||||
}
|
||||
|
||||
if msg.ImageBase64 != "" {
|
||||
part := oai.ChatMessagePart{
|
||||
Type: "image_url",
|
||||
ImageURL: &oai.ChatMessageImageURL{
|
||||
URL: fmt.Sprintf("data:image/jpeg;base64,%s", msg.ImageBase64),
|
||||
},
|
||||
}
|
||||
|
||||
m.MultiContent = append(m.MultiContent, part)
|
||||
}
|
||||
|
||||
res.Messages = append(res.Messages, m)
|
||||
}
|
||||
|
||||
for _, tool := range request.Toolbox {
|
||||
res.Tools = append(res.Tools, oai.Tool{
|
||||
Type: "function",
|
||||
Function: &oai.FunctionDefinition{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Strict: tool.Strict,
|
||||
Parameters: tool.Parameters,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if request.Temperature != nil {
|
||||
res.Temperature = *request.Temperature
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (o openai) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
|
||||
res := Response{}
|
||||
|
||||
for _, choice := range response.Choices {
|
||||
var tools []ToolCall
|
||||
for _, call := range choice.Message.ToolCalls {
|
||||
toolCall := ToolCall{
|
||||
ID: call.ID,
|
||||
FunctionCall: FunctionCall{
|
||||
Name: call.Function.Name,
|
||||
Arguments: call.Function.Arguments,
|
||||
},
|
||||
}
|
||||
|
||||
tools = append(tools, toolCall)
|
||||
|
||||
}
|
||||
res.Choices = append(res.Choices, ResponseChoice{
|
||||
Content: choice.Message.Content,
|
||||
Role: Role(choice.Message.Role),
|
||||
Name: choice.Message.Name,
|
||||
Refusal: choice.Message.Refusal,
|
||||
})
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (o openai) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
||||
cl := oai.NewClient(o.key)
|
||||
|
||||
req := o.requestToOpenAIRequest(request)
|
||||
|
||||
resp, err := cl.CreateChatCompletion(ctx, req)
|
||||
|
||||
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
|
||||
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("unhandled openai error: %w", err)
|
||||
}
|
||||
|
||||
return o.responseToLLMResponse(resp), nil
|
||||
}
|
||||
|
||||
func (o openai) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
||||
return openai{
|
||||
key: o.key,
|
||||
model: modelVersion,
|
||||
}, nil
|
||||
}
|
Reference in New Issue
Block a user