diff --git a/anthropic.go b/anthropic.go index b2982dc..90ebdf5 100644 --- a/anthropic.go +++ b/anthropic.go @@ -147,7 +147,8 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest { res.Messages = msgs if req.Temperature != nil { - res.Temperature = req.Temperature + var f = float32(*req.Temperature) + res.Temperature = &f } log.Println("llm request to anthropic request", res) diff --git a/go.mod b/go.mod index ba7d832..0ea5099 100644 --- a/go.mod +++ b/go.mod @@ -24,17 +24,22 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect + github.com/openai/openai-go v0.1.0-beta.6 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect go.opentelemetry.io/otel v1.33.0 // indirect go.opentelemetry.io/otel/metric v1.33.0 // indirect go.opentelemetry.io/otel/trace v1.33.0 // indirect - golang.org/x/crypto v0.31.0 // indirect - golang.org/x/net v0.33.0 // indirect + golang.org/x/crypto v0.32.0 // indirect + golang.org/x/net v0.34.0 // indirect golang.org/x/oauth2 v0.24.0 // indirect golang.org/x/sync v0.10.0 // indirect - golang.org/x/sys v0.28.0 // indirect + golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.8.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 // indirect diff --git a/go.sum b/go.sum index 46a13e0..9195c05 100644 --- a/go.sum +++ b/go.sum @@ -35,12 +35,24 @@ github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrk github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/liushuangls/go-anthropic/v2 v2.13.0 h1:f7KJ54IHxIpHPPhrCzs3SrdP2PfErXiJcJn7DUVstSA= github.com/liushuangls/go-anthropic/v2 v2.13.0/go.mod h1:5ZwRLF5TQ+y5s/MC9Z1IJYx9WUFgQCKfqFM2xreIQLk= +github.com/openai/openai-go v0.1.0-beta.6 h1:JquYDpprfrGnlKvQQg+apy9dQ8R9mIrm+wNvAPp6jCQ= +github.com/openai/openai-go v0.1.0-beta.6/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g= github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 h1:PS8wXpbyaDJQ2VDHHncMe9Vct0Zn1fEjpsjrLxGJoSc= @@ -59,16 +71,22 @@ go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qq go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo= golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= diff --git a/llm.go b/llm.go index 3b19f24..c892d50 100644 --- a/llm.go +++ b/llm.go @@ -3,8 +3,10 @@ package go_llm import ( "context" "fmt" + "strings" - "github.com/sashabaranov/go-openai" + "github.com/openai/openai-go" + "github.com/openai/openai-go/packages/param" ) type Role string @@ -82,41 +84,111 @@ func (m *Message) fromRaw(raw map[string]any) Message { return res } -func (m Message) toChatCompletionMessages() []openai.ChatCompletionMessage { - var res openai.ChatCompletionMessage +func (m Message) toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion { + var res openai.ChatCompletionMessageParamUnion - res.Role = string(m.Role) - res.Name = m.Name - res.Content = m.Text + var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam + var textContent param.Opt[string] for _, img := range m.Images { if img.Base64 != "" { - res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{ - Type: "image_url", - ImageURL: &openai.ChatMessageImageURL{ - URL: "data:" + img.ContentType + ";base64," + img.Base64, + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfImageURL: &openai.ChatCompletionContentPartImageParam{ + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: "data:" + img.ContentType + ";base64," + img.Base64, + }, + }, }, - }) + ) } else if img.Url != "" { - res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{ - Type: "image_url", - ImageURL: &openai.ChatMessageImageURL{ - URL: img.Url, + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfImageURL: &openai.ChatCompletionContentPartImageParam{ + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: img.Url, + }, + }, }, - }) + ) } } - // openai does not support messages with both content and multi-content - if len(res.MultiContent) > 0 && res.Content != "" { - res.MultiContent = append([]openai.ChatMessagePart{{ - Type: "text", - Text: res.Content, - }}, res.MultiContent...) - res.Content = "" + if m.Text != "" { + if len(arrayOfContentParts) > 0 { + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfText: &openai.ChatCompletionContentPartTextParam{ + Text: "\n", + }, + }, + ) + } else { + textContent = openai.String(m.Text) + } } - return []openai.ChatCompletionMessage{res} + a := strings.Split(model, "-") + + useSystemInsteadOfDeveloper := true + if len(a) > 1 && a[0][0] == 'o' { + useSystemInsteadOfDeveloper = false + } + + switch m.Role { + case RoleSystem: + if useSystemInsteadOfDeveloper { + res = openai.ChatCompletionMessageParamUnion{ + OfSystem: &openai.ChatCompletionSystemMessageParam{ + Content: openai.ChatCompletionSystemMessageParamContentUnion{ + OfString: textContent, + }, + }, + } + } else { + res = openai.ChatCompletionMessageParamUnion{ + OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ + Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ + OfString: textContent, + }, + }, + } + } + + case RoleUser: + var name param.Opt[string] + if m.Name != "" { + name = openai.String(m.Name) + } + + res = openai.ChatCompletionMessageParamUnion{ + OfUser: &openai.ChatCompletionUserMessageParam{ + Name: name, + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfString: textContent, + OfArrayOfContentParts: arrayOfContentParts, + }, + }, + } + + case RoleAssistant: + var name param.Opt[string] + if m.Name != "" { + name = openai.String(m.Name) + } + + res = openai.ChatCompletionMessageParamUnion{ + OfAssistant: &openai.ChatCompletionAssistantMessageParam{ + Name: name, + Content: openai.ChatCompletionAssistantMessageParamContentUnion{ + OfString: textContent, + }, + }, + } + + } + + return []openai.ChatCompletionMessageParamUnion{res} } type ToolCall struct { @@ -134,10 +206,19 @@ func (t ToolCall) toRaw() map[string]any { return res } -func (t ToolCall) toChatCompletionMessages() []openai.ChatCompletionMessage { - return []openai.ChatCompletionMessage{{ - Role: openai.ChatMessageRoleTool, - ToolCallID: t.ID, +func (t ToolCall) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion { + return []openai.ChatCompletionMessageParamUnion{{ + OfAssistant: &openai.ChatCompletionAssistantMessageParam{ + ToolCalls: []openai.ChatCompletionMessageToolCallParam{ + { + ID: t.ID, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: t.FunctionCall.Name, + Arguments: t.FunctionCall.Arguments, + }, + }, + }, + }, }} } @@ -160,7 +241,7 @@ func (t ToolCallResponse) toRaw() map[string]any { return res } -func (t ToolCallResponse) toChatCompletionMessages() []openai.ChatCompletionMessage { +func (t ToolCallResponse) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion { var refusal string if t.Error != nil { refusal = t.Error.Error() @@ -174,10 +255,13 @@ func (t ToolCallResponse) toChatCompletionMessages() []openai.ChatCompletionMess } } - return []openai.ChatCompletionMessage{{ - Role: openai.ChatMessageRoleTool, - Content: fmt.Sprint(t.Result), - ToolCallID: t.ID, + return []openai.ChatCompletionMessageParamUnion{{ + OfTool: &openai.ChatCompletionToolMessageParam{ + ToolCallID: t.ID, + Content: openai.ChatCompletionToolMessageParamContentUnion{ + OfString: openai.String(fmt.Sprint(t.Result)), + }, + }, }} } diff --git a/openai.go b/openai.go index decbf55..38f3748 100644 --- a/openai.go +++ b/openai.go @@ -5,70 +5,69 @@ import ( "fmt" "strings" - oai "github.com/sashabaranov/go-openai" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/shared" ) type openaiImpl struct { - key string - model string + key string + model string + baseUrl string } var _ LLM = openaiImpl{} -func (o openaiImpl) newRequestToOpenAIRequest(request Request) oai.ChatCompletionRequest { - res := oai.ChatCompletionRequest{ +func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams { + res := openai.ChatCompletionNewParams{ Model: o.model, } for _, i := range request.Conversation { - res.Messages = append(res.Messages, i.toChatCompletionMessages()...) + res.Messages = append(res.Messages, i.toChatCompletionMessages(o.model)...) } for _, msg := range request.Messages { - res.Messages = append(res.Messages, msg.toChatCompletionMessages()...) + res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...) } if request.Toolbox != nil { for _, tool := range request.Toolbox.funcs { - res.Tools = append(res.Tools, oai.Tool{ + res.Tools = append(res.Tools, openai.ChatCompletionToolParam{ Type: "function", - Function: &oai.FunctionDefinition{ + Function: shared.FunctionDefinitionParam{ Name: tool.Name, - Description: tool.Description, - Strict: tool.Strict, - Parameters: tool.Parameters.Definition(), + Description: openai.String(tool.Description), + Strict: openai.Bool(tool.Strict), + Parameters: tool.Parameters.FunctionParameters(), }, }) } if !request.Toolbox.dontRequireTool { - res.ToolChoice = "required" + res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ + OfAuto: openai.String("required"), + } } } if request.Temperature != nil { - res.Temperature = *request.Temperature - } - - // is this an o1-* model? - isO1 := strings.Split(o.model, "-")[0] == "o1" - - if isO1 { - // o1 models do not support system messages, so if any messages are system messages, we need to convert them to - // user messages - - for i, msg := range res.Messages { - if msg.Role == "system" { - res.Messages[i].Role = "user" - } - } + res.Temperature = openai.Float(*request.Temperature) } return res } -func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response { - res := Response{} +func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response { + var res Response + + if response == nil { + return res + } + + if len(response.Choices) == 0 { + return res + } for _, choice := range response.Choices { var toolCalls []ToolCall @@ -77,7 +76,7 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R ID: call.ID, FunctionCall: FunctionCall{ Name: call.Function.Name, - Arguments: call.Function.Arguments, + Arguments: strings.TrimSpace(call.Function.Arguments), }, } @@ -87,7 +86,6 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R res.Choices = append(res.Choices, ResponseChoice{ Content: choice.Message.Content, Role: Role(choice.Message.Role), - Name: choice.Message.Name, Refusal: choice.Message.Refusal, Calls: toolCalls, }) @@ -97,11 +95,20 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R } func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) { - cl := oai.NewClient(o.key) + var opts = []option.RequestOption{ + option.WithAPIKey(o.key), + } + + if o.baseUrl != "" { + opts = append(opts, option.WithBaseURL(o.baseUrl)) + } + + cl := openai.NewClient(opts...) req := o.newRequestToOpenAIRequest(request) - resp, err := cl.CreateChatCompletion(ctx, req) + resp, err := cl.Chat.Completions.New(ctx, req) + //resp, err := cl.CreateChatCompletion(ctx, req) if err != nil { return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err) diff --git a/request.go b/request.go index e3831c2..961fae7 100644 --- a/request.go +++ b/request.go @@ -1,6 +1,8 @@ package go_llm -import "github.com/sashabaranov/go-openai" +import ( + "github.com/openai/openai-go" +) type rawAble interface { toRaw() map[string]any @@ -8,13 +10,13 @@ type rawAble interface { } type Input interface { - toChatCompletionMessages() []openai.ChatCompletionMessage + toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion } type Request struct { Conversation []Input Messages []Message Toolbox *ToolBox - Temperature *float32 + Temperature *float64 } // NextRequest will take the current request's conversation, messages, the response, and any tool results, and @@ -33,16 +35,8 @@ func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallRespon res.Conversation = append(res.Conversation, msg) } - // if there are tool calls, then we need to add those to the conversation - for _, call := range resp.Calls { - res.Conversation = append(res.Conversation, call) - } - - if resp.Content != "" || resp.Refusal != "" { - res.Conversation = append(res.Conversation, Message{ - Role: RoleAssistant, - Text: resp.Content, - }) + if resp.Content != "" || resp.Refusal != "" || len(resp.Calls) > 0 { + res.Conversation = append(res.Conversation, resp) } // if there are tool results, then we need to add those to the conversation diff --git a/response.go b/response.go index 85160c3..dd193ef 100644 --- a/response.go +++ b/response.go @@ -1,6 +1,8 @@ package go_llm -import "github.com/sashabaranov/go-openai" +import ( + "github.com/openai/openai-go" +) type ResponseChoice struct { Index int @@ -30,24 +32,34 @@ func (r ResponseChoice) toRaw() map[string]any { return res } -func (r ResponseChoice) toChatCompletionMessages() []openai.ChatCompletionMessage { - var res = openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleAssistant, - Content: r.Content, - Refusal: r.Refusal, +func (r ResponseChoice) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion { + var as openai.ChatCompletionAssistantMessageParam + + if r.Name != "" { + as.Name = openai.String(r.Name) + } + if r.Refusal != "" { + as.Refusal = openai.String(r.Refusal) + } + + if r.Content != "" { + as.Content.OfString = openai.String(r.Content) } for _, call := range r.Calls { - res.ToolCalls = append(res.ToolCalls, openai.ToolCall{ - ID: call.ID, - Type: openai.ToolTypeFunction, - Function: openai.FunctionCall{ + as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{ + ID: call.ID, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ Name: call.FunctionCall.Name, Arguments: call.FunctionCall.Arguments, }, }) } - return []openai.ChatCompletionMessage{res} + return []openai.ChatCompletionMessageParamUnion{ + { + OfAssistant: &as, + }, + } } func (r ResponseChoice) toInput() []Input { diff --git a/schema/GetType.go b/schema/GetType.go index f2f7e16..a9f5950 100644 --- a/schema/GetType.go +++ b/schema/GetType.go @@ -28,22 +28,27 @@ func getFromType(t reflect.Type, b basic) Type { switch t.Kind() { case reflect.String: b.DataType = jsonschema.String + b.typeName = "string" return b case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: b.DataType = jsonschema.Integer + b.typeName = "integer" return b case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: b.DataType = jsonschema.Integer + b.typeName = "integer" return b case reflect.Float32, reflect.Float64: b.DataType = jsonschema.Number + b.typeName = "number" return b case reflect.Bool: b.DataType = jsonschema.Boolean + b.typeName = "boolean" return b case reflect.Struct: @@ -107,7 +112,7 @@ func getObject(t reflect.Type) object { } return object{ - basic: basic{DataType: jsonschema.Object}, + basic: basic{DataType: jsonschema.Object, typeName: "object"}, fields: fields, } } @@ -116,6 +121,7 @@ func getArray(t reflect.Type) array { res := array{ basic: basic{ DataType: jsonschema.Array, + typeName: "array", }, } diff --git a/schema/array.go b/schema/array.go index 6773da8..6417e82 100644 --- a/schema/array.go +++ b/schema/array.go @@ -2,6 +2,7 @@ package schema import ( "errors" + "github.com/openai/openai-go" "reflect" "github.com/sashabaranov/go-openai/jsonschema" @@ -18,6 +19,14 @@ func (a array) SchemaType() jsonschema.DataType { return jsonschema.Array } +func (a array) FunctionParameters() openai.FunctionParameters { + return openai.FunctionParameters{ + "type": "array", + "description": a.Description(), + "items": a.items.FunctionParameters(), + } +} + func (a array) Definition() jsonschema.Definition { def := a.basic.Definition() def.Type = jsonschema.Array diff --git a/schema/basic.go b/schema/basic.go index 0175c8f..b9a3d42 100644 --- a/schema/basic.go +++ b/schema/basic.go @@ -2,6 +2,7 @@ package schema import ( "errors" + "github.com/openai/openai-go" "reflect" "strconv" @@ -13,6 +14,7 @@ var _ Type = basic{} type basic struct { jsonschema.DataType + typeName string // index is the position of the parameter in the StructField of the function's parameter struct index int @@ -29,6 +31,13 @@ func (b basic) SchemaType() jsonschema.DataType { return b.DataType } +func (b basic) FunctionParameters() openai.FunctionParameters { + return openai.FunctionParameters{ + "type": b.typeName, + "description": b.description, + } +} + func (b basic) Definition() jsonschema.Definition { return jsonschema.Definition{ Type: b.DataType, diff --git a/schema/enum.go b/schema/enum.go index 272866f..3ec6b00 100644 --- a/schema/enum.go +++ b/schema/enum.go @@ -2,6 +2,7 @@ package schema import ( "errors" + "github.com/openai/openai-go" "reflect" "golang.org/x/exp/slices" @@ -19,6 +20,14 @@ func (e enum) SchemaType() jsonschema.DataType { return jsonschema.String } +func (e enum) FunctionParameters() openai.FunctionParameters { + return openai.FunctionParameters{ + "type": "string", + "description": e.Description(), + "enum": e.values, + } +} + func (e enum) Definition() jsonschema.Definition { def := e.basic.Definition() def.Enum = e.values diff --git a/schema/object.go b/schema/object.go index b0a0c1b..6930b4c 100644 --- a/schema/object.go +++ b/schema/object.go @@ -2,6 +2,7 @@ package schema import ( "errors" + "github.com/openai/openai-go" "reflect" "github.com/sashabaranov/go-openai/jsonschema" @@ -19,6 +20,19 @@ func (o object) SchemaType() jsonschema.DataType { return jsonschema.Object } +func (o object) FunctionParameters() openai.FunctionParameters { + var properties = map[string]openai.FunctionParameters{} + for k, v := range o.fields { + properties[k] = v.FunctionParameters() + } + + return openai.FunctionParameters{ + "type": "object", + "description": o.Description(), + "properties": properties, + } +} + func (o object) Definition() jsonschema.Definition { def := o.basic.Definition() def.Type = jsonschema.Object diff --git a/schema/type.go b/schema/type.go index 7f84216..5260b99 100644 --- a/schema/type.go +++ b/schema/type.go @@ -1,12 +1,15 @@ package schema import ( + "github.com/openai/openai-go" "reflect" "github.com/sashabaranov/go-openai/jsonschema" ) type Type interface { + FunctionParameters() openai.FunctionParameters + SchemaType() jsonschema.DataType Definition() jsonschema.Definition