initial commit

This commit is contained in:
2024-10-06 20:01:01 -04:00
commit a67ad54bcc
6 changed files with 350 additions and 0 deletions

113
pkg/anthropic.go Normal file
View File

@@ -0,0 +1,113 @@
package llm
import (
"context"
"fmt"
anth "github.com/liushuangls/go-anthropic/v2"
)
// anthropic is an LLM implementation backed by the Anthropic Messages API.
type anthropic struct {
	key   string // Anthropic API key used to authenticate requests
	model string // model identifier, selected via ModelVersion
}

// Compile-time check that anthropic satisfies the LLM interface.
var _ LLM = anthropic{}
// ModelVersion returns a ChatCompletion bound to the given Anthropic model
// version. The model string is not validated here.
// TODO: model verification?
func (a anthropic) ModelVersion(modelVersion string) (ChatCompletion, error) {
	return anthropic{key: a.key, model: modelVersion}, nil
}
// requestToAnthropicRequest translates our provider-agnostic Request into an
// Anthropic MessagesRequest. The Anthropic API has no "system" message role,
// so system messages are folded (newline-separated) into the request's System
// field; all other messages map to user/assistant messages. Toolbox entries
// become Anthropic tool definitions.
func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
	// Anthropic requires MaxTokens on every request.
	// TODO: make this configurable instead of hard-coded.
	const maxTokens = 1000

	res := anth.MessagesRequest{
		Model:     anth.Model(a.model),
		MaxTokens: maxTokens,
	}

	msgs := make([]anth.Message, 0, len(req.Messages))
	for _, msg := range req.Messages {
		if msg.Role == RoleSystem {
			if len(res.System) > 0 {
				res.System += "\n"
			}
			res.System += msg.Text
			continue
		}

		role := anth.RoleUser
		if msg.Role == RoleAssistant {
			role = anth.RoleAssistant
		}
		// BUG FIX: copy the text into a per-iteration local before taking
		// its address. &msg.Text aliases the loop variable, which (before
		// Go 1.22 loop-var semantics) would make every message point at
		// the final iteration's text.
		text := msg.Text
		msgs = append(msgs, anth.Message{
			Role: role,
			Content: []anth.MessageContent{
				{
					Type: anth.MessagesContentTypeText,
					Text: &text,
				},
			},
		})
	}

	for _, tool := range req.Toolbox {
		res.Tools = append(res.Tools, anth.ToolDefinition{
			Name:        tool.Name,
			Description: tool.Description,
			InputSchema: tool.Parameters,
		})
	}

	res.Messages = msgs
	return res
}
// responseToLLMResponse converts an Anthropic MessagesResponse into our
// provider-agnostic Response. Each content entry becomes one ResponseChoice:
// text content fills Content, tool-use content fills Calls.
func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
	res := Response{}
	for i, msg := range in.Content {
		choice := ResponseChoice{
			// FIX: populate Index and Role, which were previously left at
			// their zero values; the openai converter sets Role, so this
			// keeps the two providers consistent. Anthropic message
			// responses are authored by the assistant.
			Index: i,
			Role:  RoleAssistant,
		}
		switch msg.Type {
		case anth.MessagesContentTypeText:
			if msg.Text != nil {
				choice.Content = *msg.Text
			}
		case anth.MessagesContentTypeToolUse:
			if msg.MessageContentToolUse != nil {
				choice.Calls = append(choice.Calls, ToolCall{
					ID: msg.MessageContentToolUse.ID,
					FunctionCall: FunctionCall{
						Name:      msg.MessageContentToolUse.Name,
						Arguments: msg.MessageContentToolUse.Input,
					},
				})
			}
		}
		res.Choices = append(res.Choices, choice)
	}
	return res
}
// ChatComplete sends the request to the Anthropic Messages API and converts
// the result into a provider-agnostic Response. The returned error wraps the
// underlying client error.
func (a anthropic) ChatComplete(ctx context.Context, req Request) (Response, error) {
	cl := anth.NewClient(a.key)
	res, err := cl.CreateMessages(ctx, a.requestToAnthropicRequest(req))
	if err != nil {
		// Say what this layer was doing; the redundant "failed to" prefix
		// is dropped per Go error-wrapping convention.
		return Response{}, fmt.Errorf("anthropic chat completion: %w", err)
	}
	return a.responseToLLMResponse(res), nil
}

13
pkg/function.go Normal file
View File

@@ -0,0 +1,13 @@
package llm
// Function describes a callable tool that can be offered to a model.
// Parameters holds a schema-like description of the function's input; its
// concrete type is provider-dependent, hence `any`.
type Function struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	// Strict requests strict schema adherence; forwarded to OpenAI,
	// ignored by the Anthropic backend.
	Strict     bool `json:"strict,omitempty"`
	Parameters any  `json:"parameters"`
}
// FunctionCall is a model's request to invoke a named function together with
// the arguments it supplied. Arguments is `any` because providers deliver
// arguments in differing forms.
type FunctionCall struct {
	Name      string `json:"name,omitempty"`
	Arguments any    `json:"arguments,omitempty"`
}

57
pkg/llm.go Normal file
View File

@@ -0,0 +1,57 @@
package llm
import (
"context"
)
// Role identifies the author of a chat message.
type Role string

