# Copyright (c) Microsoft Corporation.

# Licensed under the MIT License.

import torch

import torch.nn as nn

import torch.utils.checkpoint as checkpoint

from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from functools import partial





class GroupNorm(nn.GroupNorm):
    """GroupNorm for [N, C, H, W] feature maps, used as a LayerNorm stand-in.

    With the default ``num_groups=1`` each sample is normalised jointly over
    all of its channels and spatial positions, which approximates LayerNorm
    for channel-first tensors.
    """

    def __init__(self, num_channels, num_groups=1):
        """
        Args:
            num_channels (int): number of channels C of the input.
            num_groups (int): number of channel groups; 1 by default.
        """
        super().__init__(num_groups=num_groups, num_channels=num_channels)





class Mlp(nn.Module):
    """Two-layer feed-forward network for channel-first [N, C, H, W] inputs.

    The fully-connected layers are realised as 1x1 convolutions so the
    feature map never has to be permuted to channel-last layout.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        """
        Args:
            in_features (int): input channels.
            hidden_features (int): hidden channels; falls back to
                ``in_features`` when falsy.
            out_features (int): output channels; falls back to
                ``in_features`` when falsy.
            act_layer (callable): activation class, instantiated with no args.
            drop (float): dropout probability, applied after the activation
                and after the second projection.
        """
        super().__init__()
        hidden_channels = hidden_features if hidden_features else in_features
        output_channels = out_features if out_features else in_features

        # Creation order (fc1 -> act -> fc2 -> drop) fixes parameter order.
        self.fc1 = nn.Conv2d(in_features, hidden_channels, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, output_channels, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))





class ShiftViTBlock(nn.Module):
    """Building block of the Shift-ViT network.

    A parameter-free spatial shift of four channel groups is applied first;
    the shifted map then goes through a normalised MLP whose output is added
    back as a residual (optionally through stochastic depth).
    """

    def __init__(self,
                 dim,
                 n_div=12,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 input_resolution=None):
        """
        Args:
            dim (int): number of feature channels.
            n_div (int): channels are split into groups of ``dim // n_div``;
                four such groups are shifted, i.e. 4/n_div of the channels.
            mlp_ratio (float): MLP hidden width as a multiple of ``dim``.
            drop (float): dropout probability inside the MLP.
            drop_path (float): stochastic-depth probability; values <= 0
                disable it.
            act_layer (callable): activation class used by the MLP.
            norm_layer (callable): normalisation class, called as
                ``norm_layer(dim)``. NOTE(review): the ``nn.LayerNorm``
                default normalises the trailing axis, which is W in the
                [N, C, H, W] layout used here; ShiftViT always passes
                GroupNorm or BatchNorm2d instead.
            input_resolution (tuple): only reported by ``extra_repr`` in this
                code (intended for FLOP counting).
        """
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.mlp_ratio = mlp_ratio

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)
        self.n_div = n_div

    def forward(self, x):
        shifted = self.shift_feat(x, self.n_div)
        # Residual branch starts from the *shifted* features.
        return shifted + self.drop_path(self.mlp(self.norm2(shifted)))

    def extra_repr(self) -> str:
        shift_pct = 4.0 / self.n_div * 100
        return (f"dim={self.dim},"
                f"input_resolution={self.input_resolution},"
                f"shift percentage={shift_pct}%.")

    @staticmethod
    def shift_feat(x, n_div):
        """Shift four channel groups of a 4-D tensor by one pixel.

        Group k (k = 0..3, each ``C // n_div`` channels) is shifted left,
        right, up and down respectively; vacated positions are zero-filled.
        Remaining channels are copied unchanged.
        """
        _, channels, _, _ = x.shape
        group = channels // n_div
        shifted = torch.zeros_like(x)

        keep, head, tail = slice(None), slice(None, -1), slice(1, None)
        # ((dst_h, dst_w), (src_h, src_w)) for left, right, up, down.
        moves = (
            ((keep, head), (keep, tail)),
            ((keep, tail), (keep, head)),
            ((head, keep), (tail, keep)),
            ((tail, keep), (head, keep)),
        )
        for idx, ((dst_h, dst_w), (src_h, src_w)) in enumerate(moves):
            chans = slice(group * idx, group * (idx + 1))
            shifted[:, chans, dst_h, dst_w] = x[:, chans, src_h, src_w]

        shifted[:, group * 4:, :, :] = x[:, group * 4:, :, :]
        return shifted





class PatchMerging(nn.Module):
    """Halve the spatial size of a [N, C, H, W] map and double its channels.

    The input is normalised first and then reduced by a bias-free 2x2
    convolution with stride 2.
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution (tuple): input (H, W); only reported by
                ``extra_repr`` here.
            dim (int): number of input channels; output has ``2 * dim``.
            norm_layer (callable): normalisation class, called as
                ``norm_layer(dim)``.
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        out_channels = 2 * dim
        self.reduction = nn.Conv2d(dim, out_channels, kernel_size=2,
                                   stride=2, bias=False)
        self.norm = norm_layer(dim)

    def forward(self, x):
        return self.reduction(self.norm(x))

    def extra_repr(self) -> str:
        return "input_resolution={}, dim={}".format(self.input_resolution,
                                                    self.dim)





class BasicLayer(nn.Module):
    """One stage of Shift-ViT: ``depth`` ShiftViTBlocks plus optional downsampling."""

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 n_div=12,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=None,
                 norm_layer=None,
                 downsample=None,
                 use_checkpoint=False,
                 act_layer=nn.GELU):
        """
        Args:
            dim (int): number of channels of every block in the stage.
            input_resolution (tuple): (H, W) of the stage input; passed to
                the blocks and the downsample layer for ``extra_repr``.
            depth (int): number of ShiftViTBlocks.
            n_div (int): shift divisions per block (see ShiftViTBlock).
            mlp_ratio (float): MLP hidden width as a multiple of ``dim``.
            drop (float): dropout probability inside each block's MLP.
            drop_path (sequence of float | float | None): per-block
                stochastic-depth rates (indexed by block). A single number is
                used for every block; None means no stochastic depth.
            norm_layer (callable | None): normalisation class called as
                ``norm_layer(dim)``. None selects this file's GroupNorm with
                one group (ShiftViT's 'GN1' default).
            downsample (callable | None): class called as
                ``downsample(input_resolution, dim=..., norm_layer=...)`` and
                applied after the blocks, e.g. PatchMerging.
            use_checkpoint (bool): run each block under gradient
                checkpointing to save memory.
            act_layer (callable): activation class used by the MLPs.
        """
        super(BasicLayer, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # The previous None defaults crashed (``None[i]`` / ``None(dim)``);
        # give them usable meanings instead.
        if drop_path is None:
            drop_path = [0.] * depth
        elif isinstance(drop_path, (int, float)):
            drop_path = [float(drop_path)] * depth
        if norm_layer is None:
            norm_layer = GroupNorm

        # build blocks
        self.blocks = nn.ModuleList([
            ShiftViTBlock(dim=dim,
                          n_div=n_div,
                          mlp_ratio=mlp_ratio,
                          drop=drop,
                          drop_path=drop_path[i],
                          norm_layer=norm_layer,
                          act_layer=act_layer,
                          input_resolution=input_resolution)
            for i in range(depth)
        ])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution,
                                         dim=dim,
                                         norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # Trade compute for memory: activations are recomputed in backward.
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}," \
               f"input_resolution={self.input_resolution}," \
               f"depth={self.depth}"





