# Public API of this module: NPU graph capture/replay helpers, graph task
# group / task update wrappers, and the super-kernel scope markers.
__all__ = [
    "is_current_stream_capturing",
    "graph_pool_handle",
    "graph_task_group_begin",
    "graph_task_group_end",
    "graph_task_update_begin",
    "graph_task_update_end",
    "NPUGraph",
    "graph",
    "make_graphed_callables",
    "super_kernel_scope_begin",
    "super_kernel_scope_end",
]

import gc
import os
import re
import threading
import typing
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Any, Optional, Tuple

import torch
from torch.fx.node import has_side_effect

import torch_npu._C
from torch_npu._C import _weak_ref_tensor as TensorWeakRef
from torch_npu.npu._npugraph_handlers.npugraph_handler import _NPU_GRAPH_OP_HANDLERS
from torch_npu.utils._error_code import ErrCode, pta_error
from torch_npu._compiler._config import force_npugraph_gc
from .utils import _dummy_type


if not hasattr(torch_npu._C, "_NPUStreamBase"):
    # The native extension does not expose the NPU stream base, so the graph
    # bindings are missing too. Install stand-ins (built by ``_dummy_type``)
    # under every name imported from ``torch_npu._C`` below so that importing
    # this module still succeeds.
    for _missing_binding in (
        "_NPUGraph",
        "_graph_pool_handle",
        "_npu_isCurrentStreamCapturing",
        "_graph_task_group_begin",
        "_graph_task_group_end",
        "_graph_task_update_begin",
        "_graph_task_update_end",
        "_super_kernel_scope_begin",
        "_super_kernel_scope_end",
    ):
        torch_npu._C.__dict__[_missing_binding] = _dummy_type(_missing_binding)
    # Do not leak the loop variable into the module namespace.
    del _missing_binding

from torch_npu._C import (  # noqa: F401
    _npu_isCurrentStreamCapturing,
    _NPUGraph,
    _graph_pool_handle,
    _graph_task_group_begin,
    _graph_task_group_end,
    _graph_task_update_begin,
    _graph_task_update_end,
    _super_kernel_scope_begin,
    _super_kernel_scope_end,
)


# Artifact logger used below for NPU graph lifecycle/replay debug messages.
# It is registered under the "cudagraphs" artifact name; presumably it is
# switched on through torch logging configuration (e.g. TORCH_LOGS) — TODO confirm.
log = torch._logging.getArtifactLogger("torch_npu.npugraph", "cudagraphs")


def is_current_stream_capturing():
    r"""Report whether the current NPU stream is in the middle of a graph capture.

    Returns:
        bool: True while NPU graph capture is underway on the current NPU
        stream, False otherwise. If no NPU context exists on the current
        device, False is returned and no context is initialized.
    """
    capturing = _npu_isCurrentStreamCapturing()
    return capturing


# Python shim helps Sphinx process docstrings more reliably.
def graph_pool_handle():
    r"""Return an opaque token identifying a graph memory pool.

    The token may be passed as ``pool`` when capturing, hinting that the
    capture may share memory with that pool.
    See :ref:`Graph memory management<graph-memory-management>`.

    .. warning::
        This API is in beta and may change in future releases.
    """
    pool_token = _graph_pool_handle()
    return pool_token


def graph_task_group_begin(stream):
    r"""Open a graph task group on ``stream``.

    Thin wrapper around the native ``_graph_task_group_begin``; whatever it
    returns is discarded, so this function always returns None.
    """
    _graph_task_group_begin(stream)
    return None


def graph_task_group_end(stream):
    r"""Close the graph task group opened on ``stream``.

    Returns:
        The handle produced by the native ``_graph_task_group_end``. It is
        later passed to :func:`graph_task_update_begin` to update the group.
    """
    group_handle = _graph_task_group_end(stream)
    return group_handle


def graph_task_update_begin(stream, handle):
    r"""Start updating the task group identified by ``handle`` on ``stream``.

    Thin wrapper around the native ``_graph_task_update_begin``; always
    returns None.
    """
    _graph_task_update_begin(stream, handle)
    return None


def graph_task_update_end(stream):
    r"""Finish the task update currently in progress on ``stream``.

    Thin wrapper around the native ``_graph_task_update_end``; always
    returns None.
    """
    _graph_task_update_end(stream)
    return None


def _super_kernel_scope_begin_impl(scope_name: Optional[str] = None) -> None:
    """Kernel for ``npu::super_kernel_scope_begin``.

    Forwards ``scope_name`` to the native ``_super_kernel_scope_begin``.

    Raises:
        RuntimeError: if ``scope_name`` is a string made only of whitespace
            (or empty).
    """
    if scope_name is None or scope_name.strip():
        _super_kernel_scope_begin(scope_name)
        return
    raise RuntimeError(
        "scope_name should be None or a non-empty string.",
        pta_error(ErrCode.PARAM),
    )


def _super_kernel_scope_end_impl(scope_name: Optional[str] = None) -> None:
    """Kernel for ``npu::super_kernel_scope_end``.

    Forwards ``scope_name`` to the native ``_super_kernel_scope_end``.

    Raises:
        RuntimeError: if ``scope_name`` is a string made only of whitespace
            (or empty).
    """
    if scope_name is None or scope_name.strip():
        _super_kernel_scope_end(scope_name)
        return
    raise RuntimeError(
        "scope_name should be None or a non-empty string.",
        pta_error(ErrCode.PARAM),
    )


# Guards _save_npugraph_tensor_counters.
_save_npugraph_tensor_lock = threading.Lock()
# Per-target-path counters used by _build_save_npugraph_tensor_path to derive
# "<stem>_<index><suffix>" file names when overwrite is False.
_save_npugraph_tensor_counters: Dict[str, int] = {}
# One side stream per NPU device index, created lazily by
# _get_save_tensor_stream and used for the device-to-host copies of the
# print/save ops.
_save_tensor_streams: Dict[int, "torch_npu.npu.Stream"] = {}
# Guards _save_tensor_streams.
_save_tensor_stream_lock = threading.Lock()
# Tag of the (marker, data_ptr, nbytes, shape, dtype) host-pointer spec built
# by _make_npugraph_tensor_ptr_spec.
_NPUGRAPH_TENSOR_PTR_SPEC_MARKER: str = "host_tensor_ptr"
# Tag of the (marker, buffer, shape, dtype) spec decoded by
# _materialize_npugraph_tensor_arg; presumably the native host-func launcher
# converts ptr specs into this form before invoking the callback — TODO confirm.
_NPUGRAPH_TENSOR_BUFFER_SPEC_MARKER: str = "host_tensor_buffer"


