import inspect
import os
import sys
import logging
from typing import Any, Optional, TYPE_CHECKING
import importlib
import torch
import torch_npu
from torch import _TorchCompileWrapper
from torch._dynamo import optimize
from torch._dynamo.utils import tensortype_to_dtype
from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.constant import ConstantVariable
from torch._dynamo.variables.ctx_manager import AutocastModeVariable
from torch._dynamo.variables.functions import SkipFunctionVariable
from torch._dynamo.variables.lists import TupleVariable
from torch._dynamo.variables.streams import StreamContextVariable, StreamVariable
from torch._dynamo.variables.tensor import TensorVariable
from torch._dynamo.variables.torch import (
    TorchCtxManagerClassVariable,
    TorchInGraphFunctionVariable,
)
from torch._dynamo.variables.user_defined import UserDefinedClassVariable
from torch_npu.dynamo import _get_global_npu_backend


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator

# Module-level flag; it is not read anywhere in this part of the file.
# NOTE(review): confirm whether other modules still consult it before removing.
use_jit_script = False
# Module logger; the inductor-wrapper patch logs backend selection through it.
log = logging.getLogger(__name__)

class NPUTorchCtxManagerClassVariable(TorchCtxManagerClassVariable):
    """Context-manager class variable for the NPU ``autocast`` classes.

    Calling it (i.e. tracing ``autocast(...)``) is delegated to
    ``NPUAutocastModeVariable.create``.
    """

    def call_function(self, tx, args, kwargs):
        # ``tx`` is not needed: the autocast variable is built from the call
        # arguments and the wrapped class alone.
        build_autocast = NPUAutocastModeVariable.create
        autocast_cls = self.value
        return build_autocast(autocast_cls, args, kwargs)


class NPUAutocastModeVariable(AutocastModeVariable):
    """Factory for Dynamo autocast-mode variables built from NPU ``autocast`` classes."""

    @staticmethod
    def create(func, args, kwargs):
        """Translate a traced NPU autocast constructor call into an ``AutocastModeVariable``.

        Args:
            func: The autocast class being constructed; expected to be one of the
                NPU autocast aliases that ``UserDefinedClassVariable__new__`` routes
                to ``NPUTorchCtxManagerClassVariable``.
            args: Positional arguments of the traced call.
            kwargs: Keyword arguments of the traced call. The dict is cleared in
                place once the arguments have been bound.

        Returns:
            An ``AutocastModeVariable`` whose target values are, in order,
            ``device_type``, ``dtype``, ``enabled`` and ``cache_enabled``. Each is
            unwrapped to a Python constant when it is a ``VariableTracker``.
        """
        # Every alias the dispatcher sends here gets a hard-coded "npu" device
        # type instead of reading ``device_type`` from the bound arguments.
        npu_autocast_classes = (
            torch.npu.amp.autocast,
            torch_npu.npu.amp.autocast,
            torch.npu.amp.autocast_mode.autocast,
            torch_npu.npu.amp.autocast_mode.autocast,
        )

        bound_args = inspect.signature(func).bind(*args, **kwargs)
        bound_args.apply_defaults()
        target_values = []
        kwargs.clear()

        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
            if key == "device_type" and func in npu_autocast_classes:
                arg = "npu"
            else:
                arg = bound_args.arguments[key]
            if isinstance(arg, VariableTracker):
                target_values.append(arg.as_python_constant())
            else:
                target_values.append(arg)

        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
        return var


def UserDefinedClassVariable__new__(cls, value, **kwargs):
    """Replacement ``__new__`` for ``UserDefinedClassVariable`` that intercepts NPU classes.

    * NPU ``autocast`` classes become an ``NPUTorchCtxManagerClassVariable``.
    * Legacy ``torch_npu.npu.*Tensor`` type objects become a
      ``TorchInGraphFunctionVariable``.
    * Anything else is allocated with the saved original ``__new__``
      (stored as ``__new__raw``).
    """
    autocast_aliases = (
        torch.npu.amp.autocast,
        torch_npu.npu.amp.autocast,
        torch.npu.amp.autocast_mode.autocast,
        torch_npu.npu.amp.autocast_mode.autocast,
    )
    if value in autocast_aliases:
        return NPUTorchCtxManagerClassVariable(value, **kwargs)

    npu = torch_npu.npu
    legacy_tensor_types = (
        npu.BoolTensor,
        npu.ByteTensor,
        npu.CharTensor,
        npu.DoubleTensor,
        npu.FloatTensor,
        npu.HalfTensor,
        npu.IntTensor,
        npu.LongTensor,
        npu.ShortTensor,
        npu.BFloat16Tensor,
    )
    if value in legacy_tensor_types:
        return TorchInGraphFunctionVariable(value, **kwargs)

    # Default path: plain allocation through the pre-patch ``__new__``.
    return cls.__new__raw(cls)


