# Owner(s): ["module: npu"]
"""
Tests for NPU Sanitizer record_stream detection.

record_stream detection is separate from data race detection:
  - Data race: detected at kernel launch time (raises CUDASanitizerErrors)
  - Missing record_stream: detected at deallocation or via flush_record_stream_warnings()

Per PyTorch docs (torch.Tensor.record_stream), record_stream is NOT needed when
creation_stream has synced with usage_stream before tensor deallocation:
  - creation_stream.wait_stream(usage_stream)
  - creation_stream.wait_event(event recorded on usage_stream)
  - torch_npu.npu.synchronize() (device-level sync covers all directions)

Conversely, usage_stream.wait_stream(creation_stream) only resolves data races
but does NOT guarantee memory safety — record_stream is still needed.

Test matrix:
┌───────────────────────────────────┬──────────────┬───────────────────────┐
│ Scenario                          │ Data race?   │ record_stream needed? │
├───────────────────────────────────┼──────────────┼───────────────────────┤
│ No sync at all                    │ Yes          │ Yes                   │
│ usage.wait(creation) only         │ No           │ Yes                   │
│ creation.wait(usage) only         │ Possible     │ No                    │
│ Both directions synced            │ No           │ No                    │
│ device synchronize                │ No           │ No                    │
│ record_stream called              │ Unresolved   │ No (recorded)         │
│ Same stream                       │ No           │ No                    │
└───────────────────────────────────┴──────────────┴───────────────────────┘
"""

import gc
import os

import torch
import torch.cuda._sanitizer as csan
import torch.distributed as dist
   
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests


def setup_sanitizer():
    """Switch the NPU sanitizer on for the current process.

    Sets the ``TORCH_NPU_SANITIZER`` environment variable to ``'1'`` and then
    calls ``npu_sanitizer.enable()`` unless the sanitizer already reports
    itself as enabled.
    """
    # The variable is written before the import below, as in the original
    # ordering. NOTE(review): presumably so that any import-time check of the
    # variable sees it -- confirm against torch_npu.npu._sanitizer.
    os.environ['TORCH_NPU_SANITIZER'] = '1'
    import torch_npu.npu._sanitizer as sanitizer

    npu_san = sanitizer.npu_sanitizer
    if npu_san.enabled:
        return
    npu_san.enable()


def reset_sanitizer():
    """Reset sanitizer state between tests.

    If a dispatch object is installed, its ``__exit__`` is invoked and the
    reference is dropped. Afterwards ``event_handler`` is cleared and the
    sanitizer is marked as not enabled.

    Exiting the dispatch object is best-effort: an ``Exception`` raised by
    ``__exit__`` does not abort the reset, but it is logged at DEBUG level
    (with traceback) instead of being discarded silently.
    """
    import logging

    import torch_npu.npu._sanitizer as sanitizer

    npu_san = sanitizer.npu_sanitizer
    if npu_san.dispatch is not None:
        try:
            npu_san.dispatch.__exit__(None, None, None)
        except Exception:
            # Keep the teardown best-effort, but leave a trace for debugging.
            logging.getLogger(__name__).debug(
                "Ignoring error while exiting NPU sanitizer dispatch",
                exc_info=True,
            )
        finally:
            # Always drop the reference, whatever happened during __exit__.
            npu_san.dispatch = None
    npu_san.event_handler = None
    npu_san.enabled = False


def get_event_handler():
    """Return the ``event_handler`` attribute of the global NPU sanitizer.

    The value is whatever is currently stored there; ``reset_sanitizer`` sets
    it to ``None``.
    """
    import torch_npu.npu._sanitizer as sanitizer

    npu_san = sanitizer.npu_sanitizer
    return npu_san.event_handler


class SanitizerRecordStreamTestBase(TestCase):
    """Common fixture: every test runs with a freshly reset, enabled sanitizer."""

    def setUp(self):
        # Clear whatever a previous test left behind before enabling again.
        reset_sanitizer()
        setup_sanitizer()

    def tearDown(self):
        reset_sanitizer()
        if not dist.is_available():
            return
        # Tear down a process group if one is still initialized.
        if dist.is_initialized():
            dist.destroy_process_group()

    @staticmethod
    def _get_tracked_tensor_info(tensor):
        """Look up the sanitizer's bookkeeping entry for ``tensor``'s storage.

        Returns:
            A ``(storage_ptr, info)`` pair, where ``storage_ptr`` is the data
            pointer of the tensor's untyped storage and ``info`` is the entry
            stored under that key in the event handler's ``_npu_tensors``.

        Raises:
            AssertionError: if no entry exists for that storage pointer.
        """
        storage_ptr = tensor.untyped_storage().data_ptr()
        tracked = get_event_handler()._npu_tensors.get(storage_ptr)
        if tracked is not None:
            return storage_ptr, tracked
        raise AssertionError(
            f"Tensor storage {storage_ptr} is not tracked by NPU sanitizer."
        )


