* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <algorithm>
#include <cstdlib>
#include <dlfcn.h>
#include <string>
#include <stdexcept>
#include <sys/stat.h>
#include <unordered_map>
#include <vector>
#include "sk_optimizer.h"
#include "sk_scope_split.h"
#include "sk_scope_postprocess.h"
#include "sk_task_builder.h"
#include "sk_log.h"
#include "sk_dump_json.h"
#include "aprof_pub.h"
#include "securec.h"
#include "sk_event_recorder.h"
namespace {
/**
 * Logs the super kernel function name, the scope id and every filtered node of the scope.
 *
 * @param skFuncName  Name of the super kernel function; taken by const reference to avoid a copy per call.
 * @param scopeInfo   Scope whose `filteredNodes` (from its ext info) are listed.
 */
void PrintSKNodesDetail(const std::string& skFuncName, SuperKernelScopeInfo& scopeInfo)
{
    uint16_t scopeId = scopeInfo.GetScopeId();
    auto& nodes = scopeInfo.GetExtInfo().filteredNodes;
    SK_LOGI(" SK Function: %s, scope id: %u, Node Count: %zu", skFuncName.c_str(), scopeId, nodes.size());
    for (size_t i = 0; i < nodes.size(); ++i) {
        // Mirror PrintTaskNodesDetail: a null entry is reported instead of being dereferenced.
        if (nodes[i] == nullptr) {
            SK_LOGI(" [%zu] nullptr", i);
            continue;
        }
        SK_LOGI(" [%zu] %s", i, nodes[i]->Format().c_str());
    }
}
/**
 * Prints the fused node list of a scope twice: once inside a scoped log context and once outside it.
 *
 * @param skFuncName  Name of the super kernel function; taken by const reference to avoid a copy per call.
 * @param scopeInfo   Scope whose nodes are printed.
 */
void PrintSKNodes(const std::string& skFuncName, SuperKernelScopeInfo& scopeInfo)
{
    {
        // NOTE(review): SK_LOG_CONTEXT_SIMPLE presumably routes the logs emitted in this scope to
        // "sk_fused_nodes.log" until the closing brace -- confirm against sk_log.h.
        SK_LOG_CONTEXT_SIMPLE("sk_fused_nodes.log");
        PrintSKNodesDetail(skFuncName, scopeInfo);
    }
    // Second pass runs outside the scoped context, so the same listing goes through the normal log path.
    PrintSKNodesDetail(skFuncName, scopeInfo);
}
/**
 * Logs a tagged header with the node count, then one line per node in list order.
 *
 * @param nodes  Task nodes to list; null entries are reported as "nullptr" instead of being dereferenced.
 * @param tag    Label printed in front of the node count.
 */
void PrintTaskNodesDetail(const std::vector<SuperKernelBaseNode*>& nodes, const char* tag)
{
    SK_LOGI("%s: node count=%zu", tag, nodes.size());
    size_t position = 0;
    for (SuperKernelBaseNode* entry : nodes) {
        if (entry != nullptr) {
            SK_LOGI(" [%zu] %s", position, entry->Format().c_str());
        } else {
            SK_LOGI(" [%zu] nullptr", position);
        }
        ++position;
    }
}
}
/**
 * Produces the node order handed to the task builder, deferring each wait node until just before the
 * next kernel node that belongs to the same stream.
 *
 * Algorithm:
 *  1. Backward pass: for every wait node, record the index of the nearest later kernel node whose
 *     stream index matches (or "invalid" when there is none).
 *  2. Forward pass: wait nodes with a target are parked in a pending list; before any other node is
 *     emitted, every pending wait whose target index has been reached is emitted first. Wait nodes
 *     without a target stay where they are.
 *
 * The result always contains exactly the input nodes. Relative order among non-wait nodes, and among
 * wait nodes flushed together, is preserved. Null entries are never dereferenced: they are ignored by
 * the backward pass and passed through in place by the forward pass.
 *
 * @param taskNodes  Nodes in their original order.
 * @return The reordered list, or an unchanged copy of `taskNodes` when reordering is disabled.
 */