def _get_save_tensor_stream(device_index: int):
    """Return the cached side stream for ``device_index``, creating it on first use.

    The stream is created with ``priority=-1`` and shared by every print/save
    op issued for that device. Lookup and creation happen under
    ``_save_tensor_stream_lock``.
    """
    with _save_tensor_stream_lock:
        stream = _save_tensor_streams.get(device_index)
        if stream is None:
            stream = torch.npu.Stream(device=device_index, priority=-1)
            _save_tensor_streams[device_index] = stream
        return stream


def _build_save_npugraph_tensor_path(save_path=None, device_index=None, overwrite=False):
    if save_path is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        save_path = f"tensor_{timestamp}.pt"
    if not isinstance(save_path, str):
        raise TypeError(f"save_path must be str or None, but got {type(save_path).__name__}")
    if not isinstance(overwrite, bool):
        raise TypeError(f"overwrite must be bool, but got {type(overwrite).__name__}")

    directory, filename = os.path.split(save_path)
    stem, suffix = os.path.splitext(filename)
    if not stem:
        raise ValueError("save_path must include a file name")
    if not suffix:
        suffix = ".pt"

    file_stem = stem
    if device_index is not None:
        file_stem = f"{file_stem}_device_{device_index}"

    if overwrite:
        final_name = f"{file_stem}{suffix}"
        return os.path.join(directory, final_name) if directory else final_name

    counter_key = os.path.join(directory, f"{file_stem}{suffix}")
    with _save_npugraph_tensor_lock:
        file_index = _save_npugraph_tensor_counters.get(counter_key, 0)
        _save_npugraph_tensor_counters[counter_key] = file_index + 1

    final_name = f"{file_stem}_{file_index}{suffix}"
    return os.path.join(directory, final_name) if directory else final_name


def _print_callback_pending(tensor_name, tensor_arg):
    """Host callback of ``npu::print_npugraph_tensor``.

    Materializes ``tensor_arg`` and prints ``[<tensor_name>=]<tensor>, shape=..., dtype=...``
    with an immediate flush.
    """
    tensor = _materialize_npugraph_tensor_arg(tensor_arg)
    # Explicit !s: formatting a 0-d tensor without it would print the bare
    # Python scalar instead of the tensor repr.
    label = "" if tensor_name is None else f"{tensor_name}="
    print(f"{label}{tensor!s}, shape={tuple(tensor.shape)}, dtype={tensor.dtype}", flush=True)


def _save_callback_pending(tensor_arg, str_arg):
    """Host callback of ``npu::save_npugraph_tensor``: write the materialized tensor(s) to ``str_arg``."""
    torch.save(_materialize_npugraph_tensor_arg(tensor_arg), str_arg)


def _make_npugraph_tensor_ptr_spec(tensor):
    """Describe a host tensor by raw pointer for a deferred host callback.

    Returns:
        ``(_NPUGRAPH_TENSOR_PTR_SPEC_MARKER, data_ptr, nbytes, shape, dtype)``
        where ``nbytes = numel * element_size``. Strides are not recorded.
    """
    nbytes = tensor.numel() * tensor.element_size()
    shape = tuple(tensor.shape)
    return (_NPUGRAPH_TENSOR_PTR_SPEC_MARKER, tensor.data_ptr(), nbytes, shape, tensor.dtype)


def _materialize_npugraph_tensor_arg(tensor_arg):
    if (
        isinstance(tensor_arg, tuple)
        and len(tensor_arg) == 4
        and tensor_arg[0] == _NPUGRAPH_TENSOR_BUFFER_SPEC_MARKER
    ):
        _, buffer, shape, dtype = tensor_arg
        if len(buffer) == 0:
            return torch.empty(shape, dtype=dtype)
        return torch.frombuffer(buffer, dtype=dtype).reshape(shape)
    if isinstance(tensor_arg, list):
        return [_materialize_npugraph_tensor_arg(arg) for arg in tensor_arg]
    return tensor_arg


def _validate_tensor_list(inputs):
    if not isinstance(inputs, (list, tuple)):
        raise TypeError(f"input must be Tensor or TensorList, but got {type(inputs).__name__}")
    if len(inputs) == 0:
        raise ValueError("input tensor list must not be empty")
    if not all(isinstance(tensor, torch.Tensor) for tensor in inputs):
        raise TypeError("all elements in input tensor list must be Tensor")
    first_device = inputs[0].device
    for tensor in inputs[1:]:
        if tensor.device != first_device:
            raise ValueError("all tensors in input tensor list must be on the same device")
    return first_device


def _print_npugraph_tensor_impl(input, tensor_name=None):
    if not isinstance(input, torch.Tensor):
        return

    device = input.device
    if device.type == "cpu":
        _print_callback_pending(tensor_name, input)
        return

    if device.type != "npu":
        return

    device_index = device.index
    save_stream = _get_save_tensor_stream(device_index)

    # Record event on the original compute stream before switching
    event1 = torch.npu.Event()
    event2 = torch.npu.Event()
    event1.record()

    with torch.npu.stream(save_stream):
        # Wait for the original stream to complete before D2H
        event1.wait()
        cpu_tensor = input.to("cpu", non_blocking=True)
        cpu_arg = _make_npugraph_tensor_ptr_spec(cpu_tensor)
        # Get current stream inside context (which is save_stream)
        current_stream = torch.npu.current_stream()
        torch_npu.npu._launch_host_func_pending(
            current_stream,
            _print_callback_pending,
            (tensor_name, cpu_arg),
        )
        # Mark save_stream completion
        event2.record()

    # Wait for save_stream to complete (back to original stream now)
    event2.wait()


def _save_npugraph_tensor_impl(input, save_path=None, overwrite=False):
    if not isinstance(input, torch.Tensor):
        return

    device = input.device
    if device.type == "cpu":
        torch.save(input, _build_save_npugraph_tensor_path(save_path, overwrite=overwrite))
        return

    if device.type != "npu":
        return

    device_index = device.index
    save_stream = _get_save_tensor_stream(device_index)
    final_path = _build_save_npugraph_tensor_path(save_path, device_index, overwrite)

    # Record event on the original compute stream before switching
    event1 = torch.npu.Event()
    event2 = torch.npu.Event()
    event1.record()

    with torch.npu.stream(save_stream):
        # Wait for the original stream to complete before D2H
        event1.wait()
        cpu_tensor = input.to("cpu", non_blocking=True)
        cpu_arg = _make_npugraph_tensor_ptr_spec(cpu_tensor)
        # Get current stream inside context (which is save_stream)
        current_stream = torch.npu.current_stream()
        torch_npu.npu._launch_host_func_pending(
            current_stream,
            _save_callback_pending,
            (cpu_arg, final_path),
        )
        # Mark save_stream completion
        event2.record()

    # Wait for save_stream to complete (back to original stream now)
    event2.wait()


