Compare commits

...

19 Commits

Author SHA1 Message Date
39ffb82237 Add image handling for Gemini requests with URL download and base64 support 2025-04-12 15:07:44 -04:00
916f07be18 added anthropic tool support i think 2025-04-12 03:41:48 -04:00
3093b988f8 Refactor toolbox and function handling to support synthetic fields and improve type definitions 2025-04-12 02:20:40 -04:00
2ae583e9f3 Add support for required fields in parameter generation
Previously, required fields were not handled in OpenAI and Google parameter generation. This update adds logic to include a "required" list for both, ensuring mandatory fields are accurately captured in the schema outputs.
2025-04-07 20:33:21 -04:00
5ba0d5df7e instead of having an openai => google translation layer, just add sister functions to the types that construct the google request just like openai's 2025-04-07 01:57:02 -04:00
58552ee226 Fix enum typing 2025-04-06 15:35:05 -04:00
14961bfbc6 Refactor candidate parsing logic in Google adapter, which fixes only one tool call per execution 2025-04-06 14:35:22 -04:00
7c9eb08cb4 Add support for integers and tool configuration in schema handling
This update introduces support for `jsonschema.Integer` types and updates the logic to handle nested items in schemas. Added a new default error log for unknown types using `slog.Error`. Also, integrated tool configuration with a `FunctionCallingConfig` when `dontRequireTool` is false.
2025-04-06 01:23:10 -04:00
ff5e4ca7b0 Add support for integers and tool configuration in schema handling
This update introduces support for `jsonschema.Integer` types and updates the logic to handle nested items in schemas. Added a new default error log for unknown types using `slog.Error`. Also, integrated tool configuration with a `FunctionCallingConfig` when `dontRequireTool` is false.
2025-04-04 20:13:46 -04:00
82feb7d8b4 Change function result type from string to any
Updated the return type of functions and related code from `string` to `any` to improve flexibility and support more diverse outputs. Adjusted function implementations, signatures, and handling of results accordingly.
2025-03-25 23:53:09 -04:00
5ba42056ad Add toolbox features for function removal and callback execution
Introduced `WithFunctionRemoved` and `ExecuteCallbacks` methods to enhance `ToolBox` functionality. This allows dynamic function removal and execution of custom callbacks during tool call processing. Also cleaned up logging and improved handling for required tools in `openai.go`.
2025-03-21 11:09:32 -04:00
52533238d3 Add getter methods for response and toolcall in Context
Introduce `Response()` and `ToolCall()` methods to access the respective fields from the `Context` struct. This enhances encapsulation and provides a standardized way to retrieve these values.
2025-03-18 03:45:38 -04:00
88fbf89a63 Fix handling of OpenAI messages with content and multi-content.
Previously, OpenAI messages containing both `Content` and `MultiContent` could cause inconsistent behavior. This update ensures `Content` is converted into a `MultiContent` entry to maintain compatibility.
2025-03-18 01:01:46 -04:00
e5a046a70b Handle execution errors by appending them to the result.
Previously, execution errors were only returned in the refusal field. This update appends errors to the result field if present, ensuring they are included in the tool's output. This change improves visibility and clarity for error reporting.
2025-03-17 23:41:48 -04:00
2737a5b2be Refactor response handling for clarity and consistency.
Simplified how responses and tool calls are appended to conversations. Adjusted structure in message formatting to better align with tool call requirements, ensuring consistent data representation.
2025-03-17 00:18:32 -04:00
7f5e34e437 Refactor entire system to be more contextual so that conversation flow can be more easily managed 2025-03-16 22:38:58 -04:00
0d909edd44 Refactor Google LLM adapter to support tool schemas.
Enhanced the `requestToChatHistory` method to include OpenAI schema conversion logic and integrate tools with generative AI schemas. This change improves flexibility when working with different schema types and tool definitions. Adjusted response handling to return a modified model alongside chat sessions and parts.
2025-01-22 23:56:20 -05:00
388a44fa79 Refactor Google LLM API to use chat session interface.
Replace message handling with a chat session model, aligning the logic with new API requirements. Adjust functions to properly build chat history and send messages via chat sessions, improving compatibility and extensibility.
2025-01-22 22:07:20 -05:00
e7b7aab62e Refactor Toolbox handling in Google LLM integration.
Implemented a nil check for Toolbox to prevent potential nil pointer dereferences. Cleaned up and reorganized code for better readability and maintainability while keeping placeholder functionality intact.
2025-01-22 19:49:37 -05:00
18 changed files with 1201 additions and 359 deletions

View File