def SkipFunctionVariable__new__(cls, value, reason=None, **kwargs):
    """Replacement ``__new__`` for ``SkipFunctionVariable``.

    The NPU ``stream`` helpers are traced in-graph as
    ``TorchInGraphFunctionVariable`` instead of being skipped. Every other
    value is allocated with the saved original ``__new__`` (``__new__raw``).
    ``reason`` is accepted for signature compatibility but unused here.
    """
    stream_helpers = (
        torch.npu.stream,
        torch_npu.npu.stream,
        torch_npu.npu.utils.stream,
    )
    if value not in stream_helpers:
        return cls.__new__raw(cls)
    return TorchInGraphFunctionVariable(value, **kwargs)


def TensorVariable_call_method(self, tx, name, args, kwargs):
    if (
        name == "type"
        and self.dtype is not None
        and len(args) == 0
        and isinstance(self.device, torch.device)
        and self.device.type == "npu"
    ):
        tensortype = next(k for k, v in tensortype_to_dtype.items() if self.dtype in v)
        constant_result = ConstantVariable.create(f"torch.npu.{tensortype.__name__}")

        if len(args) == 1:
            return constant_result.getitem_const(args[0])
        elif args:
            return TupleVariable([constant_result.getitem_const(a) for a in args])
        return constant_result
    else:
        return TensorVariable.call_method_raw(self, tx, name, args, kwargs)


class _InductorNpuRegistry:
    """Process-wide bookkeeping for loading the torch_npu inductor backend.

    Remembers which ``TORCHINDUCTOR_NPU_BACKEND`` value was loaded last, so the
    backend module is only imported or reloaded when that value changes.
    Registration can be switched off and on globally.
    """

    # When True, register_inductor_npu() returns without doing anything.
    _disabled_register = False
    # TORCHINDUCTOR_NPU_BACKEND value at the last successful load; None before any load.
    _loaded_backend = None

    @classmethod
    def register_inductor_npu(cls):
        """Import or reload ``torch_npu._inductor`` if the requested backend changed.

        The first time, the module is imported. Afterwards, its
        ``_load_backend()`` is called. The requested value is recorded only
        after that step succeeds.
        """
        if cls._disabled_register:
            return

        requested = os.getenv("TORCHINDUCTOR_NPU_BACKEND", "default")
        if requested == cls._loaded_backend:
            return

        module_name = "torch_npu._inductor"
        if module_name in sys.modules:
            sys.modules[module_name]._load_backend()
        else:
            # Presumably the module loads the selected backend at import time.
            importlib.import_module(module_name)
        cls._loaded_backend = requested

    @classmethod
    def disable_register(cls):
        """Make register_inductor_npu() a no-op until enable_register() is called."""
        cls._disabled_register = True

    @classmethod
    def enable_register(cls):
        """Allow register_inductor_npu() to load backends again."""
        cls._disabled_register = False

    @classmethod
    def has_initialized(cls):
        """Return True once any backend has been loaded successfully."""
        return cls._loaded_backend is not None


def is_inductor_npu_initialized():
    """Return True once the registry has successfully loaded an NPU inductor backend."""
    registry = _InductorNpuRegistry
    return registry.has_initialized()


def disable_register_inductor_npu():
    """Turn subsequent ``register_inductor_npu()`` calls into no-ops."""
    registry = _InductorNpuRegistry
    registry.disable_register()


def enable_register_inductor_npu():
    """Re-enable ``register_inductor_npu()`` after a previous disable."""
    registry = _InductorNpuRegistry
    registry.enable_register()


def register_inductor_npu():
    """Load or reload the NPU inductor backend named by ``TORCHINDUCTOR_NPU_BACKEND``.

    Does nothing while registration is disabled or when that backend is already loaded.
    """
    registry = _InductorNpuRegistry
    registry.register_inductor_npu()


