import os
import torch
# Remember the caller's TORCH_DEVICE_BACKEND_AUTOLOAD value ("1" when unset)
# and force it to "0" while AsyncCompile is imported and its pool is warmed.
# NOTE(review): presumably this keeps the warm-up from triggering device
# backend autoload — confirm against the AsyncCompile worker start-up path.
ORG_AUTOLOAD = os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1")
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
from torch._inductor.async_compile import AsyncCompile
AsyncCompile.warm_pool()
# Put the saved value back. If the variable was originally unset, it is now
# explicitly set to "1" rather than removed.
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = ORG_AUTOLOAD
from .codegen.common import register_device_op_overrides_npu
from .shape_handling import NPUShapeHandling, patch_shape_handling
from .utils import patch_has_triton, patch_has_triton_tma, patch_is_gpu
from .graph import patch_codegen_with_cpp_wrapper
from .codegen.common import patch_cache_base_get_system
# Registrations and patches applied unconditionally at import time, in this
# order, before _load_backend() picks a backend further down in this module.
register_device_op_overrides_npu()
patch_has_triton()
patch_has_triton_tma()
patch_is_gpu()
patch_codegen_with_cpp_wrapper()
patch_cache_base_get_system()
# The helper decides by itself whether the patch is applied ("if enabled").
from .mfusion.safe_inductor_exc import apply_safe_operator_str_patch_if_enabled
apply_safe_operator_str_patch_if_enabled()
def _get_backend() -> str:
    """Return the NPU inductor backend name chosen through the environment.

    Reads ``TORCHINDUCTOR_NPU_BACKEND`` and returns its value verbatim (no
    stripping, no case folding). When the variable is not set the string
    ``"default"`` is returned.
    """
    selected = os.environ.get("TORCHINDUCTOR_NPU_BACKEND", "default")
    return selected
def _load_mlir_backend():
    """Import the modules that make up the "mlir" NPU inductor backend.

    Nothing is returned and none of the imported names are used here; the
    imports themselves are the effect of this function.

    Raises:
        ImportError: if ``torch_mlir`` (or ``torch_mlir.ir``) cannot be
            imported. The original import error is chained as the cause.
    """
    import torch
    import torch_npu
    try:
        import torch_mlir
        from torch_mlir import ir
    except ImportError as e:
        # Replace the low-level import failure with an actionable message.
        raise ImportError("torch_mlir is not installed, install it first.") from e
    # NOTE(review): presumably these modules register/patch the backend as an
    # import side effect — confirm in ascend_npu_ir.
    from .ascend_npu_ir.ascend_npu_ir.npu import npu_inductor_plugin, torch_mlir_patch
def _load_dvm_backend():
    """Import the "dvm" backend modules and, if Triton is usable, patch Triton codegen.

    The plugin and fusion modules are always imported. The three Triton
    patches are imported and applied only when
    ``torch.utils._triton.has_triton()`` reports true.
    """
    from .ascend_npu_ir.ascend_npu_ir.npu import npu_inductor_plugin
    from .dvm import mlir_fusion

    # Without Triton there is nothing further to patch.
    if not torch.utils._triton.has_triton():
        return

    from .codegen.triton import patch_gen_common_triton_ext_imports, patch_triton_scheduling
    from .runtime import patch_triton_heuristics_cached_autotune

    # Apply the patches in the same order as they are listed.
    for apply_patch in (
        patch_gen_common_triton_ext_imports,
        patch_triton_scheduling,
        patch_triton_heuristics_cached_autotune,
    ):
        apply_patch()
