import os
import torch
# Remember the current TORCH_DEVICE_BACKEND_AUTOLOAD setting (treated as "1"
# when unset) and force it to "0" while AsyncCompile is imported and its pool
# is warmed; the saved value is written back right after warm_pool().
# Presumably this keeps processes started by warm_pool() from auto-loading
# device backends -- TODO confirm.
# NOTE: if the variable was unset beforehand, it ends up explicitly set to "1"
# rather than being removed again.
ORG_AUTOLOAD = os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1")
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
from torch._inductor.async_compile import AsyncCompile


AsyncCompile.warm_pool()
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = ORG_AUTOLOAD
# All backends need npu/cpu/mps device_op_overrides.
from .codegen.common import register_device_op_overrides_npu
from .shape_handling import NPUShapeHandling, patch_shape_handling
from .utils import patch_has_triton, patch_has_triton_tma, patch_is_gpu
from .graph import patch_codegen_with_cpp_wrapper
from .codegen.common import patch_cache_base_get_system

# Patches applied unconditionally at import time, before any specific backend
# loader runs.
register_device_op_overrides_npu()
patch_has_triton()
patch_has_triton_tma()
patch_is_gpu()
patch_codegen_with_cpp_wrapper()
patch_cache_base_get_system()

# Prevent RecursionError when formatting LoweringException for huge output tuples (e.g. many permute nodes).
from .mfusion.safe_inductor_exc import apply_safe_operator_str_patch_if_enabled


apply_safe_operator_str_patch_if_enabled()

def _get_backend() -> str:
    """Return the NPU Inductor backend name requested through the environment.

    Reads ``TORCHINDUCTOR_NPU_BACKEND`` and yields ``"default"`` when the
    variable is unset. The value is returned verbatim (no stripping or case
    folding).
    """
    return os.environ.get("TORCHINDUCTOR_NPU_BACKEND", "default")

def _load_mlir_backend():
    """Load the ``"mlir"`` backend by importing the modules it consists of.

    The function body is made up solely of imports; nothing is called or
    returned. Presumably the imported project modules perform their
    registration/patching as an import side effect -- TODO confirm.

    Raises:
        ImportError: if ``torch_mlir`` (or ``torch_mlir.ir``) cannot be
            imported; the original error is chained as the cause. Import
            failures of ``torch_npu`` or the project modules propagate
            unchanged because they sit outside the ``try`` block.
    """
    import torch
    import torch_npu
    try:
        # Only torch_mlir gets the friendlier message below.
        import torch_mlir
        from torch_mlir import ir
    except ImportError as e:
        raise ImportError("torch_mlir is not installed, install it first.") from e
    # Imported after the torch_mlir check, so a missing torch_mlir is reported first.
    from .ascend_npu_ir.ascend_npu_ir.npu import npu_inductor_plugin, torch_mlir_patch


def _load_dvm_backend():
    """Load the ``"dvm"`` backend.

    Always imports the NPU inductor plugin and the DVM ``mlir_fusion`` module.
    When Triton is available, it additionally applies the Triton codegen,
    scheduling and cached-autotune patches.
    """
    from .ascend_npu_ir.ascend_npu_ir.npu import npu_inductor_plugin
    from .dvm import mlir_fusion

    # Nothing more to do when Triton is unavailable.
    if not torch.utils._triton.has_triton():
        return

    from .codegen.triton import patch_gen_common_triton_ext_imports, patch_triton_scheduling
    from .runtime import patch_triton_heuristics_cached_autotune

    # Apply the Triton-related patches in the same order as before.
    for apply_patch in (
        patch_gen_common_triton_ext_imports,
        patch_triton_scheduling,
        patch_triton_heuristics_cached_autotune,
    ):
        apply_patch()


