2024-10-06 21:02:26 -04:00
|
|
|
package go_llm
|
2024-10-06 20:01:01 -04:00
|
|
|
|
|
|
|
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"log/slog"
	"net/http"
	"time"

	anth "github.com/liushuangls/go-anthropic/v2"
)
|
|
|
|
|
|
|
|
type anthropic struct {
|
|
|
|
key string
|
|
|
|
model string
|
|
|
|
}
|
|
|
|
|
|
|
|
var _ LLM = anthropic{}
|
|
|
|
|
|
|
|
func (a anthropic) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
|
|
|
a.model = modelVersion
|
|
|
|
|
|
|
|
// TODO: model verification?
|
|
|
|
return a, nil
|
|
|
|
}
|
|
|
|
|
2024-12-29 19:45:28 -05:00
|
|
|
// deferClose closes c and logs any failure via slog instead of returning it.
// It is meant for `defer deferClose(x)`, where a Close error would otherwise
// be silently discarded.
func deferClose(c io.Closer) {
	if err := c.Close(); err != nil {
		slog.Error("error closing", "error", err)
	}
}
|
|
|
|
|
2024-10-06 20:01:01 -04:00
|
|
|
func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
|
|
|
|
res := anth.MessagesRequest{
|
|
|
|
Model: anth.Model(a.model),
|
|
|
|
MaxTokens: 1000,
|
|
|
|
}
|
|
|
|
|
|
|
|
msgs := []anth.Message{}
|
|
|
|
|
|
|
|
// we gotta convert messages into anthropic messages, however
|
|
|
|
// anthropic does not have a "system" message type, so we need to
|
|
|
|
// append it to the res.System field instead
|
|
|
|
|
|
|
|
for _, msg := range req.Messages {
|
|
|
|
if msg.Role == RoleSystem {
|
|
|
|
if len(res.System) > 0 {
|
|
|
|
res.System += "\n"
|
|
|
|
}
|
|
|
|
res.System += msg.Text
|
|
|
|
} else {
|
|
|
|
role := anth.RoleUser
|
|
|
|
|
|
|
|
if msg.Role == RoleAssistant {
|
|
|
|
role = anth.RoleAssistant
|
|
|
|
}
|
|
|
|
|
2024-10-06 21:02:26 -04:00
|
|
|
m := anth.Message{
|
|
|
|
Role: role,
|
|
|
|
Content: []anth.MessageContent{},
|
|
|
|
}
|
|
|
|
|
|
|
|
if msg.Text != "" {
|
|
|
|
m.Content = append(m.Content, anth.MessageContent{
|
|
|
|
Type: anth.MessagesContentTypeText,
|
|
|
|
Text: &msg.Text,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2024-10-07 16:33:57 -04:00
|
|
|
for _, img := range msg.Images {
|
|
|
|
if img.Base64 != "" {
|
2024-12-26 22:46:59 -05:00
|
|
|
m.Content = append(m.Content, anth.NewImageMessageContent(
|
|
|
|
anth.NewMessageContentSource(
|
|
|
|
anth.MessagesContentSourceTypeBase64,
|
|
|
|
img.ContentType,
|
|
|
|
img.Base64,
|
|
|
|
)))
|
2024-10-07 16:33:57 -04:00
|
|
|
} else if img.Url != "" {
|
2024-12-26 22:46:59 -05:00
|
|
|
|
|
|
|
// download the image
|
|
|
|
cl, err := http.NewRequest(http.MethodGet, img.Url, nil)
|
|
|
|
if err != nil {
|
|
|
|
log.Println("failed to create request", err)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
resp, err := http.DefaultClient.Do(cl)
|
|
|
|
if err != nil {
|
|
|
|
log.Println("failed to download image", err)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2024-12-29 19:45:28 -05:00
|
|
|
defer deferClose(resp.Body)
|
2024-12-26 22:46:59 -05:00
|
|
|
|
|
|
|
img.ContentType = resp.Header.Get("Content-Type")
|
|
|
|
|
|
|
|
// read the image
|
|
|
|
b, err := io.ReadAll(resp.Body)
|
|
|
|
if err != nil {
|
|
|
|
log.Println("failed to read image", err)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
// base64 encode the image
|
|
|
|
img.Base64 = string(b)
|
|
|
|
|
|
|
|
m.Content = append(m.Content, anth.NewImageMessageContent(
|
|
|
|
anth.NewMessageContentSource(
|
|
|
|
anth.MessagesContentSourceTypeBase64,
|
|
|
|
img.ContentType,
|
|
|
|
img.Base64,
|
|
|
|
)))
|
2024-10-07 16:33:57 -04:00
|
|
|
}
|
2024-10-06 21:02:26 -04:00
|
|
|
}
|
2024-10-07 14:38:23 -04:00
|
|
|
|
|
|
|
// if this has the same role as the previous message, we can append it to the previous message
|
|
|
|
// as anthropic expects alternating assistant and user roles
|
|
|
|
|
|
|
|
if len(msgs) > 0 && msgs[len(msgs)-1].Role == role {
|
|
|
|
m2 := &msgs[len(msgs)-1]
|
|
|
|
|
|
|
|
m2.Content = append(m2.Content, m.Content...)
|
|
|
|
} else {
|
|
|
|
msgs = append(msgs, m)
|
|
|
|
}
|
2024-10-06 20:01:01 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-12-29 19:45:28 -05:00
|
|
|
if req.Toolbox != nil {
|
|
|
|
for _, tool := range req.Toolbox.funcs {
|
|
|
|
res.Tools = append(res.Tools, anth.ToolDefinition{
|
|
|
|
Name: tool.Name,
|
|
|
|
Description: tool.Description,
|
|
|
|
InputSchema: tool.Parameters,
|
|
|
|
})
|
|
|
|
}
|
2024-10-06 20:01:01 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
res.Messages = msgs
|
|
|
|
|
2024-10-06 21:02:26 -04:00
|
|
|
if req.Temperature != nil {
|
|
|
|
res.Temperature = req.Temperature
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Println("llm request to anthropic request", res)
|
|
|
|
|
2024-10-06 20:01:01 -04:00
|
|
|
return res
|
|
|
|
}
|
|
|
|
|
|
|
|
func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
|
|
|
|
res := Response{}
|
|
|
|
|
|
|
|
for _, msg := range in.Content {
|
|
|
|
choice := ResponseChoice{}
|
|
|
|
|
|
|
|
switch msg.Type {
|
|
|
|
case anth.MessagesContentTypeText:
|
|
|
|
if msg.Text != nil {
|
|
|
|
choice.Content = *msg.Text
|
|
|
|
}
|
|
|
|
|
|
|
|
case anth.MessagesContentTypeToolUse:
|
|
|
|
if msg.MessageContentToolUse != nil {
|
2024-11-11 00:23:01 -05:00
|
|
|
b, e := json.Marshal(msg.MessageContentToolUse.Input)
|
|
|
|
if e != nil {
|
|
|
|
log.Println("failed to marshal input", e)
|
|
|
|
} else {
|
|
|
|
choice.Calls = append(choice.Calls, ToolCall{
|
|
|
|
ID: msg.MessageContentToolUse.ID,
|
|
|
|
FunctionCall: FunctionCall{
|
|
|
|
Name: msg.MessageContentToolUse.Name,
|
|
|
|
Arguments: string(b),
|
|
|
|
},
|
|
|
|
})
|
|
|
|
}
|
2024-10-06 20:01:01 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
res.Choices = append(res.Choices, choice)
|
|
|
|
}
|
|
|
|
|
2024-10-06 21:02:26 -04:00
|
|
|
log.Println("anthropic response to llm response", res)
|
|
|
|
|
2024-10-06 20:01:01 -04:00
|
|
|
return res
|
|
|
|
}
|
|
|
|
|
|
|
|
// ChatComplete sends req to Anthropic's Messages API using a's key and model
// and converts the reply back into a Response. A fresh anth client is built
// for every call. API failures are returned wrapped; on error the Response is
// the zero value.
func (a anthropic) ChatComplete(ctx context.Context, req Request) (Response, error) {
	client := anth.NewClient(a.key)
	anthReq := a.requestToAnthropicRequest(req)

	anthResp, err := client.CreateMessages(ctx, anthReq)
	if err != nil {
		return Response{}, fmt.Errorf("failed to chat complete: %w", err)
	}

	return a.responseToLLMResponse(anthResp), nil
}
|