@ -123,7 +123,6 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
// if this has the same role as the previous message, we can append it to the previous message
// as anthropic expects alternating assistant and user roles
if len(msgs) > 0 && msgs[len(msgs)-1].Role == role {
m2 := &msgs[len(msgs)-1]
@ -134,20 +133,19 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
}
}
if req.Toolbox != nil {
for _, tool := range req.Toolbox.funcs {
for _, tool := range req.Toolbox.functions {
res.Tools = append(res.Tools, anth.ToolDefinition{
Name: tool.Name,
Description: tool.Description,
InputSchema: tool.Parameters,
InputSchema: tool.Parameters.AnthropicInputSchema(),
})
}
}
res.Messages = msgs
if req.Temperature != nil {
res.Temperature = req.Temperature
var f = float32(*req.Temperature)
res.Temperature = &f
}
log.Println("llm request to anthropic request", res)
@ -156,15 +154,13 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
}
func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
res := Response{}
for _, msg := range in.Content {
choice := ResponseChoice{}
for _, msg := range in.Content {
switch msg.Type {
case anth.MessagesContentTypeText:
if msg.Text != nil {
choice.Content = *msg.Text
choice.Content += *msg.Text
}
case anth.MessagesContentTypeToolUse:
@ -183,13 +179,13 @@ func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
}
}
}
res.Choices = append(res.Choices, choice)
}
log.Println("anthropic response to llm response", res)
log.Println("anthropic response to llm response", choice)
return res
return Response{
Choices: []ResponseChoice{choice},
}
}
func (a anthropic) ChatComplete(ctx context.Context, req Request) (Response, error) {

120
context.go Normal file
View File

@ -0,0 +1,120 @@
package go_llm
import (
"context"
"time"
)
type Context struct {
context.Context
request Request
response *ResponseChoice
toolcall *ToolCall
syntheticFields map[string]string
}
func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request {
var res Request
res.Toolbox = c.request.Toolbox
res.Temperature = c.request.Temperature
res.Conversation = make([]Input, len(c.request.Conversation))
copy(res.Conversation, c.request.Conversation)
// now for every input message, convert those to an Input to add to the conversation
for _, msg := range c.request.Messages {
res.Conversation = append(res.Conversation, msg)
}
// if there are tool calls, then we need to add those to the conversation
if c.response != nil {
res.Conversation = append(res.Conversation, *c.response)
}
// if there are tool results, then we need to add those to the conversation
for _, result := range toolResults {
res.Conversation = append(res.Conversation, result)
}
return res
}
func NewContext(ctx context.Context, request Request, response *ResponseChoice, toolcall *ToolCall) *Context {
return &Context{Context: ctx, request: request, response: response, toolcall: toolcall}
}
func (c *Context) Request() Request {
return c.request
}
func (c *Context) Response() *ResponseChoice {
return c.response
}
func (c *Context) ToolCall() *ToolCall {
return c.toolcall
}
func (c *Context) SyntheticFields() map[string]string {
if c.syntheticFields == nil {
c.syntheticFields = map[string]string{}
}
return c.syntheticFields
}
func (c *Context) WithContext(ctx context.Context) *Context {
return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithRequest(request Request) *Context {
return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithResponse(response *ResponseChoice) *Context {
return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithToolCall(toolcall *ToolCall) *Context {
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithSyntheticFields(syntheticFields map[string]string) *Context {
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: syntheticFields}
}
func (c *Context) Deadline() (deadline time.Time, ok bool) {
return c.Context.Deadline()
}
func (c *Context) Done() <-chan struct{} {
return c.Context.Done()
}
func (c *Context) Err() error {
return c.Context.Err()
}
func (c *Context) Value(key any) any {
switch key {
case "request":
return c.request
case "response":
return c.response
case "toolcall":
return c.toolcall
case "syntheticFields":
return c.syntheticFields
}
return c.Context.Value(key)
}
func (c *Context) WithTimeout(timeout time.Duration) (*Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(c.Context, timeout)
return c.WithContext(ctx), cancel
}

View File

@ -4,11 +4,11 @@ import (
"context"
"encoding/json"
"fmt"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
"log/slog"
"reflect"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
)
type Function struct {
@ -26,27 +26,71 @@ type Function struct {
fn reflect.Value
paramType reflect.Type
// definition is a cache of the openaiImpl jsonschema definition
definition *jsonschema.Definition
}
func (f *Function) Execute(ctx context.Context, input string) (string, error) {
func (f Function) WithSyntheticField(name string, description string) Function {
if obj, o := f.Parameters.(schema.Object); o {
f.Parameters = obj.WithSyntheticField(name, description)
}
return f
}
func (f Function) WithSyntheticFields(fieldsAndDescriptions map[string]string) Function {
if obj, o := f.Parameters.(schema.Object); o {
for k, v := range fieldsAndDescriptions {
obj = obj.WithSyntheticField(k, v)
}
f.Parameters = obj
}
return f
}
func (f Function) Execute(ctx *Context, input string) (any, error) {
if !f.fn.IsValid() {
return "", fmt.Errorf("function %s is not implemented", f.Name)
}
slog.Info("Function.Execute", "name", f.Name, "input", input, "f", f.paramType)
// first, we need to parse the input into the struct
p := reflect.New(f.paramType)
fmt.Println("Function.Execute", f.Name, "input:", input)
//m := map[string]any{}
err := json.Unmarshal([]byte(input), p.Interface())
var vals map[string]any
err := json.Unmarshal([]byte(input), &vals)
var syntheticFields map[string]string
// first eat up any synthetic fields
if obj, o := f.Parameters.(schema.Object); o {
for k := range obj.SyntheticFields() {
key := schema.SyntheticFieldPrefix + k
if val, ok := vals[key]; ok {
if syntheticFields == nil {
syntheticFields = map[string]string{}
}
syntheticFields[k] = fmt.Sprint(val)
delete(vals, key)
}
}
}
// now for any remaining fields, re-marshal them into json and then unmarshal into the struct
b, err := json.Marshal(vals)
if err != nil {
return "", fmt.Errorf("failed to marshal input: %w (input: %s)", err, input)
}
// now we can unmarshal the input into the struct
err = json.Unmarshal(b, p.Interface())
if err != nil {
return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input)
}
// now we can call the function
exec := func(ctx context.Context) (string, error) {
exec := func(ctx *Context) (any, error) {
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
if len(out) != 2 {
@ -54,7 +98,7 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) {
}
if out[1].IsNil() {
return out[0].String(), nil
return out[0].Interface(), nil
}
return "", out[1].Interface().(error)
@ -62,31 +106,26 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) {
var cancel context.CancelFunc
if f.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, f.Timeout)
ctx, cancel = ctx.WithTimeout(f.Timeout)
defer cancel()
}
return exec(ctx)
}
func (f *Function) toOpenAIFunction() *openai.FunctionDefinition {
return &openai.FunctionDefinition{
Name: f.Name,
Description: f.Description,
Strict: f.Strict,
Parameters: f.Parameters,
}
}
func (f *Function) toOpenAIDefinition() jsonschema.Definition {
if f.definition == nil {
def := f.Parameters.Definition()
f.definition = &def
}
return *f.definition
}
type FunctionCall struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
func (fc *FunctionCall) toRaw() map[string]any {
res := map[string]interface{}{
"name": fc.Name,
}
if fc.Arguments != "" {
res["arguments"] = fc.Arguments
}
return res
}

View File

@ -1,9 +1,9 @@
package go_llm
import (
"context"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
"reflect"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
)
// Parse takes a function pointer and returns a function object.
@ -13,7 +13,7 @@ import (
// The struct parameters can have the following tags:
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
func NewFunction[T any](name string, description string, fn func(context.Context, T) (string, error)) *Function {
func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) Function {
var o T
res := Function{
@ -31,5 +31,5 @@ func NewFunction[T any](name string, description string, fn func(context.Context
panic("function parameter must be a struct")
}
return &res
return res
}

57
go.mod
View File

@ -4,41 +4,44 @@ go 1.23.1
require (
github.com/google/generative-ai-go v0.19.0
github.com/liushuangls/go-anthropic/v2 v2.13.0
github.com/sashabaranov/go-openai v1.36.1
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67
google.golang.org/api v0.214.0
github.com/liushuangls/go-anthropic/v2 v2.15.0
github.com/openai/openai-go v0.1.0-beta.9
google.golang.org/api v0.228.0
)
require (
cloud.google.com/go v0.117.0 // indirect
cloud.google.com/go/ai v0.9.0 // indirect
cloud.google.com/go/auth v0.13.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
cloud.google.com/go v0.120.0 // indirect
cloud.google.com/go/ai v0.10.1 // indirect
cloud.google.com/go/auth v0.15.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.6.0 // indirect
cloud.google.com/go/longrunning v0.6.3 // indirect
cloud.google.com/go/longrunning v0.6.6 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect
go.opentelemetry.io/otel v1.33.0 // indirect
go.opentelemetry.io/otel/metric v1.33.0 // indirect
go.opentelemetry.io/otel/trace v1.33.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/oauth2 v0.24.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.8.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect
google.golang.org/grpc v1.69.2 // indirect
google.golang.org/protobuf v1.36.1 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect
go.opentelemetry.io/otel v1.35.0 // indirect
go.opentelemetry.io/otel/metric v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/net v0.39.0 // indirect
golang.org/x/oauth2 v0.29.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a // indirect
google.golang.org/grpc v1.71.1 // indirect
google.golang.org/protobuf v1.36.6 // indirect
)

130
go.sum
View File

@ -1,15 +1,15 @@
cloud.google.com/go v0.117.0 h1:Z5TNFfQxj7WG2FgOGX1ekC5RiXrYgms6QscOm32M/4s=
cloud.google.com/go v0.117.0/go.mod h1:ZbwhVTb1DBGt2Iwb3tNO6SEK4q+cplHZmLWH+DelYYc=
cloud.google.com/go/ai v0.9.0 h1:r1Ig8O8+Qr3Ia3WfoO+gokD0fxB2Rk4quppuKjmGMsY=
cloud.google.com/go/ai v0.9.0/go.mod h1:28bKM/oxmRgxmRgI1GLumFv+NSkt+DscAg/gF+54zzY=
cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs=
cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q=
cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU=
cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
cloud.google.com/go v0.120.0 h1:wc6bgG9DHyKqF5/vQvX1CiZrtHnxJjBlKUyF9nP6meA=
cloud.google.com/go v0.120.0/go.mod h1:/beW32s8/pGRuj4IILWQNd4uuebeT4dkOhKmkfit64Q=
cloud.google.com/go/ai v0.10.1 h1:EU93KqYmMeOKgaBXAz2DshH2C/BzAT1P+iJORksLIic=
cloud.google.com/go/ai v0.10.1/go.mod h1:sWWHZvmJ83BjuxAQtYEiA0SFTpijtbH+SXWFO14ri5A=
cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps=
cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
cloud.google.com/go/longrunning v0.6.3 h1:A2q2vuyXysRcwzqDpMMLSI6mb6o39miS52UEG/Rd2ng=
cloud.google.com/go/longrunning v0.6.3/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI=
cloud.google.com/go/longrunning v0.6.6 h1:XJNDo5MUfMM05xK3ewpbSdmt7R2Zw+aQEMbdQR65Rbw=
cloud.google.com/go/longrunning v0.6.6/go.mod h1:hyeGJUrPHcx0u2Uu1UFSoYZLn4lkMrccJig0t4FI7yw=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
@ -23,67 +23,73 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4=
github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q=
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
github.com/liushuangls/go-anthropic/v2 v2.13.0 h1:f7KJ54IHxIpHPPhrCzs3SrdP2PfErXiJcJn7DUVstSA=
github.com/liushuangls/go-anthropic/v2 v2.13.0/go.mod h1:5ZwRLF5TQ+y5s/MC9Z1IJYx9WUFgQCKfqFM2xreIQLk=
github.com/liushuangls/go-anthropic/v2 v2.15.0 h1:zpplg7BRV/9FlMmeMPI0eDwhViB0l9SkNrF8ErYlRoQ=
github.com/liushuangls/go-anthropic/v2 v2.15.0/go.mod h1:kq2yW3JVy1/rph8u5KzX7F3q95CEpCT2RXp/2nfCmb4=
github.com/openai/openai-go v0.1.0-beta.9 h1:ABpubc5yU/3ejee2GgRrbFta81SG/d7bQbB8mIdP0Xo=
github.com/openai/openai-go v0.1.0-beta.9/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sashabaranov/go-openai v1.36.0 h1:fcSrn8uGuorzPWCBp8L0aCR95Zjb/Dd+ZSML0YZy9EI=
github.com/sashabaranov/go-openai v1.36.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g=
github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 h1:PS8wXpbyaDJQ2VDHHncMe9Vct0Zn1fEjpsjrLxGJoSc=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0/go.mod h1:HDBUsEjOuRC0EzKZ1bSaRGZWUBAzo+MhAcUUORSr4D0=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0/go.mod h1:umTcuxiv1n/s/S6/c2AT/g2CQ7u5C59sHDNmfSwgz7Q=
go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw=
go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I=
go.opentelemetry.io/otel/metric v1.33.0 h1:r+JOocAyeRVXD8lZpjdQjzMadVZp2M4WmQ+5WtEnklQ=
go.opentelemetry.io/otel/metric v1.33.0/go.mod h1:L9+Fyctbp6HFTddIxClbQkjtubW6O9QS3Ann/M82u6M=
go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk=
go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0=
go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s=
go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo=
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
google.golang.org/api v0.214.0 h1:h2Gkq07OYi6kusGOaT/9rnNljuXmqPnaig7WGPmKbwA=
google.golang.org/api v0.214.0/go.mod h1:bYPpLG8AyeMWwDU6NXoB00xC0DFkikVvd5MfwoxjLqE=
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 h1:st3LcW/BPi75W4q1jJTEor/QWwbNlPlDG0JTn6XhZu0=
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8/go.mod h1:klhJGKFyG8Tn50enBn7gizg4nXGXJ+jqEREdCWaPcV4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 h1:TqExAhdPaB60Ux47Cn0oLV07rGnxZzIsaRhQaqS666A=
google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8/go.mod h1:lcTa1sDdWEIHMWlITnIczmw5w60CF9ffkb8Z+DVmmjA=
google.golang.org/grpc v1.69.2 h1:U3S9QEtbXC0bYNvRtcoklF3xGtLViumSYxWykJS+7AU=
google.golang.org/grpc v1.69.2/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4=
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 h1:x7wzEgXfnzJcHDwStJT+mxOz4etr2EcexjqhBvmoakw=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0/go.mod h1:rg+RlpR5dKwaS95IyyZqj5Wd4E13lk/msnTS0Xl9lJM=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 h1:sbiXRNDSWJOTobXh5HyQKjq6wUC5tNybqjIqDpAY4CU=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0/go.mod h1:69uWxva0WgAA/4bu2Yy70SLDBwZXuQ6PbBpbsa5iZrQ=
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98=
golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs=
google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4=
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a h1:OQ7sHVzkx6L57dQpzUS4ckfWJ51KDH74XHTDe23xWAs=
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a/go.mod h1:2R6XrVC8Oc08GlNh8ujEpc7HkLiEZ16QeY7FxIs20ac=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a h1:GIqLhp/cYUkuGuiT+vJk8vhOP86L4+SP5j8yXgeVpvI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI=
google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

139
google.go
View File

@ -2,8 +2,12 @@ package go_llm
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
)
@ -19,51 +23,117 @@ func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
return g, nil
}
func (g google) requestToGoogleRequest(in Request, model *genai.GenerativeModel) []genai.Part {
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) {
res := *model
if in.Temperature != nil {
model.GenerationConfig.Temperature = in.Temperature
}
res := []genai.Part{}
for _, c := range in.Messages {
res = append(res, genai.Text(c.Text))
}
for _, tool := range in.Toolbox.funcs {
panic("google ToolBox is todo" + tool.Name)
/*
t := genai.Tool{}
t.FunctionDeclarations = append(t.FunctionDeclarations, &genai.FunctionDeclaration{
for _, tool := range in.Toolbox.functions {
res.Tools = append(res.Tools, &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{
{
Name: tool.Name,
Description: tool.Description,
Parameters: nil, //tool.Parameters,
Parameters: tool.Parameters.GoogleParameters(),
},
},
})
*/
}
return res
if !in.Toolbox.RequiresTool() {
res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingAny,
}}
}
cs := res.StartChat()
for i, c := range in.Messages {
content := genai.NewUserContent(genai.Text(c.Text))
switch c.Role {
case RoleAssistant, RoleSystem:
content.Role = "model"
case RoleUser:
content.Role = "user"
}
for _, img := range c.Images {
if img.Url != "" {
// gemini does not support URLs, so we need to download the image and convert it to a blob
// Download the image from the URL
resp, err := http.Get(img.Url)
if err != nil {
panic(fmt.Sprintf("error downloading image: %v", err))
}
defer resp.Body.Close()
// Check the Content-Length to ensure it's not over 20MB
if resp.ContentLength > 20*1024*1024 {
panic(fmt.Sprintf("image size exceeds 20MB: %d bytes", resp.ContentLength))
}
// Read the content into a byte slice
data, err := io.ReadAll(resp.Body)
if err != nil {
panic(fmt.Sprintf("error reading image data: %v", err))
}
// Ensure the MIME type is appropriate
mimeType := http.DetectContentType(data)
switch mimeType {
case "image/jpeg", "image/png", "image/gif":
// MIME type is valid
default:
panic(fmt.Sprintf("unsupported image MIME type: %s", mimeType))
}
// Create a genai.Blob using the validated image data
content.Parts = append(content.Parts, genai.Blob{
MIMEType: mimeType,
Data: data,
})
} else {
// convert base64 to blob
b, e := base64.StdEncoding.DecodeString(img.Base64)
if e != nil {
panic(fmt.Sprintf("error decoding base64: %v", e))
}
content.Parts = append(content.Parts, genai.Blob{
MIMEType: img.ContentType,
Data: b,
})
}
}
// if this is the last message, we want to add to history, we want it to be the parts
if i == len(in.Messages)-1 {
return &res, cs, content.Parts
}
cs.History = append(cs.History, content)
}
return &res, cs, nil
}
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
res := Response{}
for _, c := range in.Candidates {
var choice ResponseChoice
var set = false
if c.Content != nil {
for _, p := range c.Content.Parts {
switch p.(type) {
case genai.Text:
res.Choices = append(res.Choices, ResponseChoice{
Content: string(p.(genai.Text)),
})
choice.Content = string(p.(genai.Text))
set = true
case genai.FunctionCall:
v := p.(genai.FunctionCall)
choice := ResponseChoice{}
choice.Content = v.Name
b, e := json.Marshal(v.Args)
if e != nil {
@ -79,14 +149,17 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
}
choice.Calls = append(choice.Calls, call)
res.Choices = append(res.Choices, choice)
set = true
default:
return Response{}, fmt.Errorf("unknown part type: %T", p)
}
}
}
if set {
choice.Role = RoleAssistant
res.Choices = append(res.Choices, choice)
}
}
return res, nil
@ -101,9 +174,13 @@ func (g google) ChatComplete(ctx context.Context, req Request) (Response, error)
model := cl.GenerativeModel(g.model)
parts := g.requestToGoogleRequest(req, model)
_, cs, parts := g.requestToChatHistory(req, model)
resp, err := model.GenerateContent(ctx, parts...)
resp, err := cs.SendMessage(ctx, parts...)
//parts := g.requestToGoogleRequest(req, model)
//resp, err := model.GenerateContent(ctx, parts...)
if err != nil {
return Response{}, fmt.Errorf("error generating content: %w", err)

243
llm.go
View File

@ -2,6 +2,11 @@ package go_llm
import (
"context"
"fmt"
"strings"
"github.com/openai/openai-go"
"github.com/openai/openai-go/packages/param"
)
type Role string
@ -18,6 +23,26 @@ type Image struct {
Url string
}
func (i Image) toRaw() map[string]any {
res := map[string]any{
"base64": i.Base64,
"contenttype": i.ContentType,
"url": i.Url,
}
return res
}
func (i *Image) fromRaw(raw map[string]any) Image {
var res Image
res.Base64 = raw["base64"].(string)
res.ContentType = raw["contenttype"].(string)
res.Url = raw["url"].(string)
return res
}
type Message struct {
Role Role
Name string
@ -25,10 +50,145 @@ type Message struct {
Images []Image
}
type Request struct {
Messages []Message
Toolbox *ToolBox
Temperature *float32
func (m Message) toRaw() map[string]any {
res := map[string]any{
"role": m.Role,
"name": m.Name,
"text": m.Text,
}
images := make([]map[string]any, 0, len(m.Images))
for _, img := range m.Images {
images = append(images, img.toRaw())
}
res["images"] = images
return res
}
func (m *Message) fromRaw(raw map[string]any) Message {
var res Message
res.Role = Role(raw["role"].(string))
res.Name = raw["name"].(string)
res.Text = raw["text"].(string)
images := raw["images"].([]map[string]any)
for _, img := range images {
var i Image
res.Images = append(res.Images, i.fromRaw(img))
}
return res
}
func (m Message) toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion {
var res openai.ChatCompletionMessageParamUnion
var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
var textContent param.Opt[string]
for _, img := range m.Images {
if img.Base64 != "" {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfImageURL: &openai.ChatCompletionContentPartImageParam{
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
URL: "data:" + img.ContentType + ";base64," + img.Base64,
},
},
},
)
} else if img.Url != "" {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfImageURL: &openai.ChatCompletionContentPartImageParam{
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
URL: img.Url,
},
},
},
)
}
}
if m.Text != "" {
if len(arrayOfContentParts) > 0 {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfText: &openai.ChatCompletionContentPartTextParam{
Text: "\n",
},
},
)
} else {
textContent = openai.String(m.Text)
}
}
a := strings.Split(model, "-")
useSystemInsteadOfDeveloper := true
if len(a) > 1 && a[0][0] == 'o' {
useSystemInsteadOfDeveloper = false
}
switch m.Role {
case RoleSystem:
if useSystemInsteadOfDeveloper {
res = openai.ChatCompletionMessageParamUnion{
OfSystem: &openai.ChatCompletionSystemMessageParam{
Content: openai.ChatCompletionSystemMessageParamContentUnion{
OfString: textContent,
},
},
}
} else {
res = openai.ChatCompletionMessageParamUnion{
OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
OfString: textContent,
},
},
}
}
case RoleUser:
var name param.Opt[string]
if m.Name != "" {
name = openai.String(m.Name)
}
res = openai.ChatCompletionMessageParamUnion{
OfUser: &openai.ChatCompletionUserMessageParam{
Name: name,
Content: openai.ChatCompletionUserMessageParamContentUnion{
OfString: textContent,
OfArrayOfContentParts: arrayOfContentParts,
},
},
}
case RoleAssistant:
var name param.Opt[string]
if m.Name != "" {
name = openai.String(m.Name)
}
res = openai.ChatCompletionMessageParamUnion{
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
Name: name,
Content: openai.ChatCompletionAssistantMessageParamContentUnion{
OfString: textContent,
},
},
}
}
return []openai.ChatCompletionMessageParamUnion{res}
}
type ToolCall struct {
@ -36,16 +196,73 @@ type ToolCall struct {
FunctionCall FunctionCall
}
type ResponseChoice struct {
Index int
Role Role
Content string
Refusal string
Name string
Calls []ToolCall
func (t ToolCall) toRaw() map[string]any {
res := map[string]any{
"id": t.ID,
}
type Response struct {
Choices []ResponseChoice
res["function"] = t.FunctionCall.toRaw()
return res
}
func (t ToolCall) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
return []openai.ChatCompletionMessageParamUnion{{
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
ToolCalls: []openai.ChatCompletionMessageToolCallParam{
{
ID: t.ID,
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: t.FunctionCall.Name,
Arguments: t.FunctionCall.Arguments,
},
},
},
},
}}
}
type ToolCallResponse struct {
ID string
Result any
Error error
}
func (t ToolCallResponse) toRaw() map[string]any {
res := map[string]any{
"id": t.ID,
"result": t.Result,
}
if t.Error != nil {
res["error"] = t.Error.Error()
}
return res
}
func (t ToolCallResponse) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
var refusal string
if t.Error != nil {
refusal = t.Error.Error()
}
if refusal != "" {
if t.Result != "" {
t.Result = fmt.Sprint(t.Result) + " (error in execution: " + refusal + ")"
} else {
t.Result = "error in execution:" + refusal
}
}
return []openai.ChatCompletionMessageParamUnion{{
OfTool: &openai.ChatCompletionToolMessageParam{
ToolCallID: t.ID,
Content: openai.ChatCompletionToolMessageParamContentUnion{
OfString: openai.String(fmt.Sprint(t.Result)),
},
},
}}
}
type ChatCompletion interface {

116
openai.go
View File

@ -5,120 +5,85 @@ import (
"fmt"
"strings"
oai "github.com/sashabaranov/go-openai"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/shared"
)
type openaiImpl struct {
key string
model string
baseUrl string
}
var _ LLM = openaiImpl{}
func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
res := oai.ChatCompletionRequest{
func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams {
res := openai.ChatCompletionNewParams{
Model: o.model,
}
for _, i := range request.Conversation {
res.Messages = append(res.Messages, i.toChatCompletionMessages(o.model)...)
}
for _, msg := range request.Messages {
m := oai.ChatCompletionMessage{
Content: msg.Text,
Role: string(msg.Role),
Name: msg.Name,
res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...)
}
for _, img := range msg.Images {
if img.Base64 != "" {
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
Type: "image_url",
ImageURL: &oai.ChatMessageImageURL{
URL: fmt.Sprintf("data:%s;base64,%s", img.ContentType, img.Base64),
},
})
} else if img.Url != "" {
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
Type: "image_url",
ImageURL: &oai.ChatMessageImageURL{
URL: img.Url,
},
})
}
}
// openai does not allow Content and MultiContent to be set at the same time, so we need to check
if len(m.MultiContent) > 0 && m.Content != "" {
m.MultiContent = append([]oai.ChatMessagePart{{
Type: "text",
Text: m.Content,
}}, m.MultiContent...)
m.Content = ""
}
res.Messages = append(res.Messages, m)
}
if request.Toolbox != nil {
for _, tool := range request.Toolbox.funcs {
res.Tools = append(res.Tools, oai.Tool{
for _, tool := range request.Toolbox.functions {
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
Type: "function",
Function: &oai.FunctionDefinition{
Function: shared.FunctionDefinitionParam{
Name: tool.Name,
Description: tool.Description,
Strict: tool.Strict,
Parameters: tool.Parameters.Definition(),
Description: openai.String(tool.Description),
Strict: openai.Bool(tool.Strict),
Parameters: tool.Parameters.OpenAIParameters(),
},
})
}
fmt.Println("tool:", tool.Name, tool.Description, tool.Strict, tool.Parameters.Definition())
if request.Toolbox.RequiresTool() {
res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
OfAuto: openai.String("required"),
}
}
if request.Temperature != nil {
res.Temperature = *request.Temperature
}
// is this an o1-* model?
isO1 := strings.Split(o.model, "-")[0] == "o1"
if isO1 {
// o1 models do not support system messages, so if any messages are system messages, we need to convert them to
// user messages
for i, msg := range res.Messages {
if msg.Role == "system" {
res.Messages[i].Role = "user"
}
}
res.Temperature = openai.Float(*request.Temperature)
}
return res
}
func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
res := Response{}
func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response {
var res Response
if response == nil {
return res
}
if len(response.Choices) == 0 {
return res
}
for _, choice := range response.Choices {
var toolCalls []ToolCall
for _, call := range choice.Message.ToolCalls {
fmt.Println("responseToLLMResponse: call:", call.Function.Arguments)
toolCall := ToolCall{
ID: call.ID,
FunctionCall: FunctionCall{
Name: call.Function.Name,
Arguments: call.Function.Arguments,
Arguments: strings.TrimSpace(call.Function.Arguments),
},
}
fmt.Println("toolCall.FunctionCall.Arguments:", toolCall.FunctionCall.Arguments)
toolCalls = append(toolCalls, toolCall)
}
res.Choices = append(res.Choices, ResponseChoice{
Content: choice.Message.Content,
Role: Role(choice.Message.Role),
Name: choice.Message.Name,
Refusal: choice.Message.Refusal,
Calls: toolCalls,
})
@ -128,13 +93,20 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
}
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
cl := oai.NewClient(o.key)
var opts = []option.RequestOption{
option.WithAPIKey(o.key),
}
req := o.requestToOpenAIRequest(request)
if o.baseUrl != "" {
opts = append(opts, option.WithBaseURL(o.baseUrl))
}
resp, err := cl.CreateChatCompletion(ctx, req)
cl := openai.NewClient(opts...)
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
req := o.newRequestToOpenAIRequest(request)
resp, err := cl.Chat.Completions.New(ctx, req)
//resp, err := cl.CreateChatCompletion(ctx, req)
if err != nil {
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)

48
request.go Normal file
View File

@ -0,0 +1,48 @@
package go_llm
import (
"github.com/openai/openai-go"
)
type rawAble interface {
toRaw() map[string]any
fromRaw(raw map[string]any) Input
}
type Input interface {
toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion
}
type Request struct {
Conversation []Input
Messages []Message
Toolbox ToolBox
Temperature *float64
}
// NextRequest will take the current request's conversation, messages, the response, and any tool results, and
// return a new request with the conversation updated to include the response and tool results.
func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallResponse) Request {
var res Request
res.Toolbox = req.Toolbox
res.Temperature = req.Temperature
res.Conversation = make([]Input, len(req.Conversation))
copy(res.Conversation, req.Conversation)
// now for every input message, convert those to an Input to add to the conversation
for _, msg := range req.Messages {
res.Conversation = append(res.Conversation, msg)
}
if resp.Content != "" || resp.Refusal != "" || len(resp.Calls) > 0 {
res.Conversation = append(res.Conversation, resp)
}
// if there are tool results, then we need to add those to the conversation
for _, result := range toolResults {
res.Conversation = append(res.Conversation, result)
}
return res
}

84
response.go Normal file
View File

@ -0,0 +1,84 @@
package go_llm
import (
"github.com/openai/openai-go"
)
type ResponseChoice struct {
Index int
Role Role
Content string
Refusal string
Name string
Calls []ToolCall
}
func (r ResponseChoice) toRaw() map[string]any {
res := map[string]any{
"index": r.Index,
"role": r.Role,
"content": r.Content,
"refusal": r.Refusal,
"name": r.Name,
}
calls := make([]map[string]any, 0, len(r.Calls))
for _, call := range r.Calls {
calls = append(calls, call.toRaw())
}
res["tool_calls"] = calls
return res
}
func (r ResponseChoice) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
var as openai.ChatCompletionAssistantMessageParam
if r.Name != "" {
as.Name = openai.String(r.Name)
}
if r.Refusal != "" {
as.Refusal = openai.String(r.Refusal)
}
if r.Content != "" {
as.Content.OfString = openai.String(r.Content)
}
for _, call := range r.Calls {
as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
ID: call.ID,
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: call.FunctionCall.Name,
Arguments: call.FunctionCall.Arguments,
},
})
}
return []openai.ChatCompletionMessageParamUnion{
{
OfAssistant: &as,
},
}
}
func (r ResponseChoice) toInput() []Input {
var res []Input
for _, call := range r.Calls {
res = append(res, call)
}
if r.Content != "" || r.Refusal != "" {
res = append(res, Message{
Role: RoleAssistant,
Text: r.Content,
})
}
return res
}
type Response struct {
Choices []ResponseChoice
}