def _load_triton_backend():
    """Install the Triton-based NPU Inductor backend (the ``"default"`` backend).

    Returns immediately, without patching anything, when Triton is not
    available. Otherwise it registers the NPU scheduling / wrapper classes for
    the ``"npu"`` device, calls the project's ``patch_*`` / ``_register_*``
    hooks in a fixed order, adds extra ops to Inductor's ``OpsHandler`` and
    finally adjusts Inductor configuration from environment variables.

    Environment variables read here:
        ENABLE_PARALLEL_SCHEDULER: ``"true"`` (case-insensitive) runs
            ``parallel_scheduler()``.
        TORCHINDUCTOR_COMPILE_THREADS: compile thread count; unset or empty
            means 1. The parsed value is written back to the environment and
            to ``torch._inductor.config.compile_threads``.
        FASTAUTOTUNE: ``"1"`` enables the fasta autotune adjustments below.
        AUTOTUNE_METHOD: defaults to ``"Expert"``; with ``"SampleStack"`` the
            compile thread count is forced to 1.
        ENABLE_PRINT_UB_BITS: forced to ``"1"`` when FASTAUTOTUNE is enabled
            and the variable is unset or ``"0"``.

    Raises:
        ValueError: if TORCHINDUCTOR_COMPILE_THREADS is not an integer string.
    """
    import os
    import torch
    has_triton = torch.utils._triton.has_triton()
    if not has_triton:
        return
    import logging
    log = logging.getLogger(__name__)

    # NOTE(review): a few of the names imported below are not used in this
    # function; the imports are kept as-is in case importing them matters.
    from torch._dynamo.device_interface import get_interface_for_device
    from torch._inductor import lowering as inductor_lowering
    from torch._inductor.codegen.common import (
        register_backend_for_device,
        register_device_op_overrides,
    )
    from torch.nn.attention import flex_attention

    from . import codegen, config as npu_config
    from .async_compile import patch_async_compile
    from .codegen._sizevars import patch_simplify
    from .codegen.ir import patch_indexing, patch_loop_body

    from .config import (
        aggresive_autotune,
        log as npulog,
        max_precompiled_thread_num,
        num_vector_core,
    )
    from .cpp_builder import (
        patch_get_cpp_torch_device_options,
        patch_get_optimization_cflags,
    )
    from .decomposition import _register_npu_inductor_decompositions
    from .dependencies import patch_extract_read_writes
    from .fx_passes import patch_pattern_mm_plus_mm
    from .fx_passes.graph_match_pass import (
        post_grad_custom_pass_fuc,
        pre_grad_custom_pass_fuc,
    )
    from .fx_passes.pattern_match.npu_fusion_attention_graph import register_fa_pass
    from .fx_passes.joint_graph import patch_constant_fold_uniform_value
    from .ir import patch_num_splits
    from .kernel import (
        _register_npu_inductor_addmm,
        _register_npu_inductor_bmm,
        _register_npu_inductor_flex_attention,
        _register_npu_inductor_grouped_mm,
        _register_npu_inductor_mm,
        _validate_device,
    )
    from .lowering import make_reduction
    from .runtime import (
        patch_create_device_properties,
        patch_load_cached_autotuning,
    )
    from .scheduler import patch_scheduler
    from .select_algorithm import patch_algorithm_selector
    from .shape_handling import NPUShapeHandling, patch_shape_handling
    from .utils import patch_get_first_incompatible_cudagraph_node

    from .graph import patch_count_bytes, patch_run_node

    from .autotune_process import patch_tuning_process, patch_tuning_process_pool
    from .codegen.cpp_utils import patch_device_to_aten
    from .codegen.triton import patch_gen_common_triton_ext_imports, patch_triton_scheduling
    from .runtime import patch_triton_heuristics_cached_autotune

    # Replace flex_attention's device validation with the NPU-aware version.
    flex_attention._validate_device = _validate_device

    def _patch_flex_attention_singleton_sort():
        """Wrap ``flex_attention._dense_to_ordered`` once, if it exists.

        The wrapper is tagged with ``_torch_npu_singleton_sort_patch`` so that
        calling this again does not wrap it a second time.
        """
        original = getattr(flex_attention, "_dense_to_ordered", None)
        if original is None or getattr(original, "_torch_npu_singleton_sort_patch", False):
            return

        def _dense_to_ordered_npu_safe(dense_mask):
            # Masks whose last dimension has size 1 are handled here directly:
            # the row sums become the block counts and every column index is
            # 0. Presumably this avoids a sort over a size-1 dimension that is
            # problematic on NPU -- TODO confirm.
            if dense_mask.ndim > 0 and dense_mask.size(-1) == 1:
                dense_mask = dense_mask.to(dtype=torch.int32)
                num_blocks_in_row = dense_mask.sum(dim=-1)
                col_indices = torch.zeros_like(dense_mask, dtype=torch.int32)
                return (
                    num_blocks_in_row.to(torch.int32, memory_format=torch.contiguous_format),
                    col_indices.to(torch.int32, memory_format=torch.contiguous_format),
                )
            # Every other shape goes through the unpatched implementation.
            return original(dense_mask)

        _dense_to_ordered_npu_safe._torch_npu_singleton_sort_patch = True
        flex_attention._dense_to_ordered = _dense_to_ordered_npu_safe

    _patch_flex_attention_singleton_sort()

    def _inductor_register_backend_for_device():
        """Register the NPU scheduling and wrapper codegen classes for "npu"."""
        from .codegen.cpp_wrapper import CppWrapperNpu
        from .codegen.npu_combined_scheduling import NPUCombinedScheduling
        from .codegen.wrapper import NPUWrapperCodeGen

        register_backend_for_device(
            "npu", NPUCombinedScheduling, NPUWrapperCodeGen, CppWrapperNpu
        )

    _inductor_register_backend_for_device()

    # The returned interface is not used; the lookup itself is kept
    # (presumably as a check that an "npu" interface is registered -- confirm).
    get_interface_for_device("npu")

    inductor_lowering.make_reduction = make_reduction

    patch_get_cpp_torch_device_options()
    patch_constant_fold_uniform_value()
    patch_gen_common_triton_ext_imports()
    patch_triton_scheduling()
    patch_triton_heuristics_cached_autotune()
    if npu_config.dump_fx_graph:
        from .codegen.ir_fx import _patch_npu_inductor_ir

        _patch_npu_inductor_ir()

    from .lowering import (
        _enable_full_lowering_fallback,
        _register_npu_inductor_fallbacks,
    )

    _register_npu_inductor_decompositions(backend="triton")

    # "allfallback" enables the full lowering fallback; otherwise the regular
    # fallbacks plus the mm/addmm/bmm/grouped_mm registrations are applied.
    if npu_config.enable_full_lowering_fallback.strip() == "allfallback":
        _enable_full_lowering_fallback()
    else:
        _register_npu_inductor_fallbacks()
        _register_npu_inductor_mm()
        _register_npu_inductor_addmm()
        _register_npu_inductor_bmm()
        _register_npu_inductor_grouped_mm()

    _register_npu_inductor_flex_attention()

    patch_pattern_mm_plus_mm()
    patch_algorithm_selector()
    patch_async_compile()
    patch_scheduler()
    patch_simplify()
    patch_num_splits()
    patch_loop_body()
    patch_indexing()

    patch_create_device_properties()
    patch_load_cached_autotuning()
    pre_grad_custom_pass_fuc()
    post_grad_custom_pass_fuc()
    if os.environ.get("ENABLE_PARALLEL_SCHEDULER", "false").lower() == "true":
        from .fx_passes.parallel_scheduler_pass import parallel_scheduler

        parallel_scheduler()

    # NOTE: fx-pass registration must stay after
    # _register_npu_inductor_decompositions (called above).
    def _replace_benchmark_all_configs():
        """Swap CachingAutotuner's benchmarking methods for the NPU versions."""
        from torch._inductor.runtime.triton_heuristics import CachingAutotuner

        from .runtime.triton_heuristics import (
            _benchmark_all_configs,
            benchmark_all_configs,
        )

        CachingAutotuner._benchmark_all_configs = _benchmark_all_configs
        CachingAutotuner.benchmark_all_configs = benchmark_all_configs

    def _replace_precompile():
        """Make NPUCachingAutotuner.precompile use ``precompile_parallel``."""
        from .runtime.triton_heuristics import NPUCachingAutotuner, precompile_parallel

        NPUCachingAutotuner.precompile = precompile_parallel

    if aggresive_autotune:
        _replace_benchmark_all_configs()

    if max_precompiled_thread_num > 1:
        _replace_precompile()

    register_fa_pass()
    patch_get_first_incompatible_cudagraph_node()
    patch_get_optimization_cflags()
    patch_extract_read_writes()
    patch_count_bytes()
    patch_run_node()
    patch_tuning_process()
    patch_tuning_process_pool()
    patch_device_to_aten()

    def add_additional_op():
        """Attach four extra ops to ``OpsHandler`` and register their dtype rules.

        Each op forwards its arguments to ``self._default`` under its own name
        and is registered with ``ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT``.
        """
        from torch._inductor.ops_handler import OpsHandler
        from torch._inductor.utils import register_op_dtype_propagation_rules
        from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND

        def index_select(
            self, name, index, indirect_var, set_indirect, bound, index_select_type
        ):
            return self._default(
                "index_select",
                (name, index, indirect_var, set_indirect, bound, index_select_type),
                {},
            )

        OpsHandler.index_select = index_select
        register_op_dtype_propagation_rules(
            "index_select", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None
        )

        def gather_template(
            self, name, index, indirect_var, set_indirect, index_boundary
        ):
            return self._default(
                "gather_template",
                (name, index, indirect_var, set_indirect, index_boundary),
                {},
            )

        OpsHandler.gather_template = gather_template
        register_op_dtype_propagation_rules(
            "gather_template", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None
        )

        def indexput_template(self, name, index, value, indirect_var, boundary):
            return self._default(
                "indexput_template", (name, index, value, indirect_var, boundary), {}
            )

        OpsHandler.indexput_template = indexput_template
        register_op_dtype_propagation_rules(
            "indexput_template", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None
        )

        def scatter_template(self, name, index, value, indirect_var, boundary):
            return self._default(
                "scatter_template", (name, index, value, indirect_var, boundary), {}
            )

        OpsHandler.scatter_template = scatter_template
        register_op_dtype_propagation_rules(
            "scatter_template", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None
        )

    add_additional_op()
    torch._inductor.config.comprehensive_padding = False

    # An unset or empty TORCHINDUCTOR_COMPILE_THREADS means a single thread;
    # the parsed value is mirrored to both the environment and the config.
    compile_threads = int(
        os.environ.get("TORCHINDUCTOR_COMPILE_THREADS") or "1"
    )
    os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = str(compile_threads)
    torch._inductor.config.compile_threads = compile_threads

    _fasta_autotune = os.environ.get("FASTAUTOTUNE", "0") == "1"
    _fasta_autotune_method = os.getenv("AUTOTUNE_METHOD", "Expert")
    if _fasta_autotune:
        if os.environ.get("ENABLE_PRINT_UB_BITS", "0") == "0":
            # Logger has no ``warnings`` method; ``warning`` is the correct
            # call (the old spelling raised AttributeError on this path).
            log.warning(
                "Please set ENABLE_PRINT_UB_BITS to 1. Fasta autotune need to know real ub usage."
            )
            os.environ["ENABLE_PRINT_UB_BITS"] = "1"

        if (
            _fasta_autotune_method == "SampleStack"
            and torch._inductor.config.compile_threads != 1
        ):
            log.warning(
                "fasta SampleStack method is not temporarily compatible with multi-process compile, "
                "fasta_autotune set TORCHINDUCTOR_COMPILE_THREADS "
                f"from {torch._inductor.config.compile_threads} to 1."
            )
            os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
            torch._inductor.config.compile_threads = 1

