#include <deque>
#include <cstdint>
#include <cstring>
#include <thread>
#include <utility>
#include <vector>
#include <ATen/ATen.h>
#include "torch_npu/csrc/npu/Graph.h"
#include "op_plugin/OpApiInterface.h"
#include "torch_npu/csrc/core/npu/NPUGraph.h"
#include "torch_npu/csrc/core/npu/NPUGraphsUtils.h"
#include "torch_npu/csrc/npu/Event.h"
#include "torch_npu/csrc/npu/Stream.h"
// pybind11 class binding whose instances are held by std::shared_ptr<T>.
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

// PyFuncStruct payloads handed to c10_npu::launch_callback by `_launch_host_func`,
// grouped per stream; they are deleted in `_unsubscribe_report`.
static std::map<c10_npu::NPUStream, std::vector<PyFuncStruct *>> callbacks{};

// Timeout argument for each aclrtProcessReport call in the report loop
// (presumably milliseconds per the ACL API — confirm).
constexpr int processReportTimeout{100};
// Number of Py_AddPendingCall attempts made by one TrySchedulePendingCall().
constexpr int pendingCallRetryCount{16};
// Pause after each failed Py_AddPendingCall attempt.
constexpr auto pendingCallRetryInterval = std::chrono::milliseconds{1};

// State of the report-processing thread started in `_subscribe_report`.
// threadId holds all bits set (-1) until a thread is started, and again after
// `_unsubscribe_report` tells the thread to exit.
static ThreadArgs* threadArgs{nullptr};
static uint64_t threadId = static_cast<uint64_t>(-1);

// Guards pendingCallbacksQueue and every transition of pendingCallScheduled.
static std::mutex pendingCallbacksMutex;
// One queued host-callback invocation: its payload plus the host tensor bytes
// snapshotted when the callback fired.
using PendingCallbackEntry = std::pair<PendingCallPayload*, std::vector<at::Tensor>>;
static std::deque<PendingCallbackEntry> pendingCallbacksQueue;
// True while a PendingCallHandler run has been requested but has not yet
// started draining the queue.
static std::atomic<bool> pendingCallScheduled{false};
/**
 * Entry point of the report-processing thread started by `_subscribe_report`
 * through pthread_create.
 *
 * Binds the thread to the ACL context captured at subscription time, then
 * repeatedly calls aclrtProcessReport until `exitFlag` is set.
 *
 * @param arg  A heap-allocated ThreadArgs*; this function takes ownership and
 *             deletes it after the loop ends.
 * @return Always nullptr.
 */
