import torch
import torch.nn as nn

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests


class TestInit(TestCase):
    """Check that Xavier initializers keep ``requires_grad`` consistent on CPU and NPU."""

    def _assert_requires_grad_matches(self, init_fn, gain):
        """Initialize a CPU and an NPU conv weight and compare ``requires_grad``.

        Args:
            init_fn: In-place initializer from ``torch.nn.init`` that accepts
                ``(tensor, gain=...)``.
            gain: Scaling factor forwarded to ``init_fn``.
        """
        # Module.npu() (like Module.cuda()) moves the module in place and
        # returns the very same object, so ``n = m.npu()`` would alias ``m``
        # and the comparison below would be a tensor against itself. Build two
        # independent modules so one weight really stays on the CPU.
        cpu_conv = nn.Conv2d(3, 3, 1)
        npu_conv = nn.Conv2d(3, 3, 1).npu()

        init_fn(cpu_conv.weight, gain=gain)
        init_fn(npu_conv.weight, gain=gain)

        self.assertEqual(
            cpu_conv.weight.requires_grad,
            npu_conv.weight.requires_grad,
        )

    def test_xavier_uniform_(self):
        """xavier_uniform_ leaves requires_grad identical on CPU and NPU."""
        # 1.0 and 0.0 are the numeric values of the previously used
        # gain=True / gain=False; 0.0 covers the degenerate all-zero case.
        for gain in (1.0, 0.0):
            self._assert_requires_grad_matches(nn.init.xavier_uniform_, gain)

    def test_xavier_normal_(self):
        """xavier_normal_ leaves requires_grad identical on CPU and NPU."""
        self._assert_requires_grad_matches(nn.init.xavier_normal_, 1.0)


# When executed as a script, hand control to the run_tests helper imported
# from torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()