def _resolve_npu_backend_from_wrapper(wrapper) -> str:
    """Pick the NPU inductor backend for a torch.compile inductor wrapper.

    Precedence, highest first:

    1. the wrapper's own ``npu_backend`` option;
    2. ``torch._inductor.config.npu_backend``;
    3. the ``TORCHINDUCTOR_NPU_BACKEND`` environment variable (``"default"`` if unset).

    At the first two levels, ``None``, ``""`` and ``"default"`` count as
    "not specified". Each level is looked up only if the previous one was
    unspecified.
    """
    unspecified = (None, "", "default")
    lookups = (
        lambda: wrapper.config.get("npu_backend"),
        lambda: getattr(torch._inductor.config, "npu_backend", None),
    )
    for lookup in lookups:
        candidate = lookup()
        if candidate not in unspecified:
            return candidate
    return os.getenv("TORCHINDUCTOR_NPU_BACKEND", "default")


class _NpuBackendScope:
    """Context manager that selects an NPU inductor backend for one compile call.

    On enter, it sets ``TORCHINDUCTOR_NPU_BACKEND`` to ``backend`` and calls
    ``register_inductor_npu()`` so the matching backend gets loaded. On exit, it
    puts the environment variable back exactly as before (restored or removed).
    The variable is also restored when ``register_inductor_npu()`` itself fails.
    """

    def __init__(self, backend: str):
        self.backend = backend
        # Env value before __enter__; None means the variable was unset.
        self._old_env = None

    def __enter__(self):
        self._old_env = os.environ.get("TORCHINDUCTOR_NPU_BACKEND")
        os.environ["TORCHINDUCTOR_NPU_BACKEND"] = self.backend
        try:
            register_inductor_npu()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo the
            # environment change here before propagating.
            self._restore_env()
            raise
        return self

    def __exit__(self, exc_type, exc, tb):
        self._restore_env()
        # Never suppress exceptions raised inside the with-block.
        return False

    def _restore_env(self):
        """Put TORCHINDUCTOR_NPU_BACKEND back to the value saved in __enter__."""
        if self._old_env is None:
            os.environ.pop("TORCHINDUCTOR_NPU_BACKEND", None)
        else:
            os.environ["TORCHINDUCTOR_NPU_BACKEND"] = self._old_env


def patch_inductor_wrapper():
    """Teach torch.compile's inductor wrapper about the ``npu_backend`` option.

    Installs three monkeypatches:

    * ``ConfigModule.get_config_copy``: for ``torch._inductor.config`` only,
      ensure an ``npu_backend`` entry (default ``"default"``) exists.
    * ``_TorchCompileInductorWrapper.__init__``: after normal construction,
      resolve the NPU backend. For ``"mlir"`` or ``"dvm"``, enter a
      ``_NpuBackendScope`` for it immediately.
    * ``_TorchCompileInductorWrapper.__call__``: run every compile inside the
      resolved backend's ``_NpuBackendScope``.

    Finally, ``torch._inductor.config.get_config_copy()`` is called once so the
    ``npu_backend`` entry is registered right away.
    """
    from torch import _TorchCompileInductorWrapper
    from torch.utils._config_module import Config, ConfigModule

    from torch_npu._compat.utils import make_config_entry

    wrapper_cls = _TorchCompileInductorWrapper
    original_init = wrapper_cls.__init__
    original_get_config_copy = ConfigModule.get_config_copy
    original_call = wrapper_cls.__call__

    def new_get_config_copy(self) -> dict[str, Any]:
        config_copy = original_get_config_copy(self)
        # Only the inductor config module carries the extra option.
        if self is torch._inductor.config and "npu_backend" not in config_copy:
            config_copy["npu_backend"] = "default"
            self._config["npu_backend"] = make_config_entry(
                Config(default="default", value_type=str),
                name="npu_backend",
            )
        return config_copy

    def new_init(self, mode, options, dynamic, name=None):
        # Forward ``name`` only when supplied, so the original __init__ receives
        # exactly the arguments it was given.
        init_args = (mode, options, dynamic) if name is None else (mode, options, dynamic, name)
        original_init(self, *init_args)

        backend = _resolve_npu_backend_from_wrapper(self)
        if backend == "mlir":
            with _NpuBackendScope(backend):
                log.info("Running MLIR backend")
                device_id = torch_npu.npu.current_device()
                torch_npu._C._recovery_all_npu_stream(device_id)
        elif backend == "dvm":
            with _NpuBackendScope(backend):
                log.info("Running dvm backend")

    def new_call(self, model_, inputs_):
        scope = _NpuBackendScope(_resolve_npu_backend_from_wrapper(self))
        with scope:
            return original_call(self, model_, inputs_)

    wrapper_cls.__call__ = new_call
    wrapper_cls.__init__ = new_init
    ConfigModule.get_config_copy = new_get_config_copy
    # Trigger registration of the ``npu_backend`` config entry now.
    torch._inductor.config.get_config_copy()