def _save_npugraph_tensor_tensor_list_impl(input, save_path=None, overwrite=False):
    device = _validate_tensor_list(input)
    if device.type == "cpu":
        torch.save(list(input), _build_save_npugraph_tensor_path(save_path, overwrite=overwrite))
        return

    if device.type != "npu":
        return

    device_index = device.index
    save_stream = _get_save_tensor_stream(device_index)
    final_path = _build_save_npugraph_tensor_path(save_path, device_index, overwrite)

    # Record event on the original compute stream before switching
    event1 = torch.npu.Event()
    event2 = torch.npu.Event()
    event1.record()

    with torch.npu.stream(save_stream):
        # Wait for the original stream to complete before D2H
        event1.wait()
        cpu_tensors = [tensor.to("cpu", non_blocking=True) for tensor in input]
        cpu_args = [_make_npugraph_tensor_ptr_spec(tensor) for tensor in cpu_tensors]
        # Get current stream inside context (which is save_stream)
        current_stream = torch.npu.current_stream()
        torch_npu.npu._launch_host_func_pending(
            current_stream,
            _save_callback_pending,
            (cpu_args, final_path),
        )
        # Mark save_stream completion
        event2.record()

    # Wait for save_stream to complete (back to original stream now)
    event2.wait()


# Library fragment through which the Python-implemented "npu::*" operators
# of this module are defined and implemented.
_npu_lib = torch.library.Library("npu", "FRAGMENT")
# Registration is skipped when an op of this name is already visible under
# torch.ops.npu (e.g. defined elsewhere before this module is imported).
if not hasattr(torch.ops.npu, "super_kernel_scope_begin"):
    _npu_lib.define("super_kernel_scope_begin(str? scope_name=None) -> ()")

    # The same Python kernel serves the PrivateUse1 (NPU), BackendSelect and CPU keys.
    _npu_lib.impl("super_kernel_scope_begin", _super_kernel_scope_begin_impl, "PrivateUse1")
    _npu_lib.impl("super_kernel_scope_begin", _super_kernel_scope_begin_impl, "BackendSelect")
    _npu_lib.impl("super_kernel_scope_begin", _super_kernel_scope_begin_impl, "CPU")

    # The op returns nothing; mark it side-effecting so FX dead-code
    # elimination does not drop it from traced graphs.
    has_side_effect(torch.ops.npu.super_kernel_scope_begin.default)

    # Fake implementation used during tracing: there is nothing to compute.
    @torch.library.register_fake("npu::super_kernel_scope_begin")
    def _super_kernel_scope_begin_meta(scope_name: Optional[str] = None):
        pass


# Registration is skipped when an op of this name is already visible under
# torch.ops.npu.
if not hasattr(torch.ops.npu, "super_kernel_scope_end"):
    _npu_lib.define("super_kernel_scope_end(str? scope_name=None) -> ()")

    # The same Python kernel serves the PrivateUse1 (NPU), BackendSelect and CPU keys.
    _npu_lib.impl("super_kernel_scope_end", _super_kernel_scope_end_impl, "PrivateUse1")
    _npu_lib.impl("super_kernel_scope_end", _super_kernel_scope_end_impl, "BackendSelect")
    _npu_lib.impl("super_kernel_scope_end", _super_kernel_scope_end_impl, "CPU")

    # No outputs: mark side-effecting so FX dead-code elimination keeps it.
    has_side_effect(torch.ops.npu.super_kernel_scope_end.default)

    # Fake implementation used during tracing: there is nothing to compute.
    @torch.library.register_fake("npu::super_kernel_scope_end")
    def _super_kernel_scope_end_meta(scope_name: Optional[str] = None):
        pass


# npu::print_npugraph_tensor: print a tensor from inside (possibly captured)
# NPU work. Registration is skipped if the op already exists.
if not hasattr(torch.ops.npu, "print_npugraph_tensor"):
    _npu_lib.define("print_npugraph_tensor(Tensor input, *, str? tensor_name=None) -> ()")

    # One Python kernel for both the PrivateUse1 (NPU) and CPU keys; it
    # branches on input.device itself.
    _npu_lib.impl("print_npugraph_tensor", _print_npugraph_tensor_impl, "PrivateUse1")
    _npu_lib.impl("print_npugraph_tensor", _print_npugraph_tensor_impl, "CPU")

    # No outputs: mark side-effecting so FX dead-code elimination keeps it.
    has_side_effect(torch.ops.npu.print_npugraph_tensor.default)

    # Fake implementation used during tracing: nothing to compute.
    @torch.library.register_fake("npu::print_npugraph_tensor")
    def _print_npugraph_tensor_meta(input, *, tensor_name=None):
        pass


# npu::save_npugraph_tensor: save a tensor (default overload) or a list of
# tensors (tensor_list overload) to disk. Registration is skipped if the op
# already exists.
if not hasattr(torch.ops.npu, "save_npugraph_tensor"):
    _npu_lib.define("save_npugraph_tensor(Tensor input, *, str? save_path=None, bool overwrite=False) -> ()")
    _npu_lib.define("save_npugraph_tensor.tensor_list(Tensor[] input, *, str? save_path=None, bool overwrite=False) -> ()")

    # Each overload uses one Python kernel for both the PrivateUse1 (NPU) and
    # CPU keys; the kernel branches on the input device itself.
    _npu_lib.impl("save_npugraph_tensor", _save_npugraph_tensor_impl, "PrivateUse1")
    _npu_lib.impl("save_npugraph_tensor", _save_npugraph_tensor_impl, "CPU")
    _npu_lib.impl("save_npugraph_tensor.tensor_list", _save_npugraph_tensor_tensor_list_impl, "PrivateUse1")
    _npu_lib.impl("save_npugraph_tensor.tensor_list", _save_npugraph_tensor_tensor_list_impl, "CPU")

    # No outputs: mark both overloads side-effecting so FX dead-code
    # elimination keeps them.
    has_side_effect(torch.ops.npu.save_npugraph_tensor.default)
    has_side_effect(torch.ops.npu.save_npugraph_tensor.tensor_list)

    # Fake implementations used during tracing: nothing to compute.
    @torch.library.register_fake("npu::save_npugraph_tensor")
    def _save_npugraph_tensor_meta(input, *, save_path=None, overwrite=False):
        pass

    @torch.library.register_fake("npu::save_npugraph_tensor.tensor_list")
    def _save_npugraph_tensor_tensor_list_meta(input, *, save_path=None, overwrite=False):
        pass


