import torch

import torch.distributed as dist

from torch.autograd.function import Function



import torch_npu

from torch_npu.utils._error_code import ErrCode, ops_error





__all__ = ["SyncBatchNorm"]





class SyncBatchNorm(Function):
    """Autograd function for batch normalization synchronized across ranks.

    The forward pass reduces per-channel statistics locally with a torch_npu
    kernel, exchanges them between the ranks of ``process_group`` and then
    normalizes the input with the resulting global mean / inverse std.
    The backward pass all-reduces the per-channel gradient statistics that
    are required to compute the input gradient.

    Note: both static methods name their first parameter ``self``; it is the
    autograd context object (it is used for ``save_for_backward``,
    ``saved_tensors`` and ``needs_input_grad``), not a class instance.
    """

    @staticmethod
    def forward(self, input_tensor, weight, bias, running_mean, running_var,
                eps, momentum, process_group, world_size):
        """Normalize ``input_tensor`` using statistics gathered from all ranks.

        Args:
            self: autograd context.
            input_tensor: input with channels on dim 1; made contiguous first.
            weight: scale handed to ``torch.batch_norm_elemt``.
            bias: shift handed to ``torch.batch_norm_elemt``.
            running_mean: forwarded to
                ``torch_npu.batch_norm_gather_stats_update``.
            running_var: forwarded to
                ``torch_npu.batch_norm_gather_stats_update``.
            eps: epsilon forwarded to the reduce / stats / elementwise ops.
            momentum: forwarded to
                ``torch_npu.batch_norm_gather_stats_update``.
            process_group: group used for the ``all_gather``; also stored on
                the context for the backward pass.
            world_size: number of receive buffers allocated for the gather.

        Returns:
            The result of ``torch.batch_norm_elemt``.

        Raises:
            ValueError: if the element count per channel, summed over all
                gathered ranks, equals 1.
        """
        input_tensor = input_tensor.contiguous()
        shape = input_tensor.shape
        num_channels = shape[1]

        # Per-channel sum and sum of squares of the local input, computed on
        # a (N, C, 1, -1) view of it.
        local_sum, local_square_sum = torch_npu.batch_norm_reduce(
            input_tensor.reshape(shape[0], num_channels, 1, -1), eps)

        # Number of local elements contributing to each channel.
        local_count = torch.full((1,),
                                 input_tensor.numel() // num_channels,
                                 dtype=local_sum.dtype,
                                 device=local_sum.device)

        # Pack everything into one vector so a single collective suffices:
        # C, C, 1 -> (2C + 1)
        packed = torch.cat([local_sum, local_square_sum, local_count], dim=0)
        gathered = [torch.empty_like(packed) for _ in range(world_size)]
        dist.all_gather(gathered, packed, process_group, async_op=False)

        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        sum_all, square_sum_all, count_all = torch.split(
            torch.stack(gathered, dim=0), num_channels, dim=1)

        flat_counts = count_all.view(-1)
        size = flat_counts.sum()
        if size == 1:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size) +
                             ops_error(ErrCode.VALUE))

        # Global mean / inverse std from the gathered statistics.
        # NOTE(review): presumably this kernel also updates running_mean /
        # running_var in place using momentum -- confirm against torch_npu docs.
        mean, invstd = torch_npu.batch_norm_gather_stats_update(
            input_tensor,
            sum_all,
            square_sum_all,
            running_mean,
            running_var,
            momentum,
            eps,
            flat_counts)

        # Keep what backward needs; count_all is saved in its un-flattened
        # (world_size, 1) form.
        self.save_for_backward(input_tensor, weight, mean, invstd, count_all)
        self.process_group = process_group

        # Element-wise normalization with the global statistics.
        return torch.batch_norm_elemt(input_tensor, weight, bias, mean, invstd, eps)

    @staticmethod
    def backward(self, grad_output):
        """Compute gradients for ``input_tensor``, ``weight`` and ``bias``.

        Args:
            self: autograd context populated by ``forward``.
            grad_output: gradient with respect to the forward output.

        Returns:
            A 9-tuple aligned with the forward arguments
            ``(input_tensor, weight, bias, running_mean, running_var, eps,
            momentum, process_group, world_size)``. Only the first three
            entries can be tensors; each of them is ``None`` when it is not
            required (or, for weight / bias, when ``weight`` is ``None``).
        """
        # Anything that is not already channels_last-contiguous is converted
        # to the default contiguous layout.
        if not grad_output.is_contiguous(memory_format=torch.channels_last):
            grad_output = grad_output.contiguous()

        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        needs = self.needs_input_grad
        need_input, need_weight, need_bias = needs[0], needs[1], needs[2]

        # Local per-channel reductions plus grad_weight / grad_bias.
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            need_input,
            need_weight,
            need_bias)

        grad_input = None
        if need_input:
            # The input gradient depends on statistics from every rank, so sum
            # both vectors across the group with one all_reduce.
            num_channels = sum_dy.shape[0]
            packed = torch.cat([sum_dy, sum_dy_xmu], dim=0)
            dist.all_reduce(packed, dist.ReduceOp.SUM, self.process_group, async_op=False)
            sum_dy, sum_dy_xmu = torch.split(packed, num_channels)

            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                sum_dy,
                sum_dy_xmu,
                count_tensor)

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        has_weight = weight is not None
        grad_weight = grad_weight if (has_weight and need_weight) else None
        grad_bias = grad_bias if (has_weight and need_bias) else None

        # One slot per forward argument; the last six receive no gradient.
        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None