std::vector<SuperKernelBaseNode*> SuperKernelOptimizer::ReorderWaitNodesForTaskBuild(
    const std::vector<SuperKernelBaseNode*>& taskNodes) const
{
    if (!ShouldReorderWaitNodesForTaskBuild()) {
        SK_LOGI("task reorder disabled: auto_op_parallel is not CUSTOMIZE_QUEUE");
        return taskNodes;
    }

    struct PendingWaitNode {
        SuperKernelBaseNode* node = nullptr;
        size_t originalIdx = 0;
        size_t targetKernelIdx = 0;
    };

    // taskNodes.size() can never be a valid index, so it doubles as the "no target kernel" marker.
    const size_t invalidKernelIdx = taskNodes.size();
    std::unordered_map<uint32_t, size_t> nextKernelIdxByStream;
    std::vector<size_t> waitTargetKernelIdx(taskNodes.size(), invalidKernelIdx);

    // Backward pass: walking from the end means the map always holds the nearest later kernel per stream.
    for (size_t idx = taskNodes.size(); idx > 0; --idx) {
        const size_t curIdx = idx - 1;
        SuperKernelBaseNode* node = taskNodes[curIdx];
        if (node == nullptr) {
            continue;
        }
        const uint32_t streamIdx = node->GetStreamIdxInGraph();
        const auto nodeType = node->GetNodeType();
        if (nodeType == SkNodeType::NODE_KERNEL) {
            nextKernelIdxByStream[streamIdx] = curIdx;
        } else if (nodeType == SkNodeType::NODE_WAIT) {
            auto kernelIt = nextKernelIdxByStream.find(streamIdx);
            if (kernelIt != nextKernelIdxByStream.end()) {
                waitTargetKernelIdx[curIdx] = kernelIt->second;
            }
        }
    }

    size_t moveCount = 0;
    std::vector<PendingWaitNode> pendingWaitNodes;
    std::vector<SuperKernelBaseNode*> reorderedTaskNodes;
    reorderedTaskNodes.reserve(taskNodes.size());

    // Emits every pending wait node whose target kernel index is <= currentIdx and keeps the rest.
    // The pending list is compacted in place (order preserved) so no temporary vector is needed.
    auto flushPendingWaitNodes = [&](size_t currentIdx) {
        auto keepIt = pendingWaitNodes.begin();
        for (auto it = pendingWaitNodes.begin(); it != pendingWaitNodes.end(); ++it) {
            if (it->targetKernelIdx > currentIdx) {
                if (keepIt != it) {
                    *keepIt = *it;
                }
                ++keepIt;
                continue;
            }
            // Copy first: the slot may be overwritten by a later kept entry during compaction.
            const PendingWaitNode readyWaitNode = *it;
            const size_t finalIdx = reorderedTaskNodes.size();
            reorderedTaskNodes.push_back(readyWaitNode.node);
            if (finalIdx != readyWaitNode.originalIdx) {
                ++moveCount;
            }
            SK_LOGI("task reorder: place deferred wait node, waitNodeId=%lu, streamIdx=%u, originalIdx=%zu, "
                    "finalIdx=%zu, targetKernelIdx=%zu",
                    readyWaitNode.node->GetNodeId(), readyWaitNode.node->GetStreamIdxInGraph(),
                    readyWaitNode.originalIdx, finalIdx, readyWaitNode.targetKernelIdx);
        }
        pendingWaitNodes.erase(keepIt, pendingWaitNodes.end());
    };

    for (size_t idx = 0; idx < taskNodes.size(); ++idx) {
        SuperKernelBaseNode* node = taskNodes[idx];
        // Null entries and every non-wait node keep their relative order; ready waits go in front of them.
        if (node == nullptr || node->GetNodeType() != SkNodeType::NODE_WAIT) {
            flushPendingWaitNodes(idx);
            reorderedTaskNodes.push_back(node);
            continue;
        }

        const uint32_t streamIdx = node->GetStreamIdxInGraph();
        const size_t targetKernelIdx = waitTargetKernelIdx[idx];
        if (targetKernelIdx == invalidKernelIdx) {
            SK_LOGI("task reorder: keep wait node in place, waitNodeId=%lu, "
                    "streamIdx=%u, originalIdx=%zu, reason=no later kernel in same stream",
                    node->GetNodeId(), streamIdx, idx);
            flushPendingWaitNodes(idx);
            reorderedTaskNodes.push_back(node);
            continue;
        }

        SuperKernelBaseNode* targetKernelNode = taskNodes[targetKernelIdx];
        SK_LOGI("task reorder: defer wait node for same-stream kernel, waitNodeId=%lu, "
                "streamIdx=%u, originalIdx=%zu, targetKernelIdx=%zu, targetKernelNodeId=%lu",
                node->GetNodeId(), streamIdx, idx, targetKernelIdx,
                targetKernelNode == nullptr ? INVALID_TASK_ID : targetKernelNode->GetNodeId());
        pendingWaitNodes.push_back({node, idx, targetKernelIdx});
    }

    // Every deferred wait targets an index < taskNodes.size(), so this drains whatever is still pending.
    flushPendingWaitNodes(taskNodes.size());
    SK_LOGI("task reorder finished: originalCount=%zu, moveCount=%zu", taskNodes.size(), moveCount);
    return reorderedTaskNodes;
}
/**
 * Tells whether wait nodes should be reordered before the task build.
 *
 * @return true only when the AUTO_OP_PARALLEL option is present and its integer value, interpreted as
 *         SkHeapType, equals CUSTOMIZE_QUEUE; false when the option is missing or has another value.
 */