def super_kernel_scope_begin(scope_name: Optional[str] = None):
    r"""Mark the beginning of a super-kernel scope.

    Dispatches to the ``npu::super_kernel_scope_begin`` operator so the marker
    is visible to tracing. ``scope_name`` must be None or a non-blank string.
    """
    begin_op = torch.ops.npu.super_kernel_scope_begin
    return begin_op(scope_name)


def super_kernel_scope_end(scope_name: Optional[str] = None):
    r"""Mark the end of a super-kernel scope.

    Dispatches to the ``npu::super_kernel_scope_end`` operator so the marker
    is visible to tracing. ``scope_name`` must be None or a non-blank string.
    """
    end_op = torch.ops.npu.super_kernel_scope_end
    return end_op(scope_name)


@dataclass
class _GraphDispatchRecord:
    """Record of a single dispatched operator call during graph capture."""
    event: Any = None
    handle: Any = None
    kwargs: Dict[str, Any] = field(default_factory=dict)
    args: Tuple[Any, ...] = field(default_factory=tuple)
    op_cache_entry: Any = None


class _GraphDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
    """Template-method skeleton for NPU Graph capture and update.

    The skeleton keeps the common stream / event / task-group orchestration
    and delegates operator-specific logic to
    :class:`~torch_npu.npu.NpuGraphOpHandler` classes
    looked up in ``_NPU_GRAPH_OP_HANDLERS``.
    """

    tensor_schema_name = {}
    update_stream = None

    def __new__(cls):
        if cls.update_stream is None:
            cls.update_stream = torch_npu.npu.Stream()
        return super().__new__(cls)

    def __init__(self):
        self.graph_dispatch_records = []

    @classmethod
    def is_infra_mode(cls):
        return True

    @classmethod
    def update_schema(cls, name, schema):
        if name in cls.tensor_schema_name:
            return
        # match: Tensor q_nope, Tensor? mask=None, Tensor(a!) output
        pattern = r'Tensor(?:\(a!\)|\?)?\s+(\w+)'
        cls.tensor_schema_name[name] = re.findall(pattern, schema)

    # -----------------------------------------------------------------
    #  Capture skeleton
    # -----------------------------------------------------------------

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Registry stores class objects; call via cls.method()
        handler_cls = _NPU_GRAPH_OP_HANDLERS.get(func.__name__)

        if handler_cls:
            # 1) Common: obtain stream and event
            if hasattr(handler_cls, 'should_handle'):
                if not handler_cls.should_handle(func, args, kwargs):
                    return func(*args, **kwargs)

            stream = torch_npu.npu.current_stream()
            event = torch.npu.ExternalEvent()
            event.wait(stream)
            event.reset(stream)

            # 2) Operator-specific: preprocessing (workspace, output pre-alloc, func swap)
            actual_func, args, kwargs = handler_cls.prepare_capture(
                func, args, kwargs
            )

            # 3) Common: parse operator schema
            self.update_schema(
                str(actual_func.__name__), str(actual_func._schema)
            )

            # 4) Common: record graph task group
            graph_task_group_begin(stream)
            result = actual_func(*args, **kwargs)
            handle = graph_task_group_end(stream)

            # 5) Common: create dispatch record (delegate kwarg conversion to handler)
            self.graph_dispatch_records.append(
                self._append_dispatch_record(
                    event, handle, args, kwargs, actual_func, handler_cls
                )
            )

            # 6) Operator-specific: post-process return value
            return handler_cls.postprocess_result(result, kwargs)

        return func(*args, **kwargs)

    def _append_dispatch_record(
        self, event, handle, args, kwargs, func, handler_cls
    ):
        """Create a dispatch record, converting args / kwargs to weak-refs or deep copies.

        ``handler_cls`` is a required parameter (class object) guaranteed by
        the capture skeleton.
        """
        args_ref = []
        for element in args:
            if torch.is_tensor(element) and "npu" in str(element.device):
                args_ref.append(TensorWeakRef(element))
            else:
                args_ref.append(deepcopy(element))

        tensor_param_names = self.tensor_schema_name.get(
            str(func.__name__), []
        )
        kwargs_ref = {}
        for key, value in kwargs.items():
            kwargs_ref[key] = handler_cls.record_wrap_kwarg(
                key, value, tensor_param_names
            )

        return _GraphDispatchRecord(
            event=event,
            handle=handle,
            kwargs=kwargs_ref,
            args=list(args_ref),
            op_cache_entry=func,
        )

    # -----------------------------------------------------------------
    #  Update skeleton
    # -----------------------------------------------------------------

    def update_capture_record(self, cpu_update_input):
        if len(cpu_update_input) == 1:
            new_list = [
                cpu_update_input[0].copy()
                for _ in range(len(self.graph_dispatch_records))
            ]
            cpu_update_input = new_list

        # BUG FIX: original code compared len(self.graph_dispatch_records)
        # with itself -- always True.  Corrected to compare with
        # cpu_update_input.
        if len(cpu_update_input) != len(self.graph_dispatch_records):
            raise RuntimeError(
                f"Currently, there are {len(self.graph_dispatch_records)} "
                f"operators that need to be updated by capture, and there "
                f"are only {len(cpu_update_input)} elements in the incoming "
                f"cpu_update_input list",
                pta_error(ErrCode.PARAM),
            )

        with torch.npu.stream(self.update_stream):
            for record, update_input in zip(
                self.graph_dispatch_records, cpu_update_input
            ):
                graph_task_update_begin(self.update_stream, record.handle)

                # 7) Common: update matching kwargs (direct assignment,
                #    consistent with original implementation).
                #    Capture-phase uses TensorWeakRef/deepcopy to avoid
                #    strong references; update is a short "assign -> replay"
                #    flow where direct assignment suffices.
                for key in update_input:
                    if key in record.kwargs:
                        record.kwargs[key] = update_input[key]

                # 8) Operator-specific: update args by index.
                #    Defensive assert -- unregistered ops go through
                #    passthrough in capture and never produce a dispatch
                #    record, so handler_cls must not be None here.
                handler_cls = _NPU_GRAPH_OP_HANDLERS.get(
                    record.op_cache_entry.__name__
                )
                if handler_cls is None:
                    raise RuntimeError(
                        f"No handler for recorded op: {record.op_cache_entry.__name__}. "
                        f"This indicates the handler was unregistered between capture and update.",
                        pta_error(ErrCode.PARAM),
                    )
                handler_cls.update_args(record, update_input)

                # 9) Common: replay the operator
                record.op_cache_entry(*record.args, **record.kwargs)
                graph_task_update_end(self.update_stream)
                record.event.record(self.update_stream)

