Compare commits

..

1 Commits

Author SHA1 Message Date
f4e9082ce4 updated anthropic 2024-12-26 22:46:59 -05:00
18 changed files with 343 additions and 1205 deletions

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"log/slog"
"net/http" "net/http"
anth "github.com/liushuangls/go-anthropic/v2" anth "github.com/liushuangls/go-anthropic/v2"
@ -26,13 +25,6 @@ func (a anthropic) ModelVersion(modelVersion string) (ChatCompletion, error) {
return a, nil return a, nil
} }
func deferClose(c io.Closer) {
err := c.Close()
if err != nil {
slog.Error("error closing", "error", err)
}
}
func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest { func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
res := anth.MessagesRequest{ res := anth.MessagesRequest{
Model: anth.Model(a.model), Model: anth.Model(a.model),
@ -71,11 +63,6 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
} }
for _, img := range msg.Images { for _, img := range msg.Images {
// anthropic doesn't allow the assistant to send images, so we need to say it's from the user
if m.Role == anth.RoleAssistant {
m.Role = anth.RoleUser
}
if img.Base64 != "" { if img.Base64 != "" {
m.Content = append(m.Content, anth.NewImageMessageContent( m.Content = append(m.Content, anth.NewImageMessageContent(
anth.NewMessageContentSource( anth.NewMessageContentSource(
@ -98,7 +85,7 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
continue continue
} }
defer deferClose(resp.Body) defer resp.Body.Close()
img.ContentType = resp.Header.Get("Content-Type") img.ContentType = resp.Header.Get("Content-Type")
@ -123,6 +110,7 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
// if this has the same role as the previous message, we can append it to the previous message // if this has the same role as the previous message, we can append it to the previous message
// as anthropic expects alternating assistant and user roles // as anthropic expects alternating assistant and user roles
if len(msgs) > 0 && msgs[len(msgs)-1].Role == role { if len(msgs) > 0 && msgs[len(msgs)-1].Role == role {
m2 := &msgs[len(msgs)-1] m2 := &msgs[len(msgs)-1]
@ -133,19 +121,18 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
} }
} }
for _, tool := range req.Toolbox.functions { for _, tool := range req.Toolbox.funcs {
res.Tools = append(res.Tools, anth.ToolDefinition{ res.Tools = append(res.Tools, anth.ToolDefinition{
Name: tool.Name, Name: tool.Name,
Description: tool.Description, Description: tool.Description,
InputSchema: tool.Parameters.AnthropicInputSchema(), InputSchema: tool.Parameters,
}) })
} }
res.Messages = msgs res.Messages = msgs
if req.Temperature != nil { if req.Temperature != nil {
var f = float32(*req.Temperature) res.Temperature = req.Temperature
res.Temperature = &f
} }
log.Println("llm request to anthropic request", res) log.Println("llm request to anthropic request", res)
@ -154,13 +141,15 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
} }
func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response { func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
choice := ResponseChoice{} res := Response{}
for _, msg := range in.Content { for _, msg := range in.Content {
choice := ResponseChoice{}
switch msg.Type { switch msg.Type {
case anth.MessagesContentTypeText: case anth.MessagesContentTypeText:
if msg.Text != nil { if msg.Text != nil {
choice.Content += *msg.Text choice.Content = *msg.Text
} }
case anth.MessagesContentTypeToolUse: case anth.MessagesContentTypeToolUse:
@ -179,13 +168,13 @@ func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
} }
} }
} }
res.Choices = append(res.Choices, choice)
} }
log.Println("anthropic response to llm response", choice) log.Println("anthropic response to llm response", res)
return Response{ return res
Choices: []ResponseChoice{choice},
}
} }
func (a anthropic) ChatComplete(ctx context.Context, req Request) (Response, error) { func (a anthropic) ChatComplete(ctx context.Context, req Request) (Response, error) {

View File

@ -1,120 +0,0 @@
package go_llm
import (
"context"
"time"
)
type Context struct {
context.Context
request Request
response *ResponseChoice
toolcall *ToolCall
syntheticFields map[string]string
}
func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request {
var res Request
res.Toolbox = c.request.Toolbox
res.Temperature = c.request.Temperature
res.Conversation = make([]Input, len(c.request.Conversation))
copy(res.Conversation, c.request.Conversation)
// now for every input message, convert those to an Input to add to the conversation
for _, msg := range c.request.Messages {
res.Conversation = append(res.Conversation, msg)
}
// if there are tool calls, then we need to add those to the conversation
if c.response != nil {
res.Conversation = append(res.Conversation, *c.response)
}
// if there are tool results, then we need to add those to the conversation
for _, result := range toolResults {
res.Conversation = append(res.Conversation, result)
}
return res
}
func NewContext(ctx context.Context, request Request, response *ResponseChoice, toolcall *ToolCall) *Context {
return &Context{Context: ctx, request: request, response: response, toolcall: toolcall}
}
func (c *Context) Request() Request {
return c.request
}
func (c *Context) Response() *ResponseChoice {
return c.response
}
func (c *Context) ToolCall() *ToolCall {
return c.toolcall
}
func (c *Context) SyntheticFields() map[string]string {
if c.syntheticFields == nil {
c.syntheticFields = map[string]string{}
}
return c.syntheticFields
}
func (c *Context) WithContext(ctx context.Context) *Context {
return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithRequest(request Request) *Context {
return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithResponse(response *ResponseChoice) *Context {
return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithToolCall(toolcall *ToolCall) *Context {
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithSyntheticFields(syntheticFields map[string]string) *Context {
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: syntheticFields}
}
func (c *Context) Deadline() (deadline time.Time, ok bool) {
return c.Context.Deadline()
}
func (c *Context) Done() <-chan struct{} {
return c.Context.Done()
}
func (c *Context) Err() error {
return c.Context.Err()
}
func (c *Context) Value(key any) any {
switch key {
case "request":
return c.request
case "response":
return c.response
case "toolcall":
return c.toolcall
case "syntheticFields":
return c.syntheticFields
}
return c.Context.Value(key)
}
func (c *Context) WithTimeout(timeout time.Duration) (*Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(c.Context, timeout)
return c.WithContext(ctx), cancel
}

View File

@ -4,11 +4,11 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog" "gitea.stevedudenhoeffer.com/steve/go-llm/schema"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
"reflect" "reflect"
"time" "time"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
) )
type Function struct { type Function struct {
@ -26,71 +26,27 @@ type Function struct {
fn reflect.Value fn reflect.Value
paramType reflect.Type paramType reflect.Type
// definition is a cache of the openaiImpl jsonschema definition
definition *jsonschema.Definition
} }
func (f Function) WithSyntheticField(name string, description string) Function { func (f *Function) Execute(ctx context.Context, input string) (string, error) {
if obj, o := f.Parameters.(schema.Object); o {
f.Parameters = obj.WithSyntheticField(name, description)
}
return f
}
func (f Function) WithSyntheticFields(fieldsAndDescriptions map[string]string) Function {
if obj, o := f.Parameters.(schema.Object); o {
for k, v := range fieldsAndDescriptions {
obj = obj.WithSyntheticField(k, v)
}
f.Parameters = obj
}
return f
}
func (f Function) Execute(ctx *Context, input string) (any, error) {
if !f.fn.IsValid() { if !f.fn.IsValid() {
return "", fmt.Errorf("function %s is not implemented", f.Name) return "", fmt.Errorf("function %s is not implemented", f.Name)
} }
slog.Info("Function.Execute", "name", f.Name, "input", input, "f", f.paramType)
// first, we need to parse the input into the struct // first, we need to parse the input into the struct
p := reflect.New(f.paramType) p := reflect.New(f.paramType)
fmt.Println("Function.Execute", f.Name, "input:", input) fmt.Println("Function.Execute", f.Name, "input:", input)
//m := map[string]any{}
var vals map[string]any err := json.Unmarshal([]byte(input), p.Interface())
err := json.Unmarshal([]byte(input), &vals)
var syntheticFields map[string]string
// first eat up any synthetic fields
if obj, o := f.Parameters.(schema.Object); o {
for k := range obj.SyntheticFields() {
key := schema.SyntheticFieldPrefix + k
if val, ok := vals[key]; ok {
if syntheticFields == nil {
syntheticFields = map[string]string{}
}
syntheticFields[k] = fmt.Sprint(val)
delete(vals, key)
}
}
}
// now for any remaining fields, re-marshal them into json and then unmarshal into the struct
b, err := json.Marshal(vals)
if err != nil {
return "", fmt.Errorf("failed to marshal input: %w (input: %s)", err, input)
}
// now we can unmarshal the input into the struct
err = json.Unmarshal(b, p.Interface())
if err != nil { if err != nil {
return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input) return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input)
} }
// now we can call the function // now we can call the function
exec := func(ctx *Context) (any, error) { exec := func(ctx context.Context) (string, error) {
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()}) out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
if len(out) != 2 { if len(out) != 2 {
@ -98,7 +54,7 @@ func (f Function) Execute(ctx *Context, input string) (any, error) {
} }
if out[1].IsNil() { if out[1].IsNil() {
return out[0].Interface(), nil return out[0].String(), nil
} }
return "", out[1].Interface().(error) return "", out[1].Interface().(error)
@ -106,26 +62,31 @@ func (f Function) Execute(ctx *Context, input string) (any, error) {
var cancel context.CancelFunc var cancel context.CancelFunc
if f.Timeout > 0 { if f.Timeout > 0 {
ctx, cancel = ctx.WithTimeout(f.Timeout) ctx, cancel = context.WithTimeout(ctx, f.Timeout)
defer cancel() defer cancel()
} }
return exec(ctx) return exec(ctx)
} }
func (f *Function) toOpenAIFunction() *openai.FunctionDefinition {
return &openai.FunctionDefinition{
Name: f.Name,
Description: f.Description,
Strict: f.Strict,
Parameters: f.Parameters,
}
}
func (f *Function) toOpenAIDefinition() jsonschema.Definition {
if f.definition == nil {
def := f.Parameters.Definition()
f.definition = &def
}
return *f.definition
}
type FunctionCall struct { type FunctionCall struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"` Arguments string `json:"arguments,omitempty"`
} }
func (fc *FunctionCall) toRaw() map[string]any {
res := map[string]interface{}{
"name": fc.Name,
}
if fc.Arguments != "" {
res["arguments"] = fc.Arguments
}
return res
}

View File

@ -1,9 +1,9 @@
package go_llm package go_llm
import ( import (
"reflect" "context"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema" "gitea.stevedudenhoeffer.com/steve/go-llm/schema"
"reflect"
) )
// Parse takes a function pointer and returns a function object. // Parse takes a function pointer and returns a function object.
@ -13,7 +13,7 @@ import (
// The struct parameters can have the following tags: // The struct parameters can have the following tags:
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for // - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) Function { func NewFunction[T any](name string, description string, fn func(context.Context, T) (string, error)) *Function {
var o T var o T
res := Function{ res := Function{
@ -31,5 +31,5 @@ func NewFunction[T any](name string, description string, fn func(*Context, T) (a
panic("function parameter must be a struct") panic("function parameter must be a struct")
} }
return res return &res
} }

57
go.mod
View File

@ -4,44 +4,41 @@ go 1.23.1
require ( require (
github.com/google/generative-ai-go v0.19.0 github.com/google/generative-ai-go v0.19.0
github.com/liushuangls/go-anthropic/v2 v2.15.0 github.com/liushuangls/go-anthropic/v2 v2.13.0
github.com/openai/openai-go v0.1.0-beta.9 github.com/sashabaranov/go-openai v1.36.0
google.golang.org/api v0.228.0 golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67
google.golang.org/api v0.214.0
) )
require ( require (
cloud.google.com/go v0.120.0 // indirect cloud.google.com/go v0.117.0 // indirect
cloud.google.com/go/ai v0.10.1 // indirect cloud.google.com/go/ai v0.9.0 // indirect
cloud.google.com/go/auth v0.15.0 // indirect cloud.google.com/go/auth v0.13.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
cloud.google.com/go/compute/metadata v0.6.0 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect
cloud.google.com/go/longrunning v0.6.6 // indirect cloud.google.com/go/longrunning v0.6.3 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/s2a-go v0.1.9 // indirect github.com/google/s2a-go v0.1.8 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect
go.opentelemetry.io/otel v1.35.0 // indirect go.opentelemetry.io/otel v1.33.0 // indirect
go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/metric v1.33.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.33.0 // indirect
golang.org/x/crypto v0.37.0 // indirect golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.39.0 // indirect golang.org/x/net v0.33.0 // indirect
golang.org/x/oauth2 v0.29.0 // indirect golang.org/x/oauth2 v0.24.0 // indirect
golang.org/x/sync v0.13.0 // indirect golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.32.0 // indirect golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.24.0 // indirect golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.11.0 // indirect golang.org/x/time v0.8.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect
google.golang.org/grpc v1.71.1 // indirect google.golang.org/grpc v1.69.2 // indirect
google.golang.org/protobuf v1.36.6 // indirect google.golang.org/protobuf v1.36.1 // indirect
) )

128
go.sum
View File

@ -1,15 +1,15 @@
cloud.google.com/go v0.120.0 h1:wc6bgG9DHyKqF5/vQvX1CiZrtHnxJjBlKUyF9nP6meA= cloud.google.com/go v0.117.0 h1:Z5TNFfQxj7WG2FgOGX1ekC5RiXrYgms6QscOm32M/4s=
cloud.google.com/go v0.120.0/go.mod h1:/beW32s8/pGRuj4IILWQNd4uuebeT4dkOhKmkfit64Q= cloud.google.com/go v0.117.0/go.mod h1:ZbwhVTb1DBGt2Iwb3tNO6SEK4q+cplHZmLWH+DelYYc=
cloud.google.com/go/ai v0.10.1 h1:EU93KqYmMeOKgaBXAz2DshH2C/BzAT1P+iJORksLIic= cloud.google.com/go/ai v0.9.0 h1:r1Ig8O8+Qr3Ia3WfoO+gokD0fxB2Rk4quppuKjmGMsY=
cloud.google.com/go/ai v0.10.1/go.mod h1:sWWHZvmJ83BjuxAQtYEiA0SFTpijtbH+SXWFO14ri5A= cloud.google.com/go/ai v0.9.0/go.mod h1:28bKM/oxmRgxmRgI1GLumFv+NSkt+DscAg/gF+54zzY=
cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps= cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs=
cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8= cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU=
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
cloud.google.com/go/longrunning v0.6.6 h1:XJNDo5MUfMM05xK3ewpbSdmt7R2Zw+aQEMbdQR65Rbw= cloud.google.com/go/longrunning v0.6.3 h1:A2q2vuyXysRcwzqDpMMLSI6mb6o39miS52UEG/Rd2ng=
cloud.google.com/go/longrunning v0.6.6/go.mod h1:hyeGJUrPHcx0u2Uu1UFSoYZLn4lkMrccJig0t4FI7yw= cloud.google.com/go/longrunning v0.6.3/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
@ -23,73 +23,65 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg= github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q=
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
github.com/liushuangls/go-anthropic/v2 v2.15.0 h1:zpplg7BRV/9FlMmeMPI0eDwhViB0l9SkNrF8ErYlRoQ= github.com/liushuangls/go-anthropic/v2 v2.13.0 h1:f7KJ54IHxIpHPPhrCzs3SrdP2PfErXiJcJn7DUVstSA=
github.com/liushuangls/go-anthropic/v2 v2.15.0/go.mod h1:kq2yW3JVy1/rph8u5KzX7F3q95CEpCT2RXp/2nfCmb4= github.com/liushuangls/go-anthropic/v2 v2.13.0/go.mod h1:5ZwRLF5TQ+y5s/MC9Z1IJYx9WUFgQCKfqFM2xreIQLk=
github.com/openai/openai-go v0.1.0-beta.9 h1:ABpubc5yU/3ejee2GgRrbFta81SG/d7bQbB8mIdP0Xo=
github.com/openai/openai-go v0.1.0-beta.9/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sashabaranov/go-openai v1.36.0 h1:fcSrn8uGuorzPWCBp8L0aCR95Zjb/Dd+ZSML0YZy9EI=
github.com/sashabaranov/go-openai v1.36.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 h1:x7wzEgXfnzJcHDwStJT+mxOz4etr2EcexjqhBvmoakw= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 h1:PS8wXpbyaDJQ2VDHHncMe9Vct0Zn1fEjpsjrLxGJoSc=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0/go.mod h1:rg+RlpR5dKwaS95IyyZqj5Wd4E13lk/msnTS0Xl9lJM= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0/go.mod h1:HDBUsEjOuRC0EzKZ1bSaRGZWUBAzo+MhAcUUORSr4D0=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 h1:sbiXRNDSWJOTobXh5HyQKjq6wUC5tNybqjIqDpAY4CU= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0/go.mod h1:69uWxva0WgAA/4bu2Yy70SLDBwZXuQ6PbBpbsa5iZrQ= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0/go.mod h1:umTcuxiv1n/s/S6/c2AT/g2CQ7u5C59sHDNmfSwgz7Q=
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw=
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I=
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= go.opentelemetry.io/otel/metric v1.33.0 h1:r+JOocAyeRVXD8lZpjdQjzMadVZp2M4WmQ+5WtEnklQ=
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= go.opentelemetry.io/otel/metric v1.33.0/go.mod h1:L9+Fyctbp6HFTddIxClbQkjtubW6O9QS3Ann/M82u6M=
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk=
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0=
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s=
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c=
golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs= golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4= golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a h1:OQ7sHVzkx6L57dQpzUS4ckfWJ51KDH74XHTDe23xWAs= google.golang.org/api v0.214.0 h1:h2Gkq07OYi6kusGOaT/9rnNljuXmqPnaig7WGPmKbwA=
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a/go.mod h1:2R6XrVC8Oc08GlNh8ujEpc7HkLiEZ16QeY7FxIs20ac= google.golang.org/api v0.214.0/go.mod h1:bYPpLG8AyeMWwDU6NXoB00xC0DFkikVvd5MfwoxjLqE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a h1:GIqLhp/cYUkuGuiT+vJk8vhOP86L4+SP5j8yXgeVpvI= google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 h1:st3LcW/BPi75W4q1jJTEor/QWwbNlPlDG0JTn6XhZu0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8/go.mod h1:klhJGKFyG8Tn50enBn7gizg4nXGXJ+jqEREdCWaPcV4=
google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 h1:TqExAhdPaB60Ux47Cn0oLV07rGnxZzIsaRhQaqS666A=
google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8/go.mod h1:lcTa1sDdWEIHMWlITnIczmw5w60CF9ffkb8Z+DVmmjA=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/grpc v1.69.2 h1:U3S9QEtbXC0bYNvRtcoklF3xGtLViumSYxWykJS+7AU=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= google.golang.org/grpc v1.69.2/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4=
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

139
google.go
View File

@ -2,12 +2,8 @@ package go_llm
import ( import (
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http"
"github.com/google/generative-ai-go/genai" "github.com/google/generative-ai-go/genai"
"google.golang.org/api/option" "google.golang.org/api/option"
) )
@ -23,117 +19,51 @@ func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
return g, nil return g, nil
} }
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) { func (g google) requestToGoogleRequest(in Request, model *genai.GenerativeModel) []genai.Part {
res := *model
for _, tool := range in.Toolbox.functions { if in.Temperature != nil {
res.Tools = append(res.Tools, &genai.Tool{ model.GenerationConfig.Temperature = in.Temperature
FunctionDeclarations: []*genai.FunctionDeclaration{ }
{
res := []genai.Part{}
for _, c := range in.Messages {
res = append(res, genai.Text(c.Text))
}
for _, tool := range in.Toolbox.funcs {
panic("google ToolBox is todo" + tool.Name)
/*
t := genai.Tool{}
t.FunctionDeclarations = append(t.FunctionDeclarations, &genai.FunctionDeclaration{
Name: tool.Name, Name: tool.Name,
Description: tool.Description, Description: tool.Description,
Parameters: tool.Parameters.GoogleParameters(), Parameters: nil, //tool.Parameters,
},
},
}) })
*/
} }
if !in.Toolbox.RequiresTool() { return res
res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingAny,
}}
}
cs := res.StartChat()
for i, c := range in.Messages {
content := genai.NewUserContent(genai.Text(c.Text))
switch c.Role {
case RoleAssistant, RoleSystem:
content.Role = "model"
case RoleUser:
content.Role = "user"
}
for _, img := range c.Images {
if img.Url != "" {
// gemini does not support URLs, so we need to download the image and convert it to a blob
// Download the image from the URL
resp, err := http.Get(img.Url)
if err != nil {
panic(fmt.Sprintf("error downloading image: %v", err))
}
defer resp.Body.Close()
// Check the Content-Length to ensure it's not over 20MB
if resp.ContentLength > 20*1024*1024 {
panic(fmt.Sprintf("image size exceeds 20MB: %d bytes", resp.ContentLength))
}
// Read the content into a byte slice
data, err := io.ReadAll(resp.Body)
if err != nil {
panic(fmt.Sprintf("error reading image data: %v", err))
}
// Ensure the MIME type is appropriate
mimeType := http.DetectContentType(data)
switch mimeType {
case "image/jpeg", "image/png", "image/gif":
// MIME type is valid
default:
panic(fmt.Sprintf("unsupported image MIME type: %s", mimeType))
}
// Create a genai.Blob using the validated image data
content.Parts = append(content.Parts, genai.Blob{
MIMEType: mimeType,
Data: data,
})
} else {
// convert base64 to blob
b, e := base64.StdEncoding.DecodeString(img.Base64)
if e != nil {
panic(fmt.Sprintf("error decoding base64: %v", e))
}
content.Parts = append(content.Parts, genai.Blob{
MIMEType: img.ContentType,
Data: b,
})
}
}
// if this is the last message, we want to add to history, we want it to be the parts
if i == len(in.Messages)-1 {
return &res, cs, content.Parts
}
cs.History = append(cs.History, content)
}
return &res, cs, nil
} }
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) { func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
res := Response{} res := Response{}
for _, c := range in.Candidates { for _, c := range in.Candidates {
var choice ResponseChoice
var set = false
if c.Content != nil { if c.Content != nil {
for _, p := range c.Content.Parts { for _, p := range c.Content.Parts {
switch p.(type) { switch p.(type) {
case genai.Text: case genai.Text:
choice.Content = string(p.(genai.Text)) res.Choices = append(res.Choices, ResponseChoice{
set = true Content: string(p.(genai.Text)),
})
case genai.FunctionCall: case genai.FunctionCall:
v := p.(genai.FunctionCall) v := p.(genai.FunctionCall)
choice := ResponseChoice{}
choice.Content = v.Name
b, e := json.Marshal(v.Args) b, e := json.Marshal(v.Args)
if e != nil { if e != nil {
@ -149,17 +79,14 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
} }
choice.Calls = append(choice.Calls, call) choice.Calls = append(choice.Calls, call)
set = true
res.Choices = append(res.Choices, choice)
default: default:
return Response{}, fmt.Errorf("unknown part type: %T", p) return Response{}, fmt.Errorf("unknown part type: %T", p)
} }
} }
} }
if set {
choice.Role = RoleAssistant
res.Choices = append(res.Choices, choice)
}
} }
return res, nil return res, nil
@ -174,13 +101,9 @@ func (g google) ChatComplete(ctx context.Context, req Request) (Response, error)
model := cl.GenerativeModel(g.model) model := cl.GenerativeModel(g.model)
_, cs, parts := g.requestToChatHistory(req, model) parts := g.requestToGoogleRequest(req, model)
resp, err := cs.SendMessage(ctx, parts...) resp, err := model.GenerateContent(ctx, parts...)
//parts := g.requestToGoogleRequest(req, model)
//resp, err := model.GenerateContent(ctx, parts...)
if err != nil { if err != nil {
return Response{}, fmt.Errorf("error generating content: %w", err) return Response{}, fmt.Errorf("error generating content: %w", err)

243
llm.go
View File

@ -2,11 +2,6 @@ package go_llm
import ( import (
"context" "context"
"fmt"
"strings"
"github.com/openai/openai-go"
"github.com/openai/openai-go/packages/param"
) )
type Role string type Role string
@ -23,26 +18,6 @@ type Image struct {
Url string Url string
} }
func (i Image) toRaw() map[string]any {
res := map[string]any{
"base64": i.Base64,
"contenttype": i.ContentType,
"url": i.Url,
}
return res
}
func (i *Image) fromRaw(raw map[string]any) Image {
var res Image
res.Base64 = raw["base64"].(string)
res.ContentType = raw["contenttype"].(string)
res.Url = raw["url"].(string)
return res
}
type Message struct { type Message struct {
Role Role Role Role
Name string Name string
@ -50,145 +25,10 @@ type Message struct {
Images []Image Images []Image
} }
func (m Message) toRaw() map[string]any { type Request struct {
res := map[string]any{ Messages []Message
"role": m.Role, Toolbox *ToolBox
"name": m.Name, Temperature *float32
"text": m.Text,
}
images := make([]map[string]any, 0, len(m.Images))
for _, img := range m.Images {
images = append(images, img.toRaw())
}
res["images"] = images
return res
}
func (m *Message) fromRaw(raw map[string]any) Message {
var res Message
res.Role = Role(raw["role"].(string))
res.Name = raw["name"].(string)
res.Text = raw["text"].(string)
images := raw["images"].([]map[string]any)
for _, img := range images {
var i Image
res.Images = append(res.Images, i.fromRaw(img))
}
return res
}
func (m Message) toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion {
var res openai.ChatCompletionMessageParamUnion
var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
var textContent param.Opt[string]
for _, img := range m.Images {
if img.Base64 != "" {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfImageURL: &openai.ChatCompletionContentPartImageParam{
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
URL: "data:" + img.ContentType + ";base64," + img.Base64,
},
},
},
)
} else if img.Url != "" {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfImageURL: &openai.ChatCompletionContentPartImageParam{
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
URL: img.Url,
},
},
},
)
}
}
if m.Text != "" {
if len(arrayOfContentParts) > 0 {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfText: &openai.ChatCompletionContentPartTextParam{
Text: "\n",
},
},
)
} else {
textContent = openai.String(m.Text)
}
}
a := strings.Split(model, "-")
useSystemInsteadOfDeveloper := true
if len(a) > 1 && a[0][0] == 'o' {
useSystemInsteadOfDeveloper = false
}
switch m.Role {
case RoleSystem:
if useSystemInsteadOfDeveloper {
res = openai.ChatCompletionMessageParamUnion{
OfSystem: &openai.ChatCompletionSystemMessageParam{
Content: openai.ChatCompletionSystemMessageParamContentUnion{
OfString: textContent,
},
},
}
} else {
res = openai.ChatCompletionMessageParamUnion{
OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
OfString: textContent,
},
},
}
}
case RoleUser:
var name param.Opt[string]
if m.Name != "" {
name = openai.String(m.Name)
}
res = openai.ChatCompletionMessageParamUnion{
OfUser: &openai.ChatCompletionUserMessageParam{
Name: name,
Content: openai.ChatCompletionUserMessageParamContentUnion{
OfString: textContent,
OfArrayOfContentParts: arrayOfContentParts,
},
},
}
case RoleAssistant:
var name param.Opt[string]
if m.Name != "" {
name = openai.String(m.Name)
}
res = openai.ChatCompletionMessageParamUnion{
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
Name: name,
Content: openai.ChatCompletionAssistantMessageParamContentUnion{
OfString: textContent,
},
},
}
}
return []openai.ChatCompletionMessageParamUnion{res}
} }
type ToolCall struct { type ToolCall struct {
@ -196,73 +36,16 @@ type ToolCall struct {
FunctionCall FunctionCall FunctionCall FunctionCall
} }
func (t ToolCall) toRaw() map[string]any { type ResponseChoice struct {
res := map[string]any{ Index int
"id": t.ID, Role Role
} Content string
Refusal string
res["function"] = t.FunctionCall.toRaw() Name string
Calls []ToolCall
return res
} }
type Response struct {
func (t ToolCall) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion { Choices []ResponseChoice
return []openai.ChatCompletionMessageParamUnion{{
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
ToolCalls: []openai.ChatCompletionMessageToolCallParam{
{
ID: t.ID,
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: t.FunctionCall.Name,
Arguments: t.FunctionCall.Arguments,
},
},
},
},
}}
}
type ToolCallResponse struct {
ID string
Result any
Error error
}
func (t ToolCallResponse) toRaw() map[string]any {
res := map[string]any{
"id": t.ID,
"result": t.Result,
}
if t.Error != nil {
res["error"] = t.Error.Error()
}
return res
}
func (t ToolCallResponse) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
var refusal string
if t.Error != nil {
refusal = t.Error.Error()
}
if refusal != "" {
if t.Result != "" {
t.Result = fmt.Sprint(t.Result) + " (error in execution: " + refusal + ")"
} else {
t.Result = "error in execution:" + refusal
}
}
return []openai.ChatCompletionMessageParamUnion{{
OfTool: &openai.ChatCompletionToolMessageParam{
ToolCallID: t.ID,
Content: openai.ChatCompletionToolMessageParamContentUnion{
OfString: openai.String(fmt.Sprint(t.Result)),
},
},
}}
} }
type ChatCompletion interface { type ChatCompletion interface {

119
openai.go
View File

@ -3,87 +3,119 @@ package go_llm
import ( import (
"context" "context"
"fmt" "fmt"
oai "github.com/sashabaranov/go-openai"
"strings" "strings"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/shared"
) )
type openaiImpl struct { type openaiImpl struct {
key string key string
model string model string
baseUrl string
} }
var _ LLM = openaiImpl{} var _ LLM = openaiImpl{}
func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams { func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
res := openai.ChatCompletionNewParams{ res := oai.ChatCompletionRequest{
Model: o.model, Model: o.model,
} }
for _, i := range request.Conversation {
res.Messages = append(res.Messages, i.toChatCompletionMessages(o.model)...)
}
for _, msg := range request.Messages { for _, msg := range request.Messages {
res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...) m := oai.ChatCompletionMessage{
Content: msg.Text,
Role: string(msg.Role),
Name: msg.Name,
} }
for _, tool := range request.Toolbox.functions { for _, img := range msg.Images {
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{ if img.Base64 != "" {
Type: "function", m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
Function: shared.FunctionDefinitionParam{ Type: "image_url",
Name: tool.Name, ImageURL: &oai.ChatMessageImageURL{
Description: openai.String(tool.Description), URL: fmt.Sprintf("data:%s;base64,%s", img.ContentType, img.Base64),
Strict: openai.Bool(tool.Strict), },
Parameters: tool.Parameters.OpenAIParameters(), })
} else if img.Url != "" {
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
Type: "image_url",
ImageURL: &oai.ChatMessageImageURL{
URL: img.Url,
}, },
}) })
} }
if request.Toolbox.RequiresTool() {
res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
OfAuto: openai.String("required"),
} }
// openai does not allow Content and MultiContent to be set at the same time, so we need to check
if len(m.MultiContent) > 0 && m.Content != "" {
m.MultiContent = append([]oai.ChatMessagePart{{
Type: "text",
Text: m.Content,
}}, m.MultiContent...)
m.Content = ""
}
res.Messages = append(res.Messages, m)
}
for _, tool := range request.Toolbox.funcs {
res.Tools = append(res.Tools, oai.Tool{
Type: "function",
Function: &oai.FunctionDefinition{
Name: tool.Name,
Description: tool.Description,
Strict: tool.Strict,
Parameters: tool.Parameters.Definition(),
},
})
fmt.Println("tool:", tool.Name, tool.Description, tool.Strict, tool.Parameters.Definition())
} }
if request.Temperature != nil { if request.Temperature != nil {
res.Temperature = openai.Float(*request.Temperature) res.Temperature = *request.Temperature
}
// is this an o1-* model?
isO1 := strings.Split(o.model, "-")[0] == "o1"
if isO1 {
// o1 models do not support system messages, so if any messages are system messages, we need to convert them to
// user messages
for i, msg := range res.Messages {
if msg.Role == "system" {
res.Messages[i].Role = "user"
}
}
} }
return res return res
} }
func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response { func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
var res Response res := Response{}
if response == nil {
return res
}
if len(response.Choices) == 0 {
return res
}
for _, choice := range response.Choices { for _, choice := range response.Choices {
var toolCalls []ToolCall var toolCalls []ToolCall
for _, call := range choice.Message.ToolCalls { for _, call := range choice.Message.ToolCalls {
fmt.Println("responseToLLMResponse: call:", call.Function.Arguments)
toolCall := ToolCall{ toolCall := ToolCall{
ID: call.ID, ID: call.ID,
FunctionCall: FunctionCall{ FunctionCall: FunctionCall{
Name: call.Function.Name, Name: call.Function.Name,
Arguments: strings.TrimSpace(call.Function.Arguments), Arguments: call.Function.Arguments,
}, },
} }
fmt.Println("toolCall.FunctionCall.Arguments:", toolCall.FunctionCall.Arguments)
toolCalls = append(toolCalls, toolCall) toolCalls = append(toolCalls, toolCall)
} }
res.Choices = append(res.Choices, ResponseChoice{ res.Choices = append(res.Choices, ResponseChoice{
Content: choice.Message.Content, Content: choice.Message.Content,
Role: Role(choice.Message.Role), Role: Role(choice.Message.Role),
Name: choice.Message.Name,
Refusal: choice.Message.Refusal, Refusal: choice.Message.Refusal,
Calls: toolCalls, Calls: toolCalls,
}) })
@ -93,20 +125,13 @@ func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Respo
} }
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) { func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
var opts = []option.RequestOption{ cl := oai.NewClient(o.key)
option.WithAPIKey(o.key),
}
if o.baseUrl != "" { req := o.requestToOpenAIRequest(request)
opts = append(opts, option.WithBaseURL(o.baseUrl))
}
cl := openai.NewClient(opts...) resp, err := cl.CreateChatCompletion(ctx, req)
req := o.newRequestToOpenAIRequest(request) fmt.Println("resp:", fmt.Sprintf("%#v", resp))
resp, err := cl.Chat.Completions.New(ctx, req)
//resp, err := cl.CreateChatCompletion(ctx, req)
if err != nil { if err != nil {
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err) return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)

View File

@ -1,48 +0,0 @@
package go_llm
import (
"github.com/openai/openai-go"
)
type rawAble interface {
toRaw() map[string]any
fromRaw(raw map[string]any) Input
}
type Input interface {
toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion
}
type Request struct {
Conversation []Input
Messages []Message
Toolbox ToolBox
Temperature *float64
}
// NextRequest will take the current request's conversation, messages, the response, and any tool results, and
// return a new request with the conversation updated to include the response and tool results.
func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallResponse) Request {
var res Request
res.Toolbox = req.Toolbox
res.Temperature = req.Temperature
res.Conversation = make([]Input, len(req.Conversation))
copy(res.Conversation, req.Conversation)
// now for every input message, convert those to an Input to add to the conversation
for _, msg := range req.Messages {
res.Conversation = append(res.Conversation, msg)
}
if resp.Content != "" || resp.Refusal != "" || len(resp.Calls) > 0 {
res.Conversation = append(res.Conversation, resp)
}
// if there are tool results, then we need to add those to the conversation
for _, result := range toolResults {
res.Conversation = append(res.Conversation, result)
}
return res
}

