Add support for integers and tool configuration in schema handling
This update introduces support for `jsonschema.Integer` types and updates the logic to handle nested items in schemas. Added a new default error log for unknown types using `slog.Error`. Also, integrated tool configuration with a `FunctionCallingConfig` when `dontRequireTool` is false.
This commit is contained in:
parent
ff5e4ca7b0
commit
7c9eb08cb4
@ -147,7 +147,8 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
|
|||||||
res.Messages = msgs
|
res.Messages = msgs
|
||||||
|
|
||||||
if req.Temperature != nil {
|
if req.Temperature != nil {
|
||||||
res.Temperature = req.Temperature
|
var f = float32(*req.Temperature)
|
||||||
|
res.Temperature = &f
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("llm request to anthropic request", res)
|
log.Println("llm request to anthropic request", res)
|
||||||
|
11
go.mod
11
go.mod
@ -24,17 +24,22 @@ require (
|
|||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
||||||
|
github.com/openai/openai-go v0.1.0-beta.6 // indirect
|
||||||
|
github.com/tidwall/gjson v1.14.4 // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect
|
||||||
go.opentelemetry.io/otel v1.33.0 // indirect
|
go.opentelemetry.io/otel v1.33.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.33.0 // indirect
|
go.opentelemetry.io/otel/metric v1.33.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.33.0 // indirect
|
go.opentelemetry.io/otel/trace v1.33.0 // indirect
|
||||||
golang.org/x/crypto v0.31.0 // indirect
|
golang.org/x/crypto v0.32.0 // indirect
|
||||||
golang.org/x/net v0.33.0 // indirect
|
golang.org/x/net v0.34.0 // indirect
|
||||||
golang.org/x/oauth2 v0.24.0 // indirect
|
golang.org/x/oauth2 v0.24.0 // indirect
|
||||||
golang.org/x/sync v0.10.0 // indirect
|
golang.org/x/sync v0.10.0 // indirect
|
||||||
golang.org/x/sys v0.28.0 // indirect
|
golang.org/x/sys v0.29.0 // indirect
|
||||||
golang.org/x/text v0.21.0 // indirect
|
golang.org/x/text v0.21.0 // indirect
|
||||||
golang.org/x/time v0.8.0 // indirect
|
golang.org/x/time v0.8.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 // indirect
|
||||||
|
18
go.sum
18
go.sum
@ -35,12 +35,24 @@ github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrk
|
|||||||
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
|
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
|
||||||
github.com/liushuangls/go-anthropic/v2 v2.13.0 h1:f7KJ54IHxIpHPPhrCzs3SrdP2PfErXiJcJn7DUVstSA=
|
github.com/liushuangls/go-anthropic/v2 v2.13.0 h1:f7KJ54IHxIpHPPhrCzs3SrdP2PfErXiJcJn7DUVstSA=
|
||||||
github.com/liushuangls/go-anthropic/v2 v2.13.0/go.mod h1:5ZwRLF5TQ+y5s/MC9Z1IJYx9WUFgQCKfqFM2xreIQLk=
|
github.com/liushuangls/go-anthropic/v2 v2.13.0/go.mod h1:5ZwRLF5TQ+y5s/MC9Z1IJYx9WUFgQCKfqFM2xreIQLk=
|
||||||
|
github.com/openai/openai-go v0.1.0-beta.6 h1:JquYDpprfrGnlKvQQg+apy9dQ8R9mIrm+wNvAPp6jCQ=
|
||||||
|
github.com/openai/openai-go v0.1.0-beta.6/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g=
|
github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g=
|
||||||
github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
|
||||||
|
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 h1:PS8wXpbyaDJQ2VDHHncMe9Vct0Zn1fEjpsjrLxGJoSc=
|
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 h1:PS8wXpbyaDJQ2VDHHncMe9Vct0Zn1fEjpsjrLxGJoSc=
|
||||||
@ -59,16 +71,22 @@ go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qq
|
|||||||
go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck=
|
go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck=
|
||||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||||
|
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
||||||
|
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
||||||
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo=
|
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo=
|
||||||
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c=
|
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c=
|
||||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||||
|
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||||
|
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||||
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
|
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
|
||||||
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||||
|
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||||
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
|
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
|
||||||
|
144
llm.go
144
llm.go
@ -3,8 +3,10 @@ package go_llm
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"github.com/openai/openai-go"
|
||||||
|
"github.com/openai/openai-go/packages/param"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Role string
|
type Role string
|
||||||
@ -82,41 +84,111 @@ func (m *Message) fromRaw(raw map[string]any) Message {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Message) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
func (m Message) toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion {
|
||||||
var res openai.ChatCompletionMessage
|
var res openai.ChatCompletionMessageParamUnion
|
||||||
|
|
||||||
res.Role = string(m.Role)
|
var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
|
||||||
res.Name = m.Name
|
var textContent param.Opt[string]
|
||||||
res.Content = m.Text
|
|
||||||
|
|
||||||
for _, img := range m.Images {
|
for _, img := range m.Images {
|
||||||
if img.Base64 != "" {
|
if img.Base64 != "" {
|
||||||
res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{
|
arrayOfContentParts = append(arrayOfContentParts,
|
||||||
Type: "image_url",
|
openai.ChatCompletionContentPartUnionParam{
|
||||||
ImageURL: &openai.ChatMessageImageURL{
|
OfImageURL: &openai.ChatCompletionContentPartImageParam{
|
||||||
|
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
|
||||||
URL: "data:" + img.ContentType + ";base64," + img.Base64,
|
URL: "data:" + img.ContentType + ";base64," + img.Base64,
|
||||||
},
|
},
|
||||||
})
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
} else if img.Url != "" {
|
} else if img.Url != "" {
|
||||||
res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{
|
arrayOfContentParts = append(arrayOfContentParts,
|
||||||
Type: "image_url",
|
openai.ChatCompletionContentPartUnionParam{
|
||||||
ImageURL: &openai.ChatMessageImageURL{
|
OfImageURL: &openai.ChatCompletionContentPartImageParam{
|
||||||
|
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
|
||||||
URL: img.Url,
|
URL: img.Url,
|
||||||
},
|
},
|
||||||
})
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// openai does not support messages with both content and multi-content
|
if m.Text != "" {
|
||||||
if len(res.MultiContent) > 0 && res.Content != "" {
|
if len(arrayOfContentParts) > 0 {
|
||||||
res.MultiContent = append([]openai.ChatMessagePart{{
|
arrayOfContentParts = append(arrayOfContentParts,
|
||||||
Type: "text",
|
openai.ChatCompletionContentPartUnionParam{
|
||||||
Text: res.Content,
|
OfText: &openai.ChatCompletionContentPartTextParam{
|
||||||
}}, res.MultiContent...)
|
Text: "\n",
|
||||||
res.Content = ""
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
textContent = openai.String(m.Text)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return []openai.ChatCompletionMessage{res}
|
a := strings.Split(model, "-")
|
||||||
|
|
||||||
|
useSystemInsteadOfDeveloper := true
|
||||||
|
if len(a) > 1 && a[0][0] == 'o' {
|
||||||
|
useSystemInsteadOfDeveloper = false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m.Role {
|
||||||
|
case RoleSystem:
|
||||||
|
if useSystemInsteadOfDeveloper {
|
||||||
|
res = openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfSystem: &openai.ChatCompletionSystemMessageParam{
|
||||||
|
Content: openai.ChatCompletionSystemMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
res = openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
|
||||||
|
Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case RoleUser:
|
||||||
|
var name param.Opt[string]
|
||||||
|
if m.Name != "" {
|
||||||
|
name = openai.String(m.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
res = openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfUser: &openai.ChatCompletionUserMessageParam{
|
||||||
|
Name: name,
|
||||||
|
Content: openai.ChatCompletionUserMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
OfArrayOfContentParts: arrayOfContentParts,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
case RoleAssistant:
|
||||||
|
var name param.Opt[string]
|
||||||
|
if m.Name != "" {
|
||||||
|
name = openai.String(m.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
res = openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
|
||||||
|
Name: name,
|
||||||
|
Content: openai.ChatCompletionAssistantMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return []openai.ChatCompletionMessageParamUnion{res}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
type ToolCall struct {
|
||||||
@ -134,10 +206,19 @@ func (t ToolCall) toRaw() map[string]any {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t ToolCall) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
func (t ToolCall) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
|
||||||
return []openai.ChatCompletionMessage{{
|
return []openai.ChatCompletionMessageParamUnion{{
|
||||||
Role: openai.ChatMessageRoleTool,
|
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
|
||||||
ToolCallID: t.ID,
|
ToolCalls: []openai.ChatCompletionMessageToolCallParam{
|
||||||
|
{
|
||||||
|
ID: t.ID,
|
||||||
|
Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
||||||
|
Name: t.FunctionCall.Name,
|
||||||
|
Arguments: t.FunctionCall.Arguments,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -160,7 +241,7 @@ func (t ToolCallResponse) toRaw() map[string]any {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t ToolCallResponse) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
func (t ToolCallResponse) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
|
||||||
var refusal string
|
var refusal string
|
||||||
if t.Error != nil {
|
if t.Error != nil {
|
||||||
refusal = t.Error.Error()
|
refusal = t.Error.Error()
|
||||||
@ -174,10 +255,13 @@ func (t ToolCallResponse) toChatCompletionMessages() []openai.ChatCompletionMess
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return []openai.ChatCompletionMessage{{
|
return []openai.ChatCompletionMessageParamUnion{{
|
||||||
Role: openai.ChatMessageRoleTool,
|
OfTool: &openai.ChatCompletionToolMessageParam{
|
||||||
Content: fmt.Sprint(t.Result),
|
|
||||||
ToolCallID: t.ID,
|
ToolCallID: t.ID,
|
||||||
|
Content: openai.ChatCompletionToolMessageParamContentUnion{
|
||||||
|
OfString: openai.String(fmt.Sprint(t.Result)),
|
||||||
|
},
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
71
openai.go
71
openai.go
@ -5,70 +5,69 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
oai "github.com/sashabaranov/go-openai"
|
"github.com/openai/openai-go"
|
||||||
|
"github.com/openai/openai-go/option"
|
||||||
|
"github.com/openai/openai-go/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
type openaiImpl struct {
|
type openaiImpl struct {
|
||||||
key string
|
key string
|
||||||
model string
|
model string
|
||||||
|
baseUrl string
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ LLM = openaiImpl{}
|
var _ LLM = openaiImpl{}
|
||||||
|
|
||||||
func (o openaiImpl) newRequestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams {
|
||||||
res := oai.ChatCompletionRequest{
|
res := openai.ChatCompletionNewParams{
|
||||||
Model: o.model,
|
Model: o.model,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, i := range request.Conversation {
|
for _, i := range request.Conversation {
|
||||||
res.Messages = append(res.Messages, i.toChatCompletionMessages()...)
|
res.Messages = append(res.Messages, i.toChatCompletionMessages(o.model)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, msg := range request.Messages {
|
for _, msg := range request.Messages {
|
||||||
res.Messages = append(res.Messages, msg.toChatCompletionMessages()...)
|
res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.Toolbox != nil {
|
if request.Toolbox != nil {
|
||||||
for _, tool := range request.Toolbox.funcs {
|
for _, tool := range request.Toolbox.funcs {
|
||||||
res.Tools = append(res.Tools, oai.Tool{
|
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
Function: &oai.FunctionDefinition{
|
Function: shared.FunctionDefinitionParam{
|
||||||
Name: tool.Name,
|
Name: tool.Name,
|
||||||
Description: tool.Description,
|
Description: openai.String(tool.Description),
|
||||||
Strict: tool.Strict,
|
Strict: openai.Bool(tool.Strict),
|
||||||
Parameters: tool.Parameters.Definition(),
|
Parameters: tool.Parameters.FunctionParameters(),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if !request.Toolbox.dontRequireTool {
|
if !request.Toolbox.dontRequireTool {
|
||||||
res.ToolChoice = "required"
|
res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
|
||||||
|
OfAuto: openai.String("required"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.Temperature != nil {
|
if request.Temperature != nil {
|
||||||
res.Temperature = *request.Temperature
|
res.Temperature = openai.Float(*request.Temperature)
|
||||||
}
|
|
||||||
|
|
||||||
// is this an o1-* model?
|
|
||||||
isO1 := strings.Split(o.model, "-")[0] == "o1"
|
|
||||||
|
|
||||||
if isO1 {
|
|
||||||
// o1 models do not support system messages, so if any messages are system messages, we need to convert them to
|
|
||||||
// user messages
|
|
||||||
|
|
||||||
for i, msg := range res.Messages {
|
|
||||||
if msg.Role == "system" {
|
|
||||||
res.Messages[i].Role = "user"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
|
func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response {
|
||||||
res := Response{}
|
var res Response
|
||||||
|
|
||||||
|
if response == nil {
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response.Choices) == 0 {
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
for _, choice := range response.Choices {
|
for _, choice := range response.Choices {
|
||||||
var toolCalls []ToolCall
|
var toolCalls []ToolCall
|
||||||
@ -77,7 +76,7 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
|
|||||||
ID: call.ID,
|
ID: call.ID,
|
||||||
FunctionCall: FunctionCall{
|
FunctionCall: FunctionCall{
|
||||||
Name: call.Function.Name,
|
Name: call.Function.Name,
|
||||||
Arguments: call.Function.Arguments,
|
Arguments: strings.TrimSpace(call.Function.Arguments),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,7 +86,6 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
|
|||||||
res.Choices = append(res.Choices, ResponseChoice{
|
res.Choices = append(res.Choices, ResponseChoice{
|
||||||
Content: choice.Message.Content,
|
Content: choice.Message.Content,
|
||||||
Role: Role(choice.Message.Role),
|
Role: Role(choice.Message.Role),
|
||||||
Name: choice.Message.Name,
|
|
||||||
Refusal: choice.Message.Refusal,
|
Refusal: choice.Message.Refusal,
|
||||||
Calls: toolCalls,
|
Calls: toolCalls,
|
||||||
})
|
})
|
||||||
@ -97,11 +95,20 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
||||||
cl := oai.NewClient(o.key)
|
var opts = []option.RequestOption{
|
||||||
|
option.WithAPIKey(o.key),
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.baseUrl != "" {
|
||||||
|
opts = append(opts, option.WithBaseURL(o.baseUrl))
|
||||||
|
}
|
||||||
|
|
||||||
|
cl := openai.NewClient(opts...)
|
||||||
|
|
||||||
req := o.newRequestToOpenAIRequest(request)
|
req := o.newRequestToOpenAIRequest(request)
|
||||||
|
|
||||||
resp, err := cl.CreateChatCompletion(ctx, req)
|
resp, err := cl.Chat.Completions.New(ctx, req)
|
||||||
|
//resp, err := cl.CreateChatCompletion(ctx, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
|
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
|
||||||
|
20
request.go
20
request.go
@ -1,6 +1,8 @@
|
|||||||
package go_llm
|
package go_llm
|
||||||
|
|
||||||
import "github.com/sashabaranov/go-openai"
|
import (
|
||||||
|
"github.com/openai/openai-go"
|
||||||
|
)
|
||||||
|
|
||||||
type rawAble interface {
|
type rawAble interface {
|
||||||
toRaw() map[string]any
|
toRaw() map[string]any
|
||||||
@ -8,13 +10,13 @@ type rawAble interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Input interface {
|
type Input interface {
|
||||||
toChatCompletionMessages() []openai.ChatCompletionMessage
|
toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion
|
||||||
}
|
}
|
||||||
type Request struct {
|
type Request struct {
|
||||||
Conversation []Input
|
Conversation []Input
|
||||||
Messages []Message
|
Messages []Message
|
||||||
Toolbox *ToolBox
|
Toolbox *ToolBox
|
||||||
Temperature *float32
|
Temperature *float64
|
||||||
}
|
}
|
||||||
|
|
||||||
// NextRequest will take the current request's conversation, messages, the response, and any tool results, and
|
// NextRequest will take the current request's conversation, messages, the response, and any tool results, and
|
||||||
@ -33,16 +35,8 @@ func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallRespon
|
|||||||
res.Conversation = append(res.Conversation, msg)
|
res.Conversation = append(res.Conversation, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if there are tool calls, then we need to add those to the conversation
|
if resp.Content != "" || resp.Refusal != "" || len(resp.Calls) > 0 {
|
||||||
for _, call := range resp.Calls {
|
res.Conversation = append(res.Conversation, resp)
|
||||||
res.Conversation = append(res.Conversation, call)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Content != "" || resp.Refusal != "" {
|
|
||||||
res.Conversation = append(res.Conversation, Message{
|
|
||||||
Role: RoleAssistant,
|
|
||||||
Text: resp.Content,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// if there are tool results, then we need to add those to the conversation
|
// if there are tool results, then we need to add those to the conversation
|
||||||
|
32
response.go
32
response.go
@ -1,6 +1,8 @@
|
|||||||
package go_llm
|
package go_llm
|
||||||
|
|
||||||
import "github.com/sashabaranov/go-openai"
|
import (
|
||||||
|
"github.com/openai/openai-go"
|
||||||
|
)
|
||||||
|
|
||||||
type ResponseChoice struct {
|
type ResponseChoice struct {
|
||||||
Index int
|
Index int
|
||||||
@ -30,24 +32,34 @@ func (r ResponseChoice) toRaw() map[string]any {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r ResponseChoice) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
func (r ResponseChoice) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
|
||||||
var res = openai.ChatCompletionMessage{
|
var as openai.ChatCompletionAssistantMessageParam
|
||||||
Role: openai.ChatMessageRoleAssistant,
|
|
||||||
Content: r.Content,
|
if r.Name != "" {
|
||||||
Refusal: r.Refusal,
|
as.Name = openai.String(r.Name)
|
||||||
|
}
|
||||||
|
if r.Refusal != "" {
|
||||||
|
as.Refusal = openai.String(r.Refusal)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Content != "" {
|
||||||
|
as.Content.OfString = openai.String(r.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, call := range r.Calls {
|
for _, call := range r.Calls {
|
||||||
res.ToolCalls = append(res.ToolCalls, openai.ToolCall{
|
as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
|
||||||
ID: call.ID,
|
ID: call.ID,
|
||||||
Type: openai.ToolTypeFunction,
|
Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
||||||
Function: openai.FunctionCall{
|
|
||||||
Name: call.FunctionCall.Name,
|
Name: call.FunctionCall.Name,
|
||||||
Arguments: call.FunctionCall.Arguments,
|
Arguments: call.FunctionCall.Arguments,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return []openai.ChatCompletionMessage{res}
|
return []openai.ChatCompletionMessageParamUnion{
|
||||||
|
{
|
||||||
|
OfAssistant: &as,
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r ResponseChoice) toInput() []Input {
|
func (r ResponseChoice) toInput() []Input {
|
||||||
|
@ -28,22 +28,27 @@ func getFromType(t reflect.Type, b basic) Type {
|
|||||||
switch t.Kind() {
|
switch t.Kind() {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
b.DataType = jsonschema.String
|
b.DataType = jsonschema.String
|
||||||
|
b.typeName = "string"
|
||||||
return b
|
return b
|
||||||
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
b.DataType = jsonschema.Integer
|
b.DataType = jsonschema.Integer
|
||||||
|
b.typeName = "integer"
|
||||||
return b
|
return b
|
||||||
|
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
b.DataType = jsonschema.Integer
|
b.DataType = jsonschema.Integer
|
||||||
|
b.typeName = "integer"
|
||||||
return b
|
return b
|
||||||
|
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
b.DataType = jsonschema.Number
|
b.DataType = jsonschema.Number
|
||||||
|
b.typeName = "number"
|
||||||
return b
|
return b
|
||||||
|
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
b.DataType = jsonschema.Boolean
|
b.DataType = jsonschema.Boolean
|
||||||
|
b.typeName = "boolean"
|
||||||
return b
|
return b
|
||||||
|
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
@ -107,7 +112,7 @@ func getObject(t reflect.Type) object {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return object{
|
return object{
|
||||||
basic: basic{DataType: jsonschema.Object},
|
basic: basic{DataType: jsonschema.Object, typeName: "object"},
|
||||||
fields: fields,
|
fields: fields,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -116,6 +121,7 @@ func getArray(t reflect.Type) array {
|
|||||||
res := array{
|
res := array{
|
||||||
basic: basic{
|
basic: basic{
|
||||||
DataType: jsonschema.Array,
|
DataType: jsonschema.Array,
|
||||||
|
typeName: "array",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package schema
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/openai/openai-go"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
@ -18,6 +19,14 @@ func (a array) SchemaType() jsonschema.DataType {
|
|||||||
return jsonschema.Array
|
return jsonschema.Array
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a array) FunctionParameters() openai.FunctionParameters {
|
||||||
|
return openai.FunctionParameters{
|
||||||
|
"type": "array",
|
||||||
|
"description": a.Description(),
|
||||||
|
"items": a.items.FunctionParameters(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a array) Definition() jsonschema.Definition {
|
func (a array) Definition() jsonschema.Definition {
|
||||||
def := a.basic.Definition()
|
def := a.basic.Definition()
|
||||||
def.Type = jsonschema.Array
|
def.Type = jsonschema.Array
|
||||||
|
@ -2,6 +2,7 @@ package schema
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/openai/openai-go"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@ -13,6 +14,7 @@ var _ Type = basic{}
|
|||||||
|
|
||||||
type basic struct {
|
type basic struct {
|
||||||
jsonschema.DataType
|
jsonschema.DataType
|
||||||
|
typeName string
|
||||||
|
|
||||||
// index is the position of the parameter in the StructField of the function's parameter struct
|
// index is the position of the parameter in the StructField of the function's parameter struct
|
||||||
index int
|
index int
|
||||||
@ -29,6 +31,13 @@ func (b basic) SchemaType() jsonschema.DataType {
|
|||||||
return b.DataType
|
return b.DataType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b basic) FunctionParameters() openai.FunctionParameters {
|
||||||
|
return openai.FunctionParameters{
|
||||||
|
"type": b.typeName,
|
||||||
|
"description": b.description,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b basic) Definition() jsonschema.Definition {
|
func (b basic) Definition() jsonschema.Definition {
|
||||||
return jsonschema.Definition{
|
return jsonschema.Definition{
|
||||||
Type: b.DataType,
|
Type: b.DataType,
|
||||||
|
@ -2,6 +2,7 @@ package schema
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/openai/openai-go"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
@ -19,6 +20,14 @@ func (e enum) SchemaType() jsonschema.DataType {
|
|||||||
return jsonschema.String
|
return jsonschema.String
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e enum) FunctionParameters() openai.FunctionParameters {
|
||||||
|
return openai.FunctionParameters{
|
||||||
|
"type": "string",
|
||||||
|
"description": e.Description(),
|
||||||
|
"enum": e.values,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (e enum) Definition() jsonschema.Definition {
|
func (e enum) Definition() jsonschema.Definition {
|
||||||
def := e.basic.Definition()
|
def := e.basic.Definition()
|
||||||
def.Enum = e.values
|
def.Enum = e.values
|
||||||
|
@ -2,6 +2,7 @@ package schema
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/openai/openai-go"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
@ -19,6 +20,19 @@ func (o object) SchemaType() jsonschema.DataType {
|
|||||||
return jsonschema.Object
|
return jsonschema.Object
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o object) FunctionParameters() openai.FunctionParameters {
|
||||||
|
var properties = map[string]openai.FunctionParameters{}
|
||||||
|
for k, v := range o.fields {
|
||||||
|
properties[k] = v.FunctionParameters()
|
||||||
|
}
|
||||||
|
|
||||||
|
return openai.FunctionParameters{
|
||||||
|
"type": "object",
|
||||||
|
"description": o.Description(),
|
||||||
|
"properties": properties,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (o object) Definition() jsonschema.Definition {
|
func (o object) Definition() jsonschema.Definition {
|
||||||
def := o.basic.Definition()
|
def := o.basic.Definition()
|
||||||
def.Type = jsonschema.Object
|
def.Type = jsonschema.Object
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
package schema
|
package schema
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/openai/openai-go"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Type interface {
|
type Type interface {
|
||||||
|
FunctionParameters() openai.FunctionParameters
|
||||||
|
|
||||||
SchemaType() jsonschema.DataType
|
SchemaType() jsonschema.DataType
|
||||||
Definition() jsonschema.Definition
|
Definition() jsonschema.Definition
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user