void *process_callback(void *arg)
{
    auto *args = static_cast<ThreadArgs *>(arg);
    if (args == nullptr) {
        return nullptr;
    }
    // The return code was previously stored in an unused local; the original
    // behaviour of continuing regardless is kept.
    (void)aclrtSetCurrentContext(args->context);
    // NOTE(review): exitFlag is written by `_unsubscribe_report` on another
    // thread; confirm ThreadArgs::exitFlag is atomic, otherwise this is a data race.
    while (!args->exitFlag) {
        (void)aclrtProcessReport(processReportTimeout);
    }
    delete args;
    return nullptr;
}
void LaunchCallFunc(void *userData)
{
PyGILState_STATE state = PyGILState_Ensure();
if (userData == nullptr) {
return;
}
auto data = (PyFuncStruct *)(userData);
PyObject *argslist = Py_BuildValue("(O)", data->pyFuncArgs);
if (argslist == nullptr) {
return;
}
PyObject *result = PyObject_CallObject(data->pyFunc, argslist);
if (result == nullptr) {
return;
}
if (argslist != nullptr) {
Py_XDECREF(argslist);
}
if (result != nullptr) {
Py_XDECREF(result);
}
PyGILState_Release(state);
}
namespace {
// Tag (element 0) of a 5-tuple ("host_tensor_ptr", ptr, nbytes, shape, dtype)
// describing host memory to snapshot when the host callback fires.
constexpr char kNpuGraphTensorPtrSpecMarker[] = "host_tensor_ptr";
// Tag of the 4-tuple ("host_tensor_buffer", memoryview, shape, dtype) that
// replaces each ptr spec in the arguments handed to the Python callback.
constexpr char kNpuGraphTensorBufferSpecMarker[] = "host_tensor_buffer";
// Defined below; declared early because TrySchedulePendingCall schedules it.
int PendingCallHandler(void *);
/**
 * Asks the interpreter to run PendingCallHandler via Py_AddPendingCall.
 *
 * Makes up to `pendingCallRetryCount` attempts and sleeps
 * `pendingCallRetryInterval` after every failed one.
 *
 * @return true once an attempt succeeds; false if all attempts failed.
 */
bool TrySchedulePendingCall()
{
    for (int attemptsLeft = pendingCallRetryCount; attemptsLeft > 0; --attemptsLeft) {
        const bool queued = Py_AddPendingCall(PendingCallHandler, nullptr) == 0;
        if (queued) {
            return true;
        }
        std::this_thread::sleep_for(pendingCallRetryInterval);
    }
    return false;
}
void StartPendingCallRetryThread()
{
std::thread([]() {
while (!TrySchedulePendingCall()) {
std::this_thread::sleep_for(pendingCallRetryInterval);
}
}).detach();
}
/**
 * Tells whether `obj` is a host-tensor pointer spec: a tuple of exactly five
 * items whose first item is the str "host_tensor_ptr".
 */
bool IsNpuGraphTensorPtrSpec(PyObject* obj)
{
    constexpr Py_ssize_t kSpecArity = 5;
    const bool isSpecShapedTuple = PyTuple_Check(obj) && PyTuple_GET_SIZE(obj) == kSpecArity;
    if (!isSpecShapedTuple) {
        return false;
    }
    PyObject* tag = PyTuple_GET_ITEM(obj, 0);
    if (!PyUnicode_Check(tag)) {
        return false;
    }
    return PyUnicode_CompareWithASCIIString(tag, kNpuGraphTensorPtrSpecMarker) == 0;
}
/**
 * Records one host-tensor pointer spec into `payload->pendingTensorData`.
 *
 * `obj` must already satisfy IsNpuGraphTensorPtrSpec. Item 1 is the address as
 * an int, item 2 is the byte count, and items 3/4 (shape, dtype) are stored
 * as-is for later materialization.
 *
 * @return true on success; false with a Python exception set if the address or
 *         size is not a valid int, the size is negative, or the address is
 *         null while the size is non-zero. A null address would otherwise be
 *         memcpy'd from on the callback thread.
 */
bool CollectNpuGraphTensorPtrSpec(PyObject* obj, PendingCallPayload* payload)
{
    PyObject* ptrObj = PyTuple_GET_ITEM(obj, 1);
    PyObject* nbytesObj = PyTuple_GET_ITEM(obj, 2);
    PyObject* shapeObj = PyTuple_GET_ITEM(obj, 3);
    PyObject* dtypeObj = PyTuple_GET_ITEM(obj, 4);
    auto dataPtr = static_cast<uintptr_t>(PyLong_AsUnsignedLongLong(ptrObj));
    if (PyErr_Occurred()) {
        return false;
    }
    Py_ssize_t nbytes = PyLong_AsSsize_t(nbytesObj);
    if (PyErr_Occurred()) {
        return false;
    }
    if (nbytes < 0) {
        PyErr_SetString(PyExc_ValueError, "npugraph tensor buffer size must be non-negative");
        return false;
    }
    if (nbytes > 0 && dataPtr == 0) {
        PyErr_SetString(PyExc_ValueError,
            "npugraph tensor data pointer must not be null for a non-empty buffer");
        return false;
    }
    // NOTE(review): shapeObj/dtypeObj are borrowed references; confirm that
    // PendingTensorData (or the payload) keeps them alive until materialization.
    payload->pendingTensorData.emplace_back(dataPtr, nbytes, shapeObj, dtypeObj);
    return true;
}
/**
 * Walks `obj` depth-first through nested tuples and lists. It records every
 * host-tensor pointer spec it finds, in visitation order, into
 * `payload->pendingTensorData`. Any other object is left alone.
 *
 * The traversal order must match MaterializePendingArg, which later consumes
 * the recorded entries by index.
 *
 * Nested containers are entered under Py_EnterRecursiveCall, so a
 * self-referencing list/tuple raises RecursionError instead of overflowing the
 * C stack.
 *
 * @return true on success; false with a Python exception set on failure.
 */
bool CollectPendingTensorData(PyObject* obj, PendingCallPayload* payload)
{
    if (IsNpuGraphTensorPtrSpec(obj)) {
        return CollectNpuGraphTensorPtrSpec(obj, payload);
    }
    if (!PyTuple_Check(obj) && !PyList_Check(obj)) {
        return true;
    }
    if (Py_EnterRecursiveCall(" while collecting npugraph host tensor specs") != 0) {
        return false;
    }
    bool ok = true;
    // PySequence_Fast_* work on both list and tuple objects (including subclasses).
    const Py_ssize_t size = PySequence_Fast_GET_SIZE(obj);
    for (Py_ssize_t i = 0; i < size && ok; ++i) {
        ok = CollectPendingTensorData(PySequence_Fast_GET_ITEM(obj, i), payload);
    }
    Py_LeaveRecursiveCall();
    return ok;
}
/**
 * Snapshots the host memory described by `payload->pendingTensorData` into
 * freshly allocated CPU uint8 tensors, one per entry and in the same order.
 *
 * It is invoked from a runtime host callback, so it never throws. On any
 * failure the whole snapshot is discarded and `copiedTensors` ends up empty.
 *
 * @param payload        Source descriptors (address + byte count).
 * @param copiedTensors  Output; replaced with the snapshot tensors.
 */
void CopyPendingTensorData(PendingCallPayload* payload, std::vector<at::Tensor>& copiedTensors)
{
    const auto& specs = payload->pendingTensorData;
    try {
        copiedTensors.clear();
        copiedTensors.reserve(specs.size());
        const auto byteOptions = at::TensorOptions().dtype(at::kByte).device(at::kCPU);
        for (const auto& spec : specs) {
            at::Tensor snapshot = at::empty({static_cast<int64_t>(spec.nbytes)}, byteOptions);
            if (spec.nbytes > 0) {
                const void* source = reinterpret_cast<void*>(spec.dataPtr);
                std::memcpy(snapshot.data_ptr(), source, static_cast<size_t>(spec.nbytes));
            }
            copiedTensors.push_back(std::move(snapshot));
        }
    } catch (...) {
        // Best effort: drop the partial snapshot; the materializer reports it.
        copiedTensors.clear();
    }
}
/**
 * Builds the ("host_tensor_buffer", memoryview, shape, dtype) tuple that
 * replaces the `tensorDataIndex`-th host-tensor pointer spec, then advances
 * the index.
 *
 * @param payload          Supplies the recorded shape/dtype and byte count.
 * @param copiedTensors    Snapshot produced by CopyPendingTensorData.
 * @param tensorDataIndex  In/out cursor into pendingTensorData/copiedTensors.
 * @return New reference, or nullptr with a Python exception set. The error is
 *         RuntimeError when more specs are visited than were recorded, and
 *         MemoryError when the host snapshot for this spec is missing.
 *         CopyPendingTensorData empties the whole vector when a copy fails.
 */
PyObject* MaterializeNpuGraphTensorBufferSpec(
    PendingCallPayload* payload,
    const std::vector<at::Tensor>& copiedTensors,
    size_t& tensorDataIndex)
{
    if (tensorDataIndex >= payload->pendingTensorData.size()) {
        PyErr_SetString(PyExc_RuntimeError, "npugraph tensor buffer index is out of range");
        return nullptr;
    }
    const size_t slot = tensorDataIndex++;
    const auto& tensorData = payload->pendingTensorData[slot];
    // A short snapshot vector means the host copy failed, not that the index is bad.
    if (slot >= copiedTensors.size() || !copiedTensors[slot].defined()) {
        PyErr_SetString(PyExc_MemoryError, "failed to snapshot npugraph host tensor data");
        return nullptr;
    }
    const auto& copiedTensor = copiedTensors[slot];
    // NOTE(review): PyMemoryView_FromMemory does not own the memory. The view
    // is only valid while copiedTensors keeps this tensor alive, and
    // PendingCallHandler clears it right after the Python callback returns.
    // Confirm that the Python consumer copies the data rather than retaining
    // the view.
    PyObject* buffer = PyMemoryView_FromMemory(
        static_cast<char*>(copiedTensor.data_ptr()),
        tensorData.nbytes,
        PyBUF_WRITE);
    if (buffer == nullptr) {
        return nullptr;
    }
    PyObject* marker = PyUnicode_FromString(kNpuGraphTensorBufferSpecMarker);
    if (marker == nullptr) {
        Py_DECREF(buffer);
        return nullptr;
    }
    PyObject* bufferSpec = PyTuple_Pack(4, marker, buffer, tensorData.shape, tensorData.dtype);
    Py_DECREF(marker);
    Py_DECREF(buffer);
    return bufferSpec;
}
/**
 * Returns a copy of `obj` in which every host-tensor pointer spec is replaced by
 * its buffer spec (see MaterializeNpuGraphTensorBufferSpec).
 *
 * Tuples and lists are rebuilt as plain tuples/lists; any other object is
 * returned as-is with a new reference. Traversal order matches
 * CollectPendingTensorData, so `tensorDataIndex` lines up with the recorded
 * entries.
 *
 * Nested containers are entered under Py_EnterRecursiveCall, so a cyclic
 * structure raises RecursionError instead of overflowing the C stack.
 *
 * @return New reference, or nullptr with a Python exception set.
 */
PyObject* MaterializePendingArg(
    PyObject* obj,
    PendingCallPayload* payload,
    const std::vector<at::Tensor>& copiedTensors,
    size_t& tensorDataIndex)
{
    if (IsNpuGraphTensorPtrSpec(obj)) {
        return MaterializeNpuGraphTensorBufferSpec(payload, copiedTensors, tensorDataIndex);
    }
    const bool isTuple = PyTuple_Check(obj);
    if (!isTuple && !PyList_Check(obj)) {
        Py_INCREF(obj);
        return obj;
    }
    if (Py_EnterRecursiveCall(" while materializing npugraph host tensor buffers") != 0) {
        return nullptr;
    }
    const Py_ssize_t size = PySequence_Fast_GET_SIZE(obj);
    PyObject* container = isTuple ? PyTuple_New(size) : PyList_New(size);
    if (container != nullptr) {
        for (Py_ssize_t i = 0; i < size; ++i) {
            PyObject* item = MaterializePendingArg(
                PySequence_Fast_GET_ITEM(obj, i), payload, copiedTensors, tensorDataIndex);
            if (item == nullptr) {
                // Partially filled tuples/lists tolerate NULL slots on dealloc.
                Py_CLEAR(container);
                break;
            }
            if (isTuple) {
                PyTuple_SET_ITEM(container, i, item);
            } else {
                PyList_SET_ITEM(container, i, item);
            }
        }
    }
    Py_LeaveRecursiveCall();
    return container;
}
/**
 * Materializes the payload's stored Python argument object against the host
 * snapshot in `copiedTensors`.
 *
 * @return New reference to the materialized argument object, or nullptr. On
 *         failure the Python error is reported via PyErr_WriteUnraisable
 *         (attributed to the payload's function) and cleared.
 */
PyObject* MaterializePendingArgs(PendingCallPayload* payload, const std::vector<at::Tensor>& copiedTensors)
{
    size_t nextTensorSlot = 0;
    PyObject* callArgs = MaterializePendingArg(
        payload->pyFuncData.pyFuncArgs, payload, copiedTensors, nextTensorSlot);
    if (callArgs != nullptr) {
        return callArgs;
    }
    PyErr_WriteUnraisable(payload->pyFuncData.pyFunc);
    PyErr_Clear();
    return nullptr;
}
int PendingCallHandler(void *arg)
{
std::deque<PendingCallbackEntry> readyCallbacks;
{
std::lock_guard<std::mutex> lock(pendingCallbacksMutex);
pendingCallScheduled.store(false, std::memory_order_release);
readyCallbacks.swap(pendingCallbacksQueue);
}
for (auto& callback : readyCallbacks) {
auto* payload = callback.first;
auto& copiedTensors = callback.second;
if (payload == nullptr) {
continue;
}
PyObject* materializedArgs = MaterializePendingArgs(payload, copiedTensors);
if (materializedArgs == nullptr) {
copiedTensors.clear();
continue;
}
PyObject* result = PyObject_CallObject(payload->pyFuncData.pyFunc, materializedArgs);
if (result != nullptr) {
Py_XDECREF(result);
} else {
PyErr_WriteUnraisable(payload->pyFuncData.pyFunc);
}
Py_DECREF(materializedArgs);
copiedTensors.clear();
}
bool shouldScheduleAgain = false;
{
std::lock_guard<std::mutex> lock(pendingCallbacksMutex);
if (!pendingCallbacksQueue.empty() &&
!pendingCallScheduled.exchange(true, std::memory_order_acq_rel)) {
shouldScheduleAgain = true;
}
}
if (shouldScheduleAgain && !TrySchedulePendingCall()) {
StartPendingCallRetryThread();
}
return 0;
}
/**
 * Host function registered via c10_npu::launch_host_func by
 * `_launch_host_func_pending`. It runs outside Python, so it only snapshots the
 * referenced host memory, queues the work, and asks the interpreter to run
 * PendingCallHandler.
 *
 * No C++ exception may escape this runtime callback. If the entry cannot be
 * queued (allocation failure), this invocation is dropped.
 *
 * @param userData  The PendingCallPayload* given at launch; nullptr is ignored.
 *                  The payload is not freed here.
 */
void LaunchCallbackViaPendingCall(void *userData)
{
    auto* payload = static_cast<PendingCallPayload*>(userData);
    if (payload == nullptr) {
        return;
    }
    std::vector<at::Tensor> copiedTensors;
    CopyPendingTensorData(payload, copiedTensors);
    bool shouldSchedule = false;
    try {
        std::lock_guard<std::mutex> lock(pendingCallbacksMutex);
        pendingCallbacksQueue.emplace_back(payload, std::move(copiedTensors));
        // Only the producer that flips the flag requests a handler run.
        shouldSchedule = !pendingCallScheduled.exchange(true, std::memory_order_acq_rel);
    } catch (...) {
        return;
    }
    if (shouldSchedule && !TrySchedulePendingCall()) {
        StartPendingCallRetryThread();
    }
}
}
class AclSkOptionHelper {
public:
std::vector<aclskOption> optionsVec;
std::deque<std::string> stringPool;
std::deque<std::vector<char*>> ptrArrayPool;
void processInitOption(const std::string& key, int value)
{
static const std::unordered_map<std::string, std::function<void(aclskOption&, int)>> optionHandlers = {
{"preload_code", [](aclskOption& opt, int val) {
opt.optionType = aclskOptionType::PRELOAD_CODE;
opt.preload.preloadMode = static_cast<uint32_t>(val);
}},
{"split_mode", [](aclskOption& opt, int val) {
opt.optionType = aclskOptionType::SPLIT_MODE;
opt.splitMode.splitCnt = static_cast<uint32_t>(val);
}},
{"stream_fusion", [](aclskOption& opt, int val) {
opt.optionType = aclskOptionType::STREAM_FUSION;
opt.streamFusion.streamFusion = static_cast<uint32_t>(val);
}},
{"debug_sync_all", [](aclskOption& opt, int val) {
opt.optionType = aclskOptionType::DEBUG_SYNC_ALL;
opt.debugSync.debugSyncAll = static_cast<uint32_t>(val);
}},
{"constant_codegen", [](aclskOption& opt, int val) {
opt.optionType = aclskOptionType::CONSTANT_CODEGEN;
opt.constantCodegen.enableConstant = static_cast<uint32_t>(val);
}},
{"auto_op_parallel", [](aclskOption& opt, int val) {
opt.optionType = aclskOptionType::AUTO_OP_PARALLEL;
opt.autoOpParallel.enableAutoOpParallel = static_cast<uint32_t>(val);
}},
{"debug_op_exec_trace", [](aclskOption& opt, int val) {
opt.optionType = aclskOptionType::DEBUG_OP_EXEC_TRACE;
opt.debugOpExecTrace.enableOpExecTrace = static_cast<uint32_t>(val);
}},
{"debug_cross_core_sync_check", [](aclskOption& opt, int val) {
opt.optionType = aclskOptionType::DEBUG_CROSS_CORE_SYNC_CHECK;
opt.debugCrossCoreSyncCheck.enableCrossCoreSyncCheck = static_cast<uint32_t>(val);
}},
{"early_start", [](aclskOption& opt, int val) {
opt.optionType = aclskOptionType::EARLY_START;
opt.earlyStart.enableEarlyStart = static_cast<uint32_t>(val);
}},
{"debug_per_op_max_core_num", [](aclskOption& opt, int val) {
opt.optionType = aclskOptionType::DEBUG_PER_OP_MAX_CORE_NUM;
opt.debugPerOpMaxCoreNum.enableDebugPerOpMaxCoreNum = static_cast<uint32_t>(val);
}}
};
auto it = optionHandlers.find(key);
if (it != optionHandlers.end()) {
aclskOption opt = {};
it->second(opt, value);
optionsVec.push_back(opt);
}
}
void processStringArrayOption(const std::string& key, const std::vector<std::string>& values)
{
if (key == "dcci_disable_on_kernel") {
processDcciDisableOnKernel(values);
} else if (key == "dcci_before_kernel_start") {
processDcciBeforeKernelStart(values);
} else if (key == "dcci_after_kernel_end") {
processDcciAfterKernelEnd(values);
} else if (key == "ubuf_lock_ignore_kernel") {
processUbufLockIgnoreKernel(values);
}
}
uint32_t getDictIntDefault(const py::dict& d, const std::string& key, uint32_t default_val)
{
if (d.contains(key)) {
py::object item = d.attr("__getitem__")(key);
if (py::isinstance<py::int_>(item)) {
return item.cast<uint32_t>();
}
}
return default_val;
}
void processDictOption(const std::string& key, const py::dict& dict_value)
{
if (key == "aggressive_opt_strategies") {
aclskOption opt = {};
opt.optionType = aclskOptionType::AGGRESSIVE_OPT_STRATEGIES;
opt.aggressiveOpts.eventBreakerBypass = getDictIntDefault(dict_value, "event_breaker_bypass", 0);
opt.aggressiveOpts.valueBreakerBypass = getDictIntDefault(dict_value, "value_breaker_bypass", 0);
opt.aggressiveOpts.taskBreakerBypass = getDictIntDefault(dict_value, "task_breaker_bypass", 0);
optionsVec.push_back(opt);
}
}
void processStringOption(const std::string& key, const std::string& value)
{
aclskOption opt = {};
if (key == "opt_extend") {
opt.optionType = aclskOptionType::OPT_EXTEND_OPTION;
stringPool.push_back(value);
opt.optExtend.value = const_cast<char*>(stringPool.back().c_str());
} else if (key == "debug_extend") {
opt.optionType = aclskOptionType::DEBUG_EXTEND_OPTION;
stringPool.push_back(value);
opt.debugExtend.value = const_cast<char*>(stringPool.back().c_str());
} else {
return;
}
optionsVec.push_back(opt);
}
aclskOptions getStruct()
{
aclskOptions finalOpt = {};
finalOpt.options = optionsVec.data();
finalOpt.numOptions = optionsVec.size();
return finalOpt;
}
private:
void processDcciDisableOnKernel(const std::vector<std::string>& values)
{
aclskOption opt = {};
opt.optionType = aclskOptionType::DCCI_DISABLE_ON_KERNEL;
opt.disableKernelDcci.kernelCnt = static_cast<int>(values.size());
opt.disableKernelDcci.kernelNames = convertStringArray(values);
optionsVec.push_back(opt);
}
void processDcciBeforeKernelStart(const std::vector<std::string>& values)
{
aclskOption opt = {};
opt.optionType = aclskOptionType::DCCI_BEFORE_KERNEL_START;
opt.dcciBeforeKernelStart.kernelCnt = static_cast<int>(values.size());
opt.dcciBeforeKernelStart.kernelNames = convertStringArray(values);
optionsVec.push_back(opt);
}
void processDcciAfterKernelEnd(const std::vector<std::string>& values)
{
aclskOption opt = {};
opt.optionType = aclskOptionType::DCCI_AFTER_KERNEL_END;
opt.dcciAfterKernelEnd.kernelCnt = static_cast<int>(values.size());
opt.dcciAfterKernelEnd.kernelNames = convertStringArray(values);
optionsVec.push_back(opt);
}
void processUbufLockIgnoreKernel(const std::vector<std::string>& values)
{
aclskOption opt = {};
opt.optionType = aclskOptionType::UBUF_LOCK_IGNORE_KERNEL;
opt.ubufLockIgnoreKernel.ubufLockIgnoreKernelCnt = static_cast<int>(values.size());
opt.ubufLockIgnoreKernel.ubufLockIgnoreKernel = convertStringArray(values);
optionsVec.push_back(opt);
}
char** convertStringArray(const std::vector<std::string>& values)
{
std::vector<char*> charPtrs;
charPtrs.reserve(values.size());
for (const auto& token : values) {
stringPool.push_back(token);
charPtrs.push_back(const_cast<char*>(stringPool.back().c_str()));
}
ptrArrayPool.push_back(std::move(charPtrs));
return ptrArrayPool.back().data();
}
};
/**
 * Registers the NPU-graph Python bindings on `module`.
 *
 * Adds the task-group capture/update helpers, super-kernel scope markers,
 * host-function launching, report subscription and a graph-updatable
 * fused-infer-attention entry point. It also adds the `_NPUGraph` class.
 *
 * @param module  Extension module to populate.
 */
