import importlib
import inspect
import logging
import os
import sys

import torch
import torch_npu
from torch import _TorchCompileWrapper
from torch._dynamo import optimize
from torch._dynamo.utils import tensortype_to_dtype
from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.constant import ConstantVariable
from torch._dynamo.variables.ctx_manager import AutocastModeVariable
from torch._dynamo.variables.functions import SkipFunctionVariable
from torch._dynamo.variables.lists import TupleVariable
from torch._dynamo.variables.tensor import TensorVariable
from torch._dynamo.variables.torch import (
    TorchCtxManagerClassVariable,
    TorchInGraphFunctionVariable,
)
from torch._dynamo.variables.user_defined import UserDefinedClassVariable
from torch_npu.dynamo import _get_global_npu_backend


# Module-level logger for this file.
log = logging.getLogger(__name__)
# NOTE(review): not referenced in the visible code; presumably read by other modules — confirm.
use_jit_script = False


class NPUTorchCtxManagerClassVariable(TorchCtxManagerClassVariable):
    """Context-manager class variable for the NPU autocast classes.

    Calling the wrapped class (``self.value``) builds its autocast variable
    through ``NPUAutocastModeVariable.create``.
    """

    def call_function(self, tx, args, kwargs):
        # ``tx`` is not needed to build the autocast variable.
        make_autocast = NPUAutocastModeVariable.create
        return make_autocast(self.value, args, kwargs)


