import os
import tempfile
import time
import unittest
from unittest.mock import patch

import torch
import torch_npu
import torch_npu.npu.graphs as npu_graphs
from torch_npu.testing.common_utils import SkipIfNotGteCANNVersion
from torch_npu.testing.testcase import run_tests, TestCase


def wait_until(predicate, timeout=5.0, interval=0.01):
    """Poll ``predicate`` until it returns a truthy value or ``timeout`` elapses.

    Args:
        predicate: Zero-argument callable that is evaluated repeatedly.
        timeout: Maximum number of seconds to keep polling.
        interval: Number of seconds to sleep between two evaluations.

    Returns:
        ``True`` as soon as ``predicate()`` is truthy. If the deadline passes
        first, ``predicate`` is evaluated one last time and that result is
        returned, so a condition that became true during the final sleep is
        still observed.
    """
    # Use the monotonic clock: unlike time.time() it cannot jump when the
    # wall clock is adjusted, which would cut the wait short or stretch it.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(interval)
    return predicate()


def _resolved_npu_device_index(tensor):
    idx = tensor.device.index
    if idx is None:
        idx = torch.npu.current_device()
    return idx


class TestAclgraphDfx(TestCase):
    """Tests for the NPU-graph tensor print/save ops and the tensor-spec helpers.

    The first three tests exercise private helpers of ``torch_npu.npu.graphs``
    on CPU tensors. The remaining tests capture ``print_npugraph_tensor`` /
    ``save_npugraph_tensor`` inside an ``NPUGraph``, replay the graph and then
    poll with ``wait_until`` for the host-side effect (a ``print`` call or a
    file on disk).
    """

    def test_npugraph_tensor_ptr_spec_uses_raw_pointer_metadata(self):
        """The pointer spec holds marker, data pointer, byte size, shape and dtype."""
        x = torch.arange(6, dtype=torch.float32).reshape(2, 3)

        spec = npu_graphs._make_npugraph_tensor_ptr_spec(x)

        self.assertEqual(spec[0], npu_graphs._NPUGRAPH_TENSOR_PTR_SPEC_MARKER)
        self.assertEqual(spec[1], x.data_ptr())
        self.assertEqual(spec[2], x.numel() * x.element_size())
        self.assertEqual(spec[3], tuple(x.shape))
        self.assertEqual(spec[4], x.dtype)

    def test_materialize_npugraph_tensor_buffer_spec(self):
        """A single buffer spec materializes back into the original tensor."""
        expected = torch.arange(6, dtype=torch.float32).reshape(2, 3)
        buffer_spec = (
            npu_graphs._NPUGRAPH_TENSOR_BUFFER_SPEC_MARKER,
            bytearray(expected.numpy().tobytes()),
            tuple(expected.shape),
            expected.dtype,
        )

        actual = npu_graphs._materialize_npugraph_tensor_arg(buffer_spec)

        self.assertEqual(actual, expected)

    def test_materialize_npugraph_tensor_buffer_spec_list(self):
        """A list of buffer specs materializes into the matching list of tensors."""
        expected = [
            torch.arange(6, dtype=torch.float32).reshape(2, 3),
            torch.arange(4, dtype=torch.int32).reshape(2, 2),
        ]
        buffer_specs = [
            (
                npu_graphs._NPUGRAPH_TENSOR_BUFFER_SPEC_MARKER,
                bytearray(tensor.numpy().tobytes()),
                tuple(tensor.shape),
                tensor.dtype,
            )
            for tensor in expected
        ]

        actual = npu_graphs._materialize_npugraph_tensor_arg(buffer_specs)

        self.assertEqual(len(actual), len(expected))
        self.assertEqual(actual[0], expected[0])
        self.assertEqual(actual[1], expected[1])

    @SkipIfNotGteCANNVersion("8.5.0")
    def test_print_npugraph_tensor(self):
        """Replaying a graph with the print op prints name, shape and dtype."""
        torch.npu.set_device(0)
        g = torch.npu.NPUGraph()
        x = torch.arange(6, dtype=torch.float32, device='npu').reshape(2, 3)

        with patch("builtins.print") as mock_print:
            with torch.npu.graph(g):
                torch.ops.npu.print_npugraph_tensor(x, tensor_name="tensor")
            g.replay()
            torch.npu.synchronize()
            # Poll while print is still patched so a late call is captured.
            self.assertTrue(wait_until(lambda: mock_print.call_count > 0))

        printed_messages = [call.args[0] for call in mock_print.call_args_list if call.args]
        self.assertTrue(any("tensor=tensor(" in msg for msg in printed_messages))
        self.assertTrue(any("shape=(2, 3)" in msg for msg in printed_messages))
        self.assertTrue(any("dtype=torch.float32" in msg for msg in printed_messages))

    @SkipIfNotGteCANNVersion("8.5.0")
    def test_print_npugraph_tensor_with_default_message(self):
        """Without ``tensor_name`` the printed message starts with ``tensor(``."""
        torch.npu.set_device(0)
        g = torch.npu.NPUGraph()
        x = torch.arange(6, dtype=torch.float32, device='npu').reshape(2, 3)

        with patch("builtins.print") as mock_print:
            with torch.npu.graph(g):
                torch.ops.npu.print_npugraph_tensor(x)
            g.replay()
            torch.npu.synchronize()
            # Poll while print is still patched so a late call is captured.
            self.assertTrue(wait_until(lambda: mock_print.call_count > 0))

        printed_messages = [call.args[0] for call in mock_print.call_args_list if call.args]
        self.assertTrue(any(msg.startswith("tensor(") for msg in printed_messages))

    @SkipIfNotGteCANNVersion("8.5.0")
    def test_print_npugraph_tensor_with_args(self):
        """The ``torch_npu.print_npugraph_tensor`` entry point honours ``tensor_name``."""
        torch.npu.set_device(0)
        g = torch.npu.NPUGraph()
        x = torch.arange(6, dtype=torch.float32, device='npu').reshape(2, 3)

        with patch("builtins.print") as mock_print:
            with torch.npu.graph(g):
                torch_npu.print_npugraph_tensor(x, tensor_name="x")
            g.replay()
            torch.npu.synchronize()
            # Poll while print is still patched so a late call is captured.
            self.assertTrue(wait_until(lambda: mock_print.call_count > 0))

        printed_messages = [call.args[0] for call in mock_print.call_args_list if call.args]
        self.assertTrue(any("x=tensor(" in msg for msg in printed_messages))
        self.assertTrue(any("shape=(2, 3)" in msg for msg in printed_messages))
        self.assertTrue(any("dtype=torch.float32" in msg for msg in printed_messages))

    @SkipIfNotGteCANNVersion("8.5.0")
    def test_save_npugraph_tensor(self):
        """Two graphs saving to the same path produce ``..._0.pt`` and ``..._1.pt``."""
        torch.npu.set_device(0)
        first_graph = torch.npu.NPUGraph()
        second_graph = torch.npu.NPUGraph()
        x = torch.arange(6, dtype=torch.float32, device='npu').reshape(2, 3)
        device_index = _resolved_npu_device_index(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            save_path = os.path.join(tmpdir, "tensor.pt")
            expected_counter_path = os.path.join(tmpdir, f"tensor_device_{device_index}_0.pt")
            expected_second_counter_path = os.path.join(tmpdir, f"tensor_device_{device_index}_1.pt")

            with torch.npu.graph(first_graph):
                torch.ops.npu.save_npugraph_tensor(x, save_path=save_path)
            first_graph.replay()
            torch.npu.synchronize()
            self.assertTrue(wait_until(lambda: os.path.exists(expected_counter_path)))
            self.assertEqual(torch.load(expected_counter_path), x.cpu())

            with torch.npu.graph(second_graph):
                torch.ops.npu.save_npugraph_tensor(x, save_path=save_path)
            second_graph.replay()
            torch.npu.synchronize()
            self.assertTrue(wait_until(lambda: os.path.exists(expected_second_counter_path)))
            self.assertEqual(torch.load(expected_second_counter_path), x.cpu())

    @SkipIfNotGteCANNVersion("8.5.0")
    def test_save_npugraph_tensor_overwrite(self):
        """With ``overwrite=True`` the same un-numbered file is rewritten each time."""
        torch.npu.set_device(0)
        first_graph = torch.npu.NPUGraph()
        second_graph = torch.npu.NPUGraph()
        x = torch.arange(6, dtype=torch.float32, device='npu').reshape(2, 3)
        y = torch.arange(6, dtype=torch.float32, device='npu').reshape(2, 3) + 1
        device_index = _resolved_npu_device_index(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            save_path = os.path.join(tmpdir, "tensor.pt")
            expected_overwrite_path = os.path.join(tmpdir, f"tensor_device_{device_index}.pt")
            unexpected_counter_path = os.path.join(tmpdir, f"tensor_device_{device_index}_0.pt")

            with torch.npu.graph(first_graph):
                torch.ops.npu.save_npugraph_tensor(x, save_path=save_path, overwrite=True)
            first_graph.replay()
            torch.npu.synchronize()
            self.assertTrue(wait_until(lambda: os.path.exists(expected_overwrite_path)))
            self.assertFalse(os.path.exists(unexpected_counter_path))
            self.assertEqual(torch.load(expected_overwrite_path), x.cpu())

            with torch.npu.graph(second_graph):
                torch.ops.npu.save_npugraph_tensor(y, save_path=save_path, overwrite=True)
            second_graph.replay()
            torch.npu.synchronize()

            expected_y = y.cpu()

            def overwritten_with_y():
                # The file already exists from the first replay, so existence
                # proves nothing here; poll on the content instead.
                try:
                    return torch.equal(torch.load(expected_overwrite_path), expected_y)
                except Exception:
                    # The file may be read mid-rewrite. A persistent failure
                    # is surfaced by the unguarded torch.load below.
                    return False

            # Every other save/print check in this file polls after
            # synchronize(); loading immediately could observe the stale
            # content written by the first graph.
            self.assertTrue(wait_until(overwritten_with_y))
            self.assertEqual(torch.load(expected_overwrite_path), expected_y)

    @SkipIfNotGteCANNVersion("8.5.0")
    def test_save_npugraph_tensor_with_default_save_path(self):
        """Without ``save_path`` one ``tensor_*.pt`` file appears in the working directory."""
        torch.npu.set_device(0)
        g = torch.npu.NPUGraph()
        x = torch.arange(6, dtype=torch.float32, device='npu').reshape(2, 3)
        device_index = _resolved_npu_device_index(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            original_cwd = os.getcwd()
            try:
                # Run from the temporary directory so the default output
                # location is isolated and cleaned up with it.
                os.chdir(tmpdir)
                with torch.npu.graph(g):
                    torch.ops.npu.save_npugraph_tensor(x)
                g.replay()
                torch.npu.synchronize()

                def default_saved_files():
                    return [
                        file_name for file_name in os.listdir(tmpdir)
                        if file_name.startswith("tensor_") and file_name.endswith(".pt")
                    ]

                self.assertTrue(wait_until(lambda: len(default_saved_files()) == 1))
                [saved_file] = default_saved_files()
                self.assertIn(f"_device_{device_index}_0.pt", saved_file)
                self.assertEqual(torch.load(os.path.join(tmpdir, saved_file)), x.cpu())
            finally:
                # Always restore the working directory, even on failure.
                os.chdir(original_cwd)

    @SkipIfNotGteCANNVersion("8.5.0")
    def test_save_npugraph_tensor_tensor_list(self):
        """The ``tensor_list`` overload saves both tensors, in order, to one file."""
        torch.npu.set_device(0)
        g = torch.npu.NPUGraph()
        x = torch.arange(6, dtype=torch.float32, device='npu').reshape(2, 3)
        y = torch.arange(4, dtype=torch.float16, device='npu').reshape(2, 2)
        device_index = _resolved_npu_device_index(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            save_path = os.path.join(tmpdir, "tensor_list.pt")
            expected_path = os.path.join(tmpdir, f"tensor_list_device_{device_index}_0.pt")

            with torch.npu.graph(g):
                torch.ops.npu.save_npugraph_tensor.tensor_list([x, y], save_path=save_path)
            g.replay()
            torch.npu.synchronize()
            self.assertTrue(wait_until(lambda: os.path.exists(expected_path)))

            saved = torch.load(expected_path)
            self.assertEqual(len(saved), 2)
            self.assertEqual(saved[0], x.cpu())
            self.assertEqual(saved[1], y.cpu())


# Run the tests through run_tests() when this file is executed as a script.
if __name__ == '__main__':
    run_tests()