import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from einops import rearrange

from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
from mindspeed.core.fusions.fused_rms_norm import RMSNorm
from mindspeed_mm.models.common.communications import _conv_split, _conv_gather, all_to_all
from mindspeed_mm.models.common.conv import ContextParallelCausalConv3d
from mindspeed_mm.utils.utils import (get_context_parallel_rank, get_context_parallel_world_size,
                                      get_context_parallel_group)


class LayerNorm(nn.Module):
    def __init__(self, num_channels, eps=1e-6, elementsize_affine=True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.norm = torch.nn.LayerNorm(num_channels, eps=eps, elementwise_affine=elementsize_affine)

    def forward(self, x):
        if x.dim() == 5:
            x = rearrange(x, "b c t h w -> b t h w c")
            x = self.norm(x)
            x = rearrange(x, "b t h w c -> b c t h w")
        else:
            x = rearrange(x, "b c h w -> b h w c")
            x = self.norm(x)
            x = rearrange(x, "b h w c -> b c h w")
        return x


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        try:
            return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
        except TypeError:
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            return self.weight * hidden_states.to(input_dtype)


class FP32LayerNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        sequence_parallel: bool = False,
        **kwargs
    ):
        super().__init__()
        self.dim = (dim,)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(self.dim))
        self.bias = nn.Parameter(torch.zeros(self.dim))

        setattr(self.weight, 'sequence_parallel', sequence_parallel)
        setattr(self.bias, 'sequence_parallel', sequence_parallel)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        return F.layer_norm(
            inputs.float(),
            self.dim,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        ).to(origin_dtype)


def normalize(in_channels, num_groups=32, eps=1e-6, affine=True, norm_type="groupnorm", gather=False, **kwargs):
    if not gather:
        if norm_type == "groupnorm":
            return torch.nn.GroupNorm(
                num_groups=num_groups, num_channels=in_channels, eps=eps, affine=affine
            )
        elif norm_type == "aelayernorm":
            return LayerNorm(num_channels=in_channels, eps=eps)
        elif norm_type == "layernorm":
            return nn.LayerNorm(in_channels, eps=eps, elementwise_affine=affine)
        elif norm_type == "rmsnorm":
            args = get_args()
            config = core_transformer_config_from_args(args)
            return RMSNorm(dim=in_channels, eps=eps, config=config, **kwargs)
        else:
            raise ValueError(f"unsupported norm type: {norm_type}. ")
    else:
        return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=affine)


class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _conv_split(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None


class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _conv_gather(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None


def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
    return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)


def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
    return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)


class ContextParallelGroupNorm(torch.nn.GroupNorm):
    def forward(self, input_, enable_cp=True):
        if not enable_cp:
            return super().forward(input_)

        gather_flag = input_.shape[2] > 1
        if gather_flag:
            cp_world_size = get_context_parallel_world_size()
            if cp_world_size == 1:
                return super().forward(input_)
            group = get_context_parallel_group()
            cp_rank = get_context_parallel_rank()
            _, ch, t, _, _ = input_.shape
            group_size = ch // self.num_groups
            scatter_sizes = torch.tensor_split(torch.ones(self.num_groups) * group_size, cp_world_size)
            scatter_sizes = [int(s.sum().item()) for s in scatter_sizes]
            if cp_rank == 0:
                t -= 1
            gather_sizes = [t] * cp_world_size
            gather_sizes[0] += 1
            input_ = all_to_all(input_, group, 1, 2, scatter_sizes, gather_sizes)
            begin = sum(scatter_sizes[:cp_rank])
            end = begin + scatter_sizes[cp_rank]
            output = torch.nn.functional.group_norm(
                input_, scatter_sizes[cp_rank] // group_size, self.weight[begin: end], self.bias[begin: end], self.eps)
            output = all_to_all(output, group, 2, 1, gather_sizes, scatter_sizes)
        else:
            output = super().forward(input_)
        return output


def Normalize3D(
        in_channels,
        zq_ch=None,
        add_conv=False,
        gather=False,
):
    if gather:
        return SpatialNorm3D(
            in_channels,
            zq_ch,
            gather=gather,
            freeze_norm_layer=False,
            add_conv=add_conv,
            num_groups=32,
            eps=1e-6,
            affine=True,
        )
    else:
        return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class SpatialNorm3D(nn.Module):
    def __init__(
        self,
        f_channels,
        zq_channels,
        freeze_norm_layer=False,
        add_conv=False,
        pad_mode="constant",
        gather=False,
        **norm_layer_params,
    ):
        super().__init__()

        self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
        if freeze_norm_layer:
            for p in self.norm_layer.parameters():
                p.requires_grad = False

        self.add_conv = add_conv
        if add_conv:
            self.conv = ContextParallelCausalConv3d(
                chan_in=zq_channels,
                chan_out=zq_channels,
                kernel_size=3,
            )

        self.conv_y = ContextParallelCausalConv3d(
            chan_in=zq_channels,
            chan_out=f_channels,
            kernel_size=1,
        )
        self.conv_b = ContextParallelCausalConv3d(
            chan_in=zq_channels,
            chan_out=f_channels,
            kernel_size=1,
        )

    def forward(self, f, zq, clear_fake_cp_cache=True, enable_cp=True):
        if f.shape[2] > 1 and get_context_parallel_rank() == 0 and enable_cp:
            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
            zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
            zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")

            zq_rest_splits = torch.split(zq_rest, 32, dim=1)
            interpolated_splits = [
                torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest")
                for split in zq_rest_splits
            ]

            zq_rest = torch.cat(interpolated_splits, dim=1)

            zq = torch.cat([zq_first, zq_rest], dim=2)
        else:
            f_size = f.shape[-3:]

            zq_splits = torch.split(zq, 32, dim=1)
            interpolated_splits = [
                torch.nn.functional.interpolate(split, size=f_size, mode="nearest")
                for split in zq_splits
            ]
            zq = torch.cat(interpolated_splits, dim=1)

        if self.add_conv:
            zq, _ = self.conv(zq, clear_cache=clear_fake_cp_cache, enable_cp=enable_cp)

        norm_f = self.norm_layer(f)

        conv_y_out, _ = self.conv_y(zq, enable_cp=enable_cp)
        conv_b_out, _ = self.conv_b(zq, enable_cp=enable_cp)
        new_f = norm_f * conv_y_out + conv_b_out
        return new_f