def patch_dynamo_optimize():
    src_optimize = optimize

    def npu_optimize(*args, **kwargs):
        backend = None
        if "backend" in kwargs:
            backend = kwargs["backend"]
        elif len(args) == 1:
            backend = args[0]

        backend_name = None
        if isinstance(backend, str):
            backend_name = backend
        elif isinstance(backend, _TorchCompileWrapper):
            backend_name = backend.compiler_name

        if backend_name == "npu":
            # Init torchair ahead of running model.
            _get_global_npu_backend(backend_name)
        return src_optimize(*args, **kwargs)

    torch._dynamo.optimize = npu_optimize


def patch_builtin_variable():
    """Make Dynamo's ``id()`` handling understand stream events.

    Wraps ``BuiltinVariable.call_id``. ``id(event)`` on an ``EventVariable``
    constant-folds to ``id()`` of the wrapped event object. All other calls,
    including a zero-argument ``id()``, go to the original ``call_id``.
    """
    builtin = torch._dynamo.variables.builtin
    origin_call_id = builtin.BuiltinVariable.call_id

    def _wrap_call_id(self, tx, *args):
        # Guard the index so an argument-less ``id()`` reaches the original
        # handler instead of raising IndexError inside this wrapper.
        if args and builtin.istype(
            args[0], torch._dynamo.variables.streams.EventVariable
        ):
            return torch._dynamo.variables.ConstantVariable.create(id(args[0].value))
        return origin_call_id(self, tx, *args)

    builtin.BuiltinVariable.call_id = _wrap_call_id


def patch_stream_event_variable_python_type():
    """
    Preserve backend-specific stream/event Python types in Dynamo.

    PyTorch's generic StreamVariable/EventVariable report torch.Stream and
    torch.Event. NPU subclasses have Python methods that use super(NpuType,
    self), so Dynamo must use the real runtime subclass when tracing those
    methods, especially when profiler wrappers cause Dynamo to inline them.
    """

    def python_type(self):
        return type(self.value)

    streams = torch._dynamo.variables.streams
    streams.StreamVariable.python_type = python_type
    streams.EventVariable.python_type = python_type


