import logging
import unittest
import time
from pathlib import Path

import torch
import torch_npu
import torchair
from torchair.configs.compiler_config import CompilerConfig
from torchair.core.utils import logger


# Switch the logger imported from torchair.core.utils to DEBUG level.
# This runs at import time, before any test in this module executes.
logger.setLevel(logging.DEBUG)


class DataDumpTest(unittest.TestCase):
    """Run small compiled networks with dumping enabled and inspect the dump files.

    Each test compiles a tiny module with the torchair NPU backend, runs it
    once, and then searches the current working directory for a
    ``worldsize*_global_rank*`` directory whose file names mention the
    expected ops.

    NOTE(review): the checks read whatever is already on disk, so dump files
    left behind by earlier runs in the same working directory can influence
    the outcome -- run from a clean directory when in doubt.
    """

    def _compile_and_run(self, network, config):
        """Compile *network* with the NPU backend built from *config* and run it once.

        The network is called with two random 2x2 float16 tensors moved to the
        NPU, and ``torch.npu.synchronize()`` is called before returning.
        """
        input0 = torch.randn(2, 2, dtype=torch.float16).npu()
        input1 = torch.randn(2, 2, dtype=torch.float16).npu()
        npu_backend = torchair.get_npu_backend(compiler_config=config)
        compiled = torch.compile(network, backend=npu_backend)
        compiled(input0, input1)
        torch.npu.synchronize()

    def _first_rank_dir(self):
        """Return the first (sorted) ``worldsize*_global_rank*`` entry under the cwd.

        Fails the test when no such entry exists.
        """
        rank_dirs = sorted(Path.cwd().glob("worldsize*_global_rank*"))
        # unittest assertions are used instead of bare ``assert`` so the checks
        # still run when Python is started with -O.
        self.assertTrue(rank_dirs, "No rank directory found for dump validation")
        return rank_dirs[0]

    def _file_names(self, directory):
        """Return the names of the regular files directly inside *directory*.

        Fails the test when the directory contains no regular file.
        """
        files = [entry.name for entry in directory.iterdir() if entry.is_file()]
        self.assertTrue(files, f"No files found in dump directory {directory}")
        return files

    @staticmethod
    def _any_name_contains(file_names, op_name):
        """Return True if *op_name* occurs in any lower-cased name of *file_names*."""
        return any(op_name in name.lower() for name in file_names)

    def _assert_add_and_sub_dumped(self, files):
        """Fail unless *files* has a name containing "add" and one containing "sub"."""
        has_add = self._any_name_contains(files, "add")
        has_sub = self._any_name_contains(files, "sub")
        self.assertTrue(
            has_add and has_sub,
            f"Dump files missing expected ops: add={has_add}, sub={has_sub}; files={files}",
        )

    def test_data_dump_with_scope(self):
        """Expect both add and sub dump files with ``dump_layer`` plus a dump scope.

        ``dump_layer`` names only Add, while the sub is issued inside
        ``torchair.scope.data_dump()``; the test requires dump files for both.
        """
        class Network(torch.nn.Module):
            def forward(self, data0, data1):
                add_01 = torch.add(data0, data1)
                with torchair.scope.data_dump():
                    sub_01 = torch.sub(data0, data1)
                return add_01, sub_01

        config = torchair.CompilerConfig()
        config.dump_config.enable_dump = True
        # The surrounding spaces are kept exactly as in the original test.
        # NOTE(review): presumably this exercises whitespace handling of the
        # layer list -- confirm against the torchair dump_layer contract.
        config.dump_config.dump_layer = " Add "
        self._compile_and_run(Network(), config)

        rank_dir = self._first_rank_dir()
        pattern = "*/0/graph_*/1/0"
        # Order candidates by modification time so the most recently written
        # dump directory ends up last; a path that vanished sorts first.
        candidates = sorted(
            rank_dir.glob(pattern),
            key=lambda p: p.stat().st_mtime if p.exists() else 0,
        )
        self.assertTrue(
            candidates,
            f"No dump subdirectories found under {rank_dir} with pattern {pattern}",
        )
        target_dir = candidates[-1]

        self._assert_add_and_sub_dumped(self._file_names(target_dir))

    def test_aclgraph_data_dump(self):
        """Expect add and sub dump files directly in the rank directory.

        The config uses ``mode = "reduce-overhead"`` with dumping enabled and
        no layer filter.
        """
        class Network(torch.nn.Module):
            def forward(self, data0, data1):
                add_01 = torch.add(data0, data1)
                sub_01 = torch.sub(data0, data1)
                return add_01, sub_01

        config = torchair.CompilerConfig()
        config.mode = "reduce-overhead"
        config.dump_config.enable_dump = True
        self._compile_and_run(Network(), config)

        rank_dir = self._first_rank_dir()
        self._assert_add_and_sub_dumped(self._file_names(rank_dir))

    def test_aclgraph_data_dump_has_saved(self):
        """Expect a pow dump file but no mm dump file in the rank directory.

        The mm result is passed to ``torch_npu.save_npugraph_tensor`` inside
        the network and then squared.
        NOTE(review): presumably the explicitly saved mm tensor is not dumped
        again and ``torch.square`` shows up as a pow op -- confirm.
        """
        class Network(torch.nn.Module):
            def forward(self, data0, data1):
                mm_01 = torch.mm(data0, data1)
                torch_npu.save_npugraph_tensor(mm_01, save_path="./test.pt")
                sq1 = torch.square(mm_01)
                return mm_01, sq1

        config = torchair.CompilerConfig()
        config.mode = "reduce-overhead"
        config.dump_config.enable_dump = True
        self._compile_and_run(Network(), config)

        rank_dir = self._first_rank_dir()
        files = self._file_names(rank_dir)
        has_mm = self._any_name_contains(files, "mm")
        has_pow = self._any_name_contains(files, "pow")
        # The message reports the pow flag under its own name; it used to be
        # printed as "sub=", which made failures misleading.
        self.assertTrue(
            (not has_mm) and has_pow,
            f"Expected a pow dump and no mm dump: mm={has_mm}, pow={has_pow}; files={files}",
        )


# Run the test cases through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()