import os
import atexit

import torch.cuda._sanitizer as csan
import torch_npu
import torch_npu.utils._npu_trace as npu_trace
import torch_npu.npu._stream_check as stream_check
import torch_npu.npu._kernel_check as kernel_check
from torch_npu.utils.utils import _print_warn_log
from torch_npu.npu._stream_check import apply_sanitizer_patch


class SanitizerMode:
    """Integer codes for the sanitizer's checking mode.

    The chosen value is handed to ``torch_npu._C._activate_npu_trace`` when
    the sanitizer is enabled.
    """

    # STREAM: stream/event/memory race checking through a dispatch mode.
    # KERNEL: ACL kernel start/finish hooks (selected when ASCEND_OPP_DEBUG_PATH is set).
    STREAM, KERNEL = 0, 1

class NPUSanitizer:
    """Controller for the NPU sanitizer.

    One instance, ``npu_sanitizer``, is created when the module is imported.
    ``enable()`` picks one of two mutually exclusive modes:

    * ``SanitizerMode.KERNEL`` when ``ASCEND_OPP_DEBUG_PATH`` was set when the
      instance was constructed. ACL kernel start/finish callbacks are routed to
      a ``kernel_check.EventHandler``.
    * ``SanitizerMode.STREAM`` otherwise. A
      ``stream_check.NPUSanitizerDispatchMode`` is entered, and NPU
      event/memory/stream/synchronization callbacks are routed to a
      ``torch.cuda._sanitizer.EventHandler``.
    """

    def __init__(self):
        # Receives the trace callbacks; its type depends on the mode chosen in enable().
        self.event_handler = None
        # Dispatch mode entered in stream mode; exited again in __del__.
        self.dispatch = None
        # Created in kernel mode; clear_debug_env() delegates to it.
        self.kernel_path_manager = None
        self.mode = SanitizerMode.STREAM
        # Both paths are captured once, at construction time.
        self.opp_debug_path = os.path.join(os.getcwd(), "opp_debug_path")
        self.opp_debug_kernel_path = os.getenv('ASCEND_OPP_DEBUG_PATH')
        # Becomes True once NPU tracing has been activated successfully.
        self.enabled = False

    def enable(self):
        """Set up the sanitizer for the selected mode and activate NPU tracing.

        Calling this again after it has succeeded does nothing. If setup fails,
        ``enabled`` stays False, so a later call can retry.
        """
        if self.enabled:
            # Running setup again would enter a second dispatch mode. The first
            # one would then be orphaned, because __del__ could no longer exit
            # it. It would also register every callback a second time.
            return
        if self.opp_debug_kernel_path:
            success = self.enable_kernel_check()
            self.mode = SanitizerMode.KERNEL
        else:
            success = self.enable_stream_check()
            self.mode = SanitizerMode.STREAM
        if success:
            torch_npu._C._activate_npu_trace(self.mode)
            self.enabled = True

    def enable_kernel_check(self) -> bool:
        """Register the ACL execution callbacks used by the kernel checker.

        Returns:
            True if the callbacks were registered. False (after logging a
            warning) if ASCEND_OPP_DEBUG_PATH is unset, or if
            ``self.opp_debug_path`` does not exist once the
            KernelPathManager has been created.
        """
        if not self.opp_debug_kernel_path:
            _print_warn_log("ASCEND_OPP_DEBUG_PATH is not set! TORCH_NPU_SANITIZER takes no effect!")
            return False
        self.kernel_path_manager = kernel_check.KernelPathManager()
        # NOTE(review): presumably KernelPathManager prepares opp_debug_path;
        # confirm against kernel_check before relying on this.
        if not os.path.exists(self.opp_debug_path):
            # Previously this path failed silently and left the sanitizer off.
            _print_warn_log(f"{self.opp_debug_path} does not exist! TORCH_NPU_SANITIZER takes no effect!")
            return False
        self.event_handler = kernel_check.EventHandler()
        npu_trace.register_callback_for_acl_start_execution(
            self.event_handler._handle_acl_start_execution,
            "handle_acl_start_execution"
        )
        npu_trace.register_callback_for_acl_finish_execution(
            self.event_handler._handle_acl_finish_execution,
            "handle_acl_finish_execution"
        )
        return True

    def enable_stream_check(self) -> bool:
        """Enter the stream-check dispatch mode and register the NPU trace callbacks.

        Returns:
            Always True.
        """
        self.event_handler = csan.EventHandler()
        self.dispatch = stream_check.NPUSanitizerDispatchMode(self.event_handler)
        self.dispatch.__enter__()
        npu_trace.register_callback_for_npu_event_creation(
            self.event_handler._handle_event_creation,
            "handle_event_creation"
        )
        npu_trace.register_callback_for_npu_event_deletion(
            self.event_handler._handle_event_deletion,
            "handle_event_deletion"
        )
        npu_trace.register_callback_for_npu_event_record(
            self.event_handler._handle_event_record,
            "handle_event_record"
        )
        npu_trace.register_callback_for_npu_event_wait(
            self.event_handler._handle_event_wait,
            "handle_event_wait"
        )
        npu_trace.register_callback_for_npu_memory_allocation(
            self.event_handler._handle_memory_allocation,
            "handle_memory_allocation"
        )
        npu_trace.register_callback_for_npu_memory_deallocation(
            self.event_handler._handle_memory_deallocation,
            "handle_memory_deallocation"
        )
        npu_trace.register_callback_for_npu_stream_creation(
            self.event_handler._handle_stream_creation,
            "handle_stream_creation"
        )
        npu_trace.register_callback_for_npu_device_synchronization(
            self.event_handler._handle_device_synchronization,
            "handle_device_synchronization"
        )
        npu_trace.register_callback_for_npu_stream_synchronization(
            self.event_handler._handle_stream_synchronization,
            "handle_stream_synchronization"
        )
        npu_trace.register_callback_for_npu_event_synchronization(
            self.event_handler._handle_event_synchronization,
            "handle_event_synchronization"
        )
        return True

    def __del__(self):
        """Exit the stream-check dispatch mode, if one was entered."""
        # getattr: the attribute is missing if __init__ never ran.
        dispatch = getattr(self, "dispatch", None)
        if dispatch is not None:
            # Clear the attribute first so the mode is never exited twice.
            self.dispatch = None
            dispatch.__exit__(None, None, None)

    def clear_debug_env(self):
        """Delegate to KernelPathManager.clear_debug_env() if kernel mode created one."""
        if self.kernel_path_manager:
            self.kernel_path_manager.clear_debug_env()


def enable_npu_sanitizer():
    """Turn on the NPU sanitizer for this process.

    Calls ``apply_sanitizer_patch()`` first, then ``enable()`` on the
    module-level ``npu_sanitizer`` instance.
    """
    apply_sanitizer_patch()
    sanitizer = npu_sanitizer
    sanitizer.enable()


# Single module-level sanitizer instance, used by enable_npu_sanitizer().
npu_sanitizer = NPUSanitizer()

# At interpreter exit, call NPUSanitizer.clear_debug_env(). That method
# delegates to the KernelPathManager only if kernel mode created one.
atexit.register(npu_sanitizer.clear_debug_env)