# Python shim helps Sphinx process docstrings more reliably.
class NPUGraph(torch_npu._C._NPUGraph):
    r"""Wrapper around a NPU graph.

    Besides the native capture / replay API, an instance owns a
    ``_GraphDispatchMode`` that can record dispatched operators during capture
    when ``auto_dispatch_capture`` is enabled (presumably set by the ``graph``
    context manager — TODO confirm), so that :meth:`update` can later change
    their arguments.

    .. warning::
        This API is in beta and may change in future releases.
    """

    # Accepted keys of the ``super_kernel_optimize`` option dicts, per option
    # group. ``value_type`` is the required type of the value; list values may
    # also constrain ``element_type`` and dict values may list ``sub_options``.
    _VALID_SUPER_KERNEL_OPTIONS = {
        'optimize_options': {
            'preload_code': {'value_type': int},
            'split_mode': {'value_type': int},
            'stream_fusion': {'value_type': int},
            'constant_codegen': {'value_type': int},
            'auto_op_parallel': {'value_type': int},
            'opt_extend': {'value_type': str},
            'dcci_before_kernel_start': {'value_type': list, 'element_type': str},
            'dcci_after_kernel_end': {'value_type': list, 'element_type': str},
            'dcci_disable_on_kernel': {'value_type': list, 'element_type': str},
            'aggressive_opt_strategies': {
                'value_type': dict,
                'sub_options': {
                    'event_breaker_bypass': {'value_type': int},
                    'value_breaker_bypass': {'value_type': int},
                    'task_breaker_bypass': {'value_type': int},
                },
            },
            'ubuf_lock_ignore_kernel': {'value_type': list, 'element_type': str},
            'early_start': {'value_type': int},
        },
        'debug_options': {
            'debug_sync_all': {'value_type': int},
            'debug_op_exec_trace': {'value_type': int},
            'debug_cross_core_sync_check': {'value_type': int},
            'debug_extend': {'value_type': str},
            'debug_per_op_max_core_num': {'value_type': int},
        },
    }

    def __new__(cls):
        return super().__new__(cls)

    def __init__(self):
        # Dispatch mode used to record ops for later update; only consulted
        # when auto_dispatch_capture is True.
        self.graph_dispatch_mode = _GraphDispatchMode()
        self.auto_dispatch_capture = False
        log.debug("NPUGRAPH Lifecycle NPUGraph created, graph_id=%s, auto_dispatch=%s",
                  id(self), self.auto_dispatch_capture)
        super().__init__()

    def capture_begin(self, pool=None, capture_error_mode="global"):
        r"""Begin capturing NPU work on the current stream.

        Typically, you shouldn't call ``capture_begin`` yourself.
        Use :class:`~torch.npu.graph` or :func:`~torch.npu.make_graphed_callables`,
        which call ``capture_begin`` internally.

        Arguments:
            pool (optional): Token (returned by :func:`~torch.npu.graph_pool_handle` or
                :meth:`other_Graph_instance.pool()<torch.npu.NPUGraph.pool>`) that hints this graph may share memory
                with the indicated pool.  See :ref:`Graph memory management<graph-memory-management>`.
            capture_error_mode (str, optional): specifies the aclmdlRICaptureMode for the graph capture stream.
                Can be "global", "thread_local" or "relaxed". During npu graph capture, some actions, such as npuMalloc,
                may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
                actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
                unless you're familiar with `aclmdlRICaptureMode`_
        """  # noqa: B950
        super().capture_begin(pool=pool, capture_error_mode=capture_error_mode, report_shape=True)

    def capture_end(self):
        r"""End NPU graph capture on the current stream.

        After ``capture_end``, ``replay`` may be called on this instance.

        Typically, you shouldn't call ``capture_end`` yourself.
        Use :class:`~torch.npu.graph` or :func:`~torch.npu.make_graphed_callables`,
        which call ``capture_end`` internally.
        """
        super().capture_end()

    def replay(self):
        r"""Replay the NPU work captured by this graph."""
        log.debug("NPUGRAPH Replay graph_id=%s", id(self))
        super().replay()

    def reset(self):
        r"""Delete the graph currently held by this instance."""
        log.debug("NPUGRAPH Lifecycle NPUGraph reset, graph_id=%s", id(self))
        super().reset()

    def pool(self):
        r"""Return an opaque token representing the id of this graph's memory pool.

        This id can optionally be passed to another graph's ``capture_begin``,
        which hints the other graph may share the same memory pool.
        """
        return super().pool()

    def update(self, cpu_update_input):
        r"""Update the arguments of the operators recorded during capture.

        Arguments:
            cpu_update_input (sequence of dict): one mapping of new argument
                values per recorded operator, in capture order. A single
                mapping is applied to every recorded operator.

        Raises:
            RuntimeError: if the graph was not captured with
                ``auto_dispatch_capture=True``, or if the update entries do
                not match the recorded operators.
        """
        log.debug("NPUGraph: updating graph (%s inputs)...", len(cpu_update_input))
        if not self.auto_dispatch_capture:
            raise RuntimeError(
                "The current graph configuration does not support update. "
                "Try to capture by setting auto_dispatch_capture=True during capture",
                pta_error(ErrCode.PARAM),
            )
        self.graph_dispatch_mode.update_capture_record(cpu_update_input)

    def debug_dump(self, debug_path):
        r"""Calls a function to dump the graph in JSON format.

        Arguments:
            debug_path (required): Path to dump the graph to.
        """
        return super().debug_dump(debug_path)

    def super_kernel_optimize(self, optimize_options=None, debug_options=None):
        r"""Calls a function to optimize graph by super kernel.

        Both option dicts are validated (allowed keys and value types) before
        being passed to the native implementation.

        Arguments:
            optimize_options (optional):
                preload_code (int) - Controls code preloading strategy.
                split_mode (int) - Controls kernel splitting mode for better performance.
                stream_fusion (int) - Enables/disables stream fusion optimization.
                constant_codegen (int) - Enables/disables constant code generation.
                auto_op_parallel (int) - Enables/disables auto op parallel optimization.
                opt_extend (str) - Extended optimization option string.
                dcci_before_kernel_start (list[str]) - List of kernel names to insert DCCI before kernel start.
                dcci_after_kernel_end (list[str]) - List of kernel names to insert DCCI after kernel end.
                dcci_disable_on_kernel (list[str]) - List of kernel names to disable DCCI.
                aggressive_opt_strategies (dict) - Aggressive optimization strategies:
                    event_breaker_bypass (int) - Keep fusing across event waits
                        (by default, fusion continues when the deadlock check passes).
                    value_breaker_bypass (int) - Fusion strategy for value waits.
                    task_breaker_bypass (int) - Keep fusing across unsupported operators
                        (by default, fusion is broken at them).
                ubuf_lock_ignore_kernel (list[str]) - List of kernel names
                    (semantics defined by the native implementation).
                early_start (int) - Integer flag (semantics defined by the native implementation).

            debug_options (optional):
                debug_sync_all (int) - Enables debug synchronization for all operations.
                debug_op_exec_trace (int) - Enables/disables op execution trace.
                debug_cross_core_sync_check (int) - Enables/disables cross core sync check.
                debug_extend (str) - Extended debug option string.
                debug_per_op_max_core_num (int) - Integer limit
                    (semantics defined by the native implementation).

        Raises:
            RuntimeError: if an option dict has an unknown key or a value of
                the wrong type.
        """
        self._validate_options("optimize_options", optimize_options)
        self._validate_options("debug_options", debug_options)
        return super().super_kernel_optimize(optimize_options, debug_options)

    def _validate_options(self, option_name, options=None):
        """Check ``options`` against ``_VALID_SUPER_KERNEL_OPTIONS[option_name]``.

        None is accepted as "no options". Raises RuntimeError for a non-dict
        value, an unknown option group or key, or a value (or list element /
        nested dict value) of the wrong type.
        """
        if options is None:
            return

        if not isinstance(options, dict):
            raise RuntimeError(f"{option_name} param must be dict or None.", pta_error(ErrCode.PARAM))

        # The option group is loop-invariant: validate it once up front.
        if option_name not in self._VALID_SUPER_KERNEL_OPTIONS:
            raise RuntimeError(f"Invalid {option_name} param.", pta_error(ErrCode.PARAM))
        valid_keys = self._VALID_SUPER_KERNEL_OPTIONS[option_name]

        for key, value in options.items():
            if key not in valid_keys:
                raise RuntimeError(f"Invalid {option_name} param key: '{key}'.", pta_error(ErrCode.PARAM))

            spec = valid_keys[key]
            expected_type = spec['value_type']
            if not isinstance(value, expected_type):
                raise RuntimeError(f"{option_name} param['{key}'] must be {expected_type.__name__}, "
                                   f"got {type(value).__name__}", pta_error(ErrCode.PARAM))

            if expected_type is dict and 'sub_options' in spec:
                sub_specs = spec['sub_options']
                for sub_key, sub_value in value.items():
                    if sub_key not in sub_specs:
                        raise RuntimeError(f"Invalid sub_key '{sub_key}' in {option_name}['{key}'].",
                                           pta_error(ErrCode.PARAM))
                    sub_expected_type = sub_specs[sub_key]['value_type']
                    if not isinstance(sub_value, sub_expected_type):
                        raise RuntimeError(f"{option_name}['{key}']['{sub_key}'] must be {sub_expected_type.__name__}, "
                                           f"got {type(sub_value).__name__}", pta_error(ErrCode.PARAM))

            if expected_type is list and 'element_type' in spec:
                element_type = spec['element_type']
                for i, elem in enumerate(value):
                    if not isinstance(elem, element_type):
                        raise RuntimeError(f"{option_name}['{key}'][{i}] must be {element_type.__name__}, "
                                           f"got {type(elem).__name__}", pta_error(ErrCode.PARAM))