class NpuStreamContextVariable(StreamContextVariable):
    """NPU stream context that records explicit ``set_stream`` nodes in the FX graph.

    ``create`` captures the stream current at entry time as a graph value.
    ``enter`` emits a ``set_stream`` to the target stream; ``exit`` emits a
    ``set_stream`` back to the captured one.
    """

    @staticmethod
    def create(
        tx: "InstructionTranslator",
        stream_to_enter: "StreamVariable",
        **kwargs: Any,
    ) -> "NpuStreamContextVariable":
        """Build a stream context for entering ``stream_to_enter``.

        Emits an FX ``call_function`` node calling the device interface's
        ``current_stream`` with ``(None,)`` and wraps it as a ``StreamVariable``.
        That captured stream is what ``exit`` switches back to.

        Args:
            tx: Active instruction translator; its ``output`` receives the new node.
            stream_to_enter: Stream variable being entered; its ``device`` selects
                the device interface.
            **kwargs: Forwarded unchanged to the constructor.

        Returns:
            A new ``NpuStreamContextVariable``.
        """
        from torch._dynamo.device_interface import get_interface_for_device
        from torch._dynamo.variables.builder import wrap_fx_proxy_cls

        device_interface = get_interface_for_device(stream_to_enter.device)
        # Graph node that reads the stream active before the context is entered.
        current_stream_var = wrap_fx_proxy_cls(
            StreamVariable,
            tx,
            tx.output.create_proxy(
                "call_function",
                device_interface.current_stream,
                (None,),
                {},
            ),
        )

        return NpuStreamContextVariable(
            stream_to_enter,
            current_stream=current_stream_var,
            device_interface=device_interface,
            **kwargs,
        )

    def __init__(
        self,
        stream: Optional["StreamVariable"],
        current_stream: Optional["StreamVariable"] = None,
        device_interface: Any | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the captured current stream and device interface, then initialise the base.

        Args:
            stream: Stream variable this context switches to.
            current_stream: Stream to restore on exit; ``create`` supplies it.
            device_interface: Device interface whose ``set_stream`` the emitted
                nodes call; ``create`` supplies it.
            **kwargs: Forwarded to ``StreamContextVariable.__init__``.
        """
        self.current_stream = current_stream
        self.device_interface = device_interface
        super().__init__(stream, **kwargs)

    def enter(
        self, tx: "InstructionTranslator", *args: VariableTracker
    ) -> VariableTracker:
        """Emit a ``set_stream`` node to the context's stream, then run the base enter.

        NOTE(review): only ``tx`` is forwarded to ``super().enter``; extra
        ``*args`` are dropped here, whereas ``exit`` forwards them.
        """
        # Create set_stream node to switch to self.stream (skipped when
        # get_stream() is falsy, e.g. None).
        if self.get_stream():
            tx.output.create_proxy(
                "call_function",
                self.device_interface.set_stream,
                (self.get_stream().as_proxy(),),
                {},
            )
        return super().enter(tx)

    def exit(
        self, tx: "InstructionTranslator", *args: VariableTracker
    ) -> VariableTracker:
        """Emit a ``set_stream`` node restoring the captured stream, then run the base exit.

        NOTE(review): requires ``self.current_stream`` to be set (``create`` does
        this). An instance built directly with the default ``None`` would raise
        AttributeError on ``.as_proxy()`` here.
        """
        # The restoring set_stream node is emitted BEFORE the base-class exit
        # logic runs; the guard mirrors enter().
        if self.get_stream():
            tx.output.create_proxy(
                "call_function",
                self.device_interface.set_stream,
                (self.current_stream.as_proxy(),),
                {},
            )
        return super().exit(tx, *args)


def patch_npu_stream_context():
    """Route the NPU device interface's ``stream(...)`` to ``NpuStreamContextVariable``.

    Registers a handler in ``TorchInGraphFunctionVariable``'s handler table,
    keyed by ``get_interface_for_device("npu").stream``. Dynamo then builds an
    ``NpuStreamContextVariable`` for that call.
    """
    from torch._dynamo.device_interface import get_interface_for_device

    def _handle_npu_device_interface_stream(self, tx, stream):
        return NpuStreamContextVariable.create(tx, stream)

    handler_table = TorchInGraphFunctionVariable._get_handlers()
    npu_stream_fn = get_interface_for_device("npu").stream
    handler_table[npu_stream_fn] = _handle_npu_device_interface_stream


def patch_user_defined_class_variable():
    """Register ``torch.npu.Event`` and ``torch.npu.Stream`` as Dynamo in-graph classes.

    Replaces ``UserDefinedClassVariable._in_graph_classes`` with a memoised
    static wrapper around the previous implementation. The wrapper adds the two
    NPU classes, in place, to the set the previous implementation returns. It
    then returns that same set object.
    """
    import functools

    previous = UserDefinedClassVariable._in_graph_classes

    def patched_in_graph_classes():
        in_graph = previous()
        for npu_cls in (torch.npu.Event, torch.npu.Stream):
            in_graph.add(npu_cls)
        return in_graph

    memoised = functools.lru_cache(None)(patched_in_graph_classes)
    UserDefinedClassVariable._in_graph_classes = staticmethod(memoised)


def add_dynamo_methods():
    """Install all torch_npu Dynamo/Inductor monkeypatches.

    Saves the original ``__new__`` / ``call_method`` implementations as
    ``__new__raw`` / ``call_method_raw``, which the NPU replacements delegate
    to. It then swaps in the replacements and applies the remaining patches,
    each exactly once.

    Idempotent: later calls return immediately. Without this guard, a second
    call would save the *patched* functions as the ``*_raw`` originals. For
    example, ``call_method_raw`` would then point at ``TensorVariable_call_method``
    itself and recurse endlessly.
    """
    if getattr(add_dynamo_methods, "_installed", False):
        return

    UserDefinedClassVariable.__new__raw = UserDefinedClassVariable.__new__
    UserDefinedClassVariable.__new__ = UserDefinedClassVariable__new__
    SkipFunctionVariable.__new__raw = SkipFunctionVariable.__new__
    SkipFunctionVariable.__new__ = SkipFunctionVariable__new__
    TensorVariable.call_method_raw = TensorVariable.call_method
    TensorVariable.call_method = TensorVariable_call_method
    patch_dynamo_optimize()
    patch_inductor_wrapper()
    # Applied once; wrapping _in_graph_classes a second time adds nothing.
    patch_user_defined_class_variable()
    patch_stream_event_variable_python_type()
    patch_builtin_variable()
    patch_npu_stream_context()

    add_dynamo_methods._installed = True