/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "sk_constant_codegen.h"
#include "sk_log.h"
#include "sk_model_context.h"
#include "sk_options_manager.h"
#include "sk_common.h"
#include <acl/acl.h>
#include <acl/acl_rt_compile.h>
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <mutex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <sys/stat.h>
#include <unistd.h>
// C-linkage accessor declared locally (no header is included for it here).
// In this file it is only used as a precondition: TryGenerateConstantFuncHandle
// skips constant codegen when the returned handle is null.
extern "C" aclrtBinHandle AscendGetEntryBinHandle();
// Device-side prelude that GenerateCombinedSource places at the very top of the
// generated kernel source. It declares the enums and the TaskInfo record that
// the generated code refers to by name (for example SkTaskType::TYPE_FUNC).
// This is a runtime string that is handed to the JIT compiler, so every
// character in it, including the member order of TaskInfo, is part of the
// emitted program.
const char* SK_COMMON_CODE = R"(
#include <cstdint>
#include <cstddef>
enum class SkKernelType : uint8_t {
DEFAULT = 0xFF,
AIC_ONLY = 1,
AIV_ONLY = 2,
MIX_AIV_1_0 = 3,
MIX_AIC_1_0 = 4,
MIX_AIC_1_1 = 5,
MIX_AIC_1_2 = 6,
};
enum class SkTaskType : uint8_t {
TYPE_FUNC,
TYPE_SYNC,
TYPE_PRELOAD,
TYPE_EVENT_NOTIFY,
TYPE_EVENT_WAIT,
TYPE_EVENT_RESET,
TYPE_MAX,
};
enum class SkCoreSyncType : uint8_t {
ALL_SYNC = 0,
CROSS_SYNC_AIC_TO_AIC,
CROSS_SYNC_AIV_TO_AIV,
INTER_SYNC_SET_AIC_TO_AIV,
INTER_SYNC_SET_AIV_TO_AIC,
INTER_SYNC_WAIT_AIC_TO_AIV,
INTER_SYNC_WAIT_AIV_TO_AIC,
SYNC_NONE,
};
enum class SkMemoryWaitFlag : uint32_t {
GEQ = 0x0,
EQ = 0x1,
AND = 0x2,
NOR = 0x3,
};
enum class SkEarlyStartMask : uint32_t {
NONE = 0U,
AIC_TO_AIC_SET = 1U << 0,
AIC_TO_AIC_WAIT = 1U << 1,
AIC_TO_AIV_SET = 1U << 2,
AIV_TO_AIC_WAIT = 1U << 3,
AIV_TO_AIV_SET = 1U << 4,
AIV_TO_AIV_WAIT = 1U << 5,
AIV_TO_AIC_SET = 1U << 6,
AIC_TO_AIV_WAIT = 1U << 7,
SPLIT_CORE_CTRL = 1U << 15,
};
struct TaskInfo {
uint32_t index;
SkTaskType type;
SkKernelType relatedType;
uint8_t numBlocks;
uint8_t entryCnt;
uint64_t entry[4];
uint64_t debugOptions;
uint64_t reserved;
uint64_t args;
uint32_t argsSize;
uint8_t reservedList[4];
};
)";
// Device-side helper templates that GenerateCombinedSource appends after
// SK_COMMON_CODE. The per-task code emitted by GenerateTaskExecutionForSplit
// calls into them:
//   - NotifyFunc / ResetFunc: on block 0 only, store `value` to the address
//     `param` and then issue dcci on it.
//   - WaitFunc: on block 0 only, spin (re-issuing dcci) on the address `param`
//     until the comparison selected by `flag` (a SkMemoryWaitFlag value) holds.
//   - AutoCoreSyncImpl(syncType): issue the cross-core set/wait flag pair that
//     matches the sync type; unknown types fall back on the <aic, aiv> template
//     arguments.
//   - AutoCoreSyncImpl(syncType, numBlocks, syncConfig): delegates to the
//     overload above when syncConfig is 0; otherwise blocks whose index is below
//     numBlocks return immediately and the remaining ones issue the set/wait
//     flags selected by the SkEarlyStartMask bits in syncConfig.
// This is a runtime string: the comment inside it is part of the emitted program
// and is intentionally left unchanged.
const char* SK_IMPL_CODE = R"(
#include "kernel_operator.h"
// 函数指针类型定义(使用 sk::SkSystemArgs,定义在 kernel_operator.h 中)
typedef void (*sk_sub_func)(const __gm__ void *param, const sk::SkSystemArgs *sysArgs);
namespace AscendC {
template<bool aic_flag>
__aicore__ inline void NotifyFunc(uint64_t param, uint64_t value) {
if constexpr(aic_flag) {
if (get_block_idx() == 0) {
__gm__ uint64_t *notifyLock = reinterpret_cast<__gm__ uint64_t *>(param);
*notifyLock = value;
dcci(notifyLock, 0, 2);
}
} else {
if (AscendC::GetBlockIdx() == 0) {
__gm__ uint64_t *notifyLock = reinterpret_cast<__gm__ uint64_t *>(param);
*notifyLock = value;
dcci(notifyLock, 0, 2);
}
}
}
template<bool aic_flag>
__aicore__ inline void WaitFunc(uint64_t param, uint64_t value, uint32_t flag) {
if constexpr(aic_flag) {
if (get_block_idx() == 0) {
__gm__ volatile uint64_t *waitLock = reinterpret_cast<__gm__ uint64_t *>(param);
if (flag == static_cast<uint32_t>(SkMemoryWaitFlag::GEQ)) {
dcci(waitLock, 0, 2);
while (*waitLock < value) {
dcci(waitLock, 0, 2);
}
} else if (flag == static_cast<uint32_t>(SkMemoryWaitFlag::EQ)) {
dcci(waitLock, 0, 2);
while (*waitLock != value) {
dcci(waitLock, 0, 2);
}
} else if (flag == static_cast<uint32_t>(SkMemoryWaitFlag::AND)) {
dcci(waitLock, 0, 2);
while ((*waitLock & value) == 0) {
dcci(waitLock, 0, 2);
}
} else {
dcci(waitLock, 0, 2);
while ((~(*waitLock | value)) == 0) {
dcci(waitLock, 0, 2);
}
}
}
} else {
if (AscendC::GetBlockIdx() == 0) {
__gm__ volatile uint64_t *waitLock = reinterpret_cast<__gm__ uint64_t *>(param);
if (flag == static_cast<uint32_t>(SkMemoryWaitFlag::GEQ)) {
dcci(waitLock, 0, 2);
while (*waitLock < value) {
dcci(waitLock, 0, 2);
}
} else if (flag == static_cast<uint32_t>(SkMemoryWaitFlag::EQ)) {
dcci(waitLock, 0, 2);
while (*waitLock != value) {
dcci(waitLock, 0, 2);
}
} else if (flag == static_cast<uint32_t>(SkMemoryWaitFlag::AND)) {
dcci(waitLock, 0, 2);
while ((*waitLock & value) == 0) {
dcci(waitLock, 0, 2);
}
} else {
dcci(waitLock, 0, 2);
while ((~(*waitLock | value)) == 0) {
dcci(waitLock, 0, 2);
}
}
}
}
}
template<bool aic_flag>
__aicore__ inline void ResetFunc(uint64_t param, uint64_t value) {
if constexpr(aic_flag) {
if (get_block_idx() == 0) {
__gm__ uint64_t *resetLock = reinterpret_cast<__gm__ uint64_t *>(param);
*resetLock = value;
dcci(resetLock, 0, 2);
}
} else {
if (AscendC::GetBlockIdx() == 0) {
__gm__ uint64_t *resetLock = reinterpret_cast<__gm__ uint64_t *>(param);
*resetLock = value;
dcci(resetLock, 0, 2);
}
}
}
template <uint8_t aic, uint8_t aiv>
__aicore__ inline void AutoCoreSyncImpl(SkCoreSyncType sync_type) {
switch (sync_type) {
case SkCoreSyncType::CROSS_SYNC_AIC_TO_AIC:
if ASCEND_IS_AIC {
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(AscendC::SYNC_AIC_FLAG);
AscendC::CrossCoreWaitFlag(AscendC::SYNC_AIC_FLAG);
}
return;
case SkCoreSyncType::CROSS_SYNC_AIV_TO_AIV:
if ASCEND_IS_AIV {
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(AscendC::SYNC_AIV_ONLY_ALL);
AscendC::CrossCoreWaitFlag(AscendC::SYNC_AIV_ONLY_ALL);
}
return;
case SkCoreSyncType::INTER_SYNC_SET_AIC_TO_AIV:
if ASCEND_IS_AIC {
AscendC::CrossCoreSetFlag<0x02, PIPE_FIX>(AscendC::SYNC_AIC_AIV_FLAG);
}
return;
case SkCoreSyncType::INTER_SYNC_SET_AIV_TO_AIC:
if ASCEND_IS_AIV {
AscendC::CrossCoreSetFlag<0x02, PIPE_MTE3>(AscendC::SYNC_AIV_FLAG);
}
return;
case SkCoreSyncType::INTER_SYNC_WAIT_AIC_TO_AIV:
if ASCEND_IS_AIV {
AscendC::CrossCoreWaitFlag(AscendC::SYNC_AIC_AIV_FLAG);
}
return;
case SkCoreSyncType::INTER_SYNC_WAIT_AIV_TO_AIC:
if ASCEND_IS_AIC {
AscendC::CrossCoreWaitFlag(AscendC::SYNC_AIV_FLAG);
}
return;
default:
if constexpr (aic == 1 && aiv == 0) {
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(AscendC::SYNC_AIC_FLAG);
AscendC::CrossCoreWaitFlag(AscendC::SYNC_AIC_FLAG);
} else if constexpr (aic == 0 && aiv == 1) {
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(AscendC::SYNC_AIV_ONLY_ALL);
AscendC::CrossCoreWaitFlag(AscendC::SYNC_AIV_ONLY_ALL);
} else {
AscendC::SyncAll<false>();
}
return;
}
}
template <uint8_t aic, uint8_t aiv>
__aicore__ inline void AutoCoreSyncImpl(SkCoreSyncType syncType, uint8_t numBlocks, uint64_t syncConfig)
{
if (syncConfig == 0) {
AutoCoreSyncImpl<aic, aiv>(syncType);
return;
}
if (AscendC::GetBlockIdx() < numBlocks) {
return;
}
if ASCEND_IS_AIC {
if ((syncConfig & static_cast<uint64_t>(SkEarlyStartMask::AIC_TO_AIC_SET)) != 0) {
AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(AscendC::SYNC_AIC_FLAG);
}
if ((syncConfig & static_cast<uint64_t>(SkEarlyStartMask::AIC_TO_AIC_WAIT)) != 0) {
AscendC::CrossCoreWaitFlag(AscendC::SYNC_AIC_FLAG);
}
if ((syncConfig & static_cast<uint64_t>(SkEarlyStartMask::AIC_TO_AIV_SET)) != 0) {
AscendC::CrossCoreSetFlag<0x02, PIPE_FIX>(AscendC::SYNC_AIC_AIV_FLAG);
}
if ((syncConfig & static_cast<uint64_t>(SkEarlyStartMask::AIV_TO_AIC_WAIT)) != 0) {
AscendC::CrossCoreWaitFlag(AscendC::SYNC_AIV_FLAG);
}
}
if ASCEND_IS_AIV {
if ((syncConfig & static_cast<uint64_t>(SkEarlyStartMask::AIV_TO_AIV_SET)) != 0) {
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(AscendC::SYNC_AIV_ONLY_ALL);
}
if ((syncConfig & static_cast<uint64_t>(SkEarlyStartMask::AIV_TO_AIV_WAIT)) != 0) {
AscendC::CrossCoreWaitFlag(AscendC::SYNC_AIV_ONLY_ALL);
}
if ((syncConfig & static_cast<uint64_t>(SkEarlyStartMask::AIV_TO_AIC_SET)) != 0) {
AscendC::CrossCoreSetFlag<0x02, PIPE_MTE3>(AscendC::SYNC_AIV_FLAG);
}
if ((syncConfig & static_cast<uint64_t>(SkEarlyStartMask::AIC_TO_AIV_WAIT)) != 0) {
AscendC::CrossCoreWaitFlag(AscendC::SYNC_AIC_AIV_FLAG);
}
}
}
} // namespace AscendC
)";
// Returns the source-level spelling of a task type, for pasting into the
// generated device code. Any value without a dedicated mapping is rendered as
// "SkTaskType::TYPE_MAX".
const char* TaskTypeToEnumStr(SkTaskType type) {
    if (type == SkTaskType::TYPE_FUNC) {
        return "SkTaskType::TYPE_FUNC";
    }
    if (type == SkTaskType::TYPE_SYNC) {
        return "SkTaskType::TYPE_SYNC";
    }
    if (type == SkTaskType::TYPE_PRELOAD) {
        return "SkTaskType::TYPE_PRELOAD";
    }
    if (type == SkTaskType::TYPE_EVENT_NOTIFY) {
        return "SkTaskType::TYPE_EVENT_NOTIFY";
    }
    if (type == SkTaskType::TYPE_EVENT_WAIT) {
        return "SkTaskType::TYPE_EVENT_WAIT";
    }
    if (type == SkTaskType::TYPE_EVENT_RESET) {
        return "SkTaskType::TYPE_EVENT_RESET";
    }
    return "SkTaskType::TYPE_MAX";
}
// Returns the source-level spelling of a kernel type, for pasting into the
// generated device code. Values that are not listed in the table are rendered
// as "SkKernelType::DEFAULT".
const char* KernelTypeToEnumStr(SkKernelType type) {
    struct NameEntry {
        SkKernelType value;
        const char* text;
    };
    static const NameEntry kKnownNames[] = {
        {SkKernelType::AIC_ONLY, "SkKernelType::AIC_ONLY"},
        {SkKernelType::AIV_ONLY, "SkKernelType::AIV_ONLY"},
        {SkKernelType::MIX_AIV_1_0, "SkKernelType::MIX_AIV_1_0"},
        {SkKernelType::MIX_AIC_1_0, "SkKernelType::MIX_AIC_1_0"},
        {SkKernelType::MIX_AIC_1_1, "SkKernelType::MIX_AIC_1_1"},
        {SkKernelType::MIX_AIC_1_2, "SkKernelType::MIX_AIC_1_2"},
    };
    for (const NameEntry& candidate : kKnownNames) {
        if (candidate.value == type) {
            return candidate.text;
        }
    }
    return "SkKernelType::DEFAULT";
}
// Formats a 64-bit value as a C++ literal: "0x", exactly sixteen lowercase hex
// digits (zero padded), then the "ULL" suffix.
std::string Hex64ToStr(uint64_t val) {
    static const char kHexDigits[] = "0123456789abcdef";
    std::string literal = "0x";
    // Walk the sixteen nibbles from the most significant one downwards.
    for (int shift = 60; shift >= 0; shift -= 4) {
        literal.push_back(kHexDigits[(val >> shift) & 0xFU]);
    }
    literal += "ULL";
    return literal;
}
std::string ConstantCodeGenerator::GenerateConstantTaskQue(
const SkTask& task,
const std::string& queueName)
{
std::ostringstream code;
const TaskQue* taskQue = task.GetTaskQue();
if (taskQue == nullptr || taskQue->taskCnt == 0) {
code << "// No " << queueName << " tasks\n";
return code.str();
}
code << "struct TaskQue_" << queueName << " {\n";
code << " static constexpr uint32_t taskCnt = " << taskQue->taskCnt << ";\n";
code << " static constexpr uint32_t cap = " << taskQue->cap << ";\n";
if (taskQue->taskCnt > 0) {
code << " static constexpr TaskInfo taskInfos[" << taskQue->taskCnt << "] = {\n";
for (uint32_t i = 0; i < taskQue->taskCnt; i++) {
const TaskInfo& info = taskQue->taskInfos[i];
code << " {";
code << info.index << ", ";
code << TaskTypeToEnumStr(info.type) << ", ";
code << KernelTypeToEnumStr(info.relatedType) << ", ";
code << static_cast<uint32_t>(info.numBlocks) << ", ";
code << static_cast<uint32_t>(info.entryCnt) << ", ";
code << Hex64ToStr(info.args) << ", ";
code << "{";
for (int j = 0; j < 4; j++) {
code << Hex64ToStr(info.entry[j]);
if (j < 3) code << ", ";
}
code << "}, ";
code << Hex64ToStr(info.debugOptions) << ", ";
code << Hex64ToStr(info.reserved);
code << "}";
if (i < taskQue->taskCnt - 1) code << ",";
code << "\n";
}
code << " };\n";
}
code << "};\n";
return code.str();
}
// Translates a kernel type into the {aic, aiv} pair that is later used as the
// template arguments of the split functions and as the __mix__(aic, aiv)
// arguments of the generated entry point. Every kernel type without a dedicated
// mapping yields {1, 1}.
std::pair<int, int> ConstantCodeGenerator::GetKernelTypeParams(SkKernelType kernelType)
{
    if (kernelType == SkKernelType::AIC_ONLY) {
        return {1, 0};
    }
    if (kernelType == SkKernelType::AIV_ONLY) {
        return {0, 1};
    }
    if (kernelType == SkKernelType::MIX_AIC_1_2) {
        return {1, 2};
    }
    // MIX_AIC_1_1 and all remaining values.
    return {1, 1};
}
* @brief 生成针对特定 split 的任务执行代码
* @param taskQue 任务队列
* @param taskIdx 任务索引
* @param isAic 是否为 AIC 队列
* @param splitIdx split 索引 (0-3),用于确定 entry 索引
* @return 生成的代码字符串
*
* 与 GenerateTaskExecution 不同,此函数直接使用编译期确定的 entry 索引,
* 消除运行时的 get_coreid() 和取模运算。
*/
std::string ConstantCodeGenerator::GenerateTaskExecutionForSplit(
const TaskQue* taskQue,
size_t taskIdx,
bool isAic,
int splitIdx)
{
std::ostringstream code;
if (taskQue == nullptr || taskIdx >= taskQue->taskCnt) {
return code.str();
}
const TaskInfo& task = taskQue->taskInfos[taskIdx];
code << " // Task[" << taskIdx << "]: " << TaskTypeToEnumStr(task.type)
<< " (split=" << (splitIdx + 1) << ", entryIdx=" << (splitIdx % std::max(1, (int)task.entryCnt)) << ")\n";
switch (task.type) {
case SkTaskType::TYPE_PRELOAD: {
code << " {\n";
code << " auto blockId = AscendC::GetBlockIdx();\n";
code << " if (blockId < " << static_cast<uint32_t>(task.numBlocks) << ") {\n";
int entryIdx = (task.entryCnt > 0) ? (splitIdx % task.entryCnt) : 0;
code << " // [SPLIT] Preload entry[" << entryIdx << "]\n";
code << " constexpr uint64_t PRELOAD_ADDR = 0x"
<< std::hex << task.entry[entryIdx] << std::dec << "ULL;\n";
code << " preload((const void *)(PRELOAD_ADDR), "
<< static_cast<int64_t>(task.args) << "L);\n";
code << " dc_preload(reinterpret_cast<__gm__ uint64_t*>(" << Hex64ToStr(task.reserved) << "), 0);\n";
code << " dc_preload(reinterpret_cast<__gm__ uint64_t*>(" << Hex64ToStr(task.reserved + 8) << "), 0);\n";
code << " }\n";
code << " }\n";
break;
}
case SkTaskType::TYPE_FUNC: {
code << " {\n";
code << " auto blockId = AscendC::GetBlockIdx();\n";
code << " if (blockId < " << static_cast<uint32_t>(task.numBlocks) << ") {\n";
code << " // [SPLIT] sysArgs: numBlocks=" << static_cast<uint32_t>(task.numBlocks) << "\n";
code << " sk::SkSystemArgs sysArgs = {};\n";
code << " sysArgs.skBlockIdx = static_cast<uint16_t>(AscendC::GetBlockIdx());\n";
code << " sysArgs.skNumBlocks = " << static_cast<uint32_t>(task.numBlocks) << ";\n";
code << " sysArgs.skTaskSyncCfg = static_cast<uint16_t>(" << task.reserved << "ULL);\n";
int entryIdx = (task.entryCnt > 0) ? (splitIdx % task.entryCnt) : 0;
code << " // [SPLIT] Func entry[" << entryIdx << "]\n";
code << " constexpr uint64_t FUNC_ADDR = 0x"
<< std::hex << task.entry[entryIdx] << std::dec << "ULL;\n";
code << " ((sk_sub_func)(FUNC_ADDR))"
<< "(reinterpret_cast<const __gm__ void*>(" << Hex64ToStr(task.args) << "), &sysArgs);\n";
code << " }\n";
code << " }\n";
break;
}
case SkTaskType::TYPE_SYNC: {
code << " AscendC::AutoCoreSyncImpl<aic, aiv>(static_cast<SkCoreSyncType>(" << task.args
<< "), static_cast<uint8_t>(" << static_cast<uint32_t>(task.numBlocks) << "), "
<< Hex64ToStr(task.reserved) << ");\n";
break;
}
case SkTaskType::TYPE_EVENT_NOTIFY: {
code << " if ASCEND_IS_AIC { AscendC::NotifyFunc<true>(" << Hex64ToStr(task.args)
<< ", " << Hex64ToStr(task.entry[0]) << "); }\n";
code << " if ASCEND_IS_AIV { AscendC::NotifyFunc<false>(" << Hex64ToStr(task.args)
<< ", " << Hex64ToStr(task.entry[0]) << "); }\n";
break;
}
case SkTaskType::TYPE_EVENT_WAIT: {
code << " if ASCEND_IS_AIC { AscendC::WaitFunc<true>(" << Hex64ToStr(task.args)
<< ", " << Hex64ToStr(task.entry[0]) << ", "
<< static_cast<uint32_t>(task.reserved) << "U); }\n";
code << " if ASCEND_IS_AIV { AscendC::WaitFunc<false>(" << Hex64ToStr(task.args)
<< ", " << Hex64ToStr(task.entry[0]) << ", "
<< static_cast<uint32_t>(task.reserved) << "U); }\n";
break;
}
case SkTaskType::TYPE_EVENT_RESET: {
code << " if ASCEND_IS_AIC { AscendC::ResetFunc<true>(" << Hex64ToStr(task.args)
<< ", " << Hex64ToStr(task.entry[0]) << "); }\n";
code << " if ASCEND_IS_AIV { AscendC::ResetFunc<false>(" << Hex64ToStr(task.args)
<< ", " << Hex64ToStr(task.entry[0]) << "); }\n";
break;
}
default:
break;
}
return code.str();
}
std::string ConstantCodeGenerator::GenerateSpecializedEntry(
const SkTask& aicTask,
const SkTask& aivTask,
SkKernelType kernelType)
{
std::ostringstream code;
auto [aic, aiv] = GetKernelTypeParams(kernelType);
const TaskQue* aicQue = aicTask.GetTaskQue();
const TaskQue* aivQue = aivTask.GetTaskQue();
SK_LOGI("[ConstantCodegen] GenerateSpecializedEntry: kernelType=%s, aic=%d, aiv=%d",
to_string(kernelType), aic, aiv);
SK_LOGI("[ConstantCodegen] AIC queue: taskCnt=%u", aicQue ? aicQue->taskCnt : 0);
SK_LOGI("[ConstantCodegen] AIV queue: taskCnt=%u", aivQue ? aivQue->taskCnt : 0);
code << "\n// ========== Constant TaskQue Definitions ==========\n";
code << "#ifdef __DAV_CUBE__\n";
code << GenerateConstantTaskQue(aicTask, "AIC");
code << "#endif // __DAV_CUBE__\n\n";
code << "#ifdef __DAV_VEC__\n";
code << GenerateConstantTaskQue(aivTask, "AIV");
code << "#endif // __DAV_VEC__\n\n";
for (int splitIdx = 0; splitIdx < 4; splitIdx++) {
code << "// Split entry for coreId % 4 == " << splitIdx << "\n";
code << "template <uint8_t aic, uint8_t aiv>\n";
code << "__aicore__ __attribute__((aligned(512))) void sk_constant_entry_impl_split" << (splitIdx + 1) << "(void) {\n";
if (aicQue != nullptr && aicQue->taskCnt > 0) {
code << "#ifdef __DAV_CUBE__\n";
code << " // AIC Queue Tasks (split " << (splitIdx + 1) << ")\n";
for (uint32_t i = 0; i < aicQue->taskCnt; i++) {
code << GenerateTaskExecutionForSplit(aicQue, i, true, splitIdx);
}
code << "#endif // __DAV_CUBE__\n";
}
if (aivQue != nullptr && aivQue->taskCnt > 0) {
code << "#ifdef __DAV_VEC__\n";
code << " // AIV Queue Tasks (split " << (splitIdx + 1) << ")\n";
for (uint32_t i = 0; i < aivQue->taskCnt; i++) {
code << GenerateTaskExecutionForSplit(aivQue, i, false, splitIdx);
}
code << "#endif // __DAV_VEC__\n";
}
code << " pipe_barrier(PIPE_ALL);\n";
code << "}\n\n";
}
std::string entryFuncName = "sk_constant_entry_" + options_.skId;
code << "// Entry point: dispatches to split functions based on coreId % 4\n";
code << "extern \"C\" __global__ __attribute__((aligned(512))) __mix__(" << aic << ", " << aiv << ")\n";
code << "void " << entryFuncName << "(void) {\n";
code << " uint8_t coreSplitIdx = (uint8_t)get_coreid() & 0x3; // coreId % 4\n";
code << " switch (coreSplitIdx) {\n";
code << " case 0: sk_constant_entry_impl_split1<" << aic << ", " << aiv << ">(); break;\n";
code << " case 1: sk_constant_entry_impl_split2<" << aic << ", " << aiv << ">(); break;\n";
code << " case 2: sk_constant_entry_impl_split3<" << aic << ", " << aiv << ">(); break;\n";
code << " case 3: sk_constant_entry_impl_split4<" << aic << ", " << aiv << ">(); break;\n";
code << " }\n";
code << "}\n";
SK_LOGI("[ConstantCodegen] Generated entry function with 4 splits: %s", entryFuncName.c_str());
return code.str();
}
// Assembles the complete translation unit handed to the JIT compiler:
// SK_COMMON_CODE, a newline, SK_IMPL_CODE, a newline, then the specialized
// entry code generated for the two task queues.
std::string ConstantCodeGenerator::GenerateCombinedSource(
    const SkTask& aicTask,
    const SkTask& aivTask,
    SkKernelType kernelType)
{
    std::string combined(SK_COMMON_CODE);
    combined += "\n";
    combined += SK_IMPL_CODE;
    combined += "\n";
    combined += GenerateSpecializedEntry(aicTask, aivTask, kernelType);
    return combined;
}
/**
 * @brief Produce the constant-specialized kernel source for a pair of task queues.
 *
 * The kernel type is derived from the two function counts:
 *   - only AIV functions            -> AIV_ONLY
 *   - only AIC functions            -> AIC_ONLY
 *   - both, either node MIX_AIC_1_2 -> MIX_AIC_1_2
 *   - both otherwise                -> MIX_AIC_1_1
 *   - neither                       -> DEFAULT (GetKernelTypeParams maps it to {1, 1})
 *
 * @param aicTask AIC-side task.
 * @param aivTask AIV-side task.
 * @param header  Currently not consulted.
 * @return A result whose combinedSource holds the generated text; a
 *         default-constructed result when constant codegen is disabled in
 *         options_.
 */