def _load_triton_backend():
    """Register the Triton-based NPU inductor backend and apply its patches.

    Returns immediately, before importing or patching anything else, when
    ``torch.utils._triton.has_triton()`` is false. Otherwise it registers the
    "npu" device backend, installs the lowering/kernel registrations and the
    inductor patches imported below, and finally adjusts inductor
    configuration from the environment.

    Environment variables read:
        ENABLE_PARALLEL_SCHEDULER: ``"true"`` (case-insensitive) runs the
            parallel scheduler pass.
        TORCHINDUCTOR_COMPILE_THREADS: compile thread count; unset or empty
            means 1. The parsed value is written back to the environment and
            to ``torch._inductor.config.compile_threads``.
        FASTAUTOTUNE: ``"1"`` enables the fasta-autotune adjustments.
        AUTOTUNE_METHOD: autotune method name, ``"Expert"`` when unset.
        ENABLE_PRINT_UB_BITS: forced to ``"1"`` (with a warning) when fasta
            autotune is on and the variable is unset or ``"0"``.

    Raises:
        ValueError: if ``TORCHINDUCTOR_COMPILE_THREADS`` is set to a string
            that ``int()`` cannot parse.
    """
    import os
    import torch
    has_triton = torch.utils._triton.has_triton()
    if not has_triton:
        return
    import logging
    log = logging.getLogger(__name__)
    import torch
    from torch._dynamo.device_interface import get_interface_for_device
    from torch._inductor import lowering as inductor_lowering
    from torch._inductor.codegen.common import (
        register_backend_for_device,
        register_device_op_overrides,
    )
    from torch.nn.attention import flex_attention
    from . import codegen, config as npu_config
    from .async_compile import patch_async_compile
    from .codegen._sizevars import patch_simplify
    from .codegen.ir import patch_indexing, patch_loop_body
    from .config import (
        aggresive_autotune,
        log as npulog,
        max_precompiled_thread_num,
        num_vector_core,
    )
    from .cpp_builder import (
        patch_get_cpp_torch_device_options,
        patch_get_optimization_cflags,
    )
    from .decomposition import _register_npu_inductor_decompositions
    from .dependencies import patch_extract_read_writes
    from .fx_passes import patch_pattern_mm_plus_mm
    from .fx_passes.graph_match_pass import (
        post_grad_custom_pass_fuc,
        pre_grad_custom_pass_fuc,
    )
    from .fx_passes.pattern_match.npu_fusion_attention_graph import register_fa_pass
    from .fx_passes.joint_graph import patch_constant_fold_uniform_value
    from .ir import patch_num_splits
    from .kernel import (
        _register_npu_inductor_addmm,
        _register_npu_inductor_bmm,
        _register_npu_inductor_flex_attention,
        _register_npu_inductor_grouped_mm,
        _register_npu_inductor_mm,
        _validate_device,
    )
    from .lowering import make_reduction
    from .runtime import (
        patch_create_device_properties,
        patch_load_cached_autotuning,
    )
    from .scheduler import patch_scheduler
    from .select_algorithm import patch_algorithm_selector
    from .shape_handling import NPUShapeHandling, patch_shape_handling
    from .utils import patch_get_first_incompatible_cudagraph_node
    from .graph import patch_count_bytes, patch_run_node
    from .autotune_process import patch_tuning_process, patch_tuning_process_pool
    from .codegen.cpp_utils import patch_device_to_aten
    from .codegen.triton import patch_gen_common_triton_ext_imports, patch_triton_scheduling
    from .runtime import patch_triton_heuristics_cached_autotune

    # Swap flex_attention's device validation for the NPU kernel module's one.
    flex_attention._validate_device = _validate_device

    def _patch_flex_attention_singleton_sort():
        """Wrap ``flex_attention._dense_to_ordered`` to special-case size-1 rows.

        No-op when the attribute does not exist or is already the wrapper
        (detected through the ``_torch_npu_singleton_sort_patch`` marker).
        """
        original = getattr(flex_attention, "_dense_to_ordered", None)
        if original is None or getattr(original, "_torch_npu_singleton_sort_patch", False):
            return

        def _dense_to_ordered_npu_safe(dense_mask):
            # With a single column, the row count is the int32 sum over that
            # column and the column indices are all zero, so the original
            # implementation is not called at all.
            # NOTE(review): presumably this avoids an op in the original that
            # misbehaves on NPU for singleton rows — confirm.
            if dense_mask.ndim > 0 and dense_mask.size(-1) == 1:
                dense_mask = dense_mask.to(dtype=torch.int32)
                num_blocks_in_row = dense_mask.sum(dim=-1)
                col_indices = torch.zeros_like(dense_mask, dtype=torch.int32)
                return (
                    num_blocks_in_row.to(torch.int32, memory_format=torch.contiguous_format),
                    col_indices.to(torch.int32, memory_format=torch.contiguous_format),
                )
            return original(dense_mask)

        # Marker that makes the patch idempotent across repeated loads.
        _dense_to_ordered_npu_safe._torch_npu_singleton_sort_patch = True
        flex_attention._dense_to_ordered = _dense_to_ordered_npu_safe

    _patch_flex_attention_singleton_sort()

    def _inductor_register_backend_for_device():
        """Register the NPU scheduling and wrapper codegen classes for device "npu"."""
        from .codegen.cpp_wrapper import CppWrapperNpu
        from .codegen.npu_combined_scheduling import NPUCombinedScheduling
        from .codegen.wrapper import NPUWrapperCodeGen
        register_backend_for_device(
            "npu", NPUCombinedScheduling, NPUWrapperCodeGen, CppWrapperNpu
        )

    _inductor_register_backend_for_device()
    # The result is not used afterwards; the lookup is kept as in the original.
    device = get_interface_for_device("npu")

    inductor_lowering.make_reduction = make_reduction
    patch_get_cpp_torch_device_options()
    patch_constant_fold_uniform_value()
    patch_gen_common_triton_ext_imports()
    patch_triton_scheduling()
    patch_triton_heuristics_cached_autotune()

    if npu_config.dump_fx_graph:
        from .codegen.ir_fx import _patch_npu_inductor_ir
        _patch_npu_inductor_ir()

    from .lowering import (
        _enable_full_lowering_fallback,
        _register_npu_inductor_fallbacks,
    )
    _register_npu_inductor_decompositions(backend="triton")
    # "allfallback" switches to the full-fallback lowering; any other value
    # registers the regular NPU fallbacks.
    if npu_config.enable_full_lowering_fallback.strip() == "allfallback":
        _enable_full_lowering_fallback()
    else:
        _register_npu_inductor_fallbacks()

    _register_npu_inductor_mm()
    _register_npu_inductor_addmm()
    _register_npu_inductor_bmm()
    _register_npu_inductor_grouped_mm()
    _register_npu_inductor_flex_attention()
    patch_pattern_mm_plus_mm()
    patch_algorithm_selector()
    patch_async_compile()
    patch_scheduler()
    patch_simplify()
    patch_num_splits()
    patch_loop_body()
    patch_indexing()
    patch_create_device_properties()
    patch_load_cached_autotuning()
    pre_grad_custom_pass_fuc()
    post_grad_custom_pass_fuc()

    if os.environ.get("ENABLE_PARALLEL_SCHEDULER", "false").lower() == "true":
        from .fx_passes.parallel_scheduler_pass import parallel_scheduler
        parallel_scheduler()

    def _replace_benchmark_all_configs():
        """Install the NPU benchmark routines on inductor's ``CachingAutotuner``."""
        from torch._inductor.runtime.triton_heuristics import CachingAutotuner
        from .runtime.triton_heuristics import (
            _benchmark_all_configs,
            benchmark_all_configs,
        )
        CachingAutotuner._benchmark_all_configs = _benchmark_all_configs
        CachingAutotuner.benchmark_all_configs = benchmark_all_configs

    def _replace_precompile():
        """Make ``NPUCachingAutotuner.precompile`` use ``precompile_parallel``."""
        from .runtime.triton_heuristics import NPUCachingAutotuner, precompile_parallel
        NPUCachingAutotuner.precompile = precompile_parallel

    if aggresive_autotune:
        _replace_benchmark_all_configs()
    if max_precompiled_thread_num > 1:
        _replace_precompile()

    register_fa_pass()
    patch_get_first_incompatible_cudagraph_node()
    patch_get_optimization_cflags()
    patch_extract_read_writes()
    patch_count_bytes()
    patch_run_node()
    patch_tuning_process()
    patch_tuning_process_pool()
    patch_device_to_aten()

    def add_additional_op():
        """Add four extra ops to ``OpsHandler`` and register their dtype rules.

        Each added method forwards its arguments to ``self._default`` under
        the op's own name, and each op is registered with the DEFAULT
        elementwise type-promotion kind.
        """
        from torch._inductor.ops_handler import OpsHandler
        from torch._inductor.utils import register_op_dtype_propagation_rules
        from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND

        def index_select(
            self, name, index, indirect_var, set_indirect, bound, index_select_type
        ):
            return self._default(
                "index_select",
                (name, index, indirect_var, set_indirect, bound, index_select_type),
                {},
            )

        OpsHandler.index_select = index_select
        register_op_dtype_propagation_rules(
            "index_select", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None
        )

        def gather_template(
            self, name, index, indirect_var, set_indirect, index_boundary
        ):
            return self._default(
                "gather_template",
                (name, index, indirect_var, set_indirect, index_boundary),
                {},
            )

        OpsHandler.gather_template = gather_template
        register_op_dtype_propagation_rules(
            "gather_template", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None
        )

        def indexput_template(self, name, index, value, indirect_var, boundary):
            return self._default(
                "indexput_template", (name, index, value, indirect_var, boundary), {}
            )

        OpsHandler.indexput_template = indexput_template
        register_op_dtype_propagation_rules(
            "indexput_template", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None
        )

        def scatter_template(self, name, index, value, indirect_var, boundary):
            return self._default(
                "scatter_template", (name, index, value, indirect_var, boundary), {}
            )

        OpsHandler.scatter_template = scatter_template
        register_op_dtype_propagation_rules(
            "scatter_template", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None
        )

    add_additional_op()

    torch._inductor.config.comprehensive_padding = False

    # An unset or empty TORCHINDUCTOR_COMPILE_THREADS means a single thread.
    # The value is mirrored into both the environment and the inductor config.
    compile_threads = int(
        os.environ.get("TORCHINDUCTOR_COMPILE_THREADS") or "1"
    )
    os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = str(compile_threads)
    torch._inductor.config.compile_threads = compile_threads

    _fasta_autotune = os.environ.get("FASTAUTOTUNE", "0") == "1"
    _fasta_autotune_method = os.getenv("AUTOTUNE_METHOD", "Expert")
    if _fasta_autotune:
        if os.environ.get("ENABLE_PRINT_UB_BITS", "0") == "0":
            # Logger exposes warning(), not warnings().
            log.warning(
                "Please set ENABLE_PRINT_UB_BITS to 1. Fasta autotune need to know real ub usage."
            )
            os.environ["ENABLE_PRINT_UB_BITS"] = "1"
        if (
            _fasta_autotune_method == "SampleStack"
            and torch._inductor.config.compile_threads != 1
        ):
            log.warning(
                "fasta SampleStack method is not temporarily compatible with multi-process compile, "
                "fasta_autotune set TORCHINDUCTOR_COMPILE_THREADS "
                f"from {torch._inductor.config.compile_threads} to 1."
            )
            os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
            torch._inductor.config.compile_threads = 1
