# Copyright (c) 2020, Huawei Technologies.All rights reserved.
#
# Licensed under the BSD 3-Clause License  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor


class TestBatchNormGatherStatsUpdate(TestCase):
    # pylint:disable=huawei-too-many-arguments
    def expect_cpu_out(self, cpu_sum, cpu_square_sum, cpu_running_mean, cpu_running_var,
                       momentum, eps, cpu_counts):
        count_sum = torch.sum(cpu_counts)
        cpu_sum = torch.sum(cpu_sum, dim=[0])
        cpu_square_sum = torch.sum(cpu_square_sum, dim=[0])
        try:
            cpu_mean = cpu_sum / count_sum
            cpu_square_mean = cpu_square_sum / count_sum
        except ZeroDivisionError:
            print("count_sum is 0, raise ZeroDivisionError.")
        cpu_var = cpu_square_mean - cpu_mean * cpu_mean

        if cpu_var.dtype != torch.float32:
            try:
                cpu_invstd = 1 / torch.sqrt(cpu_var.float() + eps)
            except ZeroDivisionError:
                print("cpu_var + eps is 0, raise ZeroDivisionError.")
            cpu_invstd = cpu_invstd.half()
        else:
            try:
                cpu_invstd = 1 / torch.sqrt(cpu_var + eps)
            except ZeroDivisionError:
                print("cpu_var + eps is 0, raise ZeroDivisionError.")

        cpu_running_mean = momentum * cpu_mean + (1 - momentum) * cpu_running_mean
        cpu_running_var = momentum * cpu_var + (1 - momentum) * cpu_running_var
        cpu_mean = cpu_mean.cpu().numpy()
        cpu_invstd = cpu_invstd.cpu().numpy()
        cpu_running_mean = cpu_running_mean.cpu().numpy()
        cpu_running_var = cpu_running_var.cpu().numpy()
        # pylint:disable=too-many-return-values
        return cpu_mean, cpu_invstd, cpu_running_mean, cpu_running_var

    def npu_op_exec(self, *args):
        npu_mean, npu_invstd = torch.batch_norm_gather_stats_update(*args)
        out_mean = npu_mean.cpu().numpy()
        out_invstd = npu_invstd.cpu().numpy()
        return out_mean, out_invstd

    def create_counts_tensor(self, item):
        dtype = item[0]
        npu_format = item[1]
        if dtype != torch.float32:
            data = [4, 5, 3, 2]
        else:
            data = [4, 5, 6, 4]
        input1 = np.array(data).astype(dtype)
        cpu_counts = torch.from_numpy(input1)
        npu_counts = cpu_counts.to("npu:0")
        if npu_format != -1:
            npu_counts = torch_npu.npu_format_cast(npu_counts, npu_format)
        return cpu_counts, npu_counts

    def test_batch_norm_gather_stats_update(self):
        np.random.seed(1234)
        shape_format = [
            [[np.float32, -1, [2, 3, 12, 12]], [np.float32, -1, [4, 3]], [np.float32, -1, [4, 3]],
             [np.float32, -1, [3]], [np.float32, -1, [3]], 1e-3, 1e-5, [np.float32, -1, [4]]],
            [[np.float16, -1, [16, 3, 12, 12]], [np.float16, -1, [4, 3]], [np.float16, -1, [4, 3]],
             [np.float16, -1, [3]], [np.float16, -1, [3]], 1e-2, 1e-4, [np.float16, -1, [4]]],
        ]

        for item in shape_format:
            cpu_counts, npu_counts = self.create_counts_tensor(item[7])
            cpu_input1, npu_input1fp16 = create_common_tensor(item[0], 1, 10)
            if item[1][0] == np.float32:
                npu_input1fp32 = npu_input1fp16.float()
            cpu_sum, npu_sum = create_common_tensor(item[1], 0, 1)
            cpu_square_sum, npu_square_sum = create_common_tensor(item[2], 0, 1)
            cpu_running_mean, npu_running_mean = create_common_tensor(item[3], 0, 1)
            cpu_running_var, npu_running_var = create_common_tensor(item[4], 0, 1)

            cpu_output = self.expect_cpu_out(cpu_sum, cpu_square_sum,
                                             cpu_running_mean, cpu_running_var,
                                             item[-3], item[-2], cpu_counts)

            npu_outputfp16 = self.npu_op_exec(npu_input1fp16, npu_sum, npu_square_sum,
                                              npu_running_mean, npu_running_var,
                                              item[-3], item[-2], npu_counts)

            self.assertRtolEqual(npu_outputfp16[0], cpu_output[0])
            self.assertRtolEqual(npu_outputfp16[1], cpu_output[1])
            self.assertRtolEqual(npu_running_mean.cpu().numpy(), cpu_output[2])
            self.assertRtolEqual(npu_running_var.cpu().numpy(), cpu_output[3])

            if item[1][0] == np.float32:
                npu_outputfp32 = self.npu_op_exec(npu_input1fp32, npu_sum, npu_square_sum,
                                                  npu_running_mean, npu_running_var,
                                                  item[-3], item[-2], npu_counts)
                self.assertRtolEqual(npu_outputfp32[0], cpu_output[0])
                self.assertRtolEqual(npu_outputfp32[1], cpu_output[1])


if __name__ == "__main__":
    run_tests()