import logging
import os
from triton.runtime.driver import driver
import torch
from torch._inductor import config
enable_npu_indexing = True
config.triton.unique_kernel_names = True
config.allow_buffer_reuse = False
config.trace.enabled = True
config.fallback_random = True
# Query the active Triton driver for the current target, device and the
# device's property mapping.
target = driver.active.get_current_target()
device = driver.active.get_current_device()
prop = driver.active.utils.get_device_properties(device)

# Both core counts start from the reported AI-core count; num_vector_core may
# be overridden later in this module for specific architectures.
num_cube_core = num_vector_core = prop["num_aicore"]
npu_block = 32
class aot_inductor:
    """AOT-Inductor (AOTI) debugging and dump settings for Ascend.

    Every value here is evaluated once, when the class body runs at module
    import time.
    """

    # Kernel-debug switch. Environment values are always strings, so they are
    # parsed explicitly: unset, empty, "0", "false", "no" or "off"
    # (case-insensitive, whitespace ignored) disable it; anything else enables
    # it. A plain ``os.environ.get(name, False)`` would treat "0" as enabled.
    debug_kernel = os.environ.get("AOTI_ASCEND_DEBUG_KERNEL", "").strip().lower() not in (
        "",
        "0",
        "false",
        "no",
        "off",
    )
    debug_kernel_in_run = False
    # Paths for repro tensors and C++/Python dumps (defaults are relative).
    repro_tensor_path = os.environ.get(
        "AOTI_ASCEND_REPRO_TENSOR_PATH", "aoti_repro_tensors"
    )
    dump_path_cpp = os.environ.get("AOTI_ASCEND_DUMP_PATH_CPP", "aoti_dump_cpp")
    # NOTE(review): unlike its siblings, this variable name lacks the
    # "ASCEND" segment; it is kept unchanged for compatibility.
    dump_path_py = os.environ.get("AOTI_DUMP_PATH_PY", "aoti_dump_py")
class _npugraph_trees:
    """Options for NPU graph trees, exposed through the ``npugraph_trees`` singleton."""

    def __init__(self):
        # The CPU-input check is not disabled by default.
        self._disable_cpu_input_check = False

    @property
    def disable_cpu_input_check(self):
        """Whether the CPU-input check is disabled (always stored as a bool)."""
        return self._disable_cpu_input_check

    @disable_cpu_input_check.setter
    def disable_cpu_input_check(self, value):
        # Store the normalised flag first, then, only for a truthy value,
        # switch off Inductor's slow-path cudagraph asserts. A falsy value
        # does not turn those asserts back on.
        self._disable_cpu_input_check = bool(value)
        if not value:
            return
        torch._inductor.config.triton.slow_path_cudagraph_asserts = False


npugraph_trees = _npugraph_trees()
# Environment values are always strings, so ``os.environ.get(name, False)``
# would make "0" or "false" truthy. Boolean switches go through _env_flag.
_FALSE_ENV_VALUES = frozenset({"", "0", "false", "no", "off"})


def _env_flag(name: str, default: bool = False) -> bool:
    """Return the boolean value of environment variable *name*.

    Returns *default* when the variable is unset. Otherwise the value is
    false when, after stripping whitespace and lower-casing, it is empty or
    one of "0", "false", "no", "off"; any other value is true.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in _FALSE_ENV_VALUES


# Raw, unparsed string; empty when unset.
enable_full_lowering_fallback = os.environ.get("NPU_INDUCTOR_FALLBACK_LIST", "")
# None when unset, otherwise the raw string value.
traced_fx_graph_cache = os.environ.get("INDUCTOR_ASCEND_FX_GRAPH_CACHE", None)
check_accuracy = _env_flag("INDUCTOR_ASCEND_CHECK_ACCURACY", False)
# Enabled unless explicitly switched off (e.g. "0" or "false").
auto_fallback = _env_flag("INDUCTOR_ASCEND_AUTO_FALLBACK", True)
fallback_warning = _env_flag("INDUCTOR_ASCEND_FALLBACK_WARNING", False)
# Dump FX graphs when explicitly requested, or implicitly whenever accuracy
# checking or AOTI kernel debugging is enabled. The environment value is
# parsed as a boolean: unset, empty, "0", "false", "no" or "off"
# (case-insensitive) mean "not requested"; the result is always a bool.
dump_fx_graph = (
    os.environ.get("INDUCTOR_ASCEND_DUMP_FX_GRAPH", "").strip().lower()
    not in ("", "0", "false", "no", "off")
    or bool(check_accuracy)
    or bool(aot_inductor.debug_kernel)
)
# Initially empty lists; presumably filled in at runtime by callers.
force_fallback_kernel_id = []
skip_specific_stride_asserts = []

# Relative/absolute tolerances for accuracy comparison, keyed by dtype. All
# entries share atol=1e-5; the "default" entry presumably covers dtypes
# without their own entry. Each inner dict is a distinct object.
acc_comp_tol = {
    key: {"rtol": rtol, "atol": 1e-5}
    for key, rtol in (
        (torch.float32, 1.3e-6),
        (torch.float16, 1e-3),
        (torch.bfloat16, 1.6e-2),
        ("default", 1.3e-6),
    )
}
# On Ascend910B targets the vector-core count is taken as twice the cube-core
# count; on every other target it keeps its previously assigned value.
num_vector_core = (
    num_cube_core * 2 if "Ascend910B" in target.arch else num_vector_core
)
# Log level from INDUCTOR_ASCEND_LOG_LEVEL (case-insensitive, surrounding
# whitespace ignored). WARNING when unset; an unrecognised name falls back
# to INFO.
log_level_env = os.getenv("INDUCTOR_ASCEND_LOG_LEVEL", "WARNING").strip().upper()
log_level_mapping = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}
# log_level_env is already upper-cased, so no second normalisation is needed.
log_level = log_level_mapping.get(log_level_env, logging.INFO)
# NOTE(review): basicConfig configures the *root* logger at import time, which
# affects the host application's logging. It is a no-op if the root logger
# already has handlers.
logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s")
log = logging.getLogger(__name__)
# Opt-in switches, matched case-insensitively. They accept different
# spellings: the autotune switch takes "1"/"true", while static mode also
# accepts "yes". The misspelled public name ``aggresive_autotune`` is kept
# so that existing readers keep working.
aggresive_autotune = (
    os.environ.get("INDUCTOR_ASCEND_AGGRESSIVE_AUTOTUNE", "0").lower()
    in {"1", "true"}
)
inductor_static_mode = (
    os.getenv("INDUCTOR_STATIC_MODE", "0").lower() in {"true", "yes", "1"}
)

# Relative path, presumably where profiling output is written.
profile_path = "./profile_result/"