# Maps a backend name (as returned by _get_backend) to the function that loads it.
# Names missing from this table are handled by the caller's fallback.
_BACKEND_LOADERS = {
    "mlir": _load_mlir_backend,
    "dvm": _load_dvm_backend,
    # "default" is the Triton-based backend.
    "default": _load_triton_backend,
}


def _load_backend():
    """Run the loader for the configured backend and record the selection.

    The backend name comes from ``_get_backend()``. Names without an entry in
    ``_BACKEND_LOADERS`` are served by ``_load_triton_backend``. After the
    loader has run, the shared decompositions are registered and the
    configured name is stored on ``_InductorNpuRegistry._loaded_backend``.
    """
    selected = _get_backend()
    # Unrecognised names fall back to the Triton loader.
    _BACKEND_LOADERS.get(selected, _load_triton_backend)()

    from .decomposition import _register_shared_decompositions

    _register_shared_decompositions()

    from ..utils._dynamo import _InductorNpuRegistry

    # The configured name is stored as-is, even when the fallback loader ran.
    _InductorNpuRegistry._loaded_backend = selected


# Load the configured backend as an import-time side effect of this module.
_load_backend()

# Optional MFusion integration: patch Inductor fallback / post-grad when explicitly enabled.
# Only the exact value "1" of TORCHINDUCTOR_ENABLE_MFUSION enables it; the
# MFusion module is not imported otherwise.
if os.getenv("TORCHINDUCTOR_ENABLE_MFUSION", "0") == "1":
    from .mfusion.graph_fusion import MFusionPatch

    MFusionPatch.enable()