Introduced `WithFunctionRemoved` and `ExecuteCallbacks` methods to enhance `ToolBox` functionality. This allows dynamic function removal and execution of custom callbacks during tool call processing. Also cleaned up logging and improved handling for required tools in `openai.go`.
		
			
				
	
	
		
			119 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			119 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package go_llm
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"strings"
 | |
| 
 | |
| 	oai "github.com/sashabaranov/go-openai"
 | |
| )
 | |
| 
 | |
// openaiImpl is the OpenAI-backed chat-completion provider. It carries the
// API key used to authenticate with OpenAI and the name of the model that
// requests built from it will target.
type openaiImpl struct {
	key, model string
}
 | |
| 
 | |
| var _ LLM = openaiImpl{}
 | |
| 
 | |
| func (o openaiImpl) newRequestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
 | |
| 	res := oai.ChatCompletionRequest{
 | |
| 		Model: o.model,
 | |
| 	}
 | |
| 
 | |
| 	for _, i := range request.Conversation {
 | |
| 		res.Messages = append(res.Messages, i.toChatCompletionMessages()...)
 | |
| 	}
 | |
| 
 | |
| 	for _, msg := range request.Messages {
 | |
| 		res.Messages = append(res.Messages, msg.toChatCompletionMessages()...)
 | |
| 	}
 | |
| 
 | |
| 	if request.Toolbox != nil {
 | |
| 		for _, tool := range request.Toolbox.funcs {
 | |
| 			res.Tools = append(res.Tools, oai.Tool{
 | |
| 				Type: "function",
 | |
| 				Function: &oai.FunctionDefinition{
 | |
| 					Name:        tool.Name,
 | |
| 					Description: tool.Description,
 | |
| 					Strict:      tool.Strict,
 | |
| 					Parameters:  tool.Parameters.Definition(),
 | |
| 				},
 | |
| 			})
 | |
| 		}
 | |
| 
 | |
| 		if !request.Toolbox.dontRequireTool {
 | |
| 			res.ToolChoice = "required"
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if request.Temperature != nil {
 | |
| 		res.Temperature = *request.Temperature
 | |
| 	}
 | |
| 
 | |
| 	// is this an o1-* model?
 | |
| 	isO1 := strings.Split(o.model, "-")[0] == "o1"
 | |
| 
 | |
| 	if isO1 {
 | |
| 		// o1 models do not support system messages, so if any messages are system messages, we need to convert them to
 | |
| 		// user messages
 | |
| 
 | |
| 		for i, msg := range res.Messages {
 | |
| 			if msg.Role == "system" {
 | |
| 				res.Messages[i].Role = "user"
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return res
 | |
| }
 | |
| 
 | |
| func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
 | |
| 	res := Response{}
 | |
| 
 | |
| 	for _, choice := range response.Choices {
 | |
| 		var toolCalls []ToolCall
 | |
| 		for _, call := range choice.Message.ToolCalls {
 | |
| 			toolCall := ToolCall{
 | |
| 				ID: call.ID,
 | |
| 				FunctionCall: FunctionCall{
 | |
| 					Name:      call.Function.Name,
 | |
| 					Arguments: call.Function.Arguments,
 | |
| 				},
 | |
| 			}
 | |
| 
 | |
| 			toolCalls = append(toolCalls, toolCall)
 | |
| 
 | |
| 		}
 | |
| 		res.Choices = append(res.Choices, ResponseChoice{
 | |
| 			Content: choice.Message.Content,
 | |
| 			Role:    Role(choice.Message.Role),
 | |
| 			Name:    choice.Message.Name,
 | |
| 			Refusal: choice.Message.Refusal,
 | |
| 			Calls:   toolCalls,
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	return res
 | |
| }
 | |
| 
 | |
| func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
 | |
| 	cl := oai.NewClient(o.key)
 | |
| 
 | |
| 	req := o.newRequestToOpenAIRequest(request)
 | |
| 
 | |
| 	resp, err := cl.CreateChatCompletion(ctx, req)
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	return o.responseToLLMResponse(resp), nil
 | |
| }
 | |
| 
 | |
| func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
 | |
| 	return openaiImpl{
 | |
| 		key:   o.key,
 | |
| 		model: modelVersion,
 | |
| 	}, nil
 | |
| }
 |