# Maps a TORCHINDUCTOR_NPU_BACKEND value to the function that loads that
# backend. _load_backend() falls back to _load_triton_backend for any name
# that is not a key of this table.
_BACKEND_LOADERS = {
    "mlir": _load_mlir_backend,
    "dvm": _load_dvm_backend,
    "default": _load_triton_backend,
}
def _load_backend():
    """Load the configured backend, then the decompositions shared by all backends.

    The backend name comes from ``_get_backend()``. A name with no entry in
    ``_BACKEND_LOADERS`` is served by ``_load_triton_backend``. Afterwards the
    shared decompositions are registered and the requested name is stored on
    ``_InductorNpuRegistry._loaded_backend`` exactly as it was read, even when
    the fallback loader handled it.
    """
    backend = _get_backend()
    # Look up the loader (Triton loader for unknown names) and run it at once.
    _BACKEND_LOADERS.get(backend, _load_triton_backend)()

    from .decomposition import _register_shared_decompositions

    _register_shared_decompositions()

    from ..utils._dynamo import _InductorNpuRegistry

    _InductorNpuRegistry._loaded_backend = backend
# Load the backend selected through TORCHINDUCTOR_NPU_BACKEND as soon as this
# module is imported.
_load_backend()
# MFusion is opt-in: it is enabled only when TORCHINDUCTOR_ENABLE_MFUSION is
# exactly "1" (the default when unset is "0").
if os.getenv("TORCHINDUCTOR_ENABLE_MFUSION", "0") == "1":
    from .mfusion.graph_fusion import MFusionPatch
    MFusionPatch.enable()