class TestDataRaceDetection(SanitizerRecordStreamTestBase):
    """Unsynchronized cross-stream accesses are expected to raise sanitizer errors."""

    def test_write_read_race(self):
        """Reading a tensor on another stream with no sync must raise a data race."""
        tensor = torch.randn(100, device="npu")
        side_stream = torch_npu.npu.Stream()

        # Nothing orders side_stream after the tensor's creation before the read.
        with self.assertRaises(csan.CUDASanitizerErrors), torch_npu.npu.stream(side_stream):
            _ = tensor + 1

    def test_record_stream_does_not_fix_data_race(self):
        """Calling record_stream alone must not mask a cross-stream data race."""
        tensor = torch.randn(100, device="npu")
        side_stream = torch_npu.npu.Stream()
        # record_stream is called, but no wait_stream/wait_event is issued.
        tensor.record_stream(side_stream)

        with self.assertRaises(csan.CUDASanitizerErrors), torch_npu.npu.stream(side_stream):
            _ = tensor + 1


class TestMissingRecordStream(SanitizerRecordStreamTestBase):
    """Cross-stream uses without record_stream are expected to produce warnings."""

    def test_cross_stream_no_record_stream(self):
        """A synced cross-stream read without record_stream yields at least one warning."""
        tensor = torch.randn(100, device="npu")
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        # The usage stream waits for the default stream before touching the tensor.
        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 1

        reports = get_event_handler().flush_record_stream_warnings()
        self.assertGreater(len(reports), 0)

    def test_multiple_streams_need_record_stream(self):
        """Using the tensor on two extra streams yields at least two warnings."""
        tensor = torch.randn(100, device="npu")
        first_stream = torch_npu.npu.Stream()
        second_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        first_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(first_stream):
            _ = tensor + 1

        second_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(second_stream):
            _ = tensor + 2

        reports = get_event_handler().flush_record_stream_warnings()
        self.assertGreaterEqual(len(reports), 2)


class TestRecordStreamNotNeeded(SanitizerRecordStreamTestBase):
    """Scenarios in which flushing is expected to return no record_stream warnings."""

    def test_record_stream_suppresses_warning(self):
        """An explicit record_stream for the usage stream leaves nothing to report."""
        tensor = torch.randn(100, device="npu")
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        tensor.record_stream(side_stream)
        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 1

        reports = get_event_handler().flush_record_stream_warnings()
        self.assertEqual(len(reports), 0)

    def test_creation_waits_usage_via_wait_stream(self):
        """The default stream waiting on the usage stream leaves nothing to report."""
        tensor = torch.randn(100, device="npu")
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 1

        # Reverse-direction wait, issued after the cross-stream use.
        main_stream.wait_stream(side_stream)
        reports = get_event_handler().flush_record_stream_warnings()
        self.assertEqual(len(reports), 0)

    def test_creation_waits_usage_via_event(self):
        """The default stream waiting on an event from the usage stream leaves nothing to report."""
        tensor = torch.randn(100, device="npu")
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 1
            # The event is recorded on the usage stream after the read.
            done_event = torch_npu.npu.Event()
            done_event.record(side_stream)

        main_stream.wait_event(done_event)
        reports = get_event_handler().flush_record_stream_warnings()
        self.assertEqual(len(reports), 0)

    def test_device_sync(self):
        """A device-wide synchronize after the use leaves nothing to report."""
        tensor = torch.randn(100, device="npu")
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 1

        torch_npu.npu.synchronize()

        reports = get_event_handler().flush_record_stream_warnings()
        self.assertEqual(len(reports), 0)


class TestRecordStreamSequenceBoundaries(SanitizerRecordStreamTestBase):
    """Coverage must be limited to the uses and streams it actually applies to."""

    def test_creation_waits_usage_only_covers_prior_uses(self):
        """A use issued after the reverse wait_stream must still be reported."""
        tensor = torch.randn(100, device="npu")
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 1

        # The reverse wait sits between the first and the second use.
        main_stream.wait_stream(side_stream)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 2

        reports = get_event_handler().flush_record_stream_warnings()
        self.assertGreater(len(reports), 0)

    def test_creation_wait_event_only_covers_prior_uses(self):
        """A use issued after the reverse wait_event must still be reported."""
        tensor = torch.randn(100, device="npu")
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 1
            done_event = torch_npu.npu.Event()
            done_event.record(side_stream)

        # The event was recorded before the second use below.
        main_stream.wait_event(done_event)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 2
        reports = get_event_handler().flush_record_stream_warnings()
        self.assertGreater(len(reports), 0)

    def test_record_stream_partial_coverage(self):
        """record_stream on one stream must not silence warnings for a second stream."""
        tensor = torch.randn(100, device="npu")
        first_stream = torch_npu.npu.Stream()
        second_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        # Only the first stream is recorded.
        tensor.record_stream(first_stream)

        first_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(first_stream):
            _ = tensor + 1

        second_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(second_stream):
            _ = tensor + 2

        reports = get_event_handler().flush_record_stream_warnings()
        # Split the reports by the stream they name in ``usage_stream``.
        for_first = [
            item for item in reports
            if item.usage_stream == int(first_stream.npu_stream)
        ]
        for_second = [
            item for item in reports
            if item.usage_stream == int(second_stream.npu_stream)
        ]
        self.assertEqual(len(for_first), 0)
        self.assertGreater(len(for_second), 0)