View File

@ -1,84 +0,0 @@
package go_llm
import (
"github.com/openai/openai-go"
)
type ResponseChoice struct {
Index int
Role Role
Content string
Refusal string
Name string
Calls []ToolCall
}
func (r ResponseChoice) toRaw() map[string]any {
res := map[string]any{
"index": r.Index,
"role": r.Role,
"content": r.Content,
"refusal": r.Refusal,
"name": r.Name,
}
calls := make([]map[string]any, 0, len(r.Calls))
for _, call := range r.Calls {
calls = append(calls, call.toRaw())
}
res["tool_calls"] = calls
return res
}
func (r ResponseChoice) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
var as openai.ChatCompletionAssistantMessageParam
if r.Name != "" {
as.Name = openai.String(r.Name)
}
if r.Refusal != "" {
as.Refusal = openai.String(r.Refusal)
}
if r.Content != "" {
as.Content.OfString = openai.String(r.Content)
}
for _, call := range r.Calls {
as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
ID: call.ID,
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: call.FunctionCall.Name,
Arguments: call.FunctionCall.Arguments,
},
})
}
return []openai.ChatCompletionMessageParamUnion{
{
OfAssistant: &as,
},
}
}
func (r ResponseChoice) toInput() []Input {
var res []Input
for _, call := range r.Calls {
res = append(res, call)
}
if r.Content != "" || r.Refusal != "" {
res = append(res, Message{
Role: RoleAssistant,
Text: r.Content,
})
}
return res
}
type Response struct {
Choices []ResponseChoice
}