class NPUAutocastModeVariable(AutocastModeVariable):
    """Factory for dynamo autocast variables built from NPU autocast classes."""

    @staticmethod
    def create(func, args, kwargs):
        """Build an ``AutocastModeVariable`` for a call to the autocast class ``func``.

        Args:
            func: the autocast class being called.
            args: positional arguments of the call (VariableTrackers or plain values).
            kwargs: keyword arguments of the call. The dict is not modified.

        Returns:
            An ``AutocastModeVariable`` whose target values are
            ``[device_type, dtype, enabled, cache_enabled]``. Each value is
            converted to a Python constant when it is a ``VariableTracker``.
        """
        bound_args = inspect.signature(func).bind(*args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments

        target_values = []
        for key in ("device_type", "dtype", "enabled", "cache_enabled"):
            if key == "device_type" and (
                func is torch_npu.npu.amp.autocast or key not in arguments
            ):
                # The NPU autocast class is not called with a device_type here,
                # so the device is supplied explicitly.
                arg = "npu"
            else:
                arg = arguments[key]
            if isinstance(arg, VariableTracker):
                arg = arg.as_python_constant()
            target_values.append(arg)

        return AutocastModeVariable(target_values, initial_values=None)


def UserDefinedClassVariable__new__(cls, value, **kwargs):
    """Replacement for ``UserDefinedClassVariable.__new__`` that reroutes NPU classes.

    - The NPU autocast classes become ``NPUTorchCtxManagerClassVariable``.
    - The NPU stream classes and the legacy NPU tensor types become
      ``TorchInGraphFunctionVariable``.
    - Every other class gets a bare instance from the original ``__new__``,
      which was saved as ``__new__raw``.
    """
    npu_autocast_classes = (
        torch.npu.amp.autocast,
        torch_npu.npu.amp.autocast,
        torch.npu.amp.autocast_mode.autocast,
        torch_npu.npu.amp.autocast_mode.autocast,
    )
    if value in npu_autocast_classes:
        return NPUTorchCtxManagerClassVariable(value, **kwargs)

    npu_in_graph_classes = (
        torch.npu.Stream,
        torch_npu.npu.Stream,
        torch.npu.streams.Stream,
        torch_npu.npu.streams.Stream,
        torch_npu.npu.BoolTensor,
        torch_npu.npu.ByteTensor,
        torch_npu.npu.CharTensor,
        torch_npu.npu.DoubleTensor,
        torch_npu.npu.FloatTensor,
        torch_npu.npu.HalfTensor,
        torch_npu.npu.IntTensor,
        torch_npu.npu.LongTensor,
        torch_npu.npu.ShortTensor,
        torch_npu.npu.BFloat16Tensor,
    )
    if value in npu_in_graph_classes:
        return TorchInGraphFunctionVariable(value, **kwargs)

    # Not an NPU class: defer to the original allocator (only cls is passed).
    return cls.__new__raw(cls)


def SkipFunctionVariable__new__(cls, value, reason=None, **kwargs):
    """Replacement for ``SkipFunctionVariable.__new__`` that handles NPU ``stream`` functions.

    The NPU ``stream`` functions are wrapped as ``TorchInGraphFunctionVariable``,
    and ``reason`` is dropped for them. Every other value gets a bare instance
    from the original ``__new__``, which was saved as ``__new__raw``.
    """
    npu_stream_functions = (
        torch.npu.stream,
        torch_npu.npu.stream,
        torch_npu.npu.utils.stream,
    )
    if value not in npu_stream_functions:
        return cls.__new__raw(cls)
    return TorchInGraphFunctionVariable(value, **kwargs)


def TensorVariable_call_method(self, tx, name, args, kwargs):
    """Replacement for ``TensorVariable.call_method`` handling ``tensor.type()`` on NPU.

    ``tensor.type()`` called without a target dtype on an NPU tensor of known
    dtype is constant-folded to the string ``"torch.npu.<TensorType>"``.
    Every other call goes to the original implementation, which was saved as
    ``TensorVariable.call_method_raw``. That includes casts such as
    ``tensor.type(torch.half)`` or ``tensor.type(dtype=...)``, and dtypes
    with no legacy tensor type.
    """
    if (
        name == "type"
        and not args
        # tensor.type(dtype=...) is a cast, not a query; let the original handle it.
        and "dtype" not in kwargs
        and self.dtype is not None
        and isinstance(self.device, torch.device)
        and self.device.type == "npu"
    ):
        # Find the legacy tensor class whose dtype set contains this dtype.
        tensortype = next(
            (k for k, v in tensortype_to_dtype.items() if self.dtype in v), None
        )
        if tensortype is not None:
            return ConstantVariable.create(f"torch.npu.{tensortype.__name__}")
    return TensorVariable.call_method_raw(self, tx, name, args, kwargs)


class _InductorNpuRegistry:
    """Process-wide bookkeeping for loading the NPU inductor backend.

    The backend is named by the ``TORCHINDUCTOR_NPU_BACKEND`` environment
    variable (``"default"`` when unset). It is (re)loaded only when that
    name differs from the one loaded last.
    """

    # When True, register_inductor_npu() returns without doing anything.
    _disabled_register = False
    # Backend name most recently loaded; None until the first successful load.
    _loaded_backend = None

    @classmethod
    def register_inductor_npu(cls):
        """Load ``torch_npu._inductor`` for the backend named in the environment.

        If the module has not been imported yet it is imported. Otherwise its
        ``_load_backend()`` is called. The backend name is recorded only after
        that step succeeds.
        """
        if cls._disabled_register:
            return

        requested = os.getenv("TORCHINDUCTOR_NPU_BACKEND", "default")
        if cls._loaded_backend == requested:
            return

        module_name = "torch_npu._inductor"
        if module_name in sys.modules:
            sys.modules[module_name]._load_backend()
        else:
            importlib.import_module(module_name)
        cls._loaded_backend = requested

    @classmethod
    def disable_register(cls):
        """Make subsequent register_inductor_npu() calls no-ops."""
        cls._disabled_register = True

    @classmethod
    def enable_register(cls):
        """Allow register_inductor_npu() to load backends again."""
        cls._disabled_register = False

    @classmethod
    def has_initialized(cls):
        """Return True once a backend has been loaded."""
        return cls._loaded_backend is not None


def is_inductor_npu_initialized():
    """Return True once an NPU inductor backend has been loaded."""
    registry = _InductorNpuRegistry
    return registry.has_initialized()


def disable_register_inductor_npu():
    """Turn subsequent ``register_inductor_npu()`` calls into no-ops."""
    registry = _InductorNpuRegistry
    registry.disable_register()


def enable_register_inductor_npu():
    """Allow ``register_inductor_npu()`` to load backends again."""
    registry = _InductorNpuRegistry
    registry.enable_register()


def register_inductor_npu():
    """Load the NPU inductor backend named by ``TORCHINDUCTOR_NPU_BACKEND``.

    Delegates to ``_InductorNpuRegistry.register_inductor_npu``.
    """
    registry = _InductorNpuRegistry
    return registry.register_inductor_npu()


def _resolve_npu_backend_from_wrapper(wrapper) -> str:
    """Resolve the NPU backend name for a compile wrapper.

    Sources are checked in this order:

    1. The wrapper's ``npu_backend`` option.
    2. ``torch._inductor.config.npu_backend``.
    3. The ``TORCHINDUCTOR_NPU_BACKEND`` environment variable.

    At every level a missing, empty or ``"default"`` value defers to the next
    source. If nothing selects a backend, the result is ``"default"``.

    Args:
        wrapper: object whose ``config`` dict holds the compile options.

    Returns:
        The backend name (e.g. ``"mlir"`` or ``"dvm"``), or ``"default"``.
    """
    unset = (None, "", "default")

    wrapper_backend = wrapper.config.get("npu_backend")
    if wrapper_backend not in unset:
        return wrapper_backend

    global_backend = getattr(torch._inductor.config, "npu_backend", None)
    if global_backend not in unset:
        return global_backend

    # Treat an empty environment variable like an unset one, consistent with
    # how the two sources above are handled.
    return os.getenv("TORCHINDUCTOR_NPU_BACKEND") or "default"


class _NpuBackendScope:
    """Context manager that exports an NPU backend for one compile invocation.

    On entry it sets ``TORCHINDUCTOR_NPU_BACKEND`` to ``backend`` and calls
    ``register_inductor_npu()`` so that backend is loaded. On exit, or if
    entering fails, it restores the variable's previous value, or removes the
    variable if it was previously unset. Exceptions are never suppressed.
    """

    _ENV_VAR = "TORCHINDUCTOR_NPU_BACKEND"

    def __init__(self, backend: str):
        self.backend = backend
        # Previous value of the environment variable; None means it was unset.
        self._old_env = None

    def __enter__(self):
        self._old_env = os.environ.get(self._ENV_VAR)
        os.environ[self._ENV_VAR] = self.backend
        try:
            register_inductor_npu()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo the
            # environment change here instead of leaking it.
            self._restore_env()
            raise
        return self

    def __exit__(self, exc_type, exc, tb):
        self._restore_env()
        return False

    def _restore_env(self):
        """Put the environment variable back to its value from before __enter__."""
        if self._old_env is None:
            os.environ.pop(self._ENV_VAR, None)
        else:
            os.environ[self._ENV_VAR] = self._old_env


def patch_inductor_wrapper():
    from typing import Any, Literal, Optional

    from torch import _TorchCompileInductorWrapper
    from torch.utils._config_module import _ConfigEntry, Config, ConfigModule

    src_apply_options = _TorchCompileInductorWrapper.apply_options
    src_init = _TorchCompileInductorWrapper.__init__
    src_get_config_copy = ConfigModule.get_config_copy
    src_call = _TorchCompileInductorWrapper.__call__

    def new_apply_options(self, options: Optional[dict[str, Any]]):
        if options is not None and options.get("enable_shape_handling", False):
            if not is_inductor_npu_initialized():
                register_inductor_npu()
            torch_npu._inductor.patch_shape_handling()
        src_apply_options(self, options)

    def new_get_config_copy(self) -> dict[str, Any]:
        ori_dict = src_get_config_copy(self)
        if self is not torch._inductor.config:
            return ori_dict
        NpuBackendType = Literal["default", "mlir", "dvm"]
        if "npu_backend" not in ori_dict:
            ori_dict["npu_backend"] = "default"
            self._config["npu_backend"] = _ConfigEntry(
                Config(default="default", value_type=NpuBackendType)
            )

        if "enable_shape_handling" not in ori_dict:
            ori_dict["enable_shape_handling"] = False
            self._config["enable_shape_handling"] = _ConfigEntry(
                Config(default=False, value_type=bool)
            )

        if "shape_handling_configs" not in ori_dict:
            ori_dict["shape_handling_configs"] = None
            self._config["shape_handling_configs"] = _ConfigEntry(
                Config(default=None, value_type=list)
            )

        if "shape_handling_dict" not in ori_dict:
            ori_dict["shape_handling_dict"] = None
            self._config["shape_handling_dict"] = _ConfigEntry(
                Config(default=None, value_type=dict)
            )
        return ori_dict

    def new_init(self, mode, options, dynamic):
        src_init(self, mode, options, dynamic)
        backend = _resolve_npu_backend_from_wrapper(self)
        if backend=="mlir" or backend=="dvm":
            with _NpuBackendScope(backend):
                device_id = torch_npu.npu.current_device()
                torch_npu._C._recovery_all_npu_stream(device_id)

            
    def new_call(self, model_, inputs_):
        backend = _resolve_npu_backend_from_wrapper(self)
        with _NpuBackendScope(backend):
            return src_call(self, model_, inputs_)
            
        
    _TorchCompileInductorWrapper.__call__ = new_call
    _TorchCompileInductorWrapper.apply_options = new_apply_options
    _TorchCompileInductorWrapper.__init__ = new_init
    ConfigModule.get_config_copy = new_get_config_copy
    torch._inductor.config.get_config_copy()


def patch_dynamo_optimize():
    """Replace ``torch._dynamo.optimize`` with an NPU-aware wrapper.

    When the requested backend is ``"npu"`` (given by name or as a compile
    wrapper whose ``compiler_name`` is ``"npu"``), the wrapper initializes
    that backend via ``_get_global_npu_backend`` first. It then always
    delegates to the original ``optimize``.
    """
    original_optimize = optimize

    def npu_optimize(*args, **kwargs):
        if "backend" in kwargs:
            requested = kwargs["backend"]
        elif len(args) == 1:
            requested = args[0]
        else:
            requested = None

        if isinstance(requested, str):
            name = requested
        elif isinstance(requested, _TorchCompileWrapper):
            name = requested.compiler_name
        else:
            name = None

        if name == "npu":
            # Init torchair ahead of running model.
            _get_global_npu_backend(name)
        return original_optimize(*args, **kwargs)

    torch._dynamo.optimize = npu_optimize


def patch_base_schedulernode():
    from torch._inductor.scheduler import BaseSchedulerNode, ExternKernelSchedulerNode

    original_get_read_write_buffer_accesses = (
        BaseSchedulerNode.get_read_write_buffer_accesses
    )

    def new_get_read_write_buffer_accesses(
        self_instance, include_reads: bool, include_writes: bool
    ) -> dict[str, int]:
        if isinstance(self_instance, ExternKernelSchedulerNode):
            return {}
        return original_get_read_write_buffer_accesses(
            self_instance, include_reads, include_writes
        )

    BaseSchedulerNode.get_read_write_buffer_accesses = (
        new_get_read_write_buffer_accesses
    )


def patch_user_defined_class_variable():
    import functools

    original_method = UserDefinedClassVariable._in_graph_classes

    @staticmethod
    @functools.lru_cache(None)
    def patched_in_graph_classes():
        result = original_method()
        result.add(torch.npu.Event)
        result.add(torch.npu.Stream)
        return result

    UserDefinedClassVariable._in_graph_classes = patched_in_graph_classes


def fake_record_stream(self, s):
    """Stand-in that dynamo traces in place of ``Tensor.record_stream``.

    A compile backend is expected to replace calls to this function in the
    FX graph with a real implementation. On a ``FakeTensor`` it does nothing
    and returns None. Any other input raises ``RuntimeError``, which means
    the placeholder was executed instead of being replaced.
    """
    if not isinstance(self, torch._subclasses.fake_tensor.FakeTensor):
        raise RuntimeError(
            "tensor.record_stream is not supported on torch.compile! "
            "You should write a pass to replace torch.npu.fake_record_stream to an actual function in FX graph "
            "before aot_autograd."
        )


def patch_record_stream():
    """Make dynamo trace ``tensor.record_stream(s)`` as ``torch.npu.fake_record_stream``.

    Exposes ``fake_record_stream`` as ``torch.npu.fake_record_stream`` and
    installs ``TensorVariable.method_record_stream``. That method emits an
    in-graph call to the placeholder with ``(tensor, s)`` in the current
    instruction translator.
    """
    torch.npu.fake_record_stream = fake_record_stream
    dynamo_variables = torch._dynamo.variables

    def method_record_stream(self, s):
        current_tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx()
        placeholder = dynamo_variables.TorchInGraphFunctionVariable(
            torch.npu.fake_record_stream
        )
        return placeholder.call_function(current_tx, [self, s], {})

    dynamo_variables.tensor.TensorVariable.method_record_stream = method_record_stream


def patch_variable_builder():
    """Teach ``VariableBuilder._wrap`` to wrap ``torch.npu.Event`` objects.

    For an NPU event, the replacement does four things:

    1. Installs an ``ID_MATCH`` guard.
    2. Stores a weak reference to the event.
    3. Creates a graph proxy that fetches the event back by ``id``.
    4. Returns an ``EventVariable`` for it.

    Every other value goes to the original ``_wrap``.
    """
    builder_cls = torch._dynamo.variables.builder.VariableBuilder
    upstream_wrap = builder_cls._wrap

    def _patch_wrapper(self, value):
        if not isinstance(value, torch.npu.Event):
            return upstream_wrap(self, value)

        dynamo_utils = torch._dynamo.utils
        self.install_guards(torch._dynamo.guards.GuardBuilder.ID_MATCH)
        dynamo_utils.store_user_object_weakref(value)
        event_proxy = self.tx.output.create_proxy(
            "call_function",
            dynamo_utils.get_user_object_from_id,
            (id(value),),
            {},
        )
        dynamo_utils.set_example_value(event_proxy.node, value)
        return torch._dynamo.variables.ctx_manager.EventVariable(
            event_proxy, value, source=self.source
        )

    builder_cls._wrap = _patch_wrapper


def patch_builtin_variable():
    """Make dynamo fold ``id(event)`` for an ``EventVariable`` to a constant.

    The replacement ``BuiltinVariable.call_id`` returns a constant equal to
    ``id()`` of the wrapped event object when called with exactly one
    ``EventVariable``. Any other call goes to the original ``call_id``,
    including ``id()`` with no or several arguments, so that the original's
    own handling applies.
    """
    builtin_mod = torch._dynamo.variables.builtin
    origin_call_id = builtin_mod.BuiltinVariable.call_id

    def _wrap_call_id(self, tx, *args):
        # Check the arity first: indexing args[0] unguarded raised IndexError
        # for id() with no arguments.
        if len(args) == 1 and builtin_mod.istype(
            args[0], torch._dynamo.variables.ctx_manager.EventVariable
        ):
            return torch._dynamo.variables.ConstantVariable.create(id(args[0].value))
        return origin_call_id(self, tx, *args)

    builtin_mod.BuiltinVariable.call_id = _wrap_call_id


def patch_event_variable_python_type():
    """
    Add the 'python_type' method to the EventVariable class.
    """

    def python_type(self):
        return type(self.value)

    if "python_type" not in torch._dynamo.variables.ctx_manager.EventVariable.__dict__:
        torch._dynamo.variables.ctx_manager.EventVariable.python_type = python_type


def add_dynamo_methods():
    """Install all NPU-related dynamo and inductor patches defined in this module.

    The original ``UserDefinedClassVariable.__new__``,
    ``SkipFunctionVariable.__new__`` and ``TensorVariable.call_method`` are
    saved as ``__new__raw`` / ``call_method_raw``, so the replacements can
    delegate to them.

    Repeated calls are no-ops. Applying the patches twice would store the
    replacements themselves as the saved "originals", so they would end up
    delegating to themselves. It would also wrap the remaining patches twice.
    """
    if TensorVariable.call_method is TensorVariable_call_method:
        # Patches are already installed.
        return

    UserDefinedClassVariable.__new__raw = UserDefinedClassVariable.__new__
    UserDefinedClassVariable.__new__ = UserDefinedClassVariable__new__
    SkipFunctionVariable.__new__raw = SkipFunctionVariable.__new__
    SkipFunctionVariable.__new__ = SkipFunctionVariable__new__
    TensorVariable.call_method_raw = TensorVariable.call_method
    TensorVariable.call_method = TensorVariable_call_method
    patch_dynamo_optimize()
    patch_inductor_wrapper()
    patch_base_schedulernode()
    patch_user_defined_class_variable()
    patch_record_stream()
    patch_event_variable_python_type()
    patch_variable_builder()
    patch_builtin_variable()