bool SuperKernelOptimizer::ShouldReorderWaitNodesForTaskBuild() const
{
    const auto* parallelOption = opts.GetOption(aclskOptionType::AUTO_OP_PARALLEL);
    if (parallelOption != nullptr) {
        const auto configuredMode = static_cast<SkHeapType>(parallelOption->GetIntValue());
        return configuredMode == SkHeapType::CUSTOMIZE_QUEUE;
    }
    return false;
}
/**
 * Walks every stream of the scope from its head node to its tail node and calls Update() on each node.
 *
 * Per stream, the first `customParams.size()` visited nodes each receive the matching custom-params
 * entry. After those, the node whose id equals `skMainNodeId` receives the launch info the first time
 * it is met; any later occurrence (and every other node) is updated with an empty context.
 *
 * @param scopeInfo   Scope providing the stream list and the ext info (custom params, main node id).
 * @param graph       Graph used to resolve node ids; its update flag is set on success.
 * @param launchInfo  Launch info handed to the super kernel main node.
 * @return true when all nodes were updated and the main node was found; false on a missing node, a
 *         failed node update, inconsistent stream/custom-param sizes, or a missing main node.
 */
bool SuperKernelOptimizer::Update(SuperKernelScopeInfo& scopeInfo, SuperKernelGraph& graph,
                                  const SkLaunchInfo& launchInfo)
{
    const std::string& scopeName = scopeInfo.GetExtInfo().scopeName;
    auto& scopeStreamInfos = scopeInfo.GetScopeStreamInfos();
    auto& extInfo = scopeInfo.MutableExtInfo();
    SK_LOGI("scope update begin: scopeName=%s, streamCount=%zu", scopeName.c_str(), scopeStreamInfos.size());

    // customParamsList is indexed by stream position below, so it must cover every stream.
    // NOTE(review): assumes customParamsList is a sequence container (one entry per stream) -- confirm.
    if (extInfo.customParamsList.size() < scopeStreamInfos.size()) {
        SK_LOGE("custom params list is shorter than stream list: customParamsListSize=%zu, streamCount=%zu",
                extInfo.customParamsList.size(), scopeStreamInfos.size());
        return false;
    }

    bool skMainNodeUpdated = false;
    size_t updateTotalCount = 0;
    for (size_t streamIdx = 0; streamIdx < scopeStreamInfos.size(); ++streamIdx) {
        auto& streamInfo = scopeStreamInfos[streamIdx];
        auto& customParams = extInfo.customParamsList[streamIdx];
        SK_LOGI("update stream begin: scopeName=%s, streamId=%u, headNodeId=%lu, tailNodeId=%lu, nodeSize=%lu, customParamSize=%zu",
                scopeName.c_str(), streamInfo.streamIdx, streamInfo.headNodeIdx, streamInfo.tailNodeIdx,
                streamInfo.nodeSize, customParams.size());
        const size_t customParamSize = customParams.size();
        if (streamInfo.nodeSize < customParamSize) {
            SK_LOGE("node size is less than custom params size: nodeSize=%lu, customParamSize=%zu",
                    streamInfo.nodeSize, customParamSize);
            return false;
        }

        uint64_t curNodeId = streamInfo.headNodeIdx;
        size_t eventCnt = 0;            // custom-params entries consumed so far in this stream
        size_t streamVisitedCount = 0;  // nodes visited in this stream (reported as visitedNodes)
        while (curNodeId != INVALID_TASK_ID) {
            auto* node = graph.GetNodeById(curNodeId);
            if (node == nullptr) {
                SK_LOGE("node not found during stream-based update: nodeId=%lu, streamId=%u", curNodeId,
                        streamInfo.streamIdx);
                return false;
            }

            UpdateContext ctx;
            if (eventCnt < customParamSize) {
                auto& curCustomParams = customParams[eventCnt++];
                ctx.customParams = &curCustomParams;
            } else if (curNodeId == extInfo.skMainNodeId) {
                if (!skMainNodeUpdated) {
                    skMainNodeUpdated = true;
                    // The context stores a non-const pointer; the const is cast away to fit that field.
                    ctx.launchInfo = const_cast<SkLaunchInfo*>(&launchInfo);
                } else {
                    SK_LOGI("repeat find sk launch node, skip update kernel and set invalid node");
                }
            }

            ++updateTotalCount;
            ++streamVisitedCount;
            if (!node->Update(ctx)) {
                SK_LOGE("node update failed: nodeId=%lu, streamId=%u", curNodeId, streamInfo.streamIdx);
                return false;
            }
            // The tail node is the last one of this stream that belongs to the scope.
            if (curNodeId == streamInfo.tailNodeIdx) {
                break;
            }
            curNodeId = node->GetNextNodeId();
        }
        SK_LOGI("update stream end: scopeName=%s, streamId=%u, visitedNodes=%zu", scopeName.c_str(),
                streamInfo.streamIdx, streamVisitedCount);
    }

    if (!skMainNodeUpdated) {
        SK_LOGE("not find sk launch node, sk optimize failed");
        return false;
    }
    graph.SetUpdateFlag(true);
    SK_LOGI("scope update finished: scopeName=%s, updateTotalNodes=%zu", scopeName.c_str(), updateTotalCount);
    return true;
}
/**
 * Builds the super kernel for one scope and applies the result to the graph.
 *
 * Steps: reorder the filtered nodes for task build, derive the super kernel function name, build the
 * task through `builder`, record the task-queue JSON under the scope id, run profiling and the
 * profiling-detail dump, validate the launch info, and finally update the scope's nodes.
 *
 * @param scopeInfo  Scope to schedule; its ext info supplies the filtered nodes and event nodes.
 * @param graph      Graph the scope belongs to.
 * @param builder    Task builder that produces the launch info and task-queue JSON.
 * @return true on success; false when the scope has no nodes, the reorder result has the wrong size,
 *         profiling or its dump fails, the launch info is incomplete, or the scope update fails.
 */
