package go_llm

import (
	"context"
	"fmt"
	"strings"

	oai "github.com/sashabaranov/go-openai"
)

// openaiImpl is an LLM implementation backed by the OpenAI chat completion
// API via the github.com/sashabaranov/go-openai client.
type openaiImpl struct {
	key   string
	model string
}

// Compile-time assertion that openaiImpl satisfies the LLM interface.
var _ LLM = openaiImpl{}

// requestToOpenAIRequest converts a Request into the go-openai
// ChatCompletionRequest format, mapping messages, images, tools, and
// sampling options.
func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
	res := oai.ChatCompletionRequest{
		Model: o.model,
	}

	for _, msg := range request.Messages {
		m := oai.ChatCompletionMessage{
			Content: msg.Text,
			Role:    string(msg.Role),
			Name:    msg.Name,
		}

		for _, img := range msg.Images {
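			// base64 images are sent inline as data: URIs, which the OpenAI
			// API accepts anywhere an image URL is expected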
			if img.Base64 != "" {
				m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
					Type: "image_url",
					ImageURL: &oai.ChatMessageImageURL{
						URL: fmt.Sprintf("data:%s;base64,%s", img.ContentType, img.Base64),
					},
				})
			} else if img.Url != "" {
				m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
					Type: "image_url",
					ImageURL: &oai.ChatMessageImageURL{
						URL: img.Url,
					},
				})
			}
		}

		// OpenAI does not allow Content and MultiContent to be set at the same
		// time, so when both are present, fold the plain-text content into
		// MultiContent as a leading text part.
		if len(m.MultiContent) > 0 && m.Content != "" {
			m.MultiContent = append([]oai.ChatMessagePart{{
				Type: oai.ChatMessagePartTypeText,
				Text: m.Content,
			}}, m.MultiContent...)

			m.Content = ""
		}

		res.Messages = append(res.Messages, m)
	}

	if request.Toolbox != nil {
		for _, tool := range request.Toolbox.funcs {
			res.Tools = append(res.Tools, oai.Tool{
				Type: "function",
				Function: &oai.FunctionDefinition{
					Name:        tool.Name,
					Description: tool.Description,
					Strict:      tool.Strict,
					Parameters:  tool.Parameters.Definition(),
				},
			})
		}
	}

	// go-openai tags Temperature as omitempty, so leaving it at zero (when
	// request.Temperature is nil) falls back to the API default.
	if request.Temperature != nil {
		res.Temperature = *request.Temperature
	}

	// is this an o1-* model?
	isO1 := o.model == "o1" || strings.HasPrefix(o.model, "o1-")

	if isO1 {
		// o1 models do not support system messages, so convert any system
		// messages into user messages.
		for i, msg := range res.Messages {
			if msg.Role == oai.ChatMessageRoleSystem {
				res.Messages[i].Role = oai.ChatMessageRoleUser
			}
		}
	}

	return res
}

// responseToLLMResponse converts a go-openai ChatCompletionResponse back into
// the package's Response type, including any tool calls.
func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
	res := Response{}

	for _, choice := range response.Choices {
		var toolCalls []ToolCall
		for _, call := range choice.Message.ToolCalls {
			toolCalls = append(toolCalls, ToolCall{
				ID: call.ID,
				FunctionCall: FunctionCall{
					Name:      call.Function.Name,
					Arguments: call.Function.Arguments,
				},
			})
		}
		res.Choices = append(res.Choices, ResponseChoice{
			Content: choice.Message.Content,
			Role:    Role(choice.Message.Role),
			Name:    choice.Message.Name,
			Refusal: choice.Message.Refusal,
			Calls:   toolCalls,
		})
	}

	return res
}

// ChatComplete sends the request to the OpenAI chat completion API and
// converts the response into the package's Response type.
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
	cl := oai.NewClient(o.key)

	req := o.requestToOpenAIRequest(request)

	resp, err := cl.CreateChatCompletion(ctx, req)
	if err != nil {
		return Response{}, fmt.Errorf("openai chat completion: %w", err)
	}

	return o.responseToLLMResponse(resp), nil
}

// ModelVersion returns a copy of this client configured to use the given
// model version.
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
	return openaiImpl{
		key:   o.key,
		model: modelVersion,
	}, nil
}
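
// Example usage (a minimal sketch: the Request/Message fields match the types
// used above, while the model name and API-key environment variable are
// illustrative assumptions):
//
//	llm := openaiImpl{key: os.Getenv("OPENAI_API_KEY"), model: "gpt-4o"}
//	resp, err := llm.ChatComplete(context.Background(), Request{
//		Messages: []Message{{Role: "user", Text: "Hello!"}},
//	})
//	if err != nil {
//		// handle error
//	}
//	fmt.Println(resp.Choices[0].Content)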