* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Source Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "sk_candidate_heap.h"
#include "sk_log.h"
// Builds a candidate heap bound to `inputGraph` and operating in `heapModeIn`.
// The three node sets are all ordered by CompareByNodeId, and the selection
// history starts with the same values that ResetSelectionState() assigns.
SkCandidateHeap::SkCandidateHeap(SuperKernelGraph& inputGraph, SkHeapType heapModeIn)
    : graph_(inputGraph),
      heapMode(heapModeIn),
      defaultNodes_(CompareByNodeId),
      kernelNodes_(CompareByNodeId),
      nonKernelNodes_(CompareByNodeId),
      prevKernelTypeClass_(KernelTypeClass::OTHER),
      prevKernelStreamIdx_(0),
      prevStreamIdx_(0),
      isFirstSelection_(true)
{
}
void SkCandidateHeap::push(uint64_t nodeId)
{
if (heapMode == SkHeapType::PRIORITY_QUEUE) {
nodeHeap_.push(nodeId);
return;
}
auto node = graph_.GetNodeById(nodeId);
if (node == nullptr) {
SK_LOGW("SkCandidateHeap::push: attempting to push null node, ignored");
return;
}
SkNodeType nodeType = node->GetNodeType();
if (nodeType == SkNodeType::NODE_DEFAULT) {
defaultNodes_.insert(node);
SK_LOGD("SkCandidateHeap::push: added default node %s, defaultCount=%zu",
node->Format().c_str(), defaultNodes_.size());
} else if (nodeType == SkNodeType::NODE_KERNEL) {
if (node->IsScopeNode()) {
nonKernelNodes_.insert(node);
SK_LOGD("SkCandidateHeap::push: added non-kernel node %s (type=%d), nonKernelCount=%zu",
node->Format().c_str(), static_cast<int>(nodeType), nonKernelNodes_.size());
} else {
kernelNodes_.insert(node);
SK_LOGD("SkCandidateHeap::push: added kernel node %s, kernelCount=%zu",
node->Format().c_str(), kernelNodes_.size());
}
} else if (nodeType == SkNodeType::NODE_NOTIFY ||
nodeType == SkNodeType::NODE_WAIT ||
nodeType == SkNodeType::NODE_RESET) {
nonKernelNodes_.insert(node);
SK_LOGD("SkCandidateHeap::push: added non-kernel node %s (type=%d), nonKernelCount=%zu",
node->Format().c_str(), static_cast<int>(nodeType), nonKernelNodes_.size());
} else {
SK_LOGW("SkCandidateHeap::push: unsupported node type %d for node %s, ignored",
static_cast<int>(nodeType), node->Format().c_str());
}
}
// Removes the next candidate and returns its node id.
//
// Returns 0 (after logging an error) when there is nothing to pop. Note that
// the return value alone cannot distinguish this from a real node whose id is
// 0, so callers that care should check empty() first.
//
// In PRIORITY_QUEUE mode the top of nodeHeap_ is returned. Otherwise the
// candidate sets are drained in this priority order:
//   1. defaultNodes_   - first element in set order; selection history is not touched.
//   2. nonKernelNodes_ - chosen by SelectNextNonKernelNode(); updates prevStreamIdx_.
//   3. kernelNodes_    - chosen by SelectNextKernelNode(); updates the previous
//                        kernel class/stream, prevStreamIdx_, and clears
//                        isFirstSelection_.
uint64_t SkCandidateHeap::pop()
{
    // Guard every mode up front. empty() is mode-aware, and calling top()/pop()
    // on an empty queue (nodeHeap_ is presumably a std::priority_queue) is
    // undefined behaviour, so the queue path must not be reached when empty.
    if (empty()) {
        SK_LOGE("SkCandidateHeap::pop: heap size is zero");
        return 0;
    }

    if (heapMode == SkHeapType::PRIORITY_QUEUE) {
        uint64_t nodeId = nodeHeap_.top();
        nodeHeap_.pop();
        return nodeId;
    }

    SuperKernelBaseNode* selectedNode = nullptr;

    if (!defaultNodes_.empty()) {
        selectedNode = *defaultNodes_.begin();
        defaultNodes_.erase(selectedNode);
        SK_LOGD("SkCandidateHeap::pop: selected default node %s, remaining defaultCount=%zu",
                selectedNode->Format().c_str(), defaultNodes_.size());
        return selectedNode->GetNodeId();
    }

    if (!nonKernelNodes_.empty()) {
        selectedNode = SelectNextNonKernelNode();
        if (selectedNode != nullptr) {
            nonKernelNodes_.erase(selectedNode);
            prevStreamIdx_ = selectedNode->GetStreamIdxInGraph();
            SK_LOGD("SkCandidateHeap::pop: selected non-kernel node %s, remaining nonKernelCount=%zu",
                    selectedNode->Format().c_str(), nonKernelNodes_.size());
            return selectedNode->GetNodeId();
        }
    }

    if (!kernelNodes_.empty()) {
        selectedNode = SelectNextKernelNode();
        if (selectedNode != nullptr) {
            kernelNodes_.erase(selectedNode);
            // Record what was just scheduled so the next kernel selection can
            // alternate kernel class / stream relative to it.
            prevKernelTypeClass_ = GetKernelTypeClass(selectedNode->GetKernelType());
            prevKernelStreamIdx_ = selectedNode->GetStreamIdxInGraph();
            prevStreamIdx_ = selectedNode->GetStreamIdxInGraph();
            isFirstSelection_ = false;
            SK_LOGD("SkCandidateHeap::pop: selected kernel node %s, remaining kernelCount=%zu",
                    selectedNode->Format().c_str(), kernelNodes_.size());
            return selectedNode->GetNodeId();
        }
    }

    // Defensive fallback: only reached if a selector returned nullptr for a
    // non-empty set.
    SK_LOGE("SkCandidateHeap::pop: cannot select node, both heaps are empty");
    return 0;
}
// Reports whether there is nothing left to pop in the active mode:
// nodeHeap_ in PRIORITY_QUEUE mode, otherwise all three node sets.
bool SkCandidateHeap::empty() const
{
    const bool usesQueue = (heapMode == SkHeapType::PRIORITY_QUEUE);
    return usesQueue ? nodeHeap_.empty()
                     : (defaultNodes_.empty() && kernelNodes_.empty() && nonKernelNodes_.empty());
}
// Number of pending candidates in the active mode: the size of nodeHeap_ in
// PRIORITY_QUEUE mode, otherwise the combined size of the three node sets.
size_t SkCandidateHeap::size() const
{
    const bool usesQueue = (heapMode == SkHeapType::PRIORITY_QUEUE);
    return usesQueue ? nodeHeap_.size()
                     : (defaultNodes_.size() + kernelNodes_.size() + nonKernelNodes_.size());
}
// True when at least one node is waiting in kernelNodes_.
// Only kernelNodes_ is consulted; push() never fills that set in
// PRIORITY_QUEUE mode, so nodeHeap_ contents do not affect the result.
bool SkCandidateHeap::HasKernelNodes() const
{
    const bool noKernelsPending = kernelNodes_.empty();
    return !noKernelsPending;
}
// Discards every pending candidate in both storage modes and restores the
// selection history to its initial values.
void SkCandidateHeap::clear()
{
    // Drain nodeHeap_ one element at a time.
    for (; !nodeHeap_.empty(); nodeHeap_.pop()) {
    }

    defaultNodes_.clear();
    kernelNodes_.clear();
    nonKernelNodes_.clear();

    ResetSelectionState();
}
// Alias for clear(): drops all candidates and resets the selection history.
void SkCandidateHeap::reset()
{
    clear();
}
// Forgets what was previously selected, so the next kernel pick is treated as
// the first selection again. The values match the constructor's initial state.
void SkCandidateHeap::ResetSelectionState()
{
    isFirstSelection_ = true;
    prevKernelTypeClass_ = KernelTypeClass::OTHER;
    prevKernelStreamIdx_ = 0;
    prevStreamIdx_ = 0;
}
// Chooses (without removing) the kernel node that pop() should return next,
// or nullptr when kernelNodes_ is empty. All scans follow the set's order.
//
// First selection: the first MIX node if any, otherwise the first node.
//
// Later selections, in decreasing preference:
//   1. a node on a different stream than the previous kernel whose class
//      follows the previous one (MIX -> MIX, VEC -> CUBE, CUBE -> VEC);
//   2. any node on a different stream than the previous kernel;
//   3. the first node in the set.
SuperKernelBaseNode* SkCandidateHeap::SelectNextKernelNode()
{
    if (kernelNodes_.empty()) {
        return nullptr;
    }

    SuperKernelBaseNode* const firstInOrder = *kernelNodes_.begin();

    if (isFirstSelection_) {
        for (auto* candidate : kernelNodes_) {
            if (!IsMixKernelType(candidate->GetKernelType())) {
                continue;
            }
            SK_LOGD("SkCandidateHeap::SelectNextKernelNode: first selection, chose MIX node %s",
                    candidate->Format().c_str());
            return candidate;
        }
        SK_LOGD("SkCandidateHeap::SelectNextKernelNode: first selection, no MIX found, chose node %s",
                firstInOrder->Format().c_str());
        return firstInOrder;
    }

    // Does the candidate belong to the class preferred after the previous kernel?
    // Any previous class other than MIX/VEC/CUBE has no preference.
    const auto hasPreferredClass = [this](SuperKernelBaseNode* candidate) -> bool {
        switch (prevKernelTypeClass_) {
            case KernelTypeClass::MIX:
                return IsMixKernelType(candidate->GetKernelType());
            case KernelTypeClass::VEC:
                return IsCubeKernelType(candidate->GetKernelType());
            case KernelTypeClass::CUBE:
                return IsVecKernelType(candidate->GetKernelType());
            default:
                return false;
        }
    };
    const auto isOnOtherStream = [this](SuperKernelBaseNode* candidate) -> bool {
        return candidate->GetStreamIdxInGraph() != prevKernelStreamIdx_;
    };

    // Preference 1: preferred class and a different stream.
    for (auto* candidate : kernelNodes_) {
        if (hasPreferredClass(candidate) && isOnOtherStream(candidate)) {
            return candidate;
        }
    }

    // Preference 2: any class, but still a different stream.
    for (auto* candidate : kernelNodes_) {
        if (isOnOtherStream(candidate)) {
            return candidate;
        }
    }

    // Preference 3: everything left shares the previous stream.
    return firstInOrder;
}
// Chooses (without removing) the next non-kernel node: the first element of
// nonKernelNodes_ in set order, or nullptr when the set is empty.
SuperKernelBaseNode* SkCandidateHeap::SelectNextNonKernelNode()
{
    return nonKernelNodes_.empty() ? nullptr : *nonKernelNodes_.begin();
}
// True for the kernel types this heap classifies as MIX:
// MIX_AIC_1_1 and MIX_AIC_1_2.
bool SkCandidateHeap::IsMixKernelType(SkKernelType type)
{
    switch (type) {
        case SkKernelType::MIX_AIC_1_1:
        case SkKernelType::MIX_AIC_1_2:
            return true;
        default:
            return false;
    }
}
// True for the kernel types this heap classifies as VEC:
// AIV_ONLY and MIX_AIV_1_0.
bool SkCandidateHeap::IsVecKernelType(SkKernelType type)
{
    switch (type) {
        case SkKernelType::AIV_ONLY:
        case SkKernelType::MIX_AIV_1_0:
            return true;
        default:
            return false;
    }
}
// True for the kernel types this heap classifies as CUBE:
// AIC_ONLY and MIX_AIC_1_0.
bool SkCandidateHeap::IsCubeKernelType(SkKernelType type)
{
    switch (type) {
        case SkKernelType::AIC_ONLY:
        case SkKernelType::MIX_AIC_1_0:
            return true;
        default:
            return false;
    }
}
// Maps a kernel type to its scheduling class. The MIX, VEC and CUBE predicates
// are checked in that order; a type matching none of them is OTHER.
SkCandidateHeap::KernelTypeClass SkCandidateHeap::GetKernelTypeClass(SkKernelType type)
{
    KernelTypeClass typeClass = KernelTypeClass::OTHER;
    if (IsMixKernelType(type)) {
        typeClass = KernelTypeClass::MIX;
    } else if (IsVecKernelType(type)) {
        typeClass = KernelTypeClass::VEC;
    } else if (IsCubeKernelType(type)) {
        typeClass = KernelTypeClass::CUBE;
    }
    return typeClass;
}
// Ordering used by the node sets: `a` sorts before `b` when its node id is
// smaller. Both pointers are dereferenced and must be non-null.
bool SkCandidateHeap::CompareByNodeId(SuperKernelBaseNode* a, SuperKernelBaseNode* b)
{
    const auto lhsId = a->GetNodeId();
    const auto rhsId = b->GetNodeId();
    return lhsId < rhsId;
}