void TORCH_NPU_API THNPGraph_init(PyObject* module) {
    auto torch_N_m = py::handle(module).cast<py::module>();
    py::class_<c10_npu::NPUTaskGroupHandle>(torch_N_m, "_NPUTaskGroupHandle")
        .def_readonly("task_group", &c10_npu::NPUTaskGroupHandle::task_group);
    torch_N_m.def("_graph_pool_handle", &c10_npu::graph_pool_handle)
        .def("_graph_task_group_begin", [](py::object py_stream) {
            auto stream = (*py_stream).ptr();
            c10_npu::graph_task_group_begin(THNPUtils_PyObject_to_NPUStream(stream));
            NPUGRAPH_LOGD("NPUGRAPH TaskGroup begin, stream=%p", static_cast<void*>(stream));
        })
        .def("_graph_task_group_end", [](py::object py_stream) {
            auto stream = (*py_stream).ptr();
            auto handle = c10_npu::graph_task_group_end(THNPUtils_PyObject_to_NPUStream(stream));
            NPUGRAPH_LOGD("NPUGRAPH TaskGroup end, handle=%p", static_cast<void*>(handle.task_group));
            return handle;
        })
        .def("_graph_task_update_begin", [](py::object py_stream, c10_npu::NPUTaskGroupHandle handle) {
            auto stream = (*py_stream).ptr();
            c10_npu::graph_task_update_begin(THNPUtils_PyObject_to_NPUStream(stream), handle);
            NPUGRAPH_LOGD("NPUGRAPH TaskGroup update begin, handle=%p", static_cast<void*>(handle.task_group));
        })
        .def("_graph_task_update_end", [](py::object py_stream) {
            auto stream = (*py_stream).ptr();
            c10_npu::graph_task_update_end(THNPUtils_PyObject_to_NPUStream(stream));
            NPUGRAPH_LOGD("NPUGRAPH TaskGroup update end");
        })
        .def("_super_kernel_scope_begin", [](const char* scope_name) {
            NPUGRAPH_LOGD("NPUGRAPH SuperKernel scope begin, name=%s",
                scope_name ? scope_name : "(null)");
            c10_npu::super_kernel_scope_begin(scope_name);
        })
        .def("_super_kernel_scope_end", [](const char* scope_name) {
            NPUGRAPH_LOGD("NPUGRAPH SuperKernel scope end, name=%s",
                scope_name ? scope_name : "(null)");
            c10_npu::super_kernel_scope_end(scope_name);
        })
        .def("_launch_host_func", [](py::object py_stream, py::object py_func, py::object py_data) {
            auto func = (*py_func).ptr();
            auto userDataList = (*py_data).ptr();
            auto stream = THNPUtils_PyObject_to_NPUStream((*py_stream).ptr());
            // A throwing allocation (surfaces as MemoryError) instead of handing a
            // possibly-null pointer to the runtime. unique_ptr frees the data if
            // the launch itself throws.
            auto data = std::make_unique<PyFuncStruct>(func, userDataList);
            c10_npu::launch_callback(stream, LaunchCallFunc, data.get());
            // Ownership moves to `callbacks`; entries are deleted in _unsubscribe_report.
            PyFuncStruct *owned = data.release();
            callbacks[stream].emplace_back(owned);
        })
        .def("_launch_host_func_pending", [](py::object py_stream, py::object py_func, py::object py_data) {
            auto func = (*py_func).ptr();
            auto userDataList = (*py_data).ptr();
            auto stream = THNPUtils_PyObject_to_NPUStream((*py_stream).ptr());
            auto payload = std::make_unique<PendingCallPayload>(func, userDataList);
            if (!CollectPendingTensorData(userDataList, payload.get())) {
                throw py::error_already_set();
            }
            c10_npu::launch_host_func(stream, LaunchCallbackViaPendingCall, payload.get());
            (void)payload.release();
        })
        .def("_subscribe_report", [](py::object py_stream) {
            auto stream = (*py_stream).ptr();
            aclrtContext context = aclrtContext();
            NPU_CHECK_ERROR(aclrtGetCurrentContext(&context));
            if ((threadArgs == nullptr) || (threadId == -1)) {
                auto *newArgs = new ThreadArgs(context, false);
                pthread_t tid{};
                int rc = pthread_create(&tid, nullptr, process_callback, newArgs);
                if (rc != 0) {
                    // The thread never started, so it cannot take ownership of newArgs.
                    delete newArgs;
                    TORCH_CHECK(false, "Failed to create NPU report processing thread, pthread_create returned ", rc);
                }
                threadArgs = newArgs;
                threadId = tid;
            }
            c10_npu::subscribe_report(threadId, THNPUtils_PyObject_to_NPUStream(stream));
        })
        .def("_unsubscribe_report", [](py::object py_stream) {
            auto stream = THNPUtils_PyObject_to_NPUStream((*py_stream).ptr());
            c10_npu::unsubscribe_report(threadId, stream);
            auto it = callbacks.find(stream);
            if (it != callbacks.end()) {
                for (PyFuncStruct* func : it->second) {
                    delete func;
                }
                callbacks.erase(it);
            }
            if (callbacks.empty() && threadArgs != nullptr) {
                // process_callback deletes its ThreadArgs once it sees exitFlag, so
                // forget the pointer immediately. A later call must not touch it.
                threadArgs->exitFlag = true;
                threadArgs = nullptr;
                threadId = -1;
            }
        })
        .def("_npu_fused_infer_attention_score_out_graph", [](
            py::object py_stream,
            c10_npu::NPUTaskGroupHandle handle,
            py::object py_event,
            py::args args,
            py::kwargs kwargs
        ) -> std::tuple<at::Tensor, at::Tensor> {
            static torch::PythonArgParser parser({
                "npu_fused_infer_attention_score("
                "Tensor query, Tensor key, Tensor value, *, "
                "Tensor? pse_shift=None, "
                "Tensor? atten_mask=None, "
                "SymIntArrayRef actual_seq_lengths=None, "
                "SymIntArrayRef actual_seq_lengths_kv=None, "
                "Tensor? dequant_scale1=None, "
                "Tensor? quant_scale1=None, "
                "Tensor? dequant_scale2=None, "
                "Tensor? quant_scale2=None, "
                "Tensor? quant_offset2=None, "
                "Tensor? antiquant_scale=None, "
                "Tensor? antiquant_offset=None, "
                "Tensor? key_antiquant_scale=None, "
                "Tensor? key_antiquant_offset=None, "
                "Tensor? value_antiquant_scale=None, "
                "Tensor? value_antiquant_offset=None, "
                "Tensor? block_table=None, "
                "Tensor? query_padding_size=None, "
                "Tensor? kv_padding_size=None, "
                "Tensor? key_shared_prefix=None, "
                "Tensor? value_shared_prefix=None, "
                "SymIntArrayRef actual_shared_prefix_len=None, "
                "Tensor? query_rope=None, "
                "Tensor? key_rope=None, "
                "Tensor? key_rope_antiquant_scale=None, "
                "int64_t num_heads=1, "
                "double scale=1.0, "
                "int64_t pre_tokens=2147483647, "
                "int64_t next_tokens=2147483647, "
                "std::string input_layout=\"BSH\", "
                "int64_t num_key_value_heads=0, "
                "int64_t sparse_mode=0, "
                "int64_t inner_precise=0, "
                "int64_t block_size=0, "
                "int64_t antiquant_mode=0, "
                "int64_t key_antiquant_mode=0, "
                "int64_t value_antiquant_mode=0, "
                "bool softmax_lse_flag=False, "
                "Tensor? workspace=None, "
                "TensorList out)"
            });
            PyObject* args_ptr = args.ptr();
            PyObject* kwargs_ptr = kwargs.ptr();
            torch::ParsedArgs<42> parsed;
            torch::PythonArgs py_args = parser.parse(args_ptr, kwargs_ptr, parsed);
            at::Tensor query = py_args.tensor(0);
            at::Tensor key = py_args.tensor(1);
            at::Tensor value = py_args.tensor(2);
            c10::optional<at::Tensor> pse_shift = py_args.optionalTensor(3);
            c10::optional<at::Tensor> atten_mask = py_args.optionalTensor(4);
            c10::OptionalArray<c10::SymInt> actual_seq_lengths = py_args.symintlistOptional(5);
            c10::OptionalArray<c10::SymInt> actual_seq_lengths_kv = py_args.symintlistOptional(6);
            c10::optional<at::Tensor> dequant_scale1 = py_args.optionalTensor(7);
            c10::optional<at::Tensor> quant_scale1 = py_args.optionalTensor(8);
            c10::optional<at::Tensor> dequant_scale2 = py_args.optionalTensor(9);
            c10::optional<at::Tensor> quant_scale2 = py_args.optionalTensor(10);
            c10::optional<at::Tensor> quant_offset2 = py_args.optionalTensor(11);
            c10::optional<at::Tensor> antiquant_scale = py_args.optionalTensor(12);
            c10::optional<at::Tensor> antiquant_offset = py_args.optionalTensor(13);
            c10::optional<at::Tensor> key_antiquant_scale = py_args.optionalTensor(14);
            c10::optional<at::Tensor> key_antiquant_offset = py_args.optionalTensor(15);
            c10::optional<at::Tensor> value_antiquant_scale = py_args.optionalTensor(16);
            c10::optional<at::Tensor> value_antiquant_offset = py_args.optionalTensor(17);
            c10::optional<at::Tensor> block_table = py_args.optionalTensor(18);
            c10::optional<at::Tensor> query_padding_size = py_args.optionalTensor(19);
            c10::optional<at::Tensor> kv_padding_size = py_args.optionalTensor(20);
            c10::optional<at::Tensor> key_shared_prefix = py_args.optionalTensor(21);
            c10::optional<at::Tensor> value_shared_prefix = py_args.optionalTensor(22);
            c10::OptionalArray<c10::SymInt> actual_shared_prefix_len = py_args.symintlistOptional(23);
            c10::optional<at::Tensor> query_rope = py_args.optionalTensor(24);
            c10::optional<at::Tensor> key_rope = py_args.optionalTensor(25);
            c10::optional<at::Tensor> key_rope_antiquant_scale = py_args.optionalTensor(26);
            int64_t num_heads = py_args.toInt64(27);
            double scale = py_args.toDouble(28);
            int64_t pre_tokens = py_args.toInt64(29);
            int64_t next_tokens = py_args.toInt64(30);
            std::string input_layout = py_args.string(31);
            int64_t num_key_value_heads = py_args.toInt64(32);
            int64_t sparse_mode = py_args.toInt64(33);
            int64_t inner_precise = py_args.toInt64(34);
            int64_t block_size = py_args.toInt64(35);
            int64_t antiquant_mode = py_args.toInt64(36);
            int64_t key_antiquant_mode = py_args.toInt64(37);
            int64_t value_antiquant_mode = py_args.toInt64(38);
            bool softmax_lse_flag = py_args.toBool(39);
            c10::optional<at::Tensor> workspace = py_args.optionalTensor(40);
            std::vector<at::Tensor> out = py_args.tensorlist(41);
            TORCH_CHECK(out.size() == 2,
                "out must have 2 tensors (attention_out, softmax_lse), but got ",
                out.size(), PTA_ERROR(ErrCode::PARAM));
            at::Tensor attention_out = out[0];
            at::Tensor softmax_lse = out[1];
            auto stream = THNPUtils_PyObject_to_NPUStream((*py_stream).ptr());
            auto event_ptr = THNPUtils_PyObject_to_NPUEvent((*py_event).ptr());
            pybind11::gil_scoped_release no_gil;
            c10_npu::graph_task_update_begin(stream, handle);
            auto fia_result = op_api::npu_fused_infer_attention_score_out_symint(
                query, key, value,
                pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
                dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2,
                antiquant_scale, antiquant_offset, key_antiquant_scale, key_antiquant_offset,
                value_antiquant_scale, value_antiquant_offset, block_table,
                query_padding_size, kv_padding_size, key_shared_prefix, value_shared_prefix,
                actual_shared_prefix_len, query_rope, key_rope, key_rope_antiquant_scale,
                num_heads, scale, pre_tokens, next_tokens, input_layout,
                num_key_value_heads, sparse_mode, inner_precise, block_size,
                antiquant_mode, key_antiquant_mode, value_antiquant_mode,
                softmax_lse_flag, workspace, attention_out, softmax_lse
            );
            c10_npu::graph_task_update_end(stream);
            event_ptr->record(stream);
            // Return owning tensor handles. Whether fia_result refers to the local
            // attention_out/softmax_lse or holds its own values, both die when this
            // lambda returns. A tuple of references would dangle by the time
            // pybind11 converts the result.
            return std::tuple<at::Tensor, at::Tensor>(std::get<0>(fia_result), std::get<1>(fia_result));
        });
    shared_ptr_class_<c10_npu::NPUGraph>(torch_N_m, "_NPUGraph")
        .def(py::init<>())
        .def(
            "capture_begin",
            [](c10_npu::NPUGraph& self,
               std::optional<c10_npu::MempoolId_t> pool_opt,
               std::string capture_error_mode,
               bool report_shape) {
                aclmdlRICaptureMode capture_mode;
                c10_npu::MempoolId_t pool = pool_opt.has_value()
                    ? pool_opt.value() : c10_npu::MempoolId_t{0, 0};
                if (capture_error_mode == "global") {
                    capture_mode = aclmdlRICaptureMode::ACL_MODEL_RI_CAPTURE_MODE_GLOBAL;
                } else if (capture_error_mode == "thread_local") {
                    capture_mode = aclmdlRICaptureMode::ACL_MODEL_RI_CAPTURE_MODE_THREAD_LOCAL;
                } else if (capture_error_mode == "relaxed") {
                    capture_mode = aclmdlRICaptureMode::ACL_MODEL_RI_CAPTURE_MODE_RELAXED;
                } else {
                    TORCH_CHECK(
                        false,
                        "Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ",
                        capture_error_mode);
                }
                return self.capture_begin(pool, capture_mode, report_shape);
            },
            py::arg("pool"),
            py::arg("capture_error_mode"),
            py::arg("report_shape"),
            py::call_guard<py::gil_scoped_release>())
        .def(
            "capture_end",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::capture_end))
        .def(
            "register_generator_state",
            [](c10_npu::NPUGraph& self, py::handle raw_generator) {
                auto generator = THPGenerator_Unwrap(raw_generator.ptr());
                py::gil_scoped_release release;
                return self.register_generator_state(generator);
            },
            py::arg("generator"))
        .def(
            "replay",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::replay))
        .def(
            "reset",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::reset))
        .def(
            "pool",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::pool))
        .def(
            "debug_dump",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::debug_dump),
            py::arg("debug_path"))
        .def(
            "enable_debug_mode",
            torch::wrap_pybind_function_no_gil(&c10_npu::NPUGraph::enable_debug_mode))
        .def(
            "super_kernel_optimize",
            [](c10_npu::NPUGraph& self,
               py::object optimize_options,
               py::object debug_options) {
                AclSkOptionHelper helper;
                if (!optimize_options.is_none()) {
                    auto opts = optimize_options.cast<py::dict>();
                    for (auto item : opts) {
                        std::string key = py::str(item.first);
                        if (py::isinstance<py::dict>(item.second)) {
                            helper.processDictOption(key, item.second.cast<py::dict>());
                        } else if (py::isinstance<py::str>(item.second)) {
                            helper.processStringOption(key, item.second.cast<std::string>());
                        } else if (py::isinstance<py::list>(item.second)) {
                            helper.processStringArrayOption(key, item.second.cast<std::vector<std::string>>());
                        } else if (py::isinstance<py::int_>(item.second)) {
                            helper.processInitOption(key, item.second.cast<int>());
                        }
                    }
                }
                if (!debug_options.is_none()) {
                    auto opts = debug_options.cast<py::dict>();
                    for (auto item : opts) {
                        std::string key = py::str(item.first);
                        if (py::isinstance<py::int_>(item.second)) {
                            helper.processInitOption(key, item.second.cast<int>());
                        } else if (py::isinstance<py::str>(item.second)) {
                            helper.processStringOption(key, item.second.cast<std::string>());
                        }
                    }
                }
                aclskOptions options = helper.getStruct();
                {
                    py::gil_scoped_release release;
                    return self.super_kernel_optimize(&options);
                }
            },
            py::arg("optimize_options"),
            py::arg("debug_options"));
}