import torch
import torch.nn.functional as F
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests


class TestDataParallelLayers(TestCase):
    """Smoke test: a module can be wrapped in DistributedDataParallel on NPU."""

    def test_parallel_DistributedDataParallel(self):
        """Wrap a small CNN in DistributedDataParallel using the HCCL backend.

        The rendezvous address and port are fixed to 127.0.0.1:29688. Rank and
        world size are read from the ``RANK`` and ``WORLD_SIZE`` environment
        variables and default to a single-process group (rank 0 of 1).
        """
        import os

        class Net(torch.nn.Module):
            """Conv -> ReLU -> 2x2 max-pool -> two linear layers (10 outputs)."""

            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
                # 32 channels at 3x3 after pooling, i.e. a 6x6 spatial input.
                self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
                self.dense2 = torch.nn.Linear(128, 10)

            def forward(self, x):
                # The layer is registered as ``conv1``; the previous
                # ``self.conv`` lookup raised AttributeError on any forward.
                x = F.max_pool2d(F.relu(self.conv1(x)), 2)
                x = x.view(x.size(0), -1)
                x = F.relu(self.dense1(x))
                return self.dense2(x)

        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29688"

        rank = int(os.getenv('RANK', 0))
        world_size = int(os.getenv('WORLD_SIZE', 1))
        torch.distributed.init_process_group(
            backend="hccl", rank=rank, world_size=world_size
        )
        try:
            model = Net().npu()
            net = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[0], broadcast_buffers=False
            )
            self.assertIsNotNone(net)
        finally:
            # Tear the group down even when wrapping or the assertion fails,
            # so the default process group does not outlive this test.
            torch.distributed.destroy_process_group()


# Script entry point: runs only when this file is executed directly.
if __name__ == "__main__":
    # Select NPU device 0 before any test runs; the test above moves its
    # model with ``.npu()`` and wraps it with ``device_ids=[0]``.
    torch.npu.set_device(0)
    # Hand control to the test runner imported from torch_npu.testing.testcase.
    run_tests()