ConstantCodeGenResult ConstantCodeGenerator::Generate(
    const SkTask& aicTask,
    const SkTask& aivTask,
    const SkHeaderInfo& header)
{
    static_cast<void>(header);
    ConstantCodeGenResult result;
    if (!options_.enableConstantCodeGen) {
        SK_LOGI("Constant code generation is disabled");
        return result;
    }

    const auto aicFuncCnt = aicTask.funcCnt;
    const auto aivFuncCnt = aivTask.funcCnt;
    SkKernelType kernelType = SkKernelType::DEFAULT;
    if (aicFuncCnt == 0 && aivFuncCnt > 0) {
        kernelType = SkKernelType::AIV_ONLY;
    } else if (aicFuncCnt > 0 && aivFuncCnt == 0) {
        kernelType = SkKernelType::AIC_ONLY;
    } else if (aicFuncCnt > 0 && aivFuncCnt > 0) {
        const bool oneToTwo = aicTask.nodeType == SkKernelType::MIX_AIC_1_2 ||
                              aivTask.nodeType == SkKernelType::MIX_AIC_1_2;
        kernelType = oneToTwo ? SkKernelType::MIX_AIC_1_2 : SkKernelType::MIX_AIC_1_1;
    }

    result.combinedSource = GenerateCombinedSource(aicTask, aivTask, kernelType);
    SK_LOGI("Generated constant source code, size=%zu bytes", result.combinedSource.size());
    return result;
}
/**
 * @brief JIT-compile generated source, load the binary and resolve the entry function.
 *
 * Steps: create an aclrtc program from @p source, compile it, copy the binary
 * into the result, load it with aclrtBinaryLoadFromData and look up
 * `sk_constant_entry_<skId>`.
 *
 * Failure is reported through the result: callers test
 * `specializedFuncHandle == nullptr`. On every failure path the function
 * handle is left/reset to null, and a binary that was already loaded is
 * unloaded again. The aclrtc program is destroyed on every path once it has
 * been created.
 *
 * @param source     Complete device source to compile.
 * @param kernelType Currently not consulted.
 * @return Result holding the source, the compiled binary and, on success, the
 *         binary and function handles.
 */