class graph:
    r"""Context-manager that captures NPU work into a :class:`torch.npu.NPUGraph` object for later replay.

    See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
    detailed use, and constraints.

    On entry the context manager synchronizes the device, optionally runs ``gc.collect()`` (when
    ``force_npugraph_gc`` is set), empties the device and host caching allocators, makes the capture
    stream current, and begins capture. On exit it ends capture and restores the previous stream,
    even if ending the capture raises.

    Arguments:
        npu_graph (torch.npu.NPUGraph): Graph object used for capture.
        pool (optional): Opaque token (returned by a call to :func:`~torch.npu.graph_pool_handle()` or
            :meth:`other_Graph_instance.pool()<torch.npu.NPUGraph.pool>`) hinting this graph's capture
            may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
        stream (torch.npu.Stream, optional): If supplied, will be set as the current stream in the context.
            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
        auto_dispatch_capture (bool, optional): Stored on ``npu_graph.auto_dispatch_capture``. When True,
            ``npu_graph.graph_dispatch_mode`` is entered before capture begins and exited after capture ends.
            Default: ``False``.
        capture_error_mode (str, optional): specifies the aclmdlRICaptureMode for the graph capture stream.
            Can be "global", "thread_local" or "relaxed". During npu graph capture, some actions, such as npuMalloc,
            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
            actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
            unless you're familiar with ``aclmdlRICaptureMode``.

    .. note::
        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
        used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.

    .. warning::
        This API is in beta and may change in future releases.
    """  # noqa: B950

    default_capture_stream: typing.Optional["torch.npu.Stream"] = None

    def __init__(
        self,
        npu_graph,
        pool=None,
        stream=None,
        auto_dispatch_capture=False,
        capture_error_mode: str = "global",
    ):
        # Lazy-init of default_capture_stream helps avoid circular-import errors.
        # Not thread safe, but graphs already have the general (explicitly documented)
        # restriction that only one capture may be underway at a time in the process.
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.npu.Stream()

        # Stored as a tuple so it can be splatted into capture_begin (empty tuple == no pool).
        self.pool = () if pool is None else (pool,)
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        if self.capture_stream is None:
            raise RuntimeError("capture stream is None")
        self.stream_ctx = torch.npu.stream(self.capture_stream)
        self.npu_graph = npu_graph
        self.capture_error_mode = capture_error_mode
        self.npu_graph.auto_dispatch_capture = auto_dispatch_capture

    def __enter__(self):
        log.debug("NPUGRAPH Capture device=%s, pool=%s, mode=%s, auto_dispatch=%s",
                  torch.npu.current_device(), self.pool, self.capture_error_mode,
                  self.npu_graph.auto_dispatch_capture)
        torch.npu.synchronize()
        if force_npugraph_gc:
            gc.collect()
        torch.npu.empty_cache()
        torch_npu.npu.host_empty_cache()

        # Stackoverflow seems comfortable with this pattern
        self.stream_ctx.__enter__()
        # Python does not call __exit__ when __enter__ raises, so any context entered
        # here must be unwound manually if a later step fails.
        try:
            if self.npu_graph.auto_dispatch_capture:
                self.npu_graph.graph_dispatch_mode.__enter__()
            try:
                self.npu_graph.capture_begin(
                    *self.pool, capture_error_mode=self.capture_error_mode
                )
            except BaseException as err:
                if self.npu_graph.auto_dispatch_capture:
                    self.npu_graph.graph_dispatch_mode.__exit__(type(err), err, err.__traceback__)
                raise
        except BaseException as err:
            self.stream_ctx.__exit__(type(err), err, err.__traceback__)
            raise

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            log.error("NPUGraph: ERROR — capture failed: %s: %s, "
                      "graph_id=%s, stream=%s, pool=%s, device=%s",
                      exc_type.__name__, exc_value,
                      id(self.npu_graph), self.capture_stream,
                      self.pool, torch.npu.current_device())
            log.warning("NPUGRAPH DFX common causes: default stream, CPU ops, multi-device tensors, "
                        "dynamic control flow on tensor values")
        else:
            log.debug("NPUGraph: capture completed, graph_id=%s, device=%s",
                      id(self.npu_graph), torch.npu.current_device())
        # Unwind in reverse order of __enter__. The finally clauses guarantee the dispatch mode
        # and the stream context are exited even if capture_end (or the dispatch-mode exit) raises.
        try:
            self.npu_graph.capture_end()
        finally:
            try:
                if self.npu_graph.auto_dispatch_capture:
                    self.npu_graph.graph_dispatch_mode.__exit__(exc_type, exc_value, traceback)
            finally:
                self.stream_ctx.__exit__(exc_type, exc_value, traceback)
        # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()