View File

@ -3,8 +3,6 @@ package schema
import (
"reflect"
"strings"
"github.com/sashabaranov/go-openai/jsonschema"
)
// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that
@ -27,23 +25,28 @@ func getFromType(t reflect.Type, b basic) Type {
switch t.Kind() {
case reflect.String:
b.DataType = jsonschema.String
b.DataType = TypeString
b.typeName = "string"
return b
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
b.DataType = jsonschema.Integer
b.DataType = TypeInteger
b.typeName = "integer"
return b
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
b.DataType = jsonschema.Integer
b.DataType = TypeInteger
b.typeName = "integer"
return b
case reflect.Float32, reflect.Float64:
b.DataType = jsonschema.Number
b.DataType = TypeNumber
b.typeName = "number"
return b
case reflect.Bool:
b.DataType = jsonschema.Boolean
b.DataType = TypeBoolean
b.typeName = "boolean"
return b
case reflect.Struct:
@ -89,6 +92,8 @@ func getField(f reflect.StructField, index int) Type {
}
}
b.DataType = TypeString
b.typeName = "string"
return enum{
basic: b,
values: vals,
@ -99,15 +104,26 @@ func getField(f reflect.StructField, index int) Type {
return getFromType(t, b)
}
func getObject(t reflect.Type) object {
func getObject(t reflect.Type) Object {
fields := make(map[string]Type, t.NumField())
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.Anonymous {
// if the field is anonymous, we need to get the fields of the anonymous struct
// and add them to the object
anon := getObject(field.Type)
for k, v := range anon.fields {
fields[k] = v
}
continue
} else {
fields[field.Name] = getField(field, i)
}
}
return object{
basic: basic{DataType: jsonschema.Object},
return Object{
basic: basic{DataType: TypeObject, typeName: "object"},
fields: fields,
}
}
@ -115,7 +131,8 @@ func getObject(t reflect.Type) object {
func getArray(t reflect.Type) array {
res := array{
basic: basic{
DataType: jsonschema.Array,
DataType: TypeArray,
typeName: "array",
},
}

View File

@ -4,7 +4,8 @@ import (
"errors"
"reflect"
"github.com/sashabaranov/go-openai/jsonschema"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
)
type array struct {
@ -14,17 +15,28 @@ type array struct {
items Type
}
func (a array) SchemaType() jsonschema.DataType {
return jsonschema.Array
func (a array) OpenAIParameters() openai.FunctionParameters {
return openai.FunctionParameters{
"type": "array",
"description": a.Description(),
"items": a.items.OpenAIParameters(),
}
}
func (a array) Definition() jsonschema.Definition {
def := a.basic.Definition()
def.Type = jsonschema.Array
i := a.items.Definition()
def.Items = &i
def.AdditionalProperties = false
return def
func (a array) GoogleParameters() *genai.Schema {
return &genai.Schema{
Type: genai.TypeArray,
Description: a.Description(),
Items: a.items.GoogleParameters(),
}
}
func (a array) AnthropicInputSchema() map[string]any {
return map[string]any{
"type": "array",
"description": a.Description(),
"items": a.items.AnthropicInputSchema(),
}
}
func (a array) FromAny(val any) (reflect.Value, error) {

View File

@ -5,14 +5,27 @@ import (
"reflect"
"strconv"
"github.com/sashabaranov/go-openai/jsonschema"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
)
// just enforcing that basic implements Type
var _ Type = basic{}
type DataType string
const (
TypeString DataType = "string"
TypeInteger DataType = "integer"
TypeNumber DataType = "number"
TypeBoolean DataType = "boolean"
TypeObject DataType = "object"
TypeArray DataType = "array"
)
type basic struct {
jsonschema.DataType
DataType
typeName string
// index is the position of the parameter in the StructField of the function's parameter struct
index int
@ -25,17 +38,64 @@ type basic struct {
description string
}
func (b basic) SchemaType() jsonschema.DataType {
return b.DataType
func (b basic) OpenAIParameters() openai.FunctionParameters {
return openai.FunctionParameters{
"type": b.typeName,
"description": b.description,
}
}
func (b basic) Definition() jsonschema.Definition {
return jsonschema.Definition{
Type: b.DataType,
func (b basic) GoogleParameters() *genai.Schema {
var t = genai.TypeUnspecified
switch b.DataType {
case TypeString:
t = genai.TypeString
case TypeInteger:
t = genai.TypeInteger
case TypeNumber:
t = genai.TypeNumber
case TypeBoolean:
t = genai.TypeBoolean
case TypeObject:
t = genai.TypeObject
case TypeArray:
t = genai.TypeArray
default:
t = genai.TypeUnspecified
}
return &genai.Schema{
Type: t,
Description: b.description,
}
}
func (b basic) AnthropicInputSchema() map[string]any {
var t = "string"
switch b.DataType {
case TypeString:
t = "string"
case TypeInteger:
t = "integer"
case TypeNumber:
t = "number"
case TypeBoolean:
t = "boolean"
case TypeObject:
t = "object"
case TypeArray:
t = "array"
default:
t = "unknown"
}
return map[string]any{
"type": t,
"description": b.description,
}
}
func (b basic) Required() bool {
return b.required
}
@ -48,12 +108,12 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
v := reflect.ValueOf(val)
switch b.DataType {
case jsonschema.String:
case TypeString:
var val = v.String()
return reflect.ValueOf(val), nil
case jsonschema.Integer:
case TypeInteger:
if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(int(0))), nil
} else if v.Kind() != reflect.Int {
@ -62,7 +122,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
return v, nil
}
case jsonschema.Number:
case TypeNumber:
if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(float64(0))), nil
} else if v.Kind() != reflect.Float64 {
@ -71,7 +131,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
return v, nil
}
case jsonschema.Boolean:
case TypeBoolean:
if v.Kind() == reflect.Bool {
return v, nil
} else if v.Kind() == reflect.String {

View File

@ -3,10 +3,10 @@ package schema
import (
"errors"
"reflect"
"slices"
"golang.org/x/exp/slices"
"github.com/sashabaranov/go-openai/jsonschema"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
)
type enum struct {
@ -15,14 +15,28 @@ type enum struct {
values []string
}
func (e enum) SchemaType() jsonschema.DataType {
return jsonschema.String
func (e enum) FunctionParameters() openai.FunctionParameters {
return openai.FunctionParameters{
"type": "string",
"description": e.Description(),
"enum": e.values,
}
}
func (e enum) Definition() jsonschema.Definition {
def := e.basic.Definition()
def.Enum = e.values
return def
func (e enum) GoogleParameters() *genai.Schema {
return &genai.Schema{
Type: genai.TypeString,
Description: e.Description(),
Enum: e.values,
}
}
func (e enum) AnthropicInputSchema() map[string]any {
return map[string]any{
"type": "string",
"description": e.Description(),
"enum": e.values,
}
}
func (e enum) FromAny(val any) (reflect.Value, error) {

View File

@ -4,34 +4,125 @@ import (
"errors"
"reflect"
"github.com/sashabaranov/go-openai/jsonschema"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
)
type object struct {
const (
// SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent
// collisions with the fields in the struct.
SyntheticFieldPrefix = "__"
)
type Object struct {
basic
ref reflect.Type
fields map[string]Type
// syntheticFields are fields that are not in the struct but are generated by a system.
synetheticFields map[string]Type
}
func (o object) SchemaType() jsonschema.DataType {
return jsonschema.Object
func (o Object) WithSyntheticField(name string, description string) Object {
if o.synetheticFields == nil {
o.synetheticFields = map[string]Type{}
}
func (o object) Definition() jsonschema.Definition {
def := o.basic.Definition()
def.Type = jsonschema.Object
def.Properties = make(map[string]jsonschema.Definition)
o.synetheticFields[name] = basic{
DataType: TypeString,
typeName: "string",
index: -1,
required: false,
description: description,
}
return o
}
func (o Object) SyntheticFields() map[string]Type {
return o.synetheticFields
}
func (o Object) OpenAIParameters() openai.FunctionParameters {
var properties = map[string]openai.FunctionParameters{}
var required []string
for k, v := range o.fields {
def.Properties[k] = v.Definition()
properties[k] = v.OpenAIParameters()
if v.Required() {
required = append(required, k)
}
}
def.AdditionalProperties = false
return def
for k, v := range o.synetheticFields {
properties[SyntheticFieldPrefix+k] = v.OpenAIParameters()
if v.Required() {
required = append(required, SyntheticFieldPrefix+k)
}
}
func (o object) FromAny(val any) (reflect.Value, error) {
var res = openai.FunctionParameters{
"type": "object",
"description": o.Description(),
"properties": properties,
}
if len(required) > 0 {
res["required"] = required
}
return res
}
func (o Object) GoogleParameters() *genai.Schema {
var properties = map[string]*genai.Schema{}
var required []string
for k, v := range o.fields {
properties[k] = v.GoogleParameters()
if v.Required() {
required = append(required, k)
}
}
var res = &genai.Schema{
Type: genai.TypeObject,
Description: o.Description(),
Properties: properties,
}
if len(required) > 0 {
res.Required = required
}
return res
}
func (o Object) AnthropicInputSchema() map[string]any {
var properties = map[string]any{}
var required []string
for k, v := range o.fields {
properties[k] = v.AnthropicInputSchema()
if v.Required() {
required = append(required, k)
}
}
var res = map[string]any{
"type": "object",
"description": o.Description(),
"properties": properties,
}
if len(required) > 0 {
res["required"] = required
}
return res
}
// FromAny converts the value from any to the correct type, returning the value, and an error if any
func (o Object) FromAny(val any) (reflect.Value, error) {
// if the value is nil, we can't do anything
if val == nil {
return reflect.Value{}, nil
@ -68,7 +159,7 @@ func (o object) FromAny(val any) (reflect.Value, error) {
return obj, nil
}
func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) {
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
if !o.required {
val = val.Addr()

View File

@ -3,12 +3,17 @@ package schema
import (
"reflect"
"github.com/sashabaranov/go-openai/jsonschema"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
)
type Type interface {
SchemaType() jsonschema.DataType
Definition() jsonschema.Definition
OpenAIParameters() openai.FunctionParameters
GoogleParameters() *genai.Schema
AnthropicInputSchema() map[string]any
//SchemaType() jsonschema.DataType
//Definition() jsonschema.Definition
Required() bool
Description() string

View File

@ -4,56 +4,82 @@ import (
"context"
"errors"
"fmt"
"github.com/sashabaranov/go-openai"
)
// ToolBox is a collection of tools that OpenAI can use to execute functions.
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
// the correct parameters.
type ToolBox struct {
funcs []Function
names map[string]Function
functions map[string]Function
dontRequireTool bool
}
func NewToolBox(fns ...*Function) *ToolBox {
func NewToolBox(fns ...Function) ToolBox {
res := ToolBox{
funcs: []Function{},
names: map[string]Function{},
functions: map[string]Function{},
}
for _, f := range fns {
o := *f
res.names[o.Name] = o
res.funcs = append(res.funcs, o)
}
return &res
}
func (t *ToolBox) WithFunction(f Function) *ToolBox {
t2 := *t
t2.names[f.Name] = f
t2.funcs = append(t2.funcs, f)
return &t2
}
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
func (t *ToolBox) toOpenAI() []openai.Tool {
var res []openai.Tool
for _, f := range t.funcs {
res = append(res, openai.Tool{
Type: "function",
Function: f.toOpenAIFunction(),
})
res.functions[f.Name] = f
}
return res
}
func (t *ToolBox) ToToolChoice() any {
if len(t.funcs) == 0 {
func (t ToolBox) Functions() []Function {
var res []Function
for _, f := range t.functions {
res = append(res, f)
}
return res
}
func (t ToolBox) WithFunction(f Function) ToolBox {
t.functions[f.Name] = f
return t
}
func (t ToolBox) WithFunctions(fns ...Function) ToolBox {
for _, f := range fns {
t.functions[f.Name] = f
}
return t
}
func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox {
for k, v := range t.functions {
t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions)
}
return t
}
func (t ToolBox) ForEachFunction(fn func(f Function)) {
for _, f := range t.functions {
fn(f)
}
}
func (t ToolBox) WithFunctionRemoved(name string) ToolBox {
delete(t.functions, name)
return t
}
func (t ToolBox) WithRequireTool(val bool) ToolBox {
t.dontRequireTool = !val
return t
}
func (t ToolBox) RequiresTool() bool {
return !t.dontRequireTool && len(t.functions) > 0
}
func (t ToolBox) ToToolChoice() any {
if len(t.functions) == 0 {
return nil
}
@ -64,8 +90,8 @@ var (
ErrFunctionNotFound = errors.New("function not found")
)
func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
f, ok := t.names[functionName]
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
f, ok := t.functions[functionName]
if !ok {
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
@ -74,6 +100,61 @@ func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, para
return f.Execute(ctx, params)
}
func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
return t.ExecuteFunction(ctx, toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
}
func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string {
val := ctx.Value("syntheticParameters")
if val == nil {
return nil
}
syntheticParameters, ok := val.(map[string]string)
if !ok {
return nil
}
return syntheticParameters
}
// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished.
// OnNewFunction is called when a new function is created
// OnFunctionFinished is called when a function is finished
func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
var res []ToolCallResponse
for _, call := range toolCalls {
ctx := ctx.WithToolCall(&call)
if call.FunctionCall.Name == "" {
return nil, newError(ErrFunctionNotFound, errors.New("function name is empty"))
}
var arg any
if OnNewFunction != nil {
var err error
arg, err = OnNewFunction(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments)
if err != nil {
return nil, newError(ErrFunctionNotFound, err)
}
}
out, err := t.Execute(ctx, call)
if OnFunctionFinished != nil {
err := OnFunctionFinished(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments, out, err, arg)
if err != nil {
return nil, newError(ErrFunctionNotFound, err)
}
}
res = append(res, ToolCallResponse{
ID: call.ID,
Result: out,
Error: err,
})
}
return res, nil
}