ConstantCodeGenResult ConstantCodeGenerator::CompileAndResolve(
    const std::string& source,
    SkKernelType kernelType)
{
    static_cast<void>(kernelType);
    ConstantCodeGenResult result;
    result.combinedSource = source;

    aclrtcProg prog = nullptr;
    aclError ret = aclrtcCreateProg(&prog, source.c_str(), "sk_constant.asc", 0, nullptr, nullptr);
    if (ret != ACL_SUCCESS) {
        SK_LOGE("Failed to create aclrtc program: ret=%d", ret);
        return result;
    }

    // Destroys the program when the function returns, whichever path is taken.
    class ProgGuard {
    public:
        explicit ProgGuard(aclrtcProg& target) : target_(target) {}
        ~ProgGuard() { aclrtcDestroyProg(&target_); }
        ProgGuard(const ProgGuard&) = delete;
        ProgGuard& operator=(const ProgGuard&) = delete;
    private:
        aclrtcProg& target_;
    };
    const ProgGuard progGuard(prog);

    const char* options[] = {
        "--npu-arch=dav-2201",
        "-O3"
    };
    const int numOptions = static_cast<int>(sizeof(options) / sizeof(options[0]));
    ret = aclrtcCompileProg(prog, numOptions, options);
    if (ret != ACL_SUCCESS) {
        size_t logSize = 0;
        aclrtcGetCompileLogSize(prog, &logSize);
        if (logSize > 0) {
            // One extra zero byte guarantees termination for the %s below.
            std::vector<char> logBuf(logSize + 1, '\0');
            aclrtcGetCompileLog(prog, logBuf.data());
            SK_LOGE("aclrtc compilation failed: %s", logBuf.data());
        } else {
            SK_LOGE("aclrtc compilation failed: ret=%d (no compile log available)", ret);
        }
        return result;
    }

    size_t binSize = 0;
    ret = aclrtcGetBinDataSize(prog, &binSize);
    if (ret != ACL_SUCCESS) {
        SK_LOGE("Failed to get binary size: ret=%d", ret);
        return result;
    }
    if (binSize == 0) {
        SK_LOGE("Compiled binary is empty");
        return result;
    }
    result.compiledBinary.resize(binSize);
    ret = aclrtcGetBinData(prog, reinterpret_cast<char*>(result.compiledBinary.data()));
    if (ret != ACL_SUCCESS) {
        SK_LOGE("Failed to get binary data: ret=%d", ret);
        return result;
    }

    // Value-initialize so that no field is passed to the runtime uninitialized.
    aclrtBinaryLoadOption opt{};
    opt.type = ACL_RT_BINARY_LOAD_OPT_LAZY_MAGIC;
    opt.value.magic = ACL_RT_BINARY_MAGIC_ELF_AICORE;
    aclrtBinaryLoadOptions loadOpts{};
    loadOpts.numOpt = 1;
    loadOpts.options = &opt;
    ret = aclrtBinaryLoadFromData(static_cast<void*>(result.compiledBinary.data()),
        binSize, &loadOpts, &result.specializedBinHandle);
    if (ret != ACL_SUCCESS) {
        SK_LOGE("Failed to load binary: ret=%d", ret);
        return result;
    }

    const std::string entryFuncName = "sk_constant_entry_" + options_.skId;
    ret = aclrtBinaryGetFunction(result.specializedBinHandle, entryFuncName.c_str(), &result.specializedFuncHandle);
    if (ret != ACL_SUCCESS) {
        SK_LOGE("Failed to get function handle: ret=%d, funcName=%s", ret, entryFuncName.c_str());
        aclrtBinaryUnLoad(result.specializedBinHandle);
        result.specializedBinHandle = nullptr;
        // Callers detect failure through a null function handle.
        result.specializedFuncHandle = nullptr;
        return result;
    }

    SK_LOGI("Successfully compiled and resolved funcHandle: binHandle=%p, funcHandle=%p",
        result.specializedBinHandle, result.specializedFuncHandle);
    return result;
}
// Returns the single manager object. It is a function-local static, so it is
// constructed the first time this function is called.
ConstantFuncHandleManager& ConstantFuncHandleManager::Instance()
{
    static ConstantFuncHandleManager singleton;
    return singleton;
}
void ConstantFuncHandleManager::RegisterFuncHandle(
const std::string& skId,
aclrtFuncHandle funcHandle,
aclrtBinHandle binHandle)
{
std::lock_guard<std::mutex> lock(mutex_);
HandlePair pair;
pair.funcHandle = funcHandle;
pair.binHandle = binHandle;
handleMap_[skId] = pair;
SK_LOGI("Registered constant funcHandle: skId=%s, funcHandle=%p, binHandle=%p",
skId.c_str(), funcHandle, binHandle);
}
// Looks up the function handle registered under skId.
// Returns nullptr when no entry exists for that id.
//
// RegisterFuncHandle and Clear modify handleMap_ while holding mutex_, so the
// lookup takes the same lock; an unlocked find() could run concurrently with an
// insertion or with clear(). The const_cast is needed only because this
// accessor is const; it becomes redundant if mutex_ is declared `mutable`.
aclrtFuncHandle ConstantFuncHandleManager::GetFuncHandle(const std::string& skId) const
{
    std::lock_guard<std::mutex> lock(const_cast<std::mutex&>(mutex_));
    const auto it = handleMap_.find(skId);
    if (it == handleMap_.end()) {
        return nullptr;
    }
    return it->second.funcHandle;
}
// Looks up the binary handle registered under skId.
// Returns nullptr when no entry exists for that id.
//
// RegisterFuncHandle and Clear modify handleMap_ while holding mutex_, so the
// lookup takes the same lock; an unlocked find() could run concurrently with an
// insertion or with clear(). The const_cast is needed only because this
// accessor is const; it becomes redundant if mutex_ is declared `mutable`.
aclrtBinHandle ConstantFuncHandleManager::GetBinHandle(const std::string& skId) const
{
    std::lock_guard<std::mutex> lock(const_cast<std::mutex&>(mutex_));
    const auto it = handleMap_.find(skId);
    if (it == handleMap_.end()) {
        return nullptr;
    }
    return it->second.binHandle;
}
bool ConstantFuncHandleManager::HasFuncHandle(const std::string& skId) const
{
return handleMap_.find(skId) != handleMap_.end();
}
// Unloads every registered binary that has a non-null handle and then empties
// the map. The whole operation runs while holding mutex_.
void ConstantFuncHandleManager::Clear()
{
    const std::lock_guard<std::mutex> guard(mutex_);
    for (auto it = handleMap_.begin(); it != handleMap_.end(); ++it) {
        if (it->second.binHandle == nullptr) {
            continue;
        }
        aclrtBinaryUnLoad(it->second.binHandle);
    }
    handleMap_.clear();
}
* @brief 将生成的源码和二进制写入 sk_meta 目录
* @param skId Super Kernel ID
* @param sourceCode 生成的源码
* @param binaryData 编译后的二进制数据
* @param kernelType 内核类型
* @param modelLabel 模型标签
* @return 是否写入成功
*/
bool DumpConstantCodegenFiles(
const std::string& skId,
const std::string& sourceCode,
const std::vector<uint8_t>& binaryData,
SkKernelType kernelType,
const std::string& modelLabel)
{
std::string baseDir = CreateSkMetaDirectory(modelLabel);
if (baseDir.empty()) {
SK_LOGE("[ConstantCodegen] Failed to create sk_meta directory");
return false;
}
std::string codegenDir = baseDir + "/constant_codegen";
if (!CreateDirectoryRecursive(codegenDir)) {
SK_LOGE("[ConstantCodegen] Failed to create constant_codegen directory: %s", codegenDir.c_str());
return false;
}
std::string filePrefix = codegenDir + "/sk_" + skId + "_" + to_string(kernelType);
std::string sourceFile = filePrefix + ".asc";
std::ofstream srcOut(sourceFile, std::ios::out);
if (srcOut.is_open()) {
srcOut << sourceCode;
srcOut.close();
SK_LOGI("[ConstantCodegen] Source code written to: %s (size=%zu bytes)",
sourceFile.c_str(), sourceCode.size());
} else {
SK_LOGE("[ConstantCodegen] Failed to write source file: %s", sourceFile.c_str());
return false;
}
std::string binaryFile = filePrefix + ".bin";
std::ofstream binOut(binaryFile, std::ios::binary);
if (binOut.is_open()) {
binOut.write(reinterpret_cast<const char*>(binaryData.data()), binaryData.size());
binOut.close();
SK_LOGI("[ConstantCodegen] Binary written to: %s (size=%zu bytes)",
binaryFile.c_str(), binaryData.size());
} else {
SK_LOGE("[ConstantCodegen] Failed to write binary file: %s", binaryFile.c_str());
return false;
}
return true;
}
* @brief 尝试生成并使用常量化 funcHandle
*
* 此函数是核心集成点,在 SkTaskBuilder::GenEntryInfo 中调用。
* 如果常量化成功,返回新的 funcHandle;否则返回 nullptr,使用原有逻辑。
*
* @param aicTask AIC 任务队列
* @param aivTask AIV 任务队列
* @param opts 选项管理器
* @param modelLabel 模型标签,用于生成 sk_meta 路径
* @return std::pair<aclrtFuncHandle, SkKernelType>
* first: 常量化 funcHandle(失败为 nullptr)
* second: 内核类型
*/
std::pair<aclrtFuncHandle, SkKernelType> TryGenerateConstantFuncHandle(
const SkTask& aicTask,
const SkTask& aivTask,
SuperKernelOptionsManager& opts,
const std::string& modelLabel)
{
SK_LOGI("[ConstantCodegen] Start constant codegen");
bool enableConstant = false;
const char* envConstant = std::getenv("SK_CONSTANT");
if (envConstant != nullptr && std::string(envConstant) == "1") {
enableConstant = true;
SK_LOGI("[ConstantCodegen] Enabled by environment variable SK_CONSTANT=1");
}
auto constantOpt = opts.GetOption(aclskOptionType::CONSTANT_CODEGEN);
if (constantOpt != nullptr) {
if (constantOpt->GetIntValue() == 1) {
enableConstant = true;
SK_LOGI("[ConstantCodegen] Enabled by CONSTANT_CODEGEN option");
} else {
enableConstant = false;
SK_LOGI("[ConstantCodegen] Disabled by CONSTANT_CODEGEN option");
}
}
if (!enableConstant) {
SK_LOGI("[ConstantCodegen] Constant codegen is disabled (default). Set SK_CONSTANT=1 to enable.");
return {nullptr, SkKernelType::DEFAULT};
}
SK_LOGI("[ConstantCodegen] Constant codegen is enabled");
aclrtBinHandle entryBinHandle = AscendGetEntryBinHandle();
if (entryBinHandle == nullptr) {
SK_LOGI("[ConstantCodegen] Entry binHandle is null, skip constant codegen");
return {nullptr, SkKernelType::DEFAULT};
}
const TaskQue* aicQue = aicTask.GetTaskQue();
const TaskQue* aivQue = aivTask.GetTaskQue();
SK_LOGI("[ConstantCodegen] Task analysis:");
SK_LOGI(" - AIC: funcCnt=%u, numBlocks=%u, taskCnt=%u",
aicTask.funcCnt, aicTask.numBlocks, aicQue ? aicQue->taskCnt : 0);
SK_LOGI(" - AIV: funcCnt=%u, numBlocks=%u, taskCnt=%u",
aivTask.funcCnt, aivTask.numBlocks, aivQue ? aivQue->taskCnt : 0);
if ((aicQue == nullptr || aicQue->taskCnt == 0) &&
(aivQue == nullptr || aivQue->taskCnt == 0)) {
SK_LOGE("[ConstantCodegen] Both AIC and AIV task queues are empty");
return {nullptr, SkKernelType::DEFAULT};
}
SkKernelType kernelType = SkKernelType::MIX_AIC_1_1;
if (aicTask.funcCnt == 0 && aivTask.funcCnt > 0) {
kernelType = SkKernelType::AIV_ONLY;
} else if (aicTask.funcCnt > 0 && aivTask.funcCnt == 0) {
kernelType = SkKernelType::AIC_ONLY;
} else if (aicTask.nodeType == SkKernelType::MIX_AIC_1_2 || aivTask.nodeType == SkKernelType::MIX_AIC_1_2) {
kernelType = SkKernelType::MIX_AIC_1_2;
}
SK_LOGI("[ConstantCodegen] Kernel type determined: %s", to_string(kernelType));
static std::atomic<uint64_t> skIdCounter{0};
std::string skId = std::to_string(skIdCounter.fetch_add(1));
SK_LOGI("[ConstantCodegen] Generated skId: %s", skId.c_str());
ConstantCodeGenOptions codegenOpts;
codegenOpts.enableConstantCodeGen = true;
codegenOpts.enableUnrollOptimization = true;
codegenOpts.skId = skId;
ConstantCodeGenerator generator(codegenOpts);
SK_LOGI("[ConstantCodegen] Generating constant source code...");
SkHeaderInfo header;
ConstantCodeGenResult genResult = generator.Generate(aicTask, aivTask, header);
if (genResult.combinedSource.empty()) {
SK_LOGE("[ConstantCodegen] Failed to generate constant source code");
return {nullptr, SkKernelType::DEFAULT};
}
SK_LOGI("[ConstantCodegen] Source code generated, size=%zu bytes", genResult.combinedSource.size());
SK_LOGI("[ConstantCodegen] Starting JIT compilation...");
ConstantCodeGenResult compileResult = generator.CompileAndResolve(genResult.combinedSource, kernelType);
if (compileResult.specializedFuncHandle == nullptr) {
SK_LOGE("[ConstantCodegen] JIT compilation failed");
DumpConstantCodegenFiles(skId, genResult.combinedSource, {}, kernelType, modelLabel);
return {nullptr, SkKernelType::DEFAULT};
}
SK_LOGI("[ConstantCodegen] JIT compilation succeeded, funcHandle=%p, binHandle=%p",
compileResult.specializedFuncHandle, compileResult.specializedBinHandle);
DumpConstantCodegenFiles(skId, genResult.combinedSource, compileResult.compiledBinary, kernelType, modelLabel);
ConstantFuncHandleManager::Instance().RegisterFuncHandle(
skId, compileResult.specializedFuncHandle, compileResult.specializedBinHandle);
SK_LOGI("[ConstantCodegen] Constant codegen SUCCESS");
SK_LOGI("[ConstantCodegen] Result: funcHandle=%p, kernelType=%s",
compileResult.specializedFuncHandle, to_string(kernelType));
return {compileResult.specializedFuncHandle, kernelType};
}