bool SuperKernelOptimizer::Schedule(SuperKernelScopeInfo& scopeInfo, SuperKernelGraph& graph,
                                    SkTaskBuilder& builder)
{
    const auto& extInfo = scopeInfo.GetExtInfo();
    const auto& taskNodes = extInfo.filteredNodes;
    if (taskNodes.empty()) {
        SK_LOGE("no tasks for super kernel optimization: scope has 0 nodes for optimization");
        return false;
    }

    std::vector<SuperKernelBaseNode*> reorderedTaskNodes = ReorderWaitNodesForTaskBuild(taskNodes);
    // Reordering must be a permutation; a size change means nodes were lost or duplicated.
    if (reorderedTaskNodes.size() != taskNodes.size()) {
        SK_LOGE("task reorder produced invalid size: originalCount=%zu, reorderedCount=%zu",
                taskNodes.size(), reorderedTaskNodes.size());
        return false;
    }

    const auto scopeId = scopeInfo.GetScopeId();
    std::string skFuncName = GetSkFuncName(reorderedTaskNodes, scopeId, extInfo.scopeName);
    PrintSKNodes(skFuncName, scopeInfo);
    PrintTaskNodesDetail(reorderedTaskNodes, "reordered task nodes");

    // The builder takes non-owning pointers to the event nodes.
    std::vector<SuperKernelBaseNode*> customTasks;
    customTasks.reserve(extInfo.eventNodes.size());
    for (const auto& eventNode : extInfo.eventNodes) {
        customTasks.emplace_back(eventNode.get());
    }
    SK_LOGI("schedule scope: taskCount=%zu, customTaskCount=%zu, updateStreamCount=%zu", reorderedTaskNodes.size(),
            customTasks.size(), scopeInfo.GetScopeStreamInfos().size());

    SkBuildResult buildResult = builder.Build(skFuncName, reorderedTaskNodes, customTasks, scopeId);
    SkLaunchInfo& launchInfo = buildResult.launchInfo;
    // buildResult.taskQueueJson is not read again in this function, so hand it over without copying.
    taskQueueJsons_[std::to_string(scopeId)] = std::move(buildResult.taskQueueJson);

    // NOTE(review): profiling runs before the launch info is validated below; order kept as-is because
    // it is not visible here whether SkProfiling modifies launchInfo -- confirm.
    if (!SkProfiling(scopeInfo, launchInfo, graph)) {
        SK_LOGE("SkProfiling failed");
        return false;
    }
    if (!DumpProfilingDetail(reorderedTaskNodes, launchInfo, scopeInfo, graph)) {
        SK_LOGE("Dump sk time profiling detail failed");
        return false;
    }
    if (launchInfo.entryInfo.skEntryFunc == nullptr || launchInfo.devArgs.Get() == nullptr) {
        SK_LOGE("schedule failed: build launch info failed");
        return false;
    }
    SK_LOGI("schedule scope: build finished, entryType=%s, entryFuncHandle=%p, skFuncName=%s",
            to_string(launchInfo.entryInfo.entryType), launchInfo.entryInfo.skEntryFunc, launchInfo.skFuncName.c_str());

    if (!Update(scopeInfo, graph, launchInfo)) {
        SK_LOGE("schedule failed: scope update failed");
        return false;
    }
    return true;
}
/**
 * Entry point of the optimizer: splits the graph into scopes, post-processes each scope and schedules
 * the ones that remain fusible.
 *
 * A scope that fails post-processing is marked with STREAM_SYNC_FAIL, its non-null nodes are flagged
 * as not fusible, and processing moves on to the next scope. A scope left without filtered nodes is
 * skipped. A scheduling failure aborts the whole run. After all scopes are handled, fusion-fail
 * reasons are dumped and, when any task queue was recorded, all task queues are dumped to JSON (a
 * failed dump is only logged).
 *
 * @param graph  Graph to optimize.
 * @return false when the split fails or a scope cannot be scheduled; true otherwise.
 */
