import os

# Remember the caller's setting so it can be put back after the warm-up;
# "1" is the value used when the variable is not set at all.
ORG_AUTOLOAD = os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1")
# Switch device-backend autoloading off while the compile pool is warmed
# (presumably so that anything started by the warm-up sees it disabled —
# TODO confirm against the async_compile implementation).
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
try:
    from torch._inductor.async_compile import AsyncCompile

    AsyncCompile.warm_pool()
finally:
    # Restore the setting even when the import or the warm-up raises;
    # otherwise the process would be left with autoloading forced to "0".
    os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = ORG_AUTOLOAD
# Backend selection: TORCHINDUCTOR_NPU_BACKEND is read once and compared
# against the two alternative backends; any other value (including the
# implicit 'default') falls through to the ``else`` branch that follows.
_NPU_BACKEND = os.getenv('TORCHINDUCTOR_NPU_BACKEND', 'default')
if _NPU_BACKEND == 'mlir':
    try:
        import torch_mlir
        from torch_mlir import ir
    except Exception as err:
        # ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not rewritten as ImportError;
        # the original failure is chained to keep the real cause visible.
        raise ImportError("torch_mlir is not installed, install it first.") from err
    # The imported names are not used below; the imports appear to be made
    # for their side effects (NOTE(review): confirm in ascend_npu_ir).
    from .ascend_npu_ir.ascend_npu_ir.npu import npu_inductor_plugin
    from .ascend_npu_ir.ascend_npu_ir.npu import torch_mlir_patch
elif _NPU_BACKEND == 'dvm':
    from .ascend_npu_ir.ascend_npu_ir.npu import npu_inductor_plugin
    from .dvm import mlir_fusion
else:
import os
import torch
from torch._dynamo.device_interface import register_interface_for_device, get_interface_for_device
from torch._inductor import lowering as inductor_lowering
from torch._inductor.choices import InductorChoices
from torch._inductor.codegen.common import register_backend_for_device, register_device_op_overrides
from torch._inductor.runtime import autotune_cache
from torch_npu.npu import device_count
from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device
from torch_npu.utils._inductor import NPUDeviceOpOverrides
from . import config as npu_config
from . import codegen
from .fx_passes.pattern_match.npu_fusion_attention_graph import register_fa_pass
from .config import (
aggresive_autotune, num_vector_core, set_compile_threads,
disable_comprehensive_padding, max_precompiled_thread_num
)
from .config import log as npulog
from .codegen.triton import patch_gen_common_triton_ext_imports
from .decomposition import _register_npu_inductor_decompositons
from .graph import patch_count_bytes, patch_run_node
from .lowering import make_reduction, npu_make_fallback
from .npu_choices import should_use_persistent_reduction
from .npu_device import NewNPUDeviceOpOverrides
from .npu_triton_heuristics import patch_triton_heuristics_cached_autotune
from .runtime import patch_load_cached_autotuning, patch_create_device_properties
from .utils import (
get_current_raw_stream,
patch_is_gpu,
patch_has_triton,
disable_foreach,
patch_get_first_incompatible_cudagraph_node
)
from .codecache import patch_aot_code_compiler_compile, patch_cache_base_get_system
from .scheduler import patch_scheduler
from .shape_handling import NPUShapeHandling, patch_shape_handling
from .async_compile import patch_async_compile
from .autotune_process import patch_tuning_process, patch_tuning_process_pool
from .select_algorithm import patch_algorithm_selector
from .fx_passes import patch_pattern_mm_plus_mm
from .kernel import (
_register_npu_inductor_mm,
_register_npu_inductor_addmm,
_register_npu_inductor_bmm,
_register_npu_inductor_grouped_mm,
_register_npu_inductor_flex_attention,
_validate_device,
)
from .cpp_builder import patch_get_optimization_cflags
from torch.nn.attention import flex_attention
# Replace the device validation hook of torch.nn.attention.flex_attention with
# the `_validate_device` imported from this package's `.kernel` module.
flex_attention._validate_device = _validate_device
# Apply two settings helpers from `.config` before the backend registration
# below (presumably they adjust the compile thread count and turn off
# comprehensive padding — confirm in `.config`).
set_compile_threads()
disable_comprehensive_padding()
def _inductor_register_backend_for_device():
    """Register the NPU scheduling and wrapper code generators with Inductor.

    The three code-generator classes are imported when the function runs
    rather than at module import time, and are handed to
    ``register_backend_for_device`` for the ``'npu'`` device in the order
    scheduling, Python wrapper, C++ wrapper.
    """
    from .codegen.npu_combined_scheduling import NPUCombinedScheduling
    from .codegen.wrapper import NPUWrapperCodeGen
    from .codegen.cpp_wrapper import CppWrapperNpu

    npu_codegen_classes = (NPUCombinedScheduling, NPUWrapperCodeGen, CppWrapperNpu)
    register_backend_for_device('npu', *npu_codegen_classes)


