#ifndef BUILD_LIBTORCH
#include "torch_npu/csrc/inductor/mlir/cpp_common.h"
#include <dlfcn.h>
#include <string.h>
#include <sys/syscall.h>
#include <unistd.h>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <memory>
#include <utility>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/profiler/profiler_mgr.h>
#include "third_party/acl/inc/experiment/msprof/toolchain/prof_api.h"
#include "third_party/acl/inc/experiment/msprof/toolchain/prof_common.h"
struct TilingMem {
std::unique_ptr<void, decltype(&aclrtFreeHost)> arg_tiling_host;
std::unique_ptr<void, decltype(&aclrtFree)> arg_tiling_device;
TilingMem() : arg_tiling_host(nullptr, aclrtFreeHost), arg_tiling_device(nullptr, aclrtFree) {}
};
using TilingMemInfo = TilingMem;
TilingMemInfo MEM_CACHE;
struct WorkspaceMem {
std::unique_ptr<void, decltype(&aclrtFreeHost)> arg_workspace_host;
std::unique_ptr<void, decltype(&aclrtFree)> arg_workspace_device;
WorkspaceMem() : arg_workspace_host(nullptr, aclrtFreeHost), arg_workspace_device(nullptr, aclrtFree) {}
};
using WorkspaceMemInfo = WorkspaceMem;
WorkspaceMemInfo MEM_WORK_CACHE;
rtError_t TORCH_NPU_API common_launch(char* kernelName, const void* func, uint32_t gridX, void* args, uint32_t argsSize,
rtStream_t stream)
{
unsigned long int beginTime = 0;
unsigned long int endTime = 0;
unsigned long int opName = 0;
unsigned int threadId = 0;
size_t length = strlen(kernelName);
if (torch_npu::profiler::GetTraceLevel() != -1) {
beginTime = MsprofSysCycleTime();
}
rtError_t ret = rtKernelLaunch(func, gridX, args, argsSize, NULL, stream);
if (torch_npu::profiler::GetTraceLevel() != -1) {
endTime = MsprofSysCycleTime();
opName = MsprofGetHashId(kernelName, length);
threadId = (unsigned int)(syscall(SYS_gettid));
MsprofApi info;
info.magicNumber = 0x5a5a;
info.level = 10000;
info.type = 5;
info.threadId = threadId;
info.reserve = 0;
info.beginTime = beginTime;
info.endTime = endTime;
info.itemId = opName;
MsprofReportApi(0, &info);
}
if (torch_npu::profiler::GetTraceLevel() >= 1) {
MsprofCompactInfo nodeBasicInfo;
nodeBasicInfo.magicNumber = 0x5a5a;
nodeBasicInfo.level = 10000;
nodeBasicInfo.type = 0;
nodeBasicInfo.threadId = threadId;
nodeBasicInfo.timeStamp = endTime;
nodeBasicInfo.data.nodeBasicInfo.opName = opName;
nodeBasicInfo.data.nodeBasicInfo.taskType = 0;
nodeBasicInfo.data.nodeBasicInfo.opType = opName;
nodeBasicInfo.data.nodeBasicInfo.blockDim = gridX;
MsprofReportCompactInfo(0, &nodeBasicInfo, sizeof(MsprofCompactInfo));
}
return ret;
}
static void prepare_tiling(void* args, void* tiling_func, int64_t tilingSize, void* arg_tiling_host,
void* arg_tiling_device, uint32_t gridX, rtStream_t stream, uint32_t argsSize)
{
uint32_t args_num = argsSize / sizeof(void*);
void** args_cast = static_cast<void**>(args);
args_cast[args_num - 5] = arg_tiling_host;
args_cast[args_num - 4] = arg_tiling_host;
typedef int64_t (*mlir_tiling_func)(void*);
mlir_tiling_func func_tiling_pre = reinterpret_cast<mlir_tiling_func>(tiling_func);
func_tiling_pre(args);
aclError err = aclrtMemcpy(arg_tiling_device, tilingSize, arg_tiling_host, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
if (err != ACL_ERROR_NONE) {
printf("aclrtMemcpy Failed, err: %d \n", err);
return;
}
args_cast[args_num - 5] = arg_tiling_device;
args_cast[args_num - 4] = arg_tiling_device;
}
rtError_t TORCH_NPU_API common_launch_dyn(char* kernelName, void* func, void* tiling_func, int64_t tilingSize,
void* arg_tiling_host, void* arg_tiling_device, uint32_t gridX, void* args,
uint32_t argsSize, rtStream_t stream)
{
unsigned long int beginTime = 0;
unsigned long int endTime = 0;
unsigned long int opName = 0;
unsigned int threadId = 0;
size_t length = strlen(kernelName);
if (tilingSize != 0) {
void** args_cast = static_cast<void**>(args);
prepare_tiling(args_cast, tiling_func, tilingSize, arg_tiling_host, arg_tiling_device, gridX, stream, argsSize);
typedef void (*mlir_func)(uint32_t, void*, void*, void*);
mlir_func func_cast = (mlir_func)func;
if (torch_npu::profiler::GetTraceLevel() != -1) {
beginTime = MsprofSysCycleTime();
}
func_cast(gridX, nullptr, stream, args);
} else {
typedef void (*mlir_func)(uint32_t, void*, void*, void*);
mlir_func func_cast = (mlir_func)func;
if (torch_npu::profiler::GetTraceLevel() != -1) {
beginTime = MsprofSysCycleTime();
}
func_cast(gridX, nullptr, stream, args);
}
if (torch_npu::profiler::GetTraceLevel() != -1) {
endTime = MsprofSysCycleTime();
opName = MsprofGetHashId(kernelName, length);
threadId = (unsigned int)(syscall(SYS_gettid));
MsprofApi info;
info.magicNumber = 0x5a5a;
info.level = 10000;
info.type = 5;
info.threadId = threadId;
info.reserve = 0;
info.beginTime = beginTime;
info.endTime = endTime;
info.itemId = opName;
MsprofReportApi(0, &info);
}
if (torch_npu::profiler::GetTraceLevel() >= 1) {
MsprofCompactInfo nodeBasicInfo;
nodeBasicInfo.magicNumber = 0x5a5a;
nodeBasicInfo.level = 10000;
nodeBasicInfo.type = 0;
nodeBasicInfo.threadId = threadId;
nodeBasicInfo.timeStamp = endTime;
nodeBasicInfo.data.nodeBasicInfo.opName = opName;
nodeBasicInfo.data.nodeBasicInfo.taskType = 0;
nodeBasicInfo.data.nodeBasicInfo.opType = opName;
nodeBasicInfo.data.nodeBasicInfo.blockDim = gridX;
MsprofReportCompactInfo(0, &nodeBasicInfo, sizeof(MsprofCompactInfo));
}
return RT_ERROR_NONE;
}
// Builds an OpCommand named `name` whose custom handler is `launch_call` and runs it.
// `launch_call` is taken by value and moved into the command rather than copied,
// avoiding a second std::function copy (which may allocate).
void TORCH_NPU_API opcommand_call(const char* name, std::function<int()> launch_call)
{
    at_npu::native::OpCommand cmd;
    cmd.Name(name).SetCustomHandler(std::move(launch_call)).Run();
}
#endif