// Copyright (c) 2025 Huawei Technologies Co., Ltd
// All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TORCHNPU_TORCH_NPU_CUSTOM_OPS_OP_API_PTA_COMMON_H_
#define TORCHNPU_TORCH_NPU_CUSTOM_OPS_OP_API_PTA_COMMON_H_
#include "op_plugin/utils/op_api_common.h"
// EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD_V1(aclnn_api, workspace_addr, workspace_size, ...)
//
// Launches the op-api function `aclnn_api` with a caller-supplied workspace. This is the
// variant that the EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD dispatcher selects when the
// task-queue option is not 2.
//
// Parameters:
//   aclnn_api       Bare identifier of the op-api entry point. It is stringized to look up
//                   both `<aclnn_api>` and `<aclnn_api>GetWorkspaceSize`.
//   workspace_addr  Workspace pointer forwarded unchanged to `<aclnn_api>`.
//   workspace_size  Workspace size forwarded unchanged to `<aclnn_api>`.
//   ...             Operator arguments; copied with CopyTypesV2 and converted with ConvertTypesV2.
//
// Sequence at the expansion site:
//   1. Resolve the op-api symbols once (function-local statics) and TORCH_CHECK that the two
//      mandatory ones were found.
//   2. Log the call, fetch the current NPU stream and run the enqueue-side stream check.
//   3. Invoke the optional cache / huge-memory hooks when their symbols were resolved.
//   4. Copy and convert the arguments, then call `<aclnn_api>GetWorkspaceSize`. The size it
//      reports goes into a local dummy and is never read; only the executor it fills in is used.
//   5. Build the `acl_call` lambda (captures by value) which runs `<aclnn_api>` and releases the
//      converted parameters, and hand it to OpCommand::RunOpApiV2.
//   6. Invoke the optional UnInitHugeMemThreadLocal hook and UnInitCacheThreadLocal().
//
// Errors: TORCH_CHECK fails when a mandatory symbol is missing; NPU_CHECK_ERROR is applied to
// the status of both op-api calls.
//
// NOTE(review): where and when `acl_call` actually executes is decided by RunOpApiV2 and is
// not visible here -- confirm before relying on any ordering between steps 5 and 6.
#define EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD_V1(aclnn_api, workspace_addr, workspace_size, ...) \
    do { \
        /* Function-local statics: each symbol is looked up once per expansion site. */ \
        static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \
        static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \
        static const auto initMemAddr = GetOpApiFuncAddr("InitHugeMemThreadLocal"); \
        static const auto unInitMemAddr = GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \
        static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \
        static const auto initPTACacheThreadLocalAddr = GetOpApiFuncAddr("InitPTACacheThreadLocal"); \
        static const auto setPTAHashKeyAddr = GetOpApiFuncAddr("SetPTAHashKey"); \
        /* Only the two op entry points are mandatory; the remaining hooks are null-checked at use. */ \
        TORCH_CHECK(getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, #aclnn_api, " or ", \
                    #aclnn_api "GetWorkspaceSize", " not in ", GetOpApiLibName(), ", or ", GetOpApiLibName(), \
                    " not found.", OPS_ERROR(ErrCode::PTR)); \
        OP_EXEC_LOG_WITH_TASK_QUEUE(#aclnn_api, "EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD", "1", __VA_ARGS__); \
        auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \
        if (c10_npu::check_enqueue_need_use(acl_stream)) { \
            c10_npu::UseStreamResInCurrentThread(acl_stream); \
        } \
        aclOpExecutor *executor = nullptr; \
        aclOpExecutor **executor_addr = &executor; \
        InitHugeMemThreadLocal initMemFunc = reinterpret_cast<InitHugeMemThreadLocal>(initMemAddr); \
        UnInitHugeMemThreadLocal unInitMemFunc = reinterpret_cast<UnInitHugeMemThreadLocal>(unInitMemAddr); \
        InitPTACacheThreadLocal initPTACacheThreadLocalFunc = \
            reinterpret_cast<InitPTACacheThreadLocal>(initPTACacheThreadLocalAddr); \
        SetPTAHashKey setPTAHashKeyFunc = reinterpret_cast<SetPTAHashKey>(setPTAHashKeyAddr); \
        /* The cache hooks are used only when both symbols were resolved. */ \
        if (initPTACacheThreadLocalFunc && setPTAHashKeyFunc) { \
            initPTACacheThreadLocalFunc(); \
            setPTAHashKeyFunc(0); \
        } \
        at_npu::native::SetDeterministic(); \
        if (initMemFunc) { \
            initMemFunc(nullptr, false); \
        } \
        auto copied_params = CopyTypesV2(__VA_ARGS__); \
        /* Receives the size reported by GetWorkspaceSize; it is never read afterwards because */ \
        /* the launch below uses the caller's workspace_addr / workspace_size instead.         */ \
        uint64_t fake_workspace_size = 0; \
        uint64_t *workspace_size_addr = &fake_workspace_size; \
        auto converted_params = ConvertTypesV2(copied_params, workspace_size_addr, executor_addr); \
        static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \
        auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
        NPU_CHECK_ERROR(workspace_status, "call " #aclnn_api " failed"); \
        /* `executor` is captured by value here, i.e. after the GetWorkspaceSize call above. */ \
        auto acl_call = [converted_params, workspace_addr, workspace_size, acl_stream, executor]()->int { \
            if (c10_npu::check_dequeue_need_use(acl_stream)) { \
                c10_npu::UseStreamResInCurrentThread(acl_stream); \
            } \
            OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr); \
            auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \
            NPU_CHECK_ERROR(api_ret, "call " #aclnn_api " failed"); \
            ReleaseConvertTypes(converted_params); \
            ReleaseHugeMem releaseMemFunc = reinterpret_cast<ReleaseHugeMem>(releaseMemAddr); \
            if (releaseMemFunc) { \
                releaseMemFunc(nullptr, false); \
            } \
            return api_ret; \
        }; \
        at_npu::native::OpCommand::RunOpApiV2(#aclnn_api, acl_call); \
        if (unInitMemFunc) { \
            unInitMemFunc(nullptr, false); \
        } \
        UnInitCacheThreadLocal(); \
    } while (false)
// EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD_V2(aclnn_api, workspace_addr, workspace_size, ...)
//
// Launches the op-api function `aclnn_api` with a caller-supplied workspace. This is the
// variant that the EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD dispatcher selects when the
// task-queue option equals 2.
//
// Parameters:
//   aclnn_api       Bare identifier of the op-api entry point. It is stringized to look up
//                   both `<aclnn_api>` and `<aclnn_api>GetWorkspaceSize`.
//   workspace_addr  Workspace pointer forwarded unchanged to `<aclnn_api>`.
//   workspace_size  Workspace size forwarded unchanged to `<aclnn_api>`.
//   ...             Operator arguments; copied with CopyTypesV2 at the expansion site.
//
// Work done at the expansion site: symbol resolution (function-local statics), the
// TORCH_CHECK on the two mandatory symbols, logging, the enqueue-side stream check and
// CopyTypesV2. Everything else lives inside the `acl_call` lambda, which captures the copied
// arguments by value and is handed to OpCommand::RunOpApiV2:
//   - dequeue-side stream check,
//   - optional cache hooks (InitPTACacheThreadLocal + SetPTACacheHashKey) and huge-memory hook,
//   - ConvertTypesV2 and the `<aclnn_api>GetWorkspaceSize` call (its reported size goes into a
//     local dummy that is never read),
//   - the `<aclnn_api>` call itself,
//   - release of the converted parameters and the optional release / un-init hooks.
//
// Errors: TORCH_CHECK fails when a mandatory symbol is missing; NPU_CHECK_ERROR is applied to
// the status of both op-api calls inside the lambda.
//
// NOTE(review): where and when `acl_call` actually executes is decided by RunOpApiV2 and is
// not visible here -- confirm before assuming it runs on the calling thread.
#define EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD_V2(aclnn_api, workspace_addr, workspace_size, ...) \
    do { \
        /* Function-local statics: each symbol is looked up once per expansion site. */ \
        static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \
        static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \
        static const auto initMemAddr = GetOpApiFuncAddr("InitHugeMemThreadLocal"); \
        static const auto unInitMemAddr = GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \
        static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \
        static const auto initPTACacheThreadLocalAddr = GetOpApiFuncAddr("InitPTACacheThreadLocal"); \
        static const auto setPTACacheHashKeyAddr = GetOpApiFuncAddr("SetPTACacheHashKey"); \
        /* Only the two op entry points are mandatory; the remaining hooks are null-checked at use. */ \
        TORCH_CHECK(getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, #aclnn_api, " or ", \
                    #aclnn_api "GetWorkspaceSize", " not in ", GetOpApiLibName(), ", or ", GetOpApiLibName(), \
                    " not found.", OPS_ERROR(ErrCode::PTR)); \
        OP_EXEC_LOG_WITH_TASK_QUEUE(#aclnn_api, "EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD", "2", __VA_ARGS__); \
        auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \
        if (c10_npu::check_enqueue_need_use(acl_stream)) { \
            c10_npu::UseStreamResInCurrentThread(acl_stream); \
        } \
        auto copied_params = CopyTypesV2(__VA_ARGS__); \
        /* The static addresses above need no capture; everything else is captured by value. */ \
        auto acl_call = [workspace_addr, workspace_size, copied_params, acl_stream]()->int { \
            if (c10_npu::check_dequeue_need_use(acl_stream)) { \
                c10_npu::UseStreamResInCurrentThread(acl_stream); \
            } \
            /* Receives the size reported by GetWorkspaceSize; never read afterwards. */ \
            uint64_t fake_workspace_size = 0; \
            uint64_t *workspace_size_addr = &fake_workspace_size; \
            aclOpExecutor *executor = nullptr; \
            aclOpExecutor **executor_addr = &executor; \
            InitHugeMemThreadLocal initMemFunc = reinterpret_cast<InitHugeMemThreadLocal>(initMemAddr); \
            UnInitHugeMemThreadLocal unInitMemFunc = reinterpret_cast<UnInitHugeMemThreadLocal>(unInitMemAddr); \
            InitPTACacheThreadLocal initPTACacheThreadLocalFunc = \
                reinterpret_cast<InitPTACacheThreadLocal>(initPTACacheThreadLocalAddr); \
            SetPTACacheHashKey setPTAHashKeyFunc = reinterpret_cast<SetPTACacheHashKey>(setPTACacheHashKeyAddr); \
            /* The cache hooks are used only when both symbols were resolved. */ \
            if (initPTACacheThreadLocalFunc && setPTAHashKeyFunc) { \
                initPTACacheThreadLocalFunc(); \
                setPTAHashKeyFunc(nullptr, 0); \
            } \
            at_npu::native::SetDeterministic(); \
            if (initMemFunc) { \
                initMemFunc(nullptr, false); \
            } \
            auto converted_params = ConvertTypesV2(copied_params, workspace_size_addr, executor_addr); \
            auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \
            auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
            NPU_CHECK_ERROR(workspace_status, "call " #aclnn_api " failed"); \
            OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr); \
            auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \
            NPU_CHECK_ERROR(api_ret, "call " #aclnn_api " failed"); \
            ReleaseConvertTypes(converted_params); \
            ReleaseHugeMem releaseMemFunc = reinterpret_cast<ReleaseHugeMem>(releaseMemAddr); \
            if (releaseMemFunc) { \
                releaseMemFunc(nullptr, false); \
            } \
            if (unInitMemFunc) { \
                unInitMemFunc(nullptr, false); \
            } \
            UnInitCacheThreadLocal(); \
            return api_ret; \
        }; \
        at_npu::native::OpCommand::RunOpApiV2(#aclnn_api, acl_call); \
    } while (false)
// EXEC_GET_MAX_WORKSPACE_CMD(aclnn_api, ...)
//
// Expression macro: expands to an immediately-invoked generic lambda and evaluates to the
// uint64_t workspace size written by `<aclnn_api>GetMaxWorkspaceSize`.
//
// Parameters:
//   aclnn_api  Bare identifier of the op-api entry point; `<aclnn_api>GetMaxWorkspaceSize` is
//              the only symbol looked up.
//   ...        Operator arguments. They bind to `auto &...`, so each one must be an lvalue.
//
// Sequence: resolve the symbol once (function-local static) and TORCH_CHECK it; fetch the
// current NPU stream and run the enqueue-side stream check; SetDeterministic(); convert the
// arguments together with pointers to the local size/executor slots; call the size query;
// release the converted parameters; destroy the executor that the query filled in; return
// the size.
//
// Errors: TORCH_CHECK fails when the symbol is missing; NPU_CHECK_ERROR is applied to the
// status of the size query and of AclDestroyAclOpExecutor.
//
// NOTE(review): unlike the original, this version no longer looks up the huge-memory and
// PTA-cache hook symbols, because the original never used them in this macro. Confirm that
// those GetOpApiFuncAddr lookups had no required side effect.
#define EXEC_GET_MAX_WORKSPACE_CMD(aclnn_api, ...) \
    [](auto &...args)->auto { \
        static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetMaxWorkspaceSize"); \
        TORCH_CHECK(getWorkspaceSizeFuncAddr != nullptr, #aclnn_api "GetMaxWorkspaceSize", " not in ", \
                    GetOpApiLibName(), ", or ", GetOpApiLibName(), " not found.", OPS_ERROR(ErrCode::PTR)); \
        auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \
        if (c10_npu::check_enqueue_need_use(acl_stream)) { \
            c10_npu::UseStreamResInCurrentThread(acl_stream); \
        } \
        /* Output slots handed to the size query through ConvertTypes. */ \
        uint64_t workspace_size = 0; \
        uint64_t *workspace_size_addr = &workspace_size; \
        aclOpExecutor *executor = nullptr; \
        aclOpExecutor **executor_addr = &executor; \
        at_npu::native::SetDeterministic(); \
        auto converted_params = ConvertTypes(args..., workspace_size_addr, executor_addr); \
        static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \
        auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
        NPU_CHECK_ERROR(workspace_status, "call " #aclnn_api " failed"); \
        ReleaseConvertTypes(converted_params); \
        /* No launch follows in this macro, so the executor is destroyed right away. */ \
        NPU_CHECK_ERROR(at_npu::native::AclDestroyAclOpExecutor(executor)); \
        return workspace_size; \
    }(__VA_ARGS__)
// EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD(aclnn_api, workspace_addr, workspace_size, ...)
//
// Public entry point. Reads the task-queue option once per expansion site (function-local
// static) and forwards all arguments to the _V2 implementation when the option equals 2,
// or to the _V1 implementation for every other value.
#define EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD(aclnn_api, workspace_addr, workspace_size, ...) \
    do { \
        static const auto task_queue_enable = c10_npu::option::OptionsManager::GetTaskQueueEnable(); \
        if (task_queue_enable != 2) { \
            EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD_V1(aclnn_api, workspace_addr, workspace_size, __VA_ARGS__); \
        } else { \
            EXEC_UPDATE_NPU_NO_FORMAT_CHECK_CMD_V2(aclnn_api, workspace_addr, workspace_size, __VA_ARGS__); \
        } \
    } while (false)
#endif // TORCHNPU_TORCH_NPU_CUSTOM_OPS_OP_API_PTA_COMMON_H_