import torch
from torch_npu.testing.testcase import TestCase, run_tests

import torch_npu

# Target device for every tensor created in the tests below. set_device() also
# selects it as the current NPU device when this module is imported.
device = 'npu:0'
torch.npu.set_device(device)


class TestLDGComputation(TestCase):
    """Tests for locally disabling/enabling gradient computation on NPU tensors.

    Covers ``torch.no_grad``, ``torch.enable_grad`` and ``torch.set_grad_enabled``,
    used as context managers and (for the first two) as function decorators.
    Each test checks ``requires_grad`` on the produced tensors. It also checks
    that the global grad mode (``torch.is_grad_enabled()``) is restored once
    the context or decorated call is left, so a leaking implementation cannot
    pass here and then disturb later tests.
    """

    def test_no_grad(self):
        """``no_grad`` disables autograd tracking inside its scope only."""
        grad_mode_before = torch.is_grad_enabled()
        x = torch.tensor([1], dtype=torch.float32, device=device, requires_grad=True)
        with torch.no_grad():
            self.assertFalse(torch.is_grad_enabled())
            y = x * 2
        self.assertFalse(y.requires_grad)
        # Leaving the context must restore the mode that was active on entry.
        self.assertEqual(torch.is_grad_enabled(), grad_mode_before)

        @torch.no_grad()
        def doubler(inp):
            return inp * 2
        z = doubler(x)
        self.assertFalse(z.requires_grad)
        # The decorator only affects the duration of the call.
        self.assertEqual(torch.is_grad_enabled(), grad_mode_before)

    def test_enable_grad(self):
        """``enable_grad`` re-enables tracking even inside an outer ``no_grad``."""
        grad_mode_before = torch.is_grad_enabled()
        x = torch.tensor([1], dtype=torch.float32, device=device, requires_grad=True)
        with torch.no_grad():
            with torch.enable_grad():
                y = x * 2
            # Exiting enable_grad must fall back to the enclosing no_grad state.
            self.assertFalse(torch.is_grad_enabled())
        self.assertTrue(y.requires_grad)
        self.assertEqual(torch.is_grad_enabled(), grad_mode_before)

        @torch.enable_grad()
        def doubler(inp):
            return inp * 2
        with torch.no_grad():
            z = doubler(x)
            # The decorator enables grad only while the decorated call runs.
            self.assertFalse(torch.is_grad_enabled())
        self.assertTrue(z.requires_grad)
        self.assertEqual(torch.is_grad_enabled(), grad_mode_before)

    def test_set_grad_enabled(self):
        """``set_grad_enabled`` works both as a context manager and as a function."""
        grad_mode_before = torch.is_grad_enabled()
        x = torch.tensor([1], dtype=torch.float32, device=device, requires_grad=True)
        with torch.set_grad_enabled(False):
            y = x * 2
        self.assertFalse(y.requires_grad)
        self.assertEqual(torch.is_grad_enabled(), grad_mode_before)

        with torch.set_grad_enabled(True):
            y = x * 2
        self.assertTrue(y.requires_grad)
        self.assertEqual(torch.is_grad_enabled(), grad_mode_before)

        with torch.set_grad_enabled(False):
            # Called as a plain function (not a context manager), set_grad_enabled
            # switches the grad mode immediately, overriding the enclosing `with`.
            torch.set_grad_enabled(True)
            y = x * 2
        self.assertTrue(y.requires_grad)
        # The outer manager restores the mode it saved when it was created.
        self.assertEqual(torch.is_grad_enabled(), grad_mode_before)


# Allow running this file directly, using the torch_npu test runner imported above.
if __name__ == "__main__":
    run_tests()