class PatchEmbed(nn.Module):
    r"""Image to patch embedding via a strided convolution.

    The image is cut into non-overlapping ``patch_size`` patches by a
    convolution whose kernel and stride both equal the patch size; the output
    stays in [N, C, H, W] layout.

    Args:
        img_size (int, tuple): Image size.
        patch_size (int, tuple): Patch token size.
        in_chans (int): Number of input image channels.
        embed_dim (int): Number of linear projection output channels.
        norm_layer (nn.Module, optional): Normalization layer class, called
            as ``norm_layer(embed_dim)``; None disables normalization.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        # Number of patches along (H, W).
        self.patches_resolution = [self.img_size[axis] // self.patch_size[axis]
                                   for axis in range(2)]
        rows, cols = self.patches_resolution
        self.num_patches = rows * cols

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=self.patch_size,
                              stride=self.patch_size)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        embedded = self.proj(x)
        if self.norm is None:
            return embedded
        return self.norm(embedded)





class ShiftViT(nn.Module):
    """Shift-ViT image classifier.

    A patch embedding is followed by ``len(depths)`` stages of ShiftViTBlocks.
    Every stage except the last ends in a PatchMerging layer that halves the
    spatial size and doubles the channel count. The final feature map is
    normalised, globally average-pooled and fed to a linear head. Feature
    maps are kept in [N, C, H, W] layout throughout.
    """

    def __init__(self,
                 n_div=12,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=(2, 2, 6, 2),
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer='GN1',
                 act_layer='GELU',
                 patch_norm=True,
                 use_checkpoint=False,
                 **kwargs):
        """
        Args:
            n_div (int): shift divisions per block; 4/n_div of the channels
                are shifted.
            img_size (int | tuple): input image size.
            patch_size (int | tuple): patch size of the stem convolution.
            in_chans (int): number of input image channels.
            num_classes (int): classifier outputs; <= 0 replaces the head
                with nn.Identity.
            embed_dim (int): channels after patch embedding; doubled per stage.
            depths (sequence of int): number of blocks per stage.
            mlp_ratio (float): MLP hidden width as a multiple of the channels.
            drop_rate (float): dropout after the patch embedding and in MLPs.
            drop_path_rate (float): largest stochastic-depth rate; rates grow
                linearly from 0 over all blocks.
            norm_layer (str): 'GN1' (GroupNorm, one group) or 'BN'
                (BatchNorm2d).
            act_layer (str): 'GELU' or 'RELU'.
            patch_norm (bool): normalise the patch-embedding output.
            use_checkpoint (bool): gradient checkpointing in every stage.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: unsupported ``norm_layer`` or ``act_layer``.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``-O``.
        if norm_layer == 'BN':
            norm_layer = nn.BatchNorm2d
        elif norm_layer == 'GN1':
            norm_layer = partial(GroupNorm, num_groups=1)
        else:
            raise NotImplementedError(
                f"Unsupported norm_layer {norm_layer!r}; expected 'GN1' or 'BN'.")

        if act_layer == 'GELU':
            act_layer = nn.GELU
        elif act_layer == 'RELU':
            act_layer = partial(nn.ReLU, inplace=False)
        else:
            raise NotImplementedError(
                f"Unsupported act_layer {act_layer!r}; expected 'GELU' or 'RELU'.")

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # Channels of the last stage: embed_dim doubled once per merge.
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: one rate per block, linear in depth
        dpr = [x.item()
               for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               n_div=n_div,
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               mlp_ratio=self.mlp_ratio,
                               drop=drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               act_layer=act_layer)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(self.num_features, num_classes) \
            if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; intended for ``self.apply``.

        Linear/Conv weights get a truncated normal (std 0.02) and zero bias;
        LayerNorm/GroupNorm get weight 1 and bias 0. Missing (None) weights or
        biases, e.g. non-affine norm layers, are skipped. Other modules keep
        their default initialisation.
        """
        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        x = self.patch_embed(x)  # B C H W (patch grid)
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B C H W
        x = self.avgpool(x)  # B C 1 1
        x = torch.flatten(x, 1)  # B C
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x