View File

@ -3,6 +3,8 @@ package schema
import ( import (
"reflect" "reflect"
"strings" "strings"
"github.com/sashabaranov/go-openai/jsonschema"
) )
// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that // GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that
@ -25,28 +27,23 @@ func getFromType(t reflect.Type, b basic) Type {
switch t.Kind() { switch t.Kind() {
case reflect.String: case reflect.String:
b.DataType = TypeString b.DataType = jsonschema.String
b.typeName = "string"
return b return b
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
b.DataType = TypeInteger b.DataType = jsonschema.Integer
b.typeName = "integer"
return b return b
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
b.DataType = TypeInteger b.DataType = jsonschema.Integer
b.typeName = "integer"
return b return b
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
b.DataType = TypeNumber b.DataType = jsonschema.Number
b.typeName = "number"
return b return b
case reflect.Bool: case reflect.Bool:
b.DataType = TypeBoolean b.DataType = jsonschema.Boolean
b.typeName = "boolean"
return b return b
case reflect.Struct: case reflect.Struct:
@ -92,8 +89,6 @@ func getField(f reflect.StructField, index int) Type {
} }
} }
b.DataType = TypeString
b.typeName = "string"
return enum{ return enum{
basic: b, basic: b,
values: vals, values: vals,
@ -104,26 +99,15 @@ func getField(f reflect.StructField, index int) Type {
return getFromType(t, b) return getFromType(t, b)
} }
func getObject(t reflect.Type) Object { func getObject(t reflect.Type) object {
fields := make(map[string]Type, t.NumField()) fields := make(map[string]Type, t.NumField())
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
field := t.Field(i) field := t.Field(i)
if field.Anonymous {
// if the field is anonymous, we need to get the fields of the anonymous struct
// and add them to the object
anon := getObject(field.Type)
for k, v := range anon.fields {
fields[k] = v
}
continue
} else {
fields[field.Name] = getField(field, i) fields[field.Name] = getField(field, i)
} }
}
return Object{ return object{
basic: basic{DataType: TypeObject, typeName: "object"}, basic: basic{DataType: jsonschema.Object},
fields: fields, fields: fields,
} }
} }
@ -131,8 +115,7 @@ func getObject(t reflect.Type) Object {
func getArray(t reflect.Type) array { func getArray(t reflect.Type) array {
res := array{ res := array{
basic: basic{ basic: basic{
DataType: TypeArray, DataType: jsonschema.Array,
typeName: "array",
}, },
} }

View File

@ -4,8 +4,7 @@ import (
"errors" "errors"
"reflect" "reflect"
"github.com/google/generative-ai-go/genai" "github.com/sashabaranov/go-openai/jsonschema"
"github.com/openai/openai-go"
) )
type array struct { type array struct {
@ -15,28 +14,17 @@ type array struct {
items Type items Type
} }
func (a array) OpenAIParameters() openai.FunctionParameters { func (a array) SchemaType() jsonschema.DataType {
return openai.FunctionParameters{ return jsonschema.Array
"type": "array",
"description": a.Description(),
"items": a.items.OpenAIParameters(),
}
} }
func (a array) GoogleParameters() *genai.Schema { func (a array) Definition() jsonschema.Definition {
return &genai.Schema{ def := a.basic.Definition()
Type: genai.TypeArray, def.Type = jsonschema.Array
Description: a.Description(), i := a.items.Definition()
Items: a.items.GoogleParameters(), def.Items = &i
} def.AdditionalProperties = false
} return def
func (a array) AnthropicInputSchema() map[string]any {
return map[string]any{
"type": "array",
"description": a.Description(),
"items": a.items.AnthropicInputSchema(),
}
} }
func (a array) FromAny(val any) (reflect.Value, error) { func (a array) FromAny(val any) (reflect.Value, error) {

View File

@ -5,27 +5,14 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"github.com/google/generative-ai-go/genai" "github.com/sashabaranov/go-openai/jsonschema"
"github.com/openai/openai-go"
) )
// just enforcing that basic implements Type // just enforcing that basic implements Type
var _ Type = basic{} var _ Type = basic{}
type DataType string
const (
TypeString DataType = "string"
TypeInteger DataType = "integer"
TypeNumber DataType = "number"
TypeBoolean DataType = "boolean"
TypeObject DataType = "object"
TypeArray DataType = "array"
)
type basic struct { type basic struct {
DataType jsonschema.DataType
typeName string
// index is the position of the parameter in the StructField of the function's parameter struct // index is the position of the parameter in the StructField of the function's parameter struct
index int index int
@ -38,64 +25,17 @@ type basic struct {
description string description string
} }
func (b basic) OpenAIParameters() openai.FunctionParameters { func (b basic) SchemaType() jsonschema.DataType {
return openai.FunctionParameters{ return b.DataType
"type": b.typeName,
"description": b.description,
}
} }
func (b basic) GoogleParameters() *genai.Schema { func (b basic) Definition() jsonschema.Definition {
var t = genai.TypeUnspecified return jsonschema.Definition{
Type: b.DataType,
switch b.DataType {
case TypeString:
t = genai.TypeString
case TypeInteger:
t = genai.TypeInteger
case TypeNumber:
t = genai.TypeNumber
case TypeBoolean:
t = genai.TypeBoolean
case TypeObject:
t = genai.TypeObject
case TypeArray:
t = genai.TypeArray
default:
t = genai.TypeUnspecified
}
return &genai.Schema{
Type: t,
Description: b.description, Description: b.description,
} }
} }
func (b basic) AnthropicInputSchema() map[string]any {
var t = "string"
switch b.DataType {
case TypeString:
t = "string"
case TypeInteger:
t = "integer"
case TypeNumber:
t = "number"
case TypeBoolean:
t = "boolean"
case TypeObject:
t = "object"
case TypeArray:
t = "array"
default:
t = "unknown"
}
return map[string]any{
"type": t,
"description": b.description,
}
}
func (b basic) Required() bool { func (b basic) Required() bool {
return b.required return b.required
} }
@ -108,12 +48,12 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
v := reflect.ValueOf(val) v := reflect.ValueOf(val)
switch b.DataType { switch b.DataType {
case TypeString: case jsonschema.String:
var val = v.String() var val = v.String()
return reflect.ValueOf(val), nil return reflect.ValueOf(val), nil
case TypeInteger: case jsonschema.Integer:
if v.Kind() == reflect.Float64 { if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(int(0))), nil return v.Convert(reflect.TypeOf(int(0))), nil
} else if v.Kind() != reflect.Int { } else if v.Kind() != reflect.Int {
@ -122,7 +62,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
return v, nil return v, nil
} }
case TypeNumber: case jsonschema.Number:
if v.Kind() == reflect.Float64 { if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(float64(0))), nil return v.Convert(reflect.TypeOf(float64(0))), nil
} else if v.Kind() != reflect.Float64 { } else if v.Kind() != reflect.Float64 {
@ -131,7 +71,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
return v, nil return v, nil
} }
case TypeBoolean: case jsonschema.Boolean:
if v.Kind() == reflect.Bool { if v.Kind() == reflect.Bool {
return v, nil return v, nil
} else if v.Kind() == reflect.String { } else if v.Kind() == reflect.String {

View File

@ -3,10 +3,10 @@ package schema
import ( import (
"errors" "errors"
"reflect" "reflect"
"slices"
"github.com/google/generative-ai-go/genai" "golang.org/x/exp/slices"
"github.com/openai/openai-go"
"github.com/sashabaranov/go-openai/jsonschema"
) )
type enum struct { type enum struct {
@ -15,28 +15,14 @@ type enum struct {
values []string values []string
} }
func (e enum) FunctionParameters() openai.FunctionParameters { func (e enum) SchemaType() jsonschema.DataType {
return openai.FunctionParameters{ return jsonschema.String
"type": "string",
"description": e.Description(),
"enum": e.values,
}
} }
func (e enum) GoogleParameters() *genai.Schema { func (e enum) Definition() jsonschema.Definition {
return &genai.Schema{ def := e.basic.Definition()
Type: genai.TypeString, def.Enum = e.values
Description: e.Description(), return def
Enum: e.values,
}
}
func (e enum) AnthropicInputSchema() map[string]any {
return map[string]any{
"type": "string",
"description": e.Description(),
"enum": e.values,
}
} }
func (e enum) FromAny(val any) (reflect.Value, error) { func (e enum) FromAny(val any) (reflect.Value, error) {

View File

@ -4,125 +4,34 @@ import (
"errors" "errors"
"reflect" "reflect"
"github.com/google/generative-ai-go/genai" "github.com/sashabaranov/go-openai/jsonschema"
"github.com/openai/openai-go"
) )
const ( type object struct {
// SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent
// collisions with the fields in the struct.
SyntheticFieldPrefix = "__"
)
type Object struct {
basic basic
ref reflect.Type ref reflect.Type
fields map[string]Type fields map[string]Type
// syntheticFields are fields that are not in the struct but are generated by a system.
synetheticFields map[string]Type
} }
func (o Object) WithSyntheticField(name string, description string) Object { func (o object) SchemaType() jsonschema.DataType {
if o.synetheticFields == nil { return jsonschema.Object
o.synetheticFields = map[string]Type{}
}
o.synetheticFields[name] = basic{
DataType: TypeString,
typeName: "string",
index: -1,
required: false,
description: description,
}
return o
} }
func (o Object) SyntheticFields() map[string]Type { func (o object) Definition() jsonschema.Definition {
return o.synetheticFields def := o.basic.Definition()
} def.Type = jsonschema.Object
def.Properties = make(map[string]jsonschema.Definition)
func (o Object) OpenAIParameters() openai.FunctionParameters {
var properties = map[string]openai.FunctionParameters{}
var required []string
for k, v := range o.fields { for k, v := range o.fields {
properties[k] = v.OpenAIParameters() def.Properties[k] = v.Definition()
if v.Required() {
required = append(required, k)
}
} }
for k, v := range o.synetheticFields { def.AdditionalProperties = false
properties[SyntheticFieldPrefix+k] = v.OpenAIParameters() return def
if v.Required() {
required = append(required, SyntheticFieldPrefix+k)
}
}
var res = openai.FunctionParameters{
"type": "object",
"description": o.Description(),
"properties": properties,
}
if len(required) > 0 {
res["required"] = required
}
return res
} }
func (o Object) GoogleParameters() *genai.Schema { func (o object) FromAny(val any) (reflect.Value, error) {
var properties = map[string]*genai.Schema{}
var required []string
for k, v := range o.fields {
properties[k] = v.GoogleParameters()
if v.Required() {
required = append(required, k)
}
}
var res = &genai.Schema{
Type: genai.TypeObject,
Description: o.Description(),
Properties: properties,
}
if len(required) > 0 {
res.Required = required
}
return res
}
func (o Object) AnthropicInputSchema() map[string]any {
var properties = map[string]any{}
var required []string
for k, v := range o.fields {
properties[k] = v.AnthropicInputSchema()
if v.Required() {
required = append(required, k)
}
}
var res = map[string]any{
"type": "object",
"description": o.Description(),
"properties": properties,
}
if len(required) > 0 {
res["required"] = required
}
return res
}
// FromAny converts the value from any to the correct type, returning the value, and an error if any
func (o Object) FromAny(val any) (reflect.Value, error) {
// if the value is nil, we can't do anything // if the value is nil, we can't do anything
if val == nil { if val == nil {
return reflect.Value{}, nil return reflect.Value{}, nil
@ -159,7 +68,7 @@ func (o Object) FromAny(val any) (reflect.Value, error) {
return obj, nil return obj, nil
} }
func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) { func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value // if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
if !o.required { if !o.required {
val = val.Addr() val = val.Addr()

View File

@ -3,17 +3,12 @@ package schema
import ( import (
"reflect" "reflect"
"github.com/google/generative-ai-go/genai" "github.com/sashabaranov/go-openai/jsonschema"
"github.com/openai/openai-go"
) )
type Type interface { type Type interface {
OpenAIParameters() openai.FunctionParameters SchemaType() jsonschema.DataType
GoogleParameters() *genai.Schema Definition() jsonschema.Definition
AnthropicInputSchema() map[string]any
//SchemaType() jsonschema.DataType
//Definition() jsonschema.Definition
Required() bool Required() bool
Description() string Description() string

View File

@ -4,82 +4,56 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/sashabaranov/go-openai"
) )
// ToolBox is a collection of tools that OpenAI can use to execute functions. // ToolBox is a collection of tools that OpenAI can use to execute functions.
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with // It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
// the correct parameters. // the correct parameters.
type ToolBox struct { type ToolBox struct {
functions map[string]Function funcs []Function
dontRequireTool bool names map[string]Function
} }
func NewToolBox(fns ...Function) ToolBox { func NewToolBox(fns ...*Function) *ToolBox {
res := ToolBox{ res := ToolBox{
functions: map[string]Function{}, funcs: []Function{},
names: map[string]Function{},
} }
for _, f := range fns { for _, f := range fns {
res.functions[f.Name] = f o := *f
res.names[o.Name] = o
res.funcs = append(res.funcs, o)
}
return &res
}
func (t *ToolBox) WithFunction(f Function) *ToolBox {
t2 := *t
t2.names[f.Name] = f
t2.funcs = append(t2.funcs, f)
return &t2
}
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
func (t *ToolBox) toOpenAI() []openai.Tool {
var res []openai.Tool
for _, f := range t.funcs {
res = append(res, openai.Tool{
Type: "function",
Function: f.toOpenAIFunction(),
})
} }
return res return res
} }
func (t ToolBox) Functions() []Function { func (t *ToolBox) ToToolChoice() any {
var res []Function if len(t.funcs) == 0 {
for _, f := range t.functions {
res = append(res, f)
}
return res
}
func (t ToolBox) WithFunction(f Function) ToolBox {
t.functions[f.Name] = f
return t
}
func (t ToolBox) WithFunctions(fns ...Function) ToolBox {
for _, f := range fns {
t.functions[f.Name] = f
}
return t
}
func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox {
for k, v := range t.functions {
t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions)
}
return t
}
func (t ToolBox) ForEachFunction(fn func(f Function)) {
for _, f := range t.functions {
fn(f)
}
}
func (t ToolBox) WithFunctionRemoved(name string) ToolBox {
delete(t.functions, name)
return t
}
func (t ToolBox) WithRequireTool(val bool) ToolBox {
t.dontRequireTool = !val
return t
}
func (t ToolBox) RequiresTool() bool {
return !t.dontRequireTool && len(t.functions) > 0
}
func (t ToolBox) ToToolChoice() any {
if len(t.functions) == 0 {
return nil return nil
} }
@ -90,8 +64,8 @@ var (
ErrFunctionNotFound = errors.New("function not found") ErrFunctionNotFound = errors.New("function not found")
) )
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) { func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
f, ok := t.functions[functionName] f, ok := t.names[functionName]
if !ok { if !ok {
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName)) return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
@ -100,61 +74,6 @@ func (t ToolBox) executeFunction(ctx *Context, functionName string, params strin
return f.Execute(ctx, params) return f.Execute(ctx, params)
} }
func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) { func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments) return t.ExecuteFunction(ctx, toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
}
func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string {
val := ctx.Value("syntheticParameters")
if val == nil {
return nil
}
syntheticParameters, ok := val.(map[string]string)
if !ok {
return nil
}
return syntheticParameters
}
// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished.
// OnNewFunction is called when a new function is created
// OnFunctionFinished is called when a function is finished
func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
var res []ToolCallResponse
for _, call := range toolCalls {
ctx := ctx.WithToolCall(&call)
if call.FunctionCall.Name == "" {
return nil, newError(ErrFunctionNotFound, errors.New("function name is empty"))
}
var arg any
if OnNewFunction != nil {
var err error
arg, err = OnNewFunction(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments)
if err != nil {
return nil, newError(ErrFunctionNotFound, err)
}
}
out, err := t.Execute(ctx, call)
if OnFunctionFinished != nil {
err := OnFunctionFinished(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments, out, err, arg)
if err != nil {
return nil, newError(ErrFunctionNotFound, err)
}
}
res = append(res, ToolCallResponse{
ID: call.ID,
Result: out,
Error: err,
})
}
return res, nil
} }