const (
	RoleSystem    Role = "system"
	RoleUser      Role = "user"
	RoleAssistant Role = "assistant"
)
// Message is a single chat message in a conversation.
type Message struct {
	Role Role   // who authored the message
	Name string // optional author name; forwarded to OpenAI, unused by Anthropic
	Text string // the message content
}
// Request is a provider-agnostic chat-completion request: the conversation
// so far plus the set of tools the model may call.
type Request struct {
	Messages []Message
	Toolbox  []Function
}
// ToolCall is a model-initiated function invocation, carrying the
// provider-assigned call ID and the requested function call.
type ToolCall struct {
	ID           string
	FunctionCall FunctionCall
}
// ResponseChoice is one candidate completion returned by a provider.
type ResponseChoice struct {
	Index   int        // position of this choice in the response
	Role    Role       // author role of the completion
	Content string     // text content of the completion
	Refusal string     // provider refusal message, if any
	Name    string     // author name, if supplied by the provider
	Calls   []ToolCall // tool invocations requested by the model
}
// Response is a provider-agnostic chat-completion response.
type Response struct {
	Choices []ResponseChoice
}
// ChatCompletion performs a chat-completion round trip against a specific,
// already-selected model.
type ChatCompletion interface {
	ChatComplete(ctx context.Context, req Request) (Response, error)
}
// LLM is a provider handle that can be specialized to a model version,
// yielding a ChatCompletion bound to that model.
type LLM interface {
	ModelVersion(modelVersion string) (ChatCompletion, error)
}
// OpenAI returns an LLM backed by the OpenAI API, authenticated with key.
func OpenAI(key string) LLM {
	var o openai
	o.key = key
	return o
}
// Anthropic returns an LLM backed by the Anthropic API, authenticated with key.
func Anthropic(key string) LLM {
	var a anthropic
	a.key = key
	return a
}

93
pkg/openai.go Normal file
View File

@@ -0,0 +1,93 @@
package llm
import (
"context"
"fmt"
oai "github.com/sashabaranov/go-openai"
)
// openai is an LLM implementation backed by the OpenAI chat-completion API.
type openai struct {
	key   string // OpenAI API key used to authenticate requests
	model string // model identifier, selected via ModelVersion
}

// Compile-time check that openai satisfies the LLM interface.
var _ LLM = openai{}
// requestToOpenAIRequest builds the OpenAI ChatCompletionRequest equivalent
// of the given provider-agnostic Request. Message roles are passed through
// verbatim (our Role values use the same strings as OpenAI's roles), and
// every Toolbox entry is attached as a function tool.
func (o openai) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
	var msgs []oai.ChatCompletionMessage
	for _, m := range request.Messages {
		msgs = append(msgs, oai.ChatCompletionMessage{
			Content: m.Text,
			Role:    string(m.Role),
			Name:    m.Name,
		})
	}

	var tools []oai.Tool
	for _, fn := range request.Toolbox {
		def := &oai.FunctionDefinition{
			Name:        fn.Name,
			Description: fn.Description,
			Strict:      fn.Strict,
			Parameters:  fn.Parameters,
		}
		tools = append(tools, oai.Tool{Type: "function", Function: def})
	}

	return oai.ChatCompletionRequest{
		Model:    o.model,
		Messages: msgs,
		Tools:    tools,
	}
}
// responseToLLMResponse converts an OpenAI ChatCompletionResponse into our
// provider-agnostic Response, carrying over text content, role, refusals,
// and any tool calls the model requested.
func (o openai) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
	res := Response{}
	for _, choice := range response.Choices {
		var calls []ToolCall
		for _, call := range choice.Message.ToolCalls {
			calls = append(calls, ToolCall{
				ID: call.ID,
				FunctionCall: FunctionCall{
					Name:      call.Function.Name,
					Arguments: call.Function.Arguments,
				},
			})
		}
		res.Choices = append(res.Choices, ResponseChoice{
			// FIX: carry the provider's choice index through.
			Index:   choice.Index,
			Content: choice.Message.Content,
			Role:    Role(choice.Message.Role),
			Name:    choice.Message.Name,
			Refusal: choice.Message.Refusal,
			// BUG FIX: the tool calls were collected into a local slice
			// but never attached to the choice, silently dropping every
			// tool invocation the model requested.
			Calls: calls,
		})
	}
	return res
}
// ChatComplete sends the request to the OpenAI chat-completion API and
// converts the result into a provider-agnostic Response. The returned error
// wraps the underlying client error.
func (o openai) ChatComplete(ctx context.Context, request Request) (Response, error) {
	cl := oai.NewClient(o.key)
	req := o.requestToOpenAIRequest(request)
	resp, err := cl.CreateChatCompletion(ctx, req)
	// FIX: removed leftover debug fmt.Println that dumped the entire raw
	// response to stdout on every call (even before the error was checked).
	if err != nil {
		return Response{}, fmt.Errorf("openai chat completion: %w", err)
	}
	return o.responseToLLMResponse(resp), nil
}
// ModelVersion returns a ChatCompletion bound to the given OpenAI model
// version. The model string is not validated here.
func (o openai) ModelVersion(modelVersion string) (ChatCompletion, error) {
	// o is a value receiver, so mutating it only affects the returned copy.
	o.model = modelVersion
	return o, nil
}