import torch

import torch_npu



from mindspeed_mm.models.common.chunkloss import (

    chunk_loss,

    calculate_lm_loss

)

from tests.ut.utils import judge_expression





class TestChunkLoss:
    """
    Check that ``chunk_loss`` reproduces the un-chunked ``calculate_lm_loss``.

    Every test runs ``grad_acc`` forward/backward steps through both code
    paths on the same random data, then compares the accumulated loss and the
    gradient accumulated on ``lm_head.weight``.
    """

    device = "npu"
    dtype = torch.bfloat16

    micro_batch_size = 2
    grad_acc = 2
    seq_len = 8192
    chunk_size = 1024
    hidden_dim = 4096
    vocab_size = 151674
    # Number of trailing label positions per sequence set to ``ignore_index``.
    mask_len = 200
    # Label value marking positions that are excluded from the loss.
    ignore_index = -100

    # Shared fixtures: one entry per gradient-accumulation step. They are
    # built while the class body executes, so all tests see the same data.
    inputs = []
    shift_labels = []
    hidden_states = []
    loss_masks = []
    for _ in range(grad_acc):
        full_hidden = torch.rand(
            micro_batch_size, seq_len, hidden_dim,
            requires_grad=True, dtype=dtype, device=device,
        )
        label = torch.randint(
            vocab_size, (micro_batch_size, seq_len), dtype=torch.long, device=device
        )
        # Mask the last ``mask_len`` positions; written as an explicit start
        # index so that ``mask_len == 0`` masks nothing rather than everything.
        label[:, seq_len - mask_len:] = ignore_index
        # Next-token alignment: the state at position t is paired with the
        # label at position t + 1.
        shift_label = label[:, 1:].contiguous()
        hidden_state = full_hidden[:, :-1].contiguous()
        # True where the label takes part in the loss.
        loss_mask = shift_label != ignore_index
        inputs.append(full_hidden)
        shift_labels.append(shift_label)
        hidden_states.append(hidden_state)
        loss_masks.append(loss_mask)

    lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False, dtype=dtype).to(device)

    @staticmethod
    def _judge_result(no_chunk_forward, chunk_forward, no_chunk_grad, chunk_grad):
        """Assert that chunked and un-chunked losses and weight grads agree."""
        judge_expression(torch.allclose(no_chunk_forward, chunk_forward, rtol=1e-5, atol=1e-6))
        judge_expression(torch.allclose(no_chunk_grad, chunk_grad, rtol=1e-4, atol=1e-5))

    def _loss_forward_backward_per_step(self, hidden_state, shift_label, alpha, reduction):
        """
        Run one un-chunked forward/backward step and return the loss.

        The second value returned by ``calculate_lm_loss`` is discarded.
        ``backward()`` accumulates into ``self.lm_head.weight.grad``.
        """
        no_chunk_forward, _ = calculate_lm_loss(
            hidden_states=hidden_state,
            head_weight=self.lm_head.weight,
            shift_labels=shift_label,
            alpha=alpha,
            ignore_index=self.ignore_index,
            reduction=reduction
        )
        no_chunk_forward.backward()
        return no_chunk_forward

    def _chunk_loss_forward_backward_per_step(self, hidden_state, shift_label, alpha, reduction):
        """
        Run one chunked forward/backward step and return the loss.

        The labels are split along the sequence dimension into pieces of
        ``chunk_size`` (the last piece may be shorter) and one kwargs dict per
        piece is handed to ``chunk_loss``.
        ``backward()`` accumulates into ``self.lm_head.weight.grad``.
        """
        chunk_labels = torch.split(shift_label, self.chunk_size, dim=1)
        # Every chunk shares the same alpha / reduction; only the labels differ.
        loss_ctx_kwargs = [
            {
                "shift_labels": labels,
                "ignore_index": self.ignore_index,
                "reduction": reduction,
                "alpha": alpha
            }
            for labels in chunk_labels
        ]

        chunk_forward = chunk_loss(
            hidden_states=hidden_state,
            head_weight=self.lm_head.weight,
            head_bias=None,
            loss_forward=calculate_lm_loss,
            loss_kwargs_chunks=loss_ctx_kwargs,
            chunk_size=self.chunk_size
        )

        chunk_forward.backward()
        return chunk_forward

    def _loss_forward_backward(self, alphas, reductions, per_step_func):
        """
        Run ``per_step_func`` for every accumulation step and collect results.

        Used for both the chunked and the un-chunked path.

        Args:
            alphas: one ``alpha`` value per accumulation step.
            reductions: one ``reduction`` string per accumulation step.
            per_step_func: callable taking ``(hidden_state, shift_label,
                alpha=..., reduction=...)`` that runs forward + backward and
                returns the loss tensor.

        Returns:
            Tuple ``(summed loss over all steps, lm_head.weight.grad)``.
        """
        accumulated_forward = 0
        for step in range(self.grad_acc):
            loss_forward = per_step_func(
                self.hidden_states[step],
                self.shift_labels[step],
                alpha=alphas[step],
                reduction=reductions[step]
            )
            # backward() already ran inside per_step_func, so only the value
            # is needed here; detach to avoid carrying autograd history.
            accumulated_forward = accumulated_forward + loss_forward.detach()

        grad = self.lm_head.weight.grad
        # Reset so the next run starts accumulating from scratch.
        self.lm_head.weight.grad = None

        return accumulated_forward, grad

    def _check_chunked_matches_reference(self, alphas, reductions):
        """Run both paths with the same settings and compare loss and grad."""
        no_chunk_forward, no_chunk_grad = self._loss_forward_backward(
            alphas=alphas,
            reductions=reductions,
            per_step_func=self._loss_forward_backward_per_step
        )
        chunk_forward, chunk_grad = self._loss_forward_backward(
            alphas=alphas,
            reductions=reductions,
            per_step_func=self._chunk_loss_forward_backward_per_step
        )
        self._judge_result(no_chunk_forward, chunk_forward, no_chunk_grad, chunk_grad)

    def test_default_vlm_loss(self):
        """alpha = unmasked-token count of each micro-batch, reduction "sum"."""
        alphas = [mask.sum() for mask in self.loss_masks]
        reductions = ["sum"] * self.grad_acc
        self._check_chunked_matches_reference(alphas, reductions)

    def test_per_sample_vlm_loss(self):
        """alpha = per-sample unmasked-token count times batch size, reduction "none"."""
        alphas = [mask.sum(1) * mask.shape[0] for mask in self.loss_masks]
        reductions = ["none"] * self.grad_acc
        self._check_chunked_matches_reference(alphas, reductions)

    def test_per_token_vlm_loss(self):
        """alpha = unmasked-token count over all accumulation steps, reduction "none"."""
        total_valid_tokens = sum(mask.sum() for mask in self.loss_masks)
        alphas = [total_valid_tokens] * self.grad_acc
        reductions = ["none"] * self.grad_acc
        self._check_chunked_matches_reference(alphas, reductions)