Change function result type from string to any

Updated the return type of functions and related code from `string` to `any` to improve flexibility and support more diverse outputs. Adjusted function implementations, signatures, and handling of results accordingly.
This commit is contained in:
Steve Dudenhoeffer 2025-03-25 23:53:09 -04:00
parent 5ba42056ad
commit 82feb7d8b4
4 changed files with 18 additions and 14 deletions

View File

@ -4,11 +4,13 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
"reflect" "reflect"
"time" "time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
) )
type Function struct { type Function struct {
@ -31,7 +33,7 @@ type Function struct {
definition *jsonschema.Definition definition *jsonschema.Definition
} }
func (f *Function) Execute(ctx *Context, input string) (string, error) { func (f *Function) Execute(ctx *Context, input string) (any, error) {
if !f.fn.IsValid() { if !f.fn.IsValid() {
return "", fmt.Errorf("function %s is not implemented", f.Name) return "", fmt.Errorf("function %s is not implemented", f.Name)
} }
@ -46,7 +48,7 @@ func (f *Function) Execute(ctx *Context, input string) (string, error) {
} }
// now we can call the function // now we can call the function
exec := func(ctx *Context) (string, error) { exec := func(ctx *Context) (any, error) {
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()}) out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
if len(out) != 2 { if len(out) != 2 {
@ -54,7 +56,7 @@ func (f *Function) Execute(ctx *Context, input string) (string, error) {
} }
if out[1].IsNil() { if out[1].IsNil() {
return out[0].String(), nil return out[0].Interface(), nil
} }
return "", out[1].Interface().(error) return "", out[1].Interface().(error)

View File

@ -1,8 +1,9 @@
package go_llm package go_llm
import ( import (
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
"reflect" "reflect"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
) )
// Parse takes a function pointer and returns a function object. // Parse takes a function pointer and returns a function object.
@ -12,7 +13,7 @@ import (
// The struct parameters can have the following tags: // The struct parameters can have the following tags:
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for // - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
func NewFunction[T any](name string, description string, fn func(*Context, T) (string, error)) *Function { func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) *Function {
var o T var o T
res := Function{ res := Function{

7
llm.go
View File

@ -2,6 +2,7 @@ package go_llm
import ( import (
"context" "context"
"fmt"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )
@ -142,7 +143,7 @@ func (t ToolCall) toChatCompletionMessages() []openai.ChatCompletionMessage {
type ToolCallResponse struct { type ToolCallResponse struct {
ID string ID string
Result string Result any
Error error Error error
} }
@ -167,7 +168,7 @@ func (t ToolCallResponse) toChatCompletionMessages() []openai.ChatCompletionMess
if refusal != "" { if refusal != "" {
if t.Result != "" { if t.Result != "" {
t.Result = t.Result + " (error in execution: " + refusal + ")" t.Result = fmt.Sprint(t.Result) + " (error in execution: " + refusal + ")"
} else { } else {
t.Result = "error in execution:" + refusal t.Result = "error in execution:" + refusal
} }
@ -175,7 +176,7 @@ func (t ToolCallResponse) toChatCompletionMessages() []openai.ChatCompletionMess
return []openai.ChatCompletionMessage{{ return []openai.ChatCompletionMessage{{
Role: openai.ChatMessageRoleTool, Role: openai.ChatMessageRoleTool,
Content: t.Result, Content: fmt.Sprint(t.Result),
ToolCallID: t.ID, ToolCallID: t.ID,
}} }}
} }

View File

@ -87,7 +87,7 @@ var (
ErrFunctionNotFound = errors.New("function not found") ErrFunctionNotFound = errors.New("function not found")
) )
func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (string, error) { func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
f, ok := t.names[functionName] f, ok := t.names[functionName]
if !ok { if !ok {
@ -97,14 +97,14 @@ func (t *ToolBox) executeFunction(ctx *Context, functionName string, params stri
return f.Execute(ctx, params) return f.Execute(ctx, params)
} }
func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (string, error) { func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments) return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
} }
// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished. // ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished.
// OnNewFunction is called when a new function is created // OnNewFunction is called when a new function is created
// OnFunctionFinished is called when a function is finished // OnFunctionFinished is called when a function is finished
func (t *ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result string, err error, newFunctionResult any) error) ([]ToolCallResponse, error) { func (t *ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
var res []ToolCallResponse var res []ToolCallResponse
for _, call := range toolCalls { for _, call := range toolCalls {