bool SuperKernelOptimizer::Process(SuperKernelGraph& graph)
{
    SuperKernelScopeSplitter splitter(graph, opts);
    if (!splitter.SplitGraph()) {
        SK_LOGE("graph split failed or no scopes found: cannot proceed with super kernel optimization");
        return false;
    }
    SK_LOGI("graph split into %zu scopes", splitter.GetScopeInfos().size());

    processedScopeInfos_ = std::move(splitter.GetScopeInfos());
    SkTaskBuilder builder(opts, graph);
    SuperKernelScopePostProcessor postProcessor(graph);

    for (auto& scopeInfo : processedScopeInfos_) {
        SK_LOGI("process scope begin: scopeId=%u", scopeInfo.GetScopeId());

        const bool postProcessed = postProcessor.PostProcess(scopeInfo);
        if (!postProcessed) {
            // Record the reason on the scope first, then propagate it to every node of the scope.
            scopeInfo.MutableExtInfo().failReason = ScopeFailReason::STREAM_SYNC_FAIL;
            const ScopeFailReason failReason = scopeInfo.GetExtInfo().failReason;
            SK_LOGI("scope unprocessable after post-process, skip schedule/update: scopeId=%u, reason=%s",
                    scopeInfo.GetScopeId(), ScopeFailReasonToStr(failReason));
            for (auto* node : scopeInfo.GetNodes()) {
                if (node == nullptr) {
                    continue;
                }
                node->SetFusionFailReason(FusionFailReason::SCOPE_FUSE_PART, failReason);
                node->SetIsFusible(false);
            }
            continue;
        }

        if (scopeInfo.GetExtInfo().filteredNodes.empty()) {
            SK_LOGI("scope has no nodes after post-process, skipping schedule/update: scopeId=%u", scopeInfo.GetScopeId());
            continue;
        }

        scopeInfo.MutableExtInfo().scopeName =
            ScopeSplitPass::GetScopeNamesFromBitFlags(scopeInfo.GetScopeBitFlags(), graph);
        const bool scheduled = Schedule(scopeInfo, graph, builder);
        if (!scheduled) {
            SK_LOGE("process scope failed: scopeId=%u, schedule/update returned false", scopeInfo.GetScopeId());
            return false;
        }
    }

    graph.DumpFusionFailReasons(processedScopeInfos_);

    if (!taskQueueJsons_.empty()) {
        SK_LOGI("Dumping all task queues to JSON, scopeCount=%zu", taskQueueJsons_.size());
        // A failed dump is reported but does not fail the optimization.
        const bool dumped = DumpAllTaskQueuesToJson(graph, taskQueueJsons_);
        if (!dumped) {
            SK_LOGE("Failed to dump all task queues to JSON, continuing...");
        }
    }

    SK_LOGI("super kernel process finished: scopeCount=%zu", processedScopeInfos_.size());
    return true;
}