import copy
import unittest

import torch
import torch.nn.functional as F
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests


class TestLinearFunctions(TestCase):
    """Check torch.nn.functional linear-style ops on the NPU against the CPU.

    Each test draws random CPU tensors and copies each one to the NPU. It then
    runs the same functional op on both devices and compares the two results
    with ``assertRtolEqual``. Both tests are currently disabled via
    ``unittest.skip``; the reason is not recorded in this file.
    """

    @unittest.skip("skip test_linear now")
    def test_linear(self):
        # F.linear(x, w) computes x @ w.T with no bias here: a (2, 3, 4) input
        # against a (3, 4) weight produces a (2, 3, 3) result.
        cpu_args = (torch.randn(2, 3, 4), torch.randn(3, 4))
        # NOTE(review): .npu() already yields a device copy, so the deepcopy
        # looks redundant; kept so the transfer path matches the original.
        npu_args = [copy.deepcopy(tensor).npu() for tensor in cpu_args]

        expected = F.linear(*cpu_args)
        actual = F.linear(*npu_args)

        # Move the NPU result back to the host before comparing as numpy arrays.
        self.assertRtolEqual(expected.numpy(), actual.cpu().numpy())

    @unittest.skip("skip test_bilinear now")
    def test_bilinear(self):
        # Argument order is (x1, x2, weight, bias). A (10, 30) x1, a (10, 40)
        # x2 and a (5, 30, 40) weight with a (5,) bias give a (10, 5) output.
        cpu_args = (
            torch.randn(10, 30),
            torch.randn(10, 40),
            torch.randn(5, 30, 40),
            torch.randn(5),
        )
        npu_args = [copy.deepcopy(tensor).npu() for tensor in cpu_args]

        expected = F.bilinear(*cpu_args)
        actual = F.bilinear(*npu_args)

        self.assertRtolEqual(expected.numpy(), actual.cpu().numpy())

if __name__ == "__main__":
    # Script entry point: hand off to the run_tests helper imported from
    # torch_npu.testing.testcase.
    run_tests()