def make_graphed_callables(
        callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
    r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.

    Each graphed callable's forward pass runs its source callable's
    forward NPU work as an NPU graph inside a single autograd node.

    The graphed callable's forward pass also appends
    a backward node to the autograd graph. During backward, this node runs the
    callable's backward work as an NPU graph.

    Therefore, each graphed callable should be a drop-in replacement for its source callable
    in an autograd-enabled training loop.

    See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.

    If you pass a tuple of several callables, their captures will use the same memory pool.
    See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.

    Arguments:
        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
            See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
            is appropriate.  If you pass a tuple of callables, their order in the tuple must be the same order
            they'll run in the live workload.
        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Sample args for each callable.
            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
            If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors.
        num_warmup_iters (int): The number of warmup iterations run (outside of any capture) before capturing.
            Currently, ``DistributedDataParallel`` needs 11 iterations for warm up. Default: ``3``.
        allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
            (and therefore their grad is always zero) is an error. Defaults to False.
        pool (optional): Token (returned by :func:`~torch.npu.graph_pool_handle` or
            :meth:`other_Graph_instance.pool()<torch.npu.NPUGraph.pool>`) that hints this graph may share memory
            with the indicated pool. If not supplied, a new pool handle is created and shared by all captures
            made by this call. See :ref:`Graph memory management<graph-memory-management>`.

    Returns:
        A single graphed callable if ``callables`` was not a tuple, otherwise a tuple of graphed callables in
        the same order. :class:`torch.nn.Module` inputs are returned as the same module objects with ``forward``
        replaced by a wrapper that replays the graph when ``module.training`` matches its value at capture time
        and falls back to the original ``forward`` otherwise.

    Raises:
        RuntimeError: if autocast caching is enabled, if a module has hooks registered or buffers that require
            grad, or if ``sample_args`` contains non-Tensor leaves.

    .. note::
        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
        that's expected for the corresponding real input in the training loop.

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.

    .. warning::
        Returned callables do not support higher order differentiation (e.g., double backward).

    .. warning::
        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
        may be trainable. Buffers must have ``requires_grad=False``.

    .. warning::
        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
        you may not add or remove any of that Module's parameters or buffers.

    .. warning::
        :class:`torch.nn.Module`\s passed to :func:`~torch.npu.make_graphed_callables` must not have module hooks
        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
        through :func:`~torch.npu.make_graphed_callables` is allowed.

    .. warning::
        When running a graphed callable, you must pass its arguments in the same order and format
        they appeared in that callable's ``sample_args``.

    .. warning::
        The automatic mixed precision is supported in :func:`~torch.npu.make_graphed_callables` only with disabled
        caching. The autocast context manager (e.g. ``torch.autocast()``) must have ``cache_enabled=False``.
    """

    if torch_npu.npu.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        log.error("NPUGraph: ERROR — autocast caching not supported")
        log.warning("NPUGRAPH DFX Fix: set cache_enabled=False in torch.autocast()")
        raise RuntimeError(
            "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
        )

    just_one_callable = False

    # Normalize to the tuple-of-callables form; unwrapped again before returning.
    if not isinstance(callables, tuple):
        just_one_callable = True
        callables = (callables,)
        sample_args = (sample_args,)

    flatten_sample_args = []

    for c, args in zip(callables, sample_args):
        if isinstance(c, torch.nn.Module):
            if len(c._backward_hooks) > 0 or len(c._forward_hooks) > 0 or len(c._forward_pre_hooks) > 0:
                log.error("NPUGraph: ERROR — hooks registered on module %s", type(c).__name__)
                log.warning("NPUGRAPH DFX Fix: remove hooks before make_graphed_callables, re-register after")
                raise RuntimeError(
                    "Modules must not have hooks registered at the time they are passed. However, "
                    + "registering hooks on modules after passing them through make_graphed_callables is allowed."
                )
            if any(b.requires_grad for b in c.buffers()):
                raise RuntimeError(
                    "In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`,"
                    + " only parameters may be trainable. All buffers must have ``requires_grad=False``."
                )
        flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
        flatten_sample_args.append(tuple(flatten_arg))
        if not all(isinstance(arg, torch.Tensor) for arg in flatten_arg):
            raise RuntimeError(
                "In the beta API, sample_args "
                + "for each callable must contain only Tensors. Other types are not allowed."
            )

    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
    per_callable_len_user_args = [len(args) for args in flatten_sample_args]
    per_callable_module_params = [
        tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
        for c in callables
    ]
    per_callable_static_input_surfaces = [
        flatten_sample_args[i] + per_callable_module_params[i]
        for i in range(len(callables))
    ]

    fwd_graphs = [torch_npu.npu.NPUGraph() for _ in range(len(callables))]
    bwd_graphs = [torch_npu.npu.NPUGraph() for _ in range(len(callables))]

    mempool = graph_pool_handle() if pool is None else pool

    # Warmup: run each callable's forward and backward a few times on a side stream,
    # outside of any capture, so that lazy one-time initialization work does not end
    # up in any captures.
    torch_npu.npu.synchronize()
    with torch_npu.npu.stream(torch_npu.npu.Stream()):
        for func, args, static_input_surface in zip(
            callables, sample_args, per_callable_static_input_surfaces
        ):
            grad_inputs, outputs, outputs_grad = None, None, None
            for _ in range(num_warmup_iters):
                outputs = torch.utils._pytree.tree_leaves(func(*args))
                outputs_grad = tuple(o for o in outputs if o.requires_grad)
                if len(outputs_grad) > 0:
                    grad_inputs = torch.autograd.grad(
                        outputs=outputs_grad,
                        inputs=tuple(
                            i for i in static_input_surface if i.requires_grad
                        ),
                        grad_outputs=tuple(
                            torch.empty_like(o) for o in outputs if o.requires_grad
                        ),
                        only_inputs=True,
                        allow_unused=allow_unused_input,
                    )
            # Drop the warmup results so these names no longer keep the warmup tensors alive.
            # (Deleting a loop variable bound over a list of them would only unbind the loop variable.)
            del outputs, outputs_grad, grad_inputs

    torch_npu.npu.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
    # the safest approach is to capture all passes in the same order they'll run:
    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

    # Capture forward graphs
    per_callable_static_outputs = []
    per_callable_output_unflatten_spec = []
    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
        with torch_npu.npu.graph(fwd_graph, pool=mempool):
            outputs = func(*args)

        flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
        per_callable_static_outputs.append(tuple(flatten_outputs))
        per_callable_output_unflatten_spec.append(spec)

    # Capture backward graphs in reverse order
    per_callable_static_grad_outputs = []
    per_callable_static_grad_inputs = []
    for static_input_surface, static_outputs, bwd_graph in zip(
        reversed(per_callable_static_input_surfaces),
        reversed(per_callable_static_outputs),
        reversed(bwd_graphs),
    ):
        # Static buffers for incoming grads; outputs that don't require grad get a None placeholder
        # so indices stay aligned with static_outputs.
        static_grad_outputs = tuple(
            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
        )

        outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
        grad_inputs = None
        if len(outputs_grad) > 0:
            with torch_npu.npu.graph(bwd_graph, pool=mempool):
                grad_inputs = torch.autograd.grad(
                    outputs=outputs_grad,
                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                    only_inputs=True,
                    allow_unused=allow_unused_input,
                )

        # Constructs a tuple suitable for returning from Graphed.backward:
        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
        static_grad_inputs = []
        grad_idx = 0
        for arg in static_input_surface:
            if arg.requires_grad and grad_inputs is not None:
                static_grad_inputs.append(grad_inputs[grad_idx])
                grad_idx += 1
            else:
                static_grad_inputs.append(None)  # type: ignore[arg-type]
        static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]

        per_callable_static_grad_outputs.append(static_grad_outputs)
        per_callable_static_grad_inputs.append(static_grad_inputs)

    # Reverses the most recent two lists
    per_callable_static_grad_outputs.reverse()
    per_callable_static_grad_inputs.reverse()
    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.

    def make_graphed_autograd_function(
        fwd_graph,
        bwd_graph,
        module_params,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
        static_grad_outputs,
        static_grad_inputs,
    ):
        class Graphed(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *inputs):
                # At this stage, only the user args may (potentially) be new tensors.
                for i in range(len_user_args):
                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                        static_input_surface[i].copy_(inputs[i])
                fwd_graph.replay()
                if not isinstance(static_outputs, tuple):
                    raise RuntimeError("static_outputs is not tuple.")
                return tuple(o.detach() for o in static_outputs)

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, *grads):
                if len(grads) != len(static_grad_outputs):
                    raise RuntimeError(
                        "The length of grads"
                        + " is not equal with the length of static_grad_outputs."
                    )
                for g, grad in zip(static_grad_outputs, grads):
                    if g is not None:
                        # don't copy if autograd gods have been kind and the
                        # incoming grad is already in the right place
                        if g.data_ptr() != grad.data_ptr():
                            g.copy_(grad)
                bwd_graph.replay()

                # Input args that didn't require grad expect a None gradient.
                if not isinstance(static_grad_inputs, tuple):
                    raise RuntimeError("static_grad_inputs is not tuple.")
                return tuple(
                    b.detach() if b is not None else b for b in static_grad_inputs
                )

        def functionalized(*user_args):
            # Runs the autograd function with inputs == all inputs to the graph that might require grad
            # (explicit user args + module parameters)
            # Assumes module params didn't change since capture.
            flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
            out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
            return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callables
    ret = []
    for i, func in enumerate(callables):
        graphed = make_graphed_autograd_function(
            fwd_graphs[i],
            bwd_graphs[i],
            per_callable_module_params[i],
            per_callable_len_user_args[i],
            per_callable_output_unflatten_spec[i],
            per_callable_static_input_surfaces[i],
            per_callable_static_outputs[i],
            per_callable_static_grad_outputs[i],
            per_callable_static_grad_inputs[i],
        )

        if isinstance(func, torch.nn.Module):

            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
                def new_fwd(*user_args):
                    # If the module's training-or-eval state matches what we graphed,
                    # run the graph, otherwise run the original forward method
                    if func.training == graph_training_state:
                        return graphed(*user_args)
                    else:
                        return orig_fwd(*user_args)

                return new_fwd

            func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
            ret.append(func)
        else:
            ret.append(graphed)

    log.debug("NPUGraph: graphed callables ready")
    if just_one_callable:
        return ret[0]

    return tuple(ret)