import os
import torch

import torch_npu
import torchair
from torchair.configs.compiler_config import CompilerConfig
import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce


class AllReduceSingeGroup(torch.nn.Module):
    """Module that adds two tensors and sum-reduces the result across ranks.

    The collective is issued without an explicit group, so it runs on the
    default process group, which must be initialized before ``forward`` runs.
    """

    def forward(self, x, y):
        """Return ``x + y`` summed element-wise over every participating rank.

        Args:
            x: First addend.
            y: Second addend.

        Returns:
            A new tensor holding the all-reduced sum; ``x`` and ``y`` are not
            modified.
        """
        total = torch.add(x, y)
        # all_reduce works in place on ``total``; its return value is unused.
        torch.distributed.all_reduce(total, op=torch.distributed.ReduceOp.SUM)
        return total


def example(rank, world_size):
    """Run the compiled all-reduce model on one rank and check its output.

    Written to be used as the per-process entry point of
    ``torch.multiprocessing.spawn``, which supplies ``rank`` as the first
    argument.

    Args:
        rank: Index of this process; also used as the NPU device index.
        world_size: Total number of processes taking part in the collective.

    Raises:
        AssertionError: If the reduced tensor differs from the expected value.
    """
    device = f"npu:{rank}"
    torch.npu.set_device(rank)
    torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
    # From here on the process group exists, so always tear it down, even when
    # compilation, execution or the result check fails.
    try:
        x = torch.ones([2, 2], dtype=torch.int32).to(device)
        y = torch.ones([2, 2], dtype=torch.int32).to(device)
        config = CompilerConfig()
        npu_backend = torchair.get_npu_backend(compiler_config=config)
        model = torch.compile(AllReduceSingeGroup().to(device), backend=npu_backend, dynamic=False)
        # Every rank contributes x + y == 2 per element, so the sum over all
        # ranks is 2 * world_size.
        out = torch.ones([2, 2], dtype=torch.int32).npu() * 2 * world_size
        ret = model(x, y)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not out.equal(ret):
            raise AssertionError(
                f"rank {rank}: all_reduce result does not match the expected "
                f"value {2 * world_size}"
            )
    finally:
        torch.distributed.destroy_process_group()


def mp(world_size=2):
    """Spawn one ``example`` worker per rank and wait for all of them to finish.

    Args:
        world_size: Number of processes to launch; each receives its rank and
            this value as arguments. Defaults to 2, the value that used to be
            hard-coded.
    """
    torch.multiprocessing.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    # Rendezvous address and port for the process group; set before the
    # workers are spawned so that every child process inherits them.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "29506"})
    mp()