_inductor_register_backend_for_device()
def _inductor_register_device_op_overrides():
    """Install a ``NewNPUDeviceOpOverrides`` instance as the ``'npu'`` device-op overrides."""
    # The imported name is never used; the import is kept for whatever it does
    # on load (presumably registering the CPU overrides — TODO confirm).
    from torch._inductor.codegen import cpu_device_op_overrides

    npu_overrides = NewNPUDeviceOpOverrides()
    register_device_op_overrides('npu', npu_overrides)


_inductor_register_device_op_overrides()
# Look up the device interface registered under "npu" and keep it as the
# module-level name `device`.
device = get_interface_for_device("npu")
# Point Inductor's lowering module at this package's `make_reduction` and
# `npu_make_fallback` (both imported from `.lowering`), replacing the stock
# `make_reduction` / `make_fallback` attributes.
inductor_lowering.make_reduction = make_reduction
inductor_lowering.make_fallback = npu_make_fallback
# Applies the monkey patches this package provides for AOTI (presumably
# AOTInductor — confirm). The patch helpers are imported when the function is
# called rather than when this module is imported.
def patch_torch_for_aoti():
from .graph import patch_codegen_with_cpp_wrapper
from .cpp_builder import patch_get_cpp_torch_device_options
from .codegen.cpp_utils import patch_device_to_aten
from .utils import patch_is_same_tensor
from .fx_passes.joint_graph import patch_constant_fold_uniform_value
from .ir import patch_fallback_kernel_codegen
# Apply every patch in turn; patch_aot_code_compiler_compile is the one name
# here that comes from the module-level `.codecache` import.
patch_codegen_with_cpp_wrapper()
patch_get_cpp_torch_device_options()
patch_device_to_aten()
patch_is_same_tensor()
patch_constant_fold_uniform_value()
patch_fallback_kernel_codegen()
patch_aot_code_compiler_compile()
# Invoke the pre-grad and post-grad custom pass hooks from
# `.fx_passes.graph_match_pass` (presumably they register graph-matching
# passes — confirm against that module).
from .fx_passes.graph_match_pass import pre_grad_custom_pass_fuc
pre_grad_custom_pass_fuc()
from .fx_passes.graph_match_pass import post_grad_custom_pass_fuc
post_grad_custom_pass_fuc()
# The AOTI patches are applied unless DISABLE_AOTI_PATCH is set to "1" in the
# environment.
if os.environ.get("DISABLE_AOTI_PATCH", "0") != "1":
patch_torch_for_aoti()
# When `npu_config.dump_fx_graph` is truthy, load and run the IR patch from
# `.codegen.ir_fx`.
if npu_config.dump_fx_graph:
from .codegen.ir_fx import _patch_npu_inductor_ir
_patch_npu_inductor_ir()
# Import and run the fallback registration from `.lowering`.
from .lowering import _register_npu_inductor_fallbacks
_register_npu_inductor_fallbacks()
# Decomposition and kernel registrations first, then the Inductor monkey
# patches. Every entry is a zero-argument callable imported above; they are
# invoked strictly in the order listed.
for _install_npu_hook in (
    _register_npu_inductor_decompositons,
    _register_npu_inductor_mm,
    _register_npu_inductor_addmm,
    _register_npu_inductor_bmm,
    _register_npu_inductor_grouped_mm,
    _register_npu_inductor_flex_attention,
    patch_pattern_mm_plus_mm,
    patch_algorithm_selector,
    patch_tuning_process,
    patch_tuning_process_pool,
    patch_async_compile,
    patch_scheduler,
    patch_gen_common_triton_ext_imports,
    patch_load_cached_autotuning,
    patch_create_device_properties,
    patch_triton_heuristics_cached_autotune,
):
    _install_npu_hook()
