# Copyright 2022-2023 XProbe Inc.

from typing import Any, Dict
from PIL import Image

import numpy as np
import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers.generation.streamers import TextStreamer

from mindspeed_mm.models.text_encoder import Tokenizer
from mindspeed_mm.tasks.inference.pipeline.pipeline_mixin.encode_mixin import MMEncoderMixin
from mindspeed_mm.tasks.inference.pipeline.pipeline_mixin.inputs_checks_mixin import InputsCheckMixin
from mindspeed_mm.tasks.inference.pipeline.pipeline_mixin.generation_mixin import GenerationMixin
from mindspeed_mm.data.data_utils.conversation import get_conv_template
from mindspeed_mm.data.data_utils.video_reader import VideoReader
from mindspeed_mm.data.data_utils.multimodal_image_video_preprocess import dynamic_preprocess


def build_infer_transform(input_size):
    """Return the per-tile preprocessing pipeline used at inference time.

    Each tile is converted to RGB when its mode is not already ``'RGB'``,
    resized to ``(input_size, input_size)`` with bicubic interpolation,
    converted to a tensor, and normalized with the ImageNet channel mean/std.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    def _ensure_rgb(img):
        # Leave RGB images untouched; convert everything else (e.g. L, RGBA, P).
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return T.Compose(steps)


def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    """Pick ``num_segments`` frame indices spread evenly over a frame range.

    Args:
        bound: Optional ``(start_sec, end_sec)`` pair. When falsy, the whole
            range ``[first_idx, max_frame]`` is used. Otherwise the seconds
            are converted to frame numbers with ``fps`` and clipped to
            ``[first_idx, max_frame]``.
        fps: Frames per second, used only to convert ``bound``.
        max_frame: Last valid frame index.
        first_idx: First valid frame index.
        num_segments: Number of equal-width segments (and indices) to produce.

    Returns:
        A numpy array of integer indices. Each index is the (truncated) centre
        of its segment; the segment offset is rounded with ``np.round``.
    """
    if not bound:
        lo, hi = first_idx, max_frame
    else:
        begin_sec, end_sec = bound[0], bound[1]
        lo = max(first_idx, round(begin_sec * fps))
        hi = min(round(end_sec * fps), max_frame)

    step = float(hi - lo) / num_segments
    half_step = step / 2

    picked = []
    for seg in range(num_segments):
        picked.append(int(lo + half_step + np.round(step * seg)))
    return np.array(picked)


class InternVLPipeline(GenerationMixin, InputsCheckMixin, MMEncoderMixin):
    """InternVL image/video chat inference pipeline.

    Loads the model from ``infer_config`` and turns an image (tiles) or a
    video (sampled frames, each tiled) into ``pixel_values``. It then builds
    a conversation prompt in which every ``<image>`` placeholder becomes
    ``<img>`` + ``<IMG_CONTEXT>`` * (``num_image_token`` * tiles) + ``</img>``.
    During generation, the text embeddings at the ``<IMG_CONTEXT>`` positions
    are overwritten with the vision-encoder output.
    """

    # Special tokens delimiting the visual embeddings inside the prompt.
    IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
    IMG_START_TOKEN = "<img>"
    IMG_END_TOKEN = "</img>"

    def __init__(self, infer_config) -> None:
        """Load the model/tokenizer and cache generation settings from ``infer_config``."""
        self.infer_config = infer_config
        self.prepare_model(infer_config.device, infer_config.dtype)

        self.image_encoder = self.infer_model.image_encoder

        # prepare for generate
        self.device = infer_config.device
        self.dtype = infer_config.dtype
        self.infer_data_type = infer_config.infer_data_type
        self.model_config = infer_config.text_decoder
        self.template = infer_config.template
        self.text_decoder_model_id = getattr(infer_config.text_decoder, "model_id", None)
        self.generation_config = infer_config.generation_config

        self.main_input_name = "input_ids"
        vision_cfg = self.infer_config.image_encoder.vision_encoder
        self.image_size = vision_cfg.image_size
        self.patch_size = vision_cfg.patch_size
        self.downsample_ratio = vision_cfg.downsample_ratio
        # Visual tokens per tile: (image_size // patch_size)^2 patches scaled by downsample_ratio^2.
        self.num_image_token = int((self.image_size // self.patch_size) ** 2 * (self.downsample_ratio ** 2))
        self.num_segments = infer_config.num_segments

        self.model = self.infer_model.text_decoder.eval()
        # Per-request state, set by __call__ / _inference / evaluate and read by
        # prepare_inputs_for_generation. Initialised so that direct calls fail
        # with a clear error instead of AttributeError.
        self.vit_embeds = None
        self.img_context_token_id = None

    def prepare_model(self, device, dtype):
        """Build the tokenizer and model, load the checkpoint, and move the model to ``device``/``dtype``."""
        self.tokenizer = Tokenizer(self.infer_config.tokenizer).get_tokenizer()
        from pretrain_internvl import model_provider
        self.infer_model = model_provider()
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        model_state_dict = torch.load(self.infer_config.from_pretrained, map_location="cpu")
        self.infer_model.load_state_dict(state_dict=model_state_dict["model"])
        self.infer_model.to(dtype=dtype, device=device).eval()

    def _prepare_images(self, image_path, input_size=448, max_num=12, upscale=False):
        """Load one image, split it with ``dynamic_preprocess`` and stack the transformed tiles.

        Args:
            image_path: Path of the image file.
            input_size: Tile side length passed to ``dynamic_preprocess`` and the transform.
            max_num: Maximum number of tiles passed to ``dynamic_preprocess``.
            upscale: When True, the image is first resized to twice its width and height.

        Returns:
            A tensor of stacked tiles (one tile per index along dim 0).
        """
        image = Image.open(image_path).convert("RGB")
        if upscale:
            image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR)
        transform = build_infer_transform(input_size=input_size)
        tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
        return torch.stack([transform(tile) for tile in tiles])

    def _prepare_video(self, video_path, bound=None, input_size=448, max_num=1, num_segments=32):
        """Sample ``num_segments`` frames from a video and tile each one.

        Returns:
            ``(pixel_values, num_patches_list)``. ``pixel_values`` concatenates
            the tiles of all frames along dim 0, and ``num_patches_list[i]`` is
            the tile count of frame ``i``.
        """
        video_reader = VideoReader(video_reader_type="DecordVideo")(video_path, layout="THWC", array_type="numpy")
        max_frame = video_reader.get_len() - 1
        fps = float(video_reader.get_video_fps())

        pixel_values_list, num_patches_list = [], []
        transform = build_infer_transform(input_size=input_size)
        frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
        for frame_index in frame_indices:
            frame = Image.fromarray(video_reader.get_batch([frame_index])[0]).convert('RGB')
            tiles = dynamic_preprocess(frame, image_size=input_size, use_thumbnail=True, max_num=max_num)
            frame_pixels = torch.stack([transform(tile) for tile in tiles])
            num_patches_list.append(frame_pixels.shape[0])
            pixel_values_list.append(frame_pixels)
        return torch.cat(pixel_values_list), num_patches_list

    def _prepare_prompts(self, question, num_patches_list=None):
        """Prefix ``question`` with ``<image>`` placeholders.

        For video there is one ``FrameN: <image>`` line per frame, which
        requires ``num_patches_list``. For images there is a single
        ``<image>`` line.
        """
        if self.infer_data_type == "video":
            video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))])
            return video_prefix + question
        return "<image>\n" + question

    def prepare_inputs(self, prompt, input_path):
        """Load visual inputs and build the user question.

        Args:
            prompt: User question; falls back to ``infer_config.prompts`` when falsy.
            input_path: Path, or list of paths (only the first is used).
                Falls back to ``infer_config.file_path`` when falsy.

        Returns:
            ``(pixel_values, question, num_patches_list)``, with ``pixel_values``
            moved to ``self.device`` and cast to ``self.dtype``.

        Raises:
            AssertionError: If ``infer_data_type`` is neither "image" nor "video".
        """
        if not input_path:
            input_path = self.infer_config.file_path
        elif isinstance(input_path, list):
            # Only one input is processed per call.
            input_path = input_path[0]

        if not prompt:
            prompt = self.infer_config.prompts

        if self.infer_data_type == "image":
            pixel_values = self._prepare_images(image_path=input_path, input_size=self.image_size)
            num_patches_list = [pixel_values.shape[0]]
            question = self._prepare_prompts(prompt)
        elif self.infer_data_type == "video":
            pixel_values, num_patches_list = self._prepare_video(
                video_path=input_path,
                input_size=self.image_size,
                num_segments=self.num_segments
            )
            question = self._prepare_prompts(prompt, num_patches_list)
        else:
            raise AssertionError(
                f"Inference data type must be image or video, got {self.infer_data_type!r}."
            )

        pixel_values = pixel_values.to(self.device).to(self.dtype)
        return pixel_values, question, num_patches_list

    def prepare_inputs_for_generation(
            self, input_ids, attention_mask=None, **kwargs
    ):
        """Build decoder inputs for one generation step.

        The incoming ``attention_mask`` is ignored. An all-ones ``(B, S)``
        mask is rebuilt on ``input_ids``' device and passed through
        ``infer_model._prepare_decoder_attention_mask``. Token embeddings are
        computed for ``input_ids``, and when ``self.vit_embeds`` is set, the
        rows at ``<IMG_CONTEXT>`` positions are replaced by it.

        Raises:
            ValueError: If visual embeddings are set but the input has no
                image context token, or if the input has image context tokens
                but no visual embeddings are set.
        """
        B, S = input_ids.shape
        # Build on the tokens' device rather than hard-coding .npu().
        attention_mask = torch.ones(B, S, device=input_ids.device)
        attention_mask = self.infer_model._prepare_decoder_attention_mask(attention_mask)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # NOTE(review): depends on the shape returned by _prepare_decoder_attention_mask.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)

        cur_input_embeds = self.model.embedding(input_ids, position_ids=position_ids)
        B, N, C = cur_input_embeds.shape
        # NOTE(review): assumes the flattened embedding rows line up with the
        # flattened input_ids (trivially true for batch size 1) - TODO confirm layout.
        cur_input_embeds = cur_input_embeds.reshape(B * N, C)

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        if self.vit_embeds is None:
            # Text-only request: nothing to splice in, but stray placeholders are an error.
            if selected.any():
                raise ValueError("input_ids contain image context tokens but no visual embeddings are set")
        else:
            if not selected.any():
                raise ValueError("image special token must in input_ids")
            cur_input_embeds[selected] = self.vit_embeds.reshape(-1, C).to(cur_input_embeds.device)
        cur_input_embeds = cur_input_embeds.reshape(B, N, C)

        model_inputs = {
            "decoder_input": cur_input_embeds,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "input_ids": input_ids,
        }
        return model_inputs

    def _update_model_kwargs_for_generation(self, model_kwargs: Dict[str, Any], model_inputs: Dict[str, Any]):
        """Carry the step's ``position_ids`` forward when the caller tracks them."""
        if "position_ids" in model_kwargs:
            model_kwargs["position_ids"] = model_inputs["position_ids"]

        return model_kwargs

    def _encode_visual(self, pixel_values, visual_features=None):
        """Return ``visual_features`` if given, else encode ``pixel_values``, else None (text-only)."""
        if visual_features is not None:
            return visual_features
        if pixel_values is not None:
            return self.image_encoder(pixel_values)
        return None

    def _build_query(self, question, num_patches_list):
        """Render the chat prompt for ``question`` and expand its ``<image>`` placeholders.

        Also stores the ``<IMG_CONTEXT>`` token id on ``self``.

        Returns:
            ``(query, template)``. The template is returned so callers can use
            its ``sep``.
        """
        self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN)

        template = get_conv_template(self.template)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        for num_patches in num_patches_list:
            image_tokens = (
                self.IMG_START_TOKEN
                + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                + self.IMG_END_TOKEN
            )
            # Placeholders are consumed left to right, one per image/frame.
            query = query.replace('<image>', image_tokens, 1)
        return query, template

    def _decode_response(self, generation_output, input_ids, template):
        """Decode the generated continuation (tokens after the prompt) up to the template separator.

        Slicing from the prompt length, rather than with ``-answer_len``,
        yields an empty answer when nothing was generated. The old
        ``[:, -0:]`` slice returned the whole prompt in that case.
        """
        prompt_len = input_ids.shape[-1]
        response = self.tokenizer.batch_decode(generation_output[:, prompt_len:], skip_special_tokens=True)[0]
        return response.split(template.sep)[0].strip()

    @torch.no_grad()
    def _inference(
            self,
            input_ids,
            pixel_values=None,
            attention_mask=None,
            visual_features=None,
            return_ids=False,
    ) -> torch.LongTensor:
        """Run generation, optionally streaming decoded text to stdout.

        ``visual_features`` takes precedence over ``pixel_values``. When
        neither is given, generation runs text-only.

        Raises:
            ValueError: If the image context token id has not been resolved yet.
        """
        if getattr(self, "img_context_token_id", None) is None:
            raise ValueError("img_context_token_id cannot be None")
        self.vit_embeds = self._encode_visual(pixel_values, visual_features)
        streamer = None if return_ids else TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        return self.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=self.generation_config,
            streamer=streamer)

    def __call__(self, prompt=None, input_path=None, return_ids=False):
        """Answer ``prompt`` about the image/video at ``input_path``.

        The output is streamed to stdout when ``return_ids`` is False, and the
        call then returns None. When ``return_ids`` is True, the decoded
        response string is returned (or None if generation returned None).
        """
        pixel_values, question, num_patches_list = self.prepare_inputs(prompt=prompt, input_path=input_path)
        query, template = self._build_query(question, num_patches_list)

        model_inputs = self.tokenizer(query, return_tensors='pt')
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)
        generation_output = self._inference(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            return_ids=return_ids,
        )

        if return_ids and generation_output is not None:
            return self._decode_response(generation_output, input_ids, template)
        return None

    def _load_eval_images(self, image_paths, dataset=None):
        """Load and tile every image for ``evaluate``.

        Returns:
            ``(pixel_values, num_patches_list)``, or ``(None, [])`` when
            ``image_paths`` is empty.
        """
        if not image_paths:
            return None, []
        num_patches_list, pixel_values_list = [], []
        for image_idx, file_name in enumerate(image_paths):
            # Only the first image is upscaled, and only for the 'mmmu_dev_val' dataset.
            upscale_flag = image_idx == 0 and dataset == 'mmmu_dev_val'
            curr_pixel_values = self._prepare_images(
                file_name, input_size=self.image_size, max_num=self.max_num, upscale=upscale_flag
            ).to(self.device).to(self.dtype)
            num_patches_list.append(curr_pixel_values.size(0))
            pixel_values_list.append(curr_pixel_values)
        return torch.cat(pixel_values_list, dim=0), num_patches_list

    @torch.no_grad()
    def evaluate(self, message, dataset=None):
        """Answer one benchmark sample.

        Args:
            message: List of dicts with ``'type'`` ("text" or "image") and ``'value'``.
            dataset: Optional dataset name; selects the tile budget (``self.max_num``)
                and enables first-image upscaling for 'mmmu_dev_val'.

        Returns:
            The decoded response string, or None if generation returned None.
        """
        if dataset in ['chartqa_test', 'mmmu_dev_val', 'mmmu_test']:
            self.max_num = 12
        elif dataset in ['docvqa_val', 'docvqa_test']:
            self.max_num = 18
        else:
            self.max_num = 6

        image_paths = [x['value'] for x in message if x['type'] == 'image']
        image_num = len(image_paths)
        if image_num == 1:
            prompt = '<image>\n' + '\n'.join([x['value'] for x in message if x['type'] == 'text'])
        else:
            prompt, image_idx = '', 1
            for x in message:
                if x['type'] == 'text':
                    prompt += x['value']
                elif x['type'] == 'image':
                    prompt += f'<Image-{image_idx}>'
                    image_idx += 1
            prompt = '\n'.join([f'Image-{i + 1}: <image>' for i in range(image_num)]) + '\n' + prompt

        pixel_values, num_patches_list = self._load_eval_images(image_paths, dataset)
        query, template = self._build_query(prompt, num_patches_list)
        eos_token_id = self.tokenizer.convert_tokens_to_ids(template.sep)

        model_inputs = self.tokenizer(query, return_tensors='pt')
        input_ids = model_inputs['input_ids'].to(self.device)
        # NOTE(review): not read in this class; presumably consumed by GenerationMixin.
        self.init_input_ids = input_ids
        attention_mask = model_inputs['attention_mask'].to(self.device)

        # Stop at the template separator. This mutates the shared generation_config,
        # so the setting also applies to later __call__ invocations.
        self.generation_config.eos_token_id = eos_token_id

        # None for text-only samples; prepare_inputs_for_generation then skips splicing.
        self.vit_embeds = self._encode_visual(pixel_values)
        generation_output = self.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=self.generation_config,
        )
        if generation_output is None:
            return None
        return self._decode_response(generation_output, input_ids, template)