import unittest
import torch
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import freeze_rng_state

# Hard-coded feature flag: tests decorated with
# skipIf(not TEST_NPU_SUPPORT, ...) are always skipped while this is False.
TEST_NPU_SUPPORT = False
# Whether the torch_npu runtime reports at least one usable NPU.
TEST_NPU = torch.npu.is_available()
# Multi-device tests need two or more NPUs.
TEST_MULTINPU = TEST_NPU and torch.npu.device_count() > 1

# Device string used by the tests below; set_device also makes it the current
# device, so bare `.npu()` calls in this file target it too.
device = 'npu:0'
torch.npu.set_device(device)


class TestRandomSampling(TestCase):
    """Seeding and random-factory tests for the NPU backend.

    Besides checking that each factory call succeeds, the tests assert the
    observable contract of the result: device, shape, sampling range, and
    dtype where it is dictated by the ``*_like`` input.
    """

    def _assert_npu_shape(self, tensor, shape):
        """Assert ``tensor`` lives on an NPU device and has exactly ``shape``."""
        self.assertEqual(tensor.device.type, 'npu')
        self.assertEqual(tuple(tensor.shape), tuple(shape))

    def _assert_in_range(self, tensor, low, high):
        """Assert every element of ``tensor`` lies in the half-open range [low, high)."""
        values = tensor.cpu()
        self.assertTrue(bool(((values >= low) & (values < high)).all()),
                        f"values outside [{low}, {high}): {values}")

    def _assert_integral(self, tensor):
        """Assert every element of ``tensor`` is a whole number, whatever its dtype."""
        # Compare in float64 so the check works for integer dtypes too, where
        # floor() may not be implemented on older torch versions.
        values = tensor.cpu().double()
        self.assertTrue(bool((values == values.floor()).all()),
                        f"non-integral values: {values}")

    def _assert_finite(self, tensor):
        """Assert ``tensor`` contains no inf/nan values."""
        values = tensor.cpu()
        self.assertTrue(bool(torch.isfinite(values).all()),
                        f"non-finite values: {values}")

    def test_seed(self):
        # torch.seed() reseeds the default generator with a non-deterministic
        # value. It runs inside freeze_rng_state() (as test_manual_seed does)
        # so the reseed is not meant to leak into other tests.
        with freeze_rng_state():
            seed = torch.seed()
            self.assertEqual(torch.initial_seed(), seed)

    def test_manual_seed(self):
        # Re-seeding with the same value must reproduce the same uniform_ and
        # bernoulli draws, and initial_seed() must report the seed that was set.
        with freeze_rng_state():
            x = torch.zeros(4, 4).float().npu()
            torch.npu.manual_seed(2)
            self.assertEqual(torch.npu.initial_seed(), 2)
            x.uniform_()
            a = torch.bernoulli(torch.full_like(x, 0.5))
            torch.npu.manual_seed(2)
            y = x.clone().uniform_()
            b = torch.bernoulli(torch.full_like(x, 0.5))
            self.assertRtolEqual(x.cpu().numpy(), y.cpu().numpy())
            self.assertRtolEqual(a.cpu().numpy(), b.cpu().numpy())
            self.assertEqual(torch.npu.initial_seed(), 2)

    @unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
    @unittest.skipIf(not TEST_NPU_SUPPORT, "NPU not support")
    def test_get_set_rng_state_all(self):
        # Restoring the saved per-device RNG states must replay identical
        # normal_() draws on both device 0 and device 1.
        states = torch.npu.get_rng_state_all()
        before0 = torch.npu.FloatTensor(100, device=0).normal_()
        before1 = torch.npu.FloatTensor(100, device=1).normal_()
        torch.npu.set_rng_state_all(states)
        after0 = torch.npu.FloatTensor(100, device=0).normal_()
        after1 = torch.npu.FloatTensor(100, device=1).normal_()
        # The trailing 0 is presumably the comparison tolerance (exact match).
        self.assertEqual(before0, after0, 0)
        self.assertEqual(before1, after1, 0)

    def test_rand(self):
        out = torch.rand(2, 3, device=device)
        self._assert_npu_shape(out, (2, 3))
        # torch.rand samples uniformly from [0, 1).
        self._assert_in_range(out, 0, 1)

    def test_rand_like(self):
        input1 = torch.randn((2, 3), device=device)
        out = torch.rand_like(input1, device=device)
        self._assert_npu_shape(out, input1.shape)
        self.assertEqual(out.dtype, input1.dtype)
        self._assert_in_range(out, 0, 1)

    def test_randint(self):
        # randint samples integers from [low, high); low defaults to 0.
        npu_output1 = torch.randint(3, 5, (3,), device=device)
        self._assert_npu_shape(npu_output1, (3,))
        self._assert_in_range(npu_output1, 3, 5)
        self._assert_integral(npu_output1)

        npu_output2 = torch.randint(10, (2, 2), device=device)
        self._assert_npu_shape(npu_output2, (2, 2))
        self._assert_in_range(npu_output2, 0, 10)
        self._assert_integral(npu_output2)

        npu_output3 = torch.randint(3, 10, (2, 2), device=device)
        self._assert_npu_shape(npu_output3, (2, 2))
        self._assert_in_range(npu_output3, 3, 10)
        self._assert_integral(npu_output3)

    def test_randint_like(self):
        input1 = torch.randn((2, 3), device=device)
        output = torch.randint_like(input1, high=8, device=device)
        self._assert_npu_shape(output, input1.shape)
        # The result keeps the (floating) dtype of input1 but holds integers.
        self.assertEqual(output.dtype, input1.dtype)
        self._assert_in_range(output, 0, 8)
        self._assert_integral(output)

    def test_randn(self):
        output = torch.randn((2, 3), device=device)
        self._assert_npu_shape(output, (2, 3))
        self._assert_finite(output)

    def test_randn_like(self):
        input1 = torch.randn((2, 3), device=device)
        output = torch.randn_like(input1, device=device)
        self._assert_npu_shape(output, input1.shape)
        self.assertEqual(output.dtype, input1.dtype)
        self._assert_finite(output)


class TestQuasiRandomSampling(TestCase):
    """Sanity checks for ``torch.quasirandom.SobolEngine``."""

    def test_quasirandom_sobolEngine(self):
        # No device argument is passed, so the engine uses its default device.
        soboleng = torch.quasirandom.SobolEngine(dimension=5)
        output = soboleng.draw(3)
        # draw(n) yields n points of the requested dimension in the unit cube.
        self.assertEqual(tuple(output.shape), (3, 5))
        self.assertTrue(bool(((output >= 0) & (output < 1)).all()),
                        f"Sobol points outside [0, 1): {output}")


if __name__ == "__main__":
    # Hand discovery and execution of the test classes above to the
    # torch_npu test runner.
    run_tests()