from pathlib import Path
import tempfile
import os
import os.path
import re
import subprocess
import sysconfig
from typing import Optional
import functools
import hashlib
from triton.runtime.cache import get_cache_manager, get_dump_manager
from triton.backends.driver import DriverBase
from triton.backends.compiler import GPUTarget
from triton.backends.ascend.utils import (
_precompile_npu_hash,
_precompile_npu_ext,
_build_npu_ext,
_check_cxx11_abi,
convert_sigtype_to_int,
_is_auto_map_parallel_blocks_enabled,
get_ascend_arch_from_env,
is_ffts_supported,
force_disable_ffts,
get_backend_func
)
class NPUUtils(object):
    """Singleton wrapper around the natively built ``npu_utils`` extension module.

    ``__new__`` always hands back the same instance; ``__init__`` compiles
    ``npu_utils.cpp`` (located next to this file) on a cache miss and loads
    the resulting shared object.
    """

    def __new__(cls):
        # Create the shared instance on first use and reuse it afterwards.
        if not hasattr(cls, 'instance'):
            cls.instance = super(NPUUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        """Build ``npu_utils.so`` if it is not cached yet, then import it."""
        dirname = os.path.dirname(os.path.realpath(__file__))
        src_path = os.path.join(dirname, "npu_utils.cpp")
        src = Path(src_path).read_text()
        # The cache entry is keyed by the source text, so editing the .cpp
        # file results in a fresh build.
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = get_cache_manager(key)
        fname = "npu_utils.so"
        cache_path = cache.get_file(fname)
        if cache_path is None:
            with tempfile.TemporaryDirectory() as tmpdir:
                tmp_src_path = os.path.join(tmpdir, "npu_utils.cpp")
                with open(tmp_src_path, "w") as f:
                    f.write(src)
                so = _build_npu_ext("npu_utils", None, tmp_src_path)
                with open(so, "rb") as f:
                    cache_path = cache.put(f.read(), fname, binary=True)
        import importlib.util
        spec = importlib.util.spec_from_file_location("npu_utils", cache_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.npu_utils_mod = mod
        # NOTE(review): the returned arch was never used here; the call is
        # kept in case the helper validates the environment - confirm.
        get_ascend_arch_from_env()

    def load_binary(self, name, kernel, shared, device):
        """Load a kernel binary through the native module.

        Args:
            name: "<function name> <mix mode>" separated by whitespace.
            kernel: Kernel binary, forwarded unchanged.
            shared: Forwarded unchanged to the native module.
            device: Device id, forwarded unchanged.

        Returns:
            Whatever ``load_kernel_binary`` of the native module returns.

        Raises:
            ValueError: If ``name`` does not split into exactly two fields.
        """
        fnname, mix_mode = name.split()
        return self.npu_utils_mod.load_kernel_binary(fnname, kernel, shared, device, mix_mode)

    @functools.lru_cache()
    def get_device_properties(self, device):
        """Return a dict describing the device; ``device`` only keys the cache."""
        num_aic = self.get_aicore_num()
        # The vector-core count is derived as twice the AI-core count.
        num_aiv = num_aic * 2
        return {"max_shared_mem": 1, "num_aicore": num_aic, "num_vectorcore": num_aiv}

    @functools.lru_cache()
    def get_arch(self):
        """Return the architecture reported by the native module (cached)."""
        return self.npu_utils_mod.get_arch()

    @functools.lru_cache()
    def get_aicore_num(self):
        """Return the AI-core count reported by the native module (cached)."""
        return self.npu_utils_mod.get_aicore_num()

    @functools.lru_cache()
    def get_aivector_core_num(self):
        """Return the vector-core count taken from ``get_device_properties``."""
        return self.get_device_properties("npu")["num_vectorcore"]

    # Deliberately NOT memoized: this call exists for its side effect, and a
    # cache would silently drop a repeated request with the same arguments
    # (for example when re-applying a limit after it was changed elsewhere).
    def set_device_limit(self, device, ty, val):
        """
        Set npu device limit
        Args:
            device: Device id
            ty: The type of the limit, valid types include:
                "LOW_POWER_TIMEOUT", "WARP_STACK_SIZE", "DVG_WARP_STACK_SIZE", "STACK_SIZE"
            val: The specific meaning of the value depends on the type of limit.
        """
        self.npu_utils_mod.set_device_limit(device, ty, val)
class NPULauncher(object):
    """Callable that owns the compiled launcher extension of one kernel.

    Construction generates the header and wrapper sources, builds (or fetches
    from the cache) the launcher shared object and keeps its ``launch``
    function. Calling the instance forwards the arguments to that function.
    """

    def __init__(self, src, metadata):
        """Generate, build and load the launcher for ``src``.

        Args:
            src: Kernel source object; ``src.signature`` and
                ``src.fn.arg_names`` are read, ``src.constants`` is optional.
            metadata: Compilation metadata; ``debug``, ``mix_mode`` and
                ``shared`` are read here, the rest is consumed by
                ``generate_npu_wrapper_src``.
        """
        # Behavior switches taken from the environment ('true' or '1' enable them).
        self.compile_only = os.getenv("TRITON_COMPILE_ONLY", 'false').lower() in ('true', '1')
        self.enable_msprof_register_tensor = os.getenv("TRITON_REGISTER_TENSOR_MSPROF", 'false').lower() in ('true', '1')
        debug_mode = metadata.debug
        header_src = generate_npu_header_src()
        constants = src.constants if hasattr(src, "constants") else dict()
        # Keys may be argument names or positions; normalize them to positions.
        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
        constants = {cst_key(key): value for key, value in constants.items()}
        signature = {cst_key(key): value for key, value in src.signature.items()}
        wrapper_src = generate_npu_wrapper_src(constants, signature, metadata)
        so_launcher_path = make_npu_launcher_stub(header_src, wrapper_src, metadata.debug)
        self.mix_mode = metadata.mix_mode
        self.shared = metadata.shared
        # Load the built shared object as the ``__triton_launcher`` module.
        import importlib.util
        spec = importlib.util.spec_from_file_location("__triton_launcher", so_launcher_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.launch = getattr(mod, "launch")

    def __call__(self, *args, **kwargs):
        """Launch the kernel, or only report the cache location in compile-only mode.

        ``args[5]`` is treated as a dict: its ``'hash'`` entry is read in
        compile-only mode and ``'tensor_params_shape'`` is written when tensor
        registration for msprof is enabled.
        """
        # NOTE(review): the nesting of the branches below was reconstructed
        # from a flattened listing; as written, compile-only mode still
        # reaches the launch when tensor registration is enabled - confirm.
        if self.compile_only:
            cache_manager = get_cache_manager(args[5]['hash'])
            print("[INFO]: skip running kernel")
            print(f"[INFO]: The compiled kernel cache is in {cache_manager.cache_dir}")
        if self.enable_msprof_register_tensor:
            tensor_params_shape = get_backend_func("get_tensor_params_shape", *args)
            args[5]['tensor_params_shape'] = tensor_params_shape
        else:
            if self.compile_only:
                return
        profiler_registered = self.launch(*args, **kwargs)
        import triton
        # Publish whether the launcher reported a registered profiler (value 1).
        triton.backends.ascend.utils.TRITON_PROFILER_REGISTERED = True if profiler_registered == 1 else False
class NPUDriver(DriverBase):
    """Triton driver for the NPU backend; most queries are delegated to the backend functions."""

    def __init__(self):
        self.utils = NPUUtils()
        self.launcher_cls = NPULauncher
        super().__init__()

    @classmethod
    def is_active(cls):
        """Return True when the bisheng compiler lists ``hiipu64`` among its targets.

        Any failure while probing the compiler is reported as a warning and
        makes the driver inactive instead of raising.
        """
        try:
            from triton.backends.ascend.utils import _get_bisheng_path
            compiler = _get_bisheng_path()
            raw_output = subprocess.check_output([compiler, "-print-targets"])
            supported = raw_output.decode().strip().split()
            return "hiipu64" in supported
        except Exception as probe_error:
            import warnings
            # ANSI escape codes: render the message in red, then reset.
            color_on = "\x1b[31;20m"
            color_off = "\x1b[0m"
            warnings.warn(color_on + str(probe_error) + color_off)
            return False

    def get_current_target(self):
        """Return the ``GPUTarget`` for the NPU; an arch from the environment takes precedence."""
        arch = get_ascend_arch_from_env() or self.utils.get_arch()
        return GPUTarget("npu", arch, 0)

    def get_current_device(self):
        """
        Get current device
        """
        return get_backend_func("get_current_device")

    def set_current_device(self, device):
        """
        Set current device as the given device
        """
        return get_backend_func("set_current_device", device)

    def get_current_stream(self, device: Optional[int] = None) -> int:
        """
        Get stream for current device
        """
        return get_backend_func("get_current_stream", device)

    def get_benchmarker(self):
        """Return the benchmarking function ``triton.testing.do_bench``."""
        from triton.testing import do_bench
        return do_bench

    def get_device_interface(self):
        """Return the device interface supplied by the backend."""
        return get_backend_func("get_device_interface")

    def get_empty_cache_for_benchmark(self):
        """Request an empty tensor of ``192 MiB // 4`` elements from the backend."""
        cache_bytes = 192 * 1024 * 1024
        element_count = cache_bytes // 4
        return get_backend_func("get_empty_tensor", element_count)
def make_npu_launcher_stub(header_src, wrapper_src, debug=False):
    """
    Build the launcher extension for a kernel and return the path of the cached .so.

    Args:
        header_src: Text of the header to precompile.
        wrapper_src: C++ text of the launcher module.
        debug: When true, the header (if available), the wrapper source and a
            freshly built .so are also written through the dump manager.

    Returns:
        Path of the launcher shared object inside the cache.
    """
    # --- precompiled header: only (re)created when neither artifact is cached ---
    header_cache = get_cache_manager(_precompile_npu_hash(header_src))
    header_path = header_cache.get_file("precompiled.h")
    gch_path = header_cache.get_file("precompiled.h.gch")
    if header_path is None and gch_path is None:
        header_path = header_cache.put(header_src, "precompiled.h", binary=False)
        _precompile_npu_ext(header_path)

    # --- launcher .so: cached under the hash of the wrapper source ---
    so_cache_key = hashlib.sha256(wrapper_src.encode("utf-8")).hexdigest()
    so_cache_manager = get_cache_manager(so_cache_key)
    abi_flag = _check_cxx11_abi()
    name = f"launcher_cxx11abi{abi_flag}"
    ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so_name = f"{name}{ext_suffix}"

    if debug:
        dump_manager = get_dump_manager(so_cache_key)
        if header_path is not None:
            print(f"Dumping precompiled.h to {dump_manager.cache_dir}")
            dump_manager.put(header_src, "precompiled.h", binary=False)
        print(f"Dumping {name}.cxx to {dump_manager.cache_dir}")
        dump_manager.put(wrapper_src, f"{name}.cxx", binary = False)

    cached_so = so_cache_manager.get_file(so_name)
    if cached_so is not None:
        return cached_so

    # The "torch" launcher is used unless the task queue is switched off.
    taskqueue_on = os.getenv("TRITON_ENABLE_TASKQUEUE", 'true').lower() in ('true', '1')
    kernel_launcher_type = "torch" if taskqueue_on else None

    with tempfile.TemporaryDirectory() as build_dir:
        src_path = os.path.join(build_dir, f"{name}.cxx")
        with open(src_path, "w") as src_file:
            src_file.write(wrapper_src)
        so_path = _build_npu_ext(name, header_path, src_path, kernel_launcher=kernel_launcher_type, precompile=True)
        # Read the build product once and hand the same bytes to both sinks.
        with open(so_path, "rb") as so_file:
            so_bytes = so_file.read()
        if debug:
            dump_manager.put(so_bytes, so_name, binary=True)
        return so_cache_manager.put(so_bytes, so_name, binary=True)
def extract_device_print_code_from_cann():
    """Collect the host-side device-print support code from the CANN ``ccelib`` headers.

    The ``ccelib`` directory is located relative to the resolved bisheng
    compiler path. Five headers are read; from each one local
    ``#include "..."`` lines and every function marked ``[aicore]`` are
    removed, ``__gm__`` is blanked out and the ``__CCELIB_RT_*`` names are
    mapped to their ``RT_*`` counterparts.

    Returns:
        The processed headers joined by newlines.

    Raises:
        ValueError: If an ``[aicore]`` marker cannot be matched to a function
            span (no preceding ``}`` or no balanced ``{...}`` after it).
    """
    from triton.backends.ascend.utils import _get_bisheng_path
    ccec_compiler_bin_folder, _ = os.path.split(os.path.realpath(_get_bisheng_path()))
    ccec_compiler_folder, _ = os.path.split(ccec_compiler_bin_folder)
    # The first entry found under lib/clang is used as the clang version.
    clang_version = os.listdir(os.path.join(ccec_compiler_folder, "lib/clang/"))[0]
    ccelib_path = os.path.join(ccec_compiler_folder, f"lib/clang/{clang_version}/include/ccelib")

    # Raw string: '\[' is not a valid escape sequence in a normal literal.
    aicore_marker = re.compile(r'\[aicore\]')
    # Applied in order after the device functions have been removed.
    replacements = (
        ('__gm__', ' '),
        ('__CCELIB_RT_ERROR_NONE', 'RT_ERROR_NONE'),
        ('__CCELIB_RT_MEMORY_HBM', 'RT_MEMORY_HBM'),
        ('__CCELIB_RT_MEMCPY_HOST_TO_DEVICE', 'RT_MEMCPY_HOST_TO_DEVICE'),
        ('__CCELIB_RT_MEMCPY_DEVICE_TO_HOST', 'RT_MEMCPY_DEVICE_TO_HOST'),
    )

    def find_aicore_function_span(src, pos):
        """Return ``(left, right)``: the span from just after the previous ``}`` to the function's closing ``}``."""
        prev_close = src.rfind('}', 0, pos)
        if prev_close < 0:
            raise ValueError(
                f"cannot locate the start of the [aicore] function at offset {pos}: no preceding '}}'")
        brace_nest = 0
        for j in range(pos, len(src)):
            if src[j] == '{':
                brace_nest += 1
            elif src[j] == '}':
                brace_nest -= 1
                # Only a closing brace can end the function body.
                if brace_nest == 0:
                    return prev_close + 1, j
        raise ValueError(
            f"cannot locate the end of the [aicore] function at offset {pos}: unbalanced braces")

    def read_header(header_path):
        """Read one header below ``ccelib_path`` and return its cleaned-up text."""
        with open(os.path.join(ccelib_path, header_path), 'r') as f:
            code = f.read()
        # Drop local includes; whitespace is normalized only for the check.
        code = '\n'.join(
            line for line in code.splitlines()
            if not ' '.join(line.split()).startswith('#include "')
        )
        kept_parts = []
        segment_start = 0
        for marker in aicore_marker.finditer(code):
            left, right = find_aicore_function_span(code, marker.start())
            kept_parts.append(code[segment_start:left])
            segment_start = right + 1
        kept_parts.append(code[segment_start:])
        new_code = ''.join(kept_parts)
        for old, new in replacements:
            new_code = new_code.replace(old, new)
        return new_code

    return '\n'.join([
        read_header('common/common_impl.h'),
        read_header('internal/debug_tunnel/payload.h'),
        read_header('internal/debug_tunnel/payload_impl.h'),
        read_header('internal/debug_tunnel/tunnel.h'),
        read_header('internal/debug_tunnel/tunnel_impl.h')
    ])
def generate_npu_header_src():
    """Return the text of the header that is precompiled for the launcher.

    The header consists of a fixed license banner, the common includes and
    the backend-specific part obtained from ``get_backend_func("header_file", ...)``,
    which receives whether ``TRITON_ENABLE_TASKQUEUE`` is enabled.
    """
    taskqueue_setting = os.getenv("TRITON_ENABLE_TASKQUEUE", 'true')
    enable_taskqueue = taskqueue_setting.lower() in ('true', '1')
    backend_header = get_backend_func("header_file", enable_taskqueue)
    return f"""
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef TRITON_NPU_HEADERS
#define TRITON_NPU_HEADERS
#include <assert.h>
#include <stdbool.h>
#include <string>
#include <sys/syscall.h>
#include <vector>
#include <Python.h>
#include "runtime/runtime/rt.h"
#include <acl/acl.h>
{backend_header}
#endif
"""
def generate_npu_wrapper_src(constants, signature, metadata):
    """Generate the C++ source of the ``__triton_launcher`` extension for one kernel.

    Args:
        constants: Mapping keyed by argument position; arguments listed here
            are left out of the struct that is passed to the kernel.
        signature: Mapping of argument position to type string. A leading
            ``*`` marks a pointer argument.
        metadata: Compilation metadata. ``mix_mode``, ``compile_on_910_95``,
            ``parallel_mode`` and ``force_simt_only`` are required;
            ``workspace_size``, ``lock_init_value`` and ``lock_num`` are
            optional.

    Returns:
        The complete C++ translation unit as a string. Its content also
        depends on the ``TRITON_DEVICE_PRINT``, ``TRITON_ENABLE_TASKQUEUE``
        and ``TRITON_GRID_WARN_PRINT`` environment variables.

    Raises:
        KeyError: If a type string in ``signature`` is not supported.
    """
    workspace_size = int(metadata.workspace_size) \
        if hasattr(metadata, 'workspace_size') else -1
    lock_init_value = int(metadata.lock_init_value) \
        if hasattr(metadata, 'lock_init_value') else 0
    lock_num = int(metadata.lock_num) \
        if hasattr(metadata, 'lock_num') else -1
    mix_mode = metadata.mix_mode
    compile_on_910_95 = metadata.compile_on_910_95
    parallel_mode = metadata.parallel_mode
    enable_simt = ("simt" in parallel_mode) or metadata.force_simt_only

    def _ty_to_cpp(ty):
        """C++ type used for the argument in ``_launch`` and in the kernel argument struct."""
        if ty[0] == '*':
            return "void*"
        return {
            "i1": "int32_t",
            "i8": "int8_t",
            "i16": "int16_t",
            "i32": "int32_t",
            "i64": "int64_t",
            "u32": "uint32_t",
            "u64": "uint64_t",
            "fp16": "float",
            "bf16": "float",
            "fp32": "float",
            "f32": "float",
            "fp64": "double",
        }[ty]

    def _extracted_ty(ty):
        """C++ type of the variable that ``PyArg_ParseTuple`` fills for the argument."""
        if ty[0] == '*':
            return "PyObject*"
        return {
            'i1': 'int32_t',
            # i8/i16 are parsed as 32-bit integers (like i1) and narrowed when
            # they are handed to _launch; without these entries such
            # arguments failed with KeyError although _ty_to_cpp accepts them.
            'i8': 'int32_t',
            'i16': 'int32_t',
            'i32': 'int32_t',
            'i64': 'int64_t',
            'u32': 'uint32_t',
            'u64': 'uint64_t',
            'fp16': 'float',
            'bf16': 'float',
            'fp32': 'float',
            'f32': 'float',
            'fp64': 'double',
        }[ty]

    def _format_of(ty):
        """``PyArg_ParseTuple`` format unit for an extracted C++ type."""
        return {
            "PyObject*": "O",
            "float": "f",
            "double": "d",
            "long": "l",
            "uint32_t": "I",
            "int32_t": "i",
            "uint64_t": "K",
            "int64_t": "L",
        }[ty]

    arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
    # Positional arguments expected by the generated `launch` entry point:
    #   int gridX, gridY, gridZ;
    #   rtStream_t stream;
    #   const void *function;
    #   PyObject* packed_metadata, *launch_metadata;
    #   PyObject* launch_enter_hook, *launch_exit_hook;
    #   *args_expand
    parse_format = "iiiKKOOOO" + ''.join([_format_of(_extracted_ty(ty)) for ty in signature.values()])
    grid_info = {'X': 'i32', 'Y': 'i32', 'Z': 'i32'}

    # Initializers for the kernel arguments in the packed struct. The trailing
    # comma is only emitted when at least one argument remains after dropping
    # the constants; otherwise the initializer list would contain a bare ','.
    kernel_params = [(i, ty) for i, ty in signature.items() if i not in constants]
    kernel_arg_inits = (
        ', '.join(f'static_cast<{_ty_to_cpp(ty)}>(arg{i})' for i, ty in kernel_params) + ','
    ) if kernel_params else ''

    arch = get_ascend_arch_from_env()
    target_support_ffts = is_ffts_supported(arch) and (not force_disable_ffts())
    enable_device_print = os.getenv(
        "TRITON_DEVICE_PRINT", 'false').lower() in ('true', '1')
    enable_taskqueue = os.getenv(
        "TRITON_ENABLE_TASKQUEUE", 'true').lower() in ('true', '1')
    enable_grid_warn_print = os.getenv(
        "TRITON_GRID_WARN_PRINT", 'false').lower() in ('true', '1')
    enable_auto_map_parallel_blocks = _is_auto_map_parallel_blocks_enabled()
    npu_utils = NPUUtils()
    num_physical_blocks = npu_utils.get_aivector_core_num(
    ) if mix_mode == "aiv" else npu_utils.get_aicore_num()
    task_type = "MSPROF_GE_TASK_TYPE_AIV" if mix_mode == "aiv" else "MSPROF_GE_TASK_TYPE_AI_CORE"
    # Newline held in a variable so it can be used inside f-string expressions.
    LINE_CHANGE_CHAR = chr(10)
    # Snippets emitted when a device allocation fails; which one is used
    # depends on whether the launch body is a lambda (task queue) or inline.
    alloc_success_code = 'return 1;'
    sync_lock_fail_code = 'fprintf(stderr, "Error: syncBlockLock allocation failed\\n"); return;'
    workspace_fail_code = 'fprintf(stderr, "Error: workspace allocation failed\\n"); return;'

    # Converts a Python argument into a device pointer. Every failure path
    # marks the result invalid, and the object returned by data_ptr() is
    # released on all paths.
    cpp_device_pointer = """
typedef struct _DevicePtrInfo {
void *dev_ptr;
bool valid;
} DevicePtrInfo;
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {
DevicePtrInfo ptr_info;
ptr_info.dev_ptr = 0;
ptr_info.valid = true;
if (PyLong_Check(obj)) {
ptr_info.dev_ptr = reinterpret_cast<void *>(PyLong_AsUnsignedLongLong(obj));
return ptr_info;
}
if (obj == Py_None) {
// valid nullptr
return ptr_info;
}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!ret) {
// data_ptr() raised; the Python error is already set
ptr_info.valid = false;
return ptr_info;
}
if (!PyLong_Check(ret)) {
Py_DECREF(ret);
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
ptr_info.valid = false;
return ptr_info;
}
ptr_info.dev_ptr = reinterpret_cast<void *>(PyLong_AsUnsignedLongLong(ret));
Py_DECREF(ret);
return ptr_info;
}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
ptr_info.valid = false;
return ptr_info;
}
"""
    cpp_msprof_extern = """
extern "C" {
typedef int (* callback)(unsigned int type, void* data, unsigned int len);
extern int MsprofReportApi(unsigned int agingFlag, const MsprofApi *api);
extern unsigned long int MsprofSysCycleTime();
extern int MsprofRegisterCallback(unsigned int moduleId, callback handle);
static unsigned int __MsprofFlagL0 = 0;
static unsigned int __MsprofFlagL1 = 0;
int ProfCtrlHandle(unsigned int CtrlType, void* CtrlData, unsigned int DataLen) {
if ((CtrlData == nullptr) || (DataLen == 0U)) {
return 1;
}
if (CtrlType == 1) {
MsprofCommandHandle* handle = (MsprofCommandHandle *)(CtrlData);
if (handle->type >= 6) // 6 is not used here
return 1;
if (handle->type == 1) { // init - 0 , start - 1
__MsprofFlagL0 = ((0x00000800ULL & handle->profSwitch) == 0x00000800ULL) ? 1 : 0;
__MsprofFlagL1 = ((0x00000002ULL & handle->profSwitch) == 0x00000002ULL) ? 1 : 0;
}
}
return 0;
}
}
"""
    cpp_msprof_callback = """
MsprofRegisterCallback(8, ProfCtrlHandle); // 8 - CCE defined in msprof headerfile slog.h
"""
    cpp_msprof_call_before_launch = """
unsigned long int beginTime = 0;
unsigned long int endTime = 0;
unsigned long int opNameHashID = 0;
unsigned int threadId = 0;
char* _kernelName = const_cast<char*>(name.c_str());
size_t length = name.length();
if (__MsprofFlagL0 || __MsprofFlagL1)
{
beginTime = MsprofSysCycleTime();
}
"""
    # dataTypes is zero-initialized so that slots without a generated
    # assignment are never read uninitialized.
    cpp_msprof_call_after_launch = f"""
if (__MsprofFlagL0 || __MsprofFlagL1)
{{
endTime = MsprofSysCycleTime();
opNameHashID = MsprofGetHashId(_kernelName, length);
threadId = (unsigned int)(syscall(SYS_gettid));
MsprofApi info;
info.level = MSPROF_REPORT_NODE_LEVEL;
info.magicNumber = 0x5a5a; //MSPROF_REPORT_DATA_MAGIC_NUM
info.type = MSPROF_REPORT_NODE_LAUNCH_TYPE;
info.threadId = threadId;
info.reserve = 0;
info.beginTime = beginTime;
info.endTime = endTime;
info.itemId = opNameHashID;
MsprofReportApi(false, &info);
}}
if (__MsprofFlagL1)
{{
MsprofCompactInfo nodeBasicInfo;
nodeBasicInfo.level = MSPROF_REPORT_NODE_LEVEL;
nodeBasicInfo.magicNumber = 0x5a5a; //MSPROF_REPORT_DATA_MAGIC_NUM
nodeBasicInfo.type = MSPROF_REPORT_NODE_BASIC_INFO_TYPE;
nodeBasicInfo.threadId = threadId;
nodeBasicInfo.timeStamp = endTime;
nodeBasicInfo.data.nodeBasicInfo.opName = opNameHashID;
nodeBasicInfo.data.nodeBasicInfo.opType = opNameHashID;
nodeBasicInfo.data.nodeBasicInfo.taskType = {task_type};
nodeBasicInfo.data.nodeBasicInfo.blockDim = blockNum;
MsprofReportCompactInfo(0, static_cast<void *>(&nodeBasicInfo), sizeof(MsprofCompactInfo));
// workspace > 0 indicates a 'mix' kernel, which requires reporting the ctxID
if ({workspace_size} > 0) {{
MsprofAdditionalInfo info;
info.level = MSPROF_REPORT_NODE_LEVEL;
info.type = MSPROF_REPORT_NODE_CONTEXT_ID_INFO_TYPE;
info.threadId = threadId;
info.timeStamp = endTime;
MsprofContextIdInfo ctxId;
ctxId.opName = opNameHashID;
ctxId.ctxIdNum = 1;
for (uint32_t i = 0; i < ctxId.ctxIdNum; i++) {{
ctxId.ctxIds[i] = i;
}}
size_t copyLen = sizeof(MsprofContextIdInfo);
if (copyLen > MSPROF_ADDTIONAL_INFO_DATA_LENGTH) {{
copyLen = MSPROF_ADDTIONAL_INFO_DATA_LENGTH;
}}
std::memcpy(info.data, &ctxId, copyLen);
MsprofReportAdditionalInfo(false, static_cast<void *>(&info), sizeof(MsprofAdditionalInfo));
}}
// Report tensor info
int max_tensors_num = tensorShapes.size() < MSPROF_GE_TENSOR_DATA_NUM ? tensorShapes.size() : MSPROF_GE_TENSOR_DATA_NUM;
MsprofAdditionalInfo tensorInfo;
tensorInfo.level = MSPROF_REPORT_NODE_LEVEL;
tensorInfo.type = MSPROF_REPORT_NODE_TENSOR_INFO_TYPE;
tensorInfo.threadId = threadId;
tensorInfo.timeStamp = endTime;
auto profTensorData = reinterpret_cast<MsprofTensorInfo *>(tensorInfo.data);
profTensorData->opName = opNameHashID;
int tensorCount = 0;
int dataTypes[MSPROF_GE_TENSOR_DATA_NUM] = {{0}};
if (tensorShapes.size() > 0) {{
{LINE_CHANGE_CHAR.join(
f'dataTypes[{i}] = {convert_sigtype_to_int(ty[1:])};'
for i, ty in signature.items()
if ty.startswith("*") and i < 5
)}
}}
for (int i = 0; i < tensorShapes.size() && tensorCount < MSPROF_GE_TENSOR_DATA_NUM; i++) {{
auto fillTensorData = [&](int index, int tensorType) {{
profTensorData->tensorData[index].tensorType = tensorType;
profTensorData->tensorData[index].format = 2; // GeDataFormat: ND = 2
profTensorData->tensorData[index].dataType = dataTypes[i];
int nDim = tensorShapes[i].size();
nDim = nDim < MSPROF_GE_TENSOR_DATA_SHAPE_LEN ? nDim : MSPROF_GE_TENSOR_DATA_SHAPE_LEN;
for (int j = 0; j < nDim; j++) {{
profTensorData->tensorData[index].shape[j] = tensorShapes[i][j];
}}
for (int j = nDim; j < MSPROF_GE_TENSOR_DATA_SHAPE_LEN; j++) {{
profTensorData->tensorData[index].shape[j] = 0;
}}
}};
int tensorType = (i < tensorKinds.size()) ? tensorKinds[i] : 0; // DeFault tensor type is input
if (tensorType == TENSOR_KIND_INPUT || tensorType == TENSOR_KIND_INPUT_OUTPUT) {{
fillTensorData(tensorCount, MSPROF_GE_TENSOR_TYPE_INPUT);
tensorCount++;
}}
if ((tensorType == TENSOR_KIND_OUTPUT || tensorType == TENSOR_KIND_INPUT_OUTPUT) && tensorCount < MSPROF_GE_TENSOR_DATA_NUM){{
fillTensorData(tensorCount, MSPROF_GE_TENSOR_TYPE_OUTPUT);
tensorCount++;
}}
}}
profTensorData->tensorNum = tensorCount;
MsprofReportAdditionalInfo(false, static_cast<void *>(&tensorInfo), sizeof(MsprofAdditionalInfo));
}}
"""
    cpp_kernel_launch = """
ret = rtKernelLaunch(func, blockNum, static_cast<void*>(&args), sizeof(args), NULL, stream);
"""
    if compile_on_910_95 and enable_simt:
        # SIMT kernels compiled on 910_95 use the V2 launch with an explicit
        # local memory size.
        cpp_kernel_launch = """
rtArgsEx_t argsInfo = {};
argsInfo.args = static_cast<void*>(&args);
argsInfo.argsSize = sizeof(args);
rtTaskCfgInfo_t cfgInfo = {};
cfgInfo.localMemorySize = 216 * 1024;
ret = rtKernelLaunchWithFlagV2(func, blockNum, &argsInfo, NULL, stream, 0, &cfgInfo);
"""
    precompile_headers = """
#include "precompiled.h"
"""
    # Notes on the template below:
    #  * `ret` starts as RT_ERROR_NONE so the generated status checks never
    #    read an uninitialized value.
    #  * _get_tensor_shape is best-effort: Python errors raised while
    #    querying the shape are cleared, and `seq` is released with
    #    Py_XDECREF because it may be NULL.
    #  * The item returned by PySequence_GetItem is a new reference and is
    #    released after use.
    return f"""
{precompile_headers}
{'#define __CCE_ENABLE_PRINT__' if enable_device_print else ''}
{extract_device_print_code_from_cann() if enable_device_print else ''}
#define PY_SSIZE_T_CLEAN
{'#define ENABLE_GRID_WARN_PRINT' if enable_grid_warn_print else ''}
#define TENSOR_KIND_INPUT 0
#define TENSOR_KIND_OUTPUT 1
#define TENSOR_KIND_INPUT_OUTPUT 2
{cpp_msprof_extern}
{cpp_device_pointer}
static void _launch(const char* kernelName, const void* func, rtStream_t stream, int gridX, int gridY, int gridZ, std::vector<std::vector<int64_t>> &tensorShapes, std::vector<int> &tensorKinds{', ' + arg_decls if len(signature) > 0 else ''}) {{
// only 1D parallelization is supported for NPU
// Pointer type becomes flattend 1-D Memref tuple: base_ptr, data_ptr, offset, shape, stride
// base_ptr offset shape and stride are not used, arbitrarily set for now
std::string name = "";
name.append(kernelName);
{'auto launch_call = [=]() -> rtError_t' if enable_taskqueue else ''} {{
uint32_t blockNum = gridX * gridY * gridZ;
#ifdef ENABLE_GRID_WARN_PRINT
static bool warned = false;
if (!warned && blockNum > (uint32_t){num_physical_blocks}) {{
printf("WARNING: Grid %u > physical limit {num_physical_blocks}, performance maybe reduced.\\n",blockNum);
warned = true;
}}
#endif
{get_backend_func("pre_launch")}
{'blockNum = std::min(blockNum, (uint32_t)' + str(num_physical_blocks) + ');' if enable_auto_map_parallel_blocks else ''}
{'cce::internal::DebugTunnelData *DTData = cce::internal::DebugTunnel::Open(blockNum);' if enable_device_print else ''}
rtError_t ret = RT_ERROR_NONE;
{'void *ffts_addr = NULL; uint32_t ffts_len; ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len);' if target_support_ffts else ''}
{'if (ret != RT_ERROR_NONE) return ret;' if (target_support_ffts and enable_taskqueue) else 'if (ret != RT_ERROR_NONE) return;' if (target_support_ffts and (not enable_taskqueue)) else ''}
// stub argument for workspace
void *syncBlockLock_ptr = NULL;
void *workspace_addr_ptr = NULL;
uint16_t ModuleId = 0;
{f'''
uint64_t syncBlockLockSize = {lock_num} * sizeof(int64_t);
syncBlockLock_ptr = {get_backend_func("allocate_memory", "syncBlockLockSize", "stream")}
if (!syncBlockLock_ptr) {{
{alloc_success_code if enable_taskqueue else sync_lock_fail_code}
}}
std::vector<int64_t> lockInitData({lock_num}, {lock_init_value});
ret = rtMemcpy(
syncBlockLock_ptr, syncBlockLockSize,
reinterpret_cast<void *>(lockInitData.data()), syncBlockLockSize,
RT_MEMCPY_HOST_TO_DEVICE
);
if (ret != RT_ERROR_NONE) {{
return {'ret' if enable_taskqueue else ''};
}}
''' if lock_num > 0 else ''}
{f'''
uint64_t totalWorkSpaceSize = {workspace_size} * blockNum;
workspace_addr_ptr = {get_backend_func("allocate_memory", "totalWorkSpaceSize", "stream")}
if (!workspace_addr_ptr) {{
{alloc_success_code if enable_taskqueue else workspace_fail_code}
}}
''' if workspace_size > 0 else ''}
{'if (ret != RT_ERROR_NONE) return ret;' if (workspace_size > 0 and enable_taskqueue) else 'if (ret != RT_ERROR_NONE) return;' if (workspace_size > 0 and not enable_taskqueue) else ''}
struct __attribute__((packed)) {{
{'void* ffts_addr __attribute__((aligned(8)));' if target_support_ffts else ''}
{'void* syncBlockLock __attribute__((aligned(8)));' if not metadata.force_simt_only else ''}
{'void* workspace_addr __attribute__((aligned(8)));' if not metadata.force_simt_only else ''}
{' '.join(f'{_ty_to_cpp(ty)} arg{i} __attribute__((aligned({4 if ty[0] != "*" and ty[-2:] != "64" else 8})));' for i, ty in signature.items() if i not in constants)}
{' '.join(f'{_ty_to_cpp(ty)} grid{mark} __attribute__((aligned(4)));' for mark, ty in grid_info.items())}
{'void* DTData __attribute__((aligned(8)));' if enable_device_print else ''}
}} args = {{
{'static_cast<void*>(ffts_addr),' if target_support_ffts else ''}
{('static_cast<void*>(syncBlockLock_ptr),' if lock_num > 0 else 'nullptr,') if not metadata.force_simt_only else ''}
{('static_cast<void*>(workspace_addr_ptr),' if workspace_size > 0 else 'nullptr,') if not metadata.force_simt_only else ''}
{kernel_arg_inits}
{', '.join(f'static_cast<{_ty_to_cpp(ty)}>(grid{mark})' for mark, ty in grid_info.items())}
{', static_cast<void*>(DTData)' if enable_device_print else ''}
}};
{cpp_msprof_call_before_launch}
{cpp_kernel_launch}
{'void *&stream_ref = const_cast<void*&>(stream);' if enable_device_print else ''}
{'cce::internal::DebugTunnel::Close(DTData, stream_ref);' if enable_device_print else ''}
{cpp_msprof_call_after_launch}
{'return ret;' if enable_taskqueue else 'ret = rtStreamSynchronize(stream);'}
}};
{f'''{get_backend_func("async_launch", "launch_call") if enable_taskqueue else ''}'''}
return;
}}
// Extract tensor shape from PyObject
static std::vector<int64_t> _get_tensor_shape(PyObject *tensor) {{
std::vector<int64_t> shape;
// Early return if tensor is None or null
if (!tensor || tensor == Py_None) {{
return shape;
}}
// Calling tensor.size()
PyObject* size_result = PyObject_CallMethod(tensor, "size", NULL);
if (!size_result) {{
// Shape collection is best-effort: do not leave a Python error pending
PyErr_Clear();
return shape;
}}
// Using PySequence_Fast to improve access efficiency
PyObject* seq = PySequence_Fast(size_result, "Expected a sequence from tensor.size()");
if (seq) {{
Py_ssize_t len = PySequence_Fast_GET_SIZE(seq);
PyObject** items = PySequence_Fast_ITEMS(seq);
for (Py_ssize_t i = 0; i < len; ++i) {{
PyObject* dim = items[i];
if (PyLong_Check(dim)) {{
shape.push_back(PyLong_AsLong(dim));
}}
}}
}} else {{
PyErr_Clear();
}}
Py_XDECREF(seq);
Py_DECREF(size_result);
return shape;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
rtStream_t stream;
const void *function;
PyObject *packedMetadata = NULL;
PyObject *launch_metadata = NULL;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
std::vector<std::vector<int64_t>> tensorShapes;
{' '.join([f"{_extracted_ty(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(
args, \"{parse_format}\",
&gridX, &gridY, &gridZ, &stream, &function,
&packedMetadata, &launch_metadata,
&launch_enter_hook, &launch_exit_hook
{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}
)
) {{
return NULL;
}}
if (__MsprofFlagL1)
{{
{
LINE_CHANGE_CHAR.join(
f"{{ auto tmp = _get_tensor_shape(_arg{i}); if (!tmp.empty()) tensorShapes.push_back(tmp); }}"
for i, ty in signature.items() if ty[0] == "*"
)
}
}}
if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{
return NULL;
}}
// get kernel_name
PyObject *kernelNameObj = PyDict_GetItemString(packedMetadata, "kernel_name");
const char *kernelName = PyUnicode_AsUTF8(kernelNameObj);
// get tensor_kinds
std::vector<int> tensorKinds;
PyObject *tensorKindList = PyDict_GetItemString(packedMetadata, "tensor_kinds");
if (tensorKindList) {{
int size = PyObject_Size(tensorKindList);
for (int i = 0; i < size; i++) {{
PyObject *kind = PySequence_GetItem(tensorKindList, i);
tensorKinds.push_back(PyLong_AsLong(kind));
Py_XDECREF(kind); // PySequence_GetItem returns a new reference
}}
}}
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0]=="*" else "" for i, ty in signature.items()])};
_launch(kernelName, function, stream, gridX, gridY, gridZ, tensorShapes, tensorKinds{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
if (PyErr_Occurred()) {{
return NULL;
}}
if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{
return NULL;
}}
Py_RETURN_NONE;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
{cpp_msprof_callback}
return m;
}}
"""