# Drop the loop variable so the module namespace gains no extra name.
del _install_npu_hook
def _replace_benchmark_all_configs():
    """Point ``CachingAutotuner``'s benchmarking entry points at the NPU versions.

    Both ``_benchmark_all_configs`` and ``benchmark_all_configs`` on
    ``CachingAutotuner`` are overwritten with the same-named functions from
    ``.npu_triton_heuristics``.
    """
    from torch._inductor.runtime.triton_heuristics import CachingAutotuner
    from .npu_triton_heuristics import benchmark_all_configs, _benchmark_all_configs

    npu_replacements = {
        "_benchmark_all_configs": _benchmark_all_configs,
        "benchmark_all_configs": benchmark_all_configs,
    }
    for attr_name, npu_impl in npu_replacements.items():
        setattr(CachingAutotuner, attr_name, npu_impl)
def _replace_precompile():
    """Make ``NPUCachingAutotuner.precompile`` use ``precompile_parallel``.

    Both names come from ``.npu_triton_heuristics`` and are imported when the
    function runs.
    """
    from .npu_triton_heuristics import precompile_parallel, NPUCachingAutotuner

    setattr(NPUCachingAutotuner, "precompile", precompile_parallel)
# Swap in the NPU benchmark routines only when `aggresive_autotune` (from
# `.config`) is truthy.
if (aggresive_autotune):
_replace_benchmark_all_configs()
# Use the parallel precompile implementation when more than one precompile
# thread is configured via `max_precompiled_thread_num` (from `.config`).
if (max_precompiled_thread_num > 1):
_replace_precompile()
# Override InductorChoices' persistent-reduction decision with the function
# imported from `.npu_choices`.
InductorChoices.should_use_persistent_reduction = should_use_persistent_reduction
def patch_device_override_func():
def get_device_op_overrides_patch(device_name: str):
def register_cpu_backend():
from torch._inductor.codegen import cpu_device_op_overrides
return
def register_mps_backend():
from torch._inductor.codegen import mps_device_op_overrides
return
backend_factory = {"cpu": register_cpu_backend, "mps": register_mps_backend}
if device_name not in torch._inductor.codegen.common.device_op_overrides_dict:
if device_name not in backend_factory:
raise ValueError("backend not found: ", device_name)
backend_factory[device_name]()
return torch._inductor.codegen.common.device_op_overrides_dict[device_name]
torch._inductor.graph.get_device_op_overrides = get_device_op_overrides_patch
# Remaining one-shot registrations and monkey patches. Every entry is a
# zero-argument callable defined or imported above; they are invoked strictly
# in the order listed.
for _apply_npu_patch in (
    register_fa_pass,
    patch_cache_base_get_system,
    patch_count_bytes,
    patch_run_node,
    patch_is_gpu,
    patch_has_triton,
    disable_foreach,
    patch_get_first_incompatible_cudagraph_node,
    patch_device_override_func,
    patch_get_optimization_cflags,
):
    _apply_npu_patch()
# Drop the loop variable so the module namespace gains no extra name.
del _apply_npu_patch
def add_additional_op():
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch._inductor.ops_handler import OpsHandler
from torch._inductor.utils import register_op_dtype_propagation_rules
def cat_insert_slice(self, dst, src, offset, size, output_size):
return self._default("cat_insert_slice", (dst, src, offset, size, output_size), {})
OpsHandler.cat_insert_slice = cat_insert_slice
register_op_dtype_propagation_rules("cat_insert_slice", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None)
def cat_store(self, dst, src, size, store_offset_index, output_buffer_index):
return self._default("cat_store", (dst, src, size, store_offset_index, output_buffer_index), {})
OpsHandler.cat_store = cat_store
register_op_dtype_propagation_rules("cat_store", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None)
def index_select(self, name, index, indirect_var, set_indirect, bound, index_select_type):
return self._default("index_select", (name, index, indirect_var, set_indirect, bound, index_select_type), {})
OpsHandler.index_select = index_select
register_op_dtype_propagation_rules("index_select", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None)
def gather_template(self, name, index, indirect_var, set_indirect, index_boundary):
return self._default("gather_template", (name, index, indirect_var, set_indirect, index_boundary), {})
OpsHandler.gather_template = gather_template
register_op_dtype_propagation_rules("gather_template", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None)
def indexput_template(self, name, index, value, indirect_var, boundary):
return self._default("indexput_template", (name, index, value, indirect_var, boundary), {})
OpsHandler.indexput_template = indexput_template
register_op_dtype_propagation_rules("indexput_template", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None)
def scatter_template(self, name, index, value, indirect_var, boundary):
return self._default("scatter_template", (name, index, value, indirect_var, boundary), {})
OpsHandler.scatter_template = scatter_template
register_op_dtype_propagation_rules("scatter_template", ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, None)
add_additional_op()