import os
from typing import Callable
from torch_npu.utils.utils import _print_error_log


def print_check_msg(msg: str):
    """Write ``msg`` to stdout, prefixed with ``[sanitizer]`` and the current process id."""
    print(f"[sanitizer]({os.getpid()}) {msg}")


class CallbackRegistry:
    """A named, ordered list of callbacks that are fired together.

    Callbacks run in registration order. If one callback raises, the error is
    reported through ``_print_error_log`` and the remaining callbacks still run.
    """

    def __init__(self, name: str):
        # Label included in error messages to identify which registry failed.
        self.name = name
        # (callback, callback_name) pairs, kept in registration order.
        self.callback_list = []

    def add_callback(self, cb: Callable, cb_name: str) -> None:
        """Append ``cb`` under the label ``cb_name``.

        Args:
            cb: Callable invoked by ``fire_callbacks``.
            cb_name: Name reported in the error log if ``cb`` raises.
        """
        self.callback_list.append((cb, cb_name))

    def fire_callbacks(self, *args, **kwargs) -> None:
        """Call every registered callback with ``*args`` and ``**kwargs``.

        Exceptions (subclasses of ``Exception``) raised by a callback are logged
        with the callback name, the registry name and the exception details,
        then swallowed so that later callbacks still run.
        """
        for cb, cb_name in self.callback_list:
            try:
                cb(*args, **kwargs)
            except Exception as e:
                # Include the exception itself; without it the log only says
                # *which* callback failed, not *why*.
                _print_error_log(
                    f"Exception in callback {cb_name} for {self.name} registered with NPU trace: "
                    f"{type(e).__name__}: {e}"
                )


# Registries fired around ACL kernel execution.
NPUACLStartExecuteCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[kernel check] NPU acl start execution"
)
NPUACLFinishExecuteCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[kernel check] NPU acl finish execution"
)

# Registries for event lifecycle and event-based ordering.
NPUEventCreationCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU event creation"
)
NPUEventDeletionCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU event deletion"
)
NPUEventRecordCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU event record"
)
NPUEventWaitCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU event wait"
)

# Registries for memory allocation / deallocation.
NPUMemoryAllocationCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU memory allocation"
)
NPUMemoryDeallocationCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU memory deallocation"
)

# Registries for stream creation and synchronization points.
NPUStreamCreationCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU stream creation"
)
NPUDeviceSynchronizationCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU device synchronization"
)
NPUStreamSynchronizationCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU stream synchronization"
)
NPUEventSynchronizationCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU event synchronization"
)

# Registries for tensor/stream association changes.
NPURecordStreamCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU record_stream"
)
NPUEraseStreamCallbacks: "CallbackRegistry" = CallbackRegistry(
    "[stream check] NPU erase_stream"
)


def register_callback_for_acl_start_execution(cb: Callable[[str], None], cb_name: str) -> None:
    """Add ``cb`` to the "[kernel check] NPU acl start execution" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUACLStartExecuteCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_acl_finish_execution(cb: Callable[[str], None], cb_name: str) -> None:
    """Add ``cb`` to the "[kernel check] NPU acl finish execution" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUACLFinishExecuteCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_event_creation(cb: Callable[[int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU event creation" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUEventCreationCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_event_deletion(cb: Callable[[int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU event deletion" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUEventDeletionCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_event_record(cb: Callable[[int, int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU event record" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUEventRecordCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_event_wait(cb: Callable[[int, int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU event wait" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUEventWaitCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_memory_allocation(cb: Callable[[int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU memory allocation" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUMemoryAllocationCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_memory_deallocation(cb: Callable[[int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU memory deallocation" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUMemoryDeallocationCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_stream_creation(cb: Callable[[int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU stream creation" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUStreamCreationCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_device_synchronization(cb: Callable[[], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU device synchronization" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUDeviceSynchronizationCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_stream_synchronization(cb: Callable[[int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU stream synchronization" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUStreamSynchronizationCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_event_synchronization(cb: Callable[[int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU event synchronization" registry.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUEventSynchronizationCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_record_stream(cb: Callable[[int, int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU record_stream" registry.

    The registry is fired on record_stream calls.

    Args:
        cb: Callback function taking (data_ptr, stream_id) as arguments.
        cb_name: Name of the callback, reported in the error log if ``cb`` raises.
    """
    NPURecordStreamCallbacks.add_callback(cb=cb, cb_name=cb_name)


def register_callback_for_npu_erase_stream(cb: Callable[[int, int], None], cb_name: str) -> None:
    """Add ``cb`` to the "[stream check] NPU erase_stream" registry.

    The registry is fired on eraseStream calls.

    Args:
        cb: Called with the arguments passed when this registry fires.
        cb_name: Label reported in the error log if ``cb`` raises.
    """
    NPUEraseStreamCallbacks.add_callback(cb=cb, cb_name=cb_name)