import numpy as np
import torch
import torch.nn.functional as F
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests

import mx_driving.fused


@golden_data_cache(__file__)
def gen_inputs(shape, dtype):
    """Create the random operands and upstream gradient for one test case.

    Every tensor is sampled with ``np.random.uniform(-1, 1, shape)``, cast to
    ``dtype`` and then wrapped with ``torch.from_numpy``.

    Args:
        shape: Shape passed to ``np.random.uniform`` for all three tensors.
        dtype: NumPy dtype the samples are cast to.

    Returns:
        The tuple ``(x, y, grad)`` of torch tensors.
    """

    def _random_tensor():
        samples = np.random.uniform(-1, 1, shape)
        return torch.from_numpy(samples.astype(dtype))

    # The draws happen in the order x, y, grad so the global NumPy random
    # stream is consumed in the same sequence on every call.
    x = _random_tensor()
    y = _random_tensor()
    grad = _random_tensor()
    return x, y, grad


@golden_data_cache(__file__)
def gen_cpu_outputs(x, y, grad):
    """Compute the reference ``relu(x + y)`` and its input gradients.

    The inputs are converted with ``.float()`` and detached before gradient
    tracking is enabled, so the reference is evaluated on float tensors that
    are leaves of their own autograd graph.

    Args:
        x: First operand tensor.
        y: Second operand tensor.
        grad: Upstream gradient fed to ``backward`` (converted with
            ``.float()``).

    Returns:
        The tuple ``(output, grad_x, grad_y)``, where ``output`` is the
        result of ``F.relu`` and the gradients are read from the ``.grad``
        attribute of the float copies of ``x`` and ``y``.
    """
    # ``requires_grad_`` works in place and returns the tensor itself, which
    # allows the leaf creation to be written as a single chained expression.
    lhs = x.float().detach().requires_grad_()
    rhs = y.float().detach().requires_grad_()

    summed = lhs + rhs
    activated = F.relu(summed)
    activated.backward(grad.float())

    return activated, lhs.grad, rhs.grad


class TestAddRelu(TestCase):
    """Check ``npu_add_relu`` against the reference from ``gen_cpu_outputs``.

    Each test case only selects an input shape and dtype; the comparison
    itself lives in ``_check_add_relu`` so that all cases exercise exactly
    the same forward and backward checks.
    """

    def _check_add_relu(self, shape, dtype):
        """Run the forward and backward comparison for one shape/dtype pair.

        Args:
            shape: Shape of the two operands and of the upstream gradient.
            dtype: NumPy dtype used when generating the inputs.

        All compared values are converted to float16 (``.half()``) on both
        sides before being handed to ``assertRtolEqual``.
        """
        x, y, grad = gen_inputs(shape, dtype)
        cpu_result, grad_x_cpu, grad_y_cpu = gen_cpu_outputs(x, y, grad)
        expected = cpu_result.detach().half().numpy()

        # Forward-only call through the ``mx_driving.fused`` entry point.
        result = mx_driving.fused.npu_add_relu(x.npu(), y.npu())
        self.assertRtolEqual(result.detach().cpu().half().numpy(), expected)

        # Forward + backward call through the top-level ``mx_driving`` entry
        # point, using fresh device copies that track gradients.
        x_npu, y_npu, grad_npu = x.npu(), y.npu(), grad.npu()
        x_npu.requires_grad_()
        y_npu.requires_grad_()
        result = mx_driving.npu_add_relu(x_npu, y_npu)
        result.backward(grad_npu)
        self.assertRtolEqual(result.detach().cpu().half().numpy(), expected)
        self.assertRtolEqual(x_npu.grad.cpu().half().numpy(), grad_x_cpu.half().numpy())
        self.assertRtolEqual(y_npu.grad.cpu().half().numpy(), grad_y_cpu.half().numpy())

    def test_npu_add_relu_three_dim(self, device="npu"):
        """float32 inputs of shape ``[1, 100, 3]``."""
        self._check_add_relu([1, 100, 3], np.float32)

    def test_npu_add_relu_large_number(self, device="npu"):
        """float32 inputs of shape ``[18, 256, 232, 100]``."""
        self._check_add_relu([18, 256, 232, 100], np.float32)

    def test_npu_add_relu_fp16_large_number(self, device="npu"):
        """float16 inputs of shape ``[18, 256, 232, 100]``."""
        self._check_add_relu([18, 256, 232, 100], np.float16)

    def test_npu_add_relu_fp16_small_case(self, device="npu"):
        """float16 inputs of shape ``[18]``."""
        self._check_add_relu([18], np.float16)


# Allow this test module to be executed directly as a script.
if __name__ == "__main__":
    run_tests()