class TestViewAndSlice(SanitizerRecordStreamTestBase):
    """Views and their base tensor are expected to share record_stream tracking."""

    def test_view_cross_stream_no_record_stream_warns(self):
        """Using a slice on another stream without record_stream is reported."""
        tensor = torch.randn(100, device="npu")
        window = tensor[10:50]
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = window + 1

        reports = get_event_handler().flush_record_stream_warnings()
        self.assertGreater(len(reports), 0)

    def test_full_record_stream_covers_view_use(self):
        """record_stream on the base tensor also covers a slice taken afterwards."""
        tensor = torch.randn(100, device="npu")
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        # Record on the base first, then slice it.
        tensor.record_stream(side_stream)
        window = tensor[10:50]

        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = window + 1

        reports = get_event_handler().flush_record_stream_warnings()
        self.assertEqual(len(reports), 0)

    def test_view_record_stream_covers_full_use(self):
        """record_stream on a slice also covers a later use of the base tensor."""
        tensor = torch.randn(100, device="npu")
        window = tensor[10:50]
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        window.record_stream(side_stream)
        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 1

        reports = get_event_handler().flush_record_stream_warnings()
        self.assertEqual(len(reports), 0)


class TestMemoryReuse(SanitizerRecordStreamTestBase):
    """Recorded-stream state must not carry over from a deleted tensor to a new one."""

    def test_no_stale_recorded_streams_after_realloc(self):
        """A tensor allocated after a recorded one is deleted must still be reported."""
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        stale = torch.randn(100, device="npu")
        stale.record_stream(side_stream)
        del stale

        # Same shape as the deleted tensor; no record_stream is called on it.
        fresh = torch.randn(100, device="npu")
        side_stream.wait_stream(main_stream)

        with torch_npu.npu.stream(side_stream):
            _ = fresh + 1

        reports = get_event_handler().flush_record_stream_warnings()
        self.assertGreater(len(reports), 0)

    def test_realloc_with_explicit_record_stream(self):
        """A tensor allocated after a deletion, with its own record_stream, is not reported."""
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        stale = torch.randn(100, device="npu")
        del stale

        fresh = torch.randn(100, device="npu")
        fresh.record_stream(side_stream)

        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = fresh + 1

        reports = get_event_handler().flush_record_stream_warnings()
        self.assertEqual(len(reports), 0)


class TestSanitizerDisabled(TestCase):
    """Behavior when sanitizer is disabled."""

    def setUp(self):
        # Reset the sanitizer and remove the enabling environment variable so
        # the test body runs with the sanitizer switched off.
        reset_sanitizer()
        os.environ.pop("TORCH_NPU_SANITIZER", None)

    def test_no_errors_when_disabled(self):
        """Disabled sanitizer should not report cross-stream issues."""
        x = torch.randn(100, device="npu")
        stream = torch_npu.npu.Stream()
        try:
            # Unsynchronized cross-stream read, followed by a device sync.
            with torch_npu.npu.stream(stream):
                _ = x + 1
            torch_npu.npu.synchronize()
        except Exception as exc:
            # Fail with the actual exception in the message instead of an
            # opaque "True is not false" from a boolean flag.
            self.fail(f"Disabled sanitizer unexpectedly raised {exc!r}")


class TestFlushBehavior(SanitizerRecordStreamTestBase):
    """Checks what the event handler keeps in ``record_stream_errors``."""

    def test_dealloc_records_into_error_log(self):
        """Deleting a tensor used cross-stream without record_stream leaves an entry in the log."""
        side_stream = torch_npu.npu.Stream()
        main_stream = torch_npu.npu.default_stream()

        tensor = torch.randn(100, device="npu")
        side_stream.wait_stream(main_stream)
        with torch_npu.npu.stream(side_stream):
            _ = tensor + 1
        # Drop the only reference held by this test, then force a collection;
        # no explicit flush is requested before the assertion.
        del tensor
        gc.collect()
        error_log = get_event_handler().record_stream_errors
        self.assertGreater(len(error_log), 0)


# When executed directly as a script, hand control to the ``run_tests`` helper
# imported from torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()