* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include <algorithm>
#include <bitset>
#define private public
#define protected public
#include "sk_graph.h"
#include "sk_optimizer.h"
#include "sk_options_manager.h"
#include "sk_node.h"
#include "sk_candidate_heap.h"
class SkOptimizerTaskReorderTest : public testing::Test {
protected:
void SetUp() override
{
graph = std::make_unique<SuperKernelGraph>();
opts = std::make_unique<SuperKernelOptionsManager>();
optimizer = std::make_unique<SuperKernelOptimizer>(*opts);
}
SuperKernelBaseNode* CreateKernelNode(uint64_t nodeId, uint32_t streamIdx)
{
auto node = std::make_unique<SuperKernelKernelNode>(
nullptr, ACL_MODEL_RI_TASK_KERNEL, 0, streamIdx, INVALID_STREAM_ID, INVALID_TASK_ID);
node->SetNodeType(SkNodeType::NODE_KERNEL);
node->SetNodeId(nodeId);
node->SetPreNodeId(INVALID_TASK_ID);
node->SetNextNodeId(INVALID_TASK_ID);
node->nodeInfos.kernelInfos.funcName = "k";
auto* ptr = node.get();
graph->graphMap[nodeId] = std::move(node);
return ptr;
}
SuperKernelBaseNode* CreateWaitNode(uint64_t nodeId, uint32_t streamIdx)
{
auto node = std::make_unique<SuperKernelMemoryNode>(
nullptr, ACL_MODEL_RI_TASK_EVENT_WAIT, 0, streamIdx, INVALID_STREAM_ID, INVALID_TASK_ID);
node->SetNodeType(SkNodeType::NODE_WAIT);
node->SetNodeId(nodeId);
node->SetPreNodeId(INVALID_TASK_ID);
node->SetNextNodeId(INVALID_TASK_ID);
auto* ptr = node.get();
graph->graphMap[nodeId] = std::move(node);
return ptr;
}
std::vector<uint64_t> ToNodeIds(const std::vector<SuperKernelBaseNode*>& nodes)
{
std::vector<uint64_t> nodeIds;
nodeIds.reserve(nodes.size());
for (const auto* node : nodes) {
nodeIds.push_back(node == nullptr ? INVALID_TASK_ID : node->GetNodeId());
}
return nodeIds;
}
void EnableCustomizeQueueReorder()
{
auto option = std::make_unique<NumberOptOption>(
"auto_op_parallel", aclskOptionType::AUTO_OP_PARALLEL,
static_cast<uint32_t>(SkHeapType::CUSTOMIZE_QUEUE), 0, 0xFFFFFFFF);
opts->AddOption(std::move(option));
}
std::unique_ptr<SuperKernelGraph> graph;
std::unique_ptr<SuperKernelOptionsManager> opts;
std::unique_ptr<SuperKernelOptimizer> optimizer;
};
TEST_F(SkOptimizerTaskReorderTest, ReorderWaitNodesForTaskBuild_DefaultDisabled_KeepOriginalOrder)
{
std::vector<SuperKernelBaseNode*> taskNodes = {
CreateKernelNode(1, 0),
CreateWaitNode(2, 1),
CreateKernelNode(3, 1),
};
std::vector<SuperKernelBaseNode*> reorderedTaskNodes = optimizer->ReorderWaitNodesForTaskBuild(taskNodes);
EXPECT_EQ(ToNodeIds(reorderedTaskNodes), (std::vector<uint64_t>{1, 2, 3}));
}
TEST_F(SkOptimizerTaskReorderTest, ReorderWaitNodesForTaskBuild_CustomizeQueue_WaitMovesBeforeLaterSameStreamKernel)
{
EnableCustomizeQueueReorder();
std::vector<SuperKernelBaseNode*> taskNodes = {
CreateKernelNode(1, 0),
CreateWaitNode(2, 1),
CreateKernelNode(3, 1),
};
std::vector<SuperKernelBaseNode*> reorderedTaskNodes = optimizer->ReorderWaitNodesForTaskBuild(taskNodes);
EXPECT_EQ(ToNodeIds(reorderedTaskNodes), (std::vector<uint64_t>{1, 2, 3}));
}
TEST_F(SkOptimizerTaskReorderTest, ReorderWaitNodesForTaskBuild_CustomizeQueue_WaitWithoutLaterSameStreamKernelStaysInPlace)
{
EnableCustomizeQueueReorder();
std::vector<SuperKernelBaseNode*> taskNodes = {
CreateKernelNode(1, 0),
CreateWaitNode(2, 1),
CreateKernelNode(3, 0),
};
std::vector<SuperKernelBaseNode*> reorderedTaskNodes = optimizer->ReorderWaitNodesForTaskBuild(taskNodes);
EXPECT_EQ(ToNodeIds(reorderedTaskNodes), (std::vector<uint64_t>{1, 2, 3}));
}
TEST_F(SkOptimizerTaskReorderTest, ReorderWaitNodesForTaskBuild_CustomizeQueue_ConsecutiveWaitsSameStreamKeepRelativeOrder)
{
EnableCustomizeQueueReorder();
std::vector<SuperKernelBaseNode*> taskNodes = {
CreateWaitNode(10, 1),
CreateKernelNode(20, 0),
CreateWaitNode(11, 1),
CreateKernelNode(12, 1),
};
std::vector<SuperKernelBaseNode*> reorderedTaskNodes = optimizer->ReorderWaitNodesForTaskBuild(taskNodes);
EXPECT_EQ(ToNodeIds(reorderedTaskNodes), (std::vector<uint64_t>{20, 10, 11, 12}));
}
TEST_F(SkOptimizerTaskReorderTest, ReorderWaitNodesForTaskBuild_CustomizeQueue_ConsecutiveWaitsDifferentStreamsAttachToOwnKernel)
{
EnableCustomizeQueueReorder();
std::vector<SuperKernelBaseNode*> taskNodes = {
CreateWaitNode(10, 1),
CreateWaitNode(11, 2),
CreateKernelNode(20, 0),
CreateKernelNode(12, 1),
CreateKernelNode(13, 2),
};
std::vector<SuperKernelBaseNode*> reorderedTaskNodes = optimizer->ReorderWaitNodesForTaskBuild(taskNodes);
EXPECT_EQ(ToNodeIds(reorderedTaskNodes), (std::vector<uint64_t>{20, 10, 12, 11, 13}));
}
TEST_F(SkOptimizerTaskReorderTest, ReorderWaitNodesForTaskBuild_CustomizeQueue_OutOfOrderTargetsDoNotFlushPrefix)
{
EnableCustomizeQueueReorder();
std::vector<SuperKernelBaseNode*> taskNodes = {
CreateWaitNode(10, 1),
CreateWaitNode(11, 2),
CreateKernelNode(13, 2),
CreateKernelNode(12, 1),
};
std::vector<SuperKernelBaseNode*> reorderedTaskNodes = optimizer->ReorderWaitNodesForTaskBuild(taskNodes);
EXPECT_EQ(ToNodeIds(reorderedTaskNodes), (std::vector<uint64_t>{11, 13, 10, 12}));
}
TEST_F(SkOptimizerTaskReorderTest, ReorderWaitNodesForTaskBuild_CustomizeQueue_AllNodesSameStream_KeepOriginalOrder)
{
EnableCustomizeQueueReorder();
std::vector<SuperKernelBaseNode*> taskNodes = {
CreateKernelNode(1, 3),
CreateWaitNode(2, 3),
CreateKernelNode(3, 3),
CreateWaitNode(4, 3),
CreateKernelNode(5, 3),
};
std::vector<SuperKernelBaseNode*> reorderedTaskNodes = optimizer->ReorderWaitNodesForTaskBuild(taskNodes);
EXPECT_EQ(ToNodeIds(reorderedTaskNodes), (std::vector<uint64_t>{1, 2, 3, 4, 5}));
}
TEST_F(SkOptimizerTaskReorderTest, ReorderWaitNodesForTaskBuild_CustomizeQueue_MultipleWaitsTargetSameKernelKeepRelativeOrder)
{
EnableCustomizeQueueReorder();
std::vector<SuperKernelBaseNode*> taskNodes = {
CreateKernelNode(1, 0),
CreateWaitNode(2, 2),
CreateKernelNode(3, 1),
CreateWaitNode(4, 2),
CreateKernelNode(5, 2),
};
std::vector<SuperKernelBaseNode*> reorderedTaskNodes = optimizer->ReorderWaitNodesForTaskBuild(taskNodes);
EXPECT_EQ(ToNodeIds(reorderedTaskNodes), (std::vector<uint64_t>{1, 3, 2, 4, 5}));
}
TEST_F(SkOptimizerTaskReorderTest, ReorderWaitNodesForTaskBuild_CustomizeQueue_NonWaitRelativeOrderPreserved)
{
EnableCustomizeQueueReorder();
std::vector<SuperKernelBaseNode*> taskNodes = {
CreateKernelNode(1, 0),
CreateWaitNode(2, 1),
CreateKernelNode(3, 0),
CreateKernelNode(4, 1),
CreateKernelNode(5, 0),
};
std::vector<SuperKernelBaseNode*> reorderedTaskNodes = optimizer->ReorderWaitNodesForTaskBuild(taskNodes);
EXPECT_EQ(ToNodeIds(reorderedTaskNodes), (std::vector<uint64_t>{1, 3, 2, 4, 5}));
}
* @brief Test case: only op4-op9 are fusible, others are not
*
* Graph structure:
* stream1: op1 -> op4 -> op5 -> op10
* stream2: op2 -> op6 -> op7 -> op11
* stream3: op3 -> op8 -> op9 -> op12
*
* op4-op9 are fusible (SetIsFusible(true))
* op1, op2, op3, op10, op11, op12 are NOT fusible (default isFusible=false)
*
* Expected: SplitGraph should split op4-op9 into one scope
* In Process() loop at line 310, this scope should be processed
*/
TEST_F(SkOptimizerTaskReorderTest, ThreeStreamParallel_OnlyOp4ToOp9Fusible)
{
auto* op1 = CreateKernelNode(1, 0);
auto* op4 = CreateKernelNode(4, 0);
auto* op5 = CreateKernelNode(5, 0);
auto* op10 = CreateKernelNode(10, 0);
auto* op2 = CreateKernelNode(2, 1);
auto* op6 = CreateKernelNode(6, 1);
auto* op7 = CreateKernelNode(7, 1);
auto* op11 = CreateKernelNode(11, 1);
auto* op3 = CreateKernelNode(3, 2);
auto* op8 = CreateKernelNode(8, 2);
auto* op9 = CreateKernelNode(9, 2);
auto* op12 = CreateKernelNode(12, 2);
op4->SetIsFusible(true);
op5->SetIsFusible(true);
op6->SetIsFusible(true);
op7->SetIsFusible(true);
op8->SetIsFusible(true);
op9->SetIsFusible(true);
graph->streams.clear();
graph->headNodes.clear();
graph->streams.emplace_back();
graph->headNodes.push_back(1);
op1->SetNextNodeId(4);
op1->SetPreNodeId(INVALID_TASK_ID);
op1->nodeIdxInStream = 0;
op1->streamIdxInGraph = 0;
op4->SetNextNodeId(5);
op4->SetPreNodeId(1);
op4->nodeIdxInStream = 1;
op4->streamIdxInGraph = 0;
op5->SetNextNodeId(10);
op5->SetPreNodeId(4);
op5->nodeIdxInStream = 2;
op5->streamIdxInGraph = 0;
op10->SetNextNodeId(INVALID_TASK_ID);
op10->SetPreNodeId(5);
op10->nodeIdxInStream = 3;
op10->streamIdxInGraph = 0;
graph->streams.emplace_back();
graph->headNodes.push_back(2);
op2->SetNextNodeId(6);
op2->SetPreNodeId(INVALID_TASK_ID);
op2->nodeIdxInStream = 0;
op2->streamIdxInGraph = 1;
op6->SetNextNodeId(7);
op6->SetPreNodeId(2);
op6->nodeIdxInStream = 1;
op6->streamIdxInGraph = 1;
op7->SetNextNodeId(11);
op7->SetPreNodeId(6);
op7->nodeIdxInStream = 2;
op7->streamIdxInGraph = 1;
op11->SetNextNodeId(INVALID_TASK_ID);
op11->SetPreNodeId(7);
op11->nodeIdxInStream = 3;
op11->streamIdxInGraph = 1;
graph->streams.emplace_back();
graph->headNodes.push_back(3);
op3->SetNextNodeId(8);
op3->SetPreNodeId(INVALID_TASK_ID);
op3->nodeIdxInStream = 0;
op3->streamIdxInGraph = 2;
op8->SetNextNodeId(9);
op8->SetPreNodeId(3);
op8->nodeIdxInStream = 1;
op8->streamIdxInGraph = 2;
op9->SetNextNodeId(12);
op9->SetPreNodeId(8);
op9->nodeIdxInStream = 2;
op9->streamIdxInGraph = 2;
op12->SetNextNodeId(INVALID_TASK_ID);
op12->SetPreNodeId(9);
op12->nodeIdxInStream = 3;
op12->streamIdxInGraph = 2;
bool result = optimizer->Process(*graph);
SK_LOGI("Process result: %d", result);
const auto& scopeInfos = optimizer->GetScopeInfos();
SK_LOGI("Number of scopes: %zu", scopeInfos.size());
for (size_t i = 0; i < scopeInfos.size(); ++i) {
const auto& scope = scopeInfos[i];
SK_LOGI("Scope %zu: scopeId=%u, nodeCount=%zu", i, scope.GetScopeId(), scope.GetNodes().size());
for (const auto* node : scope.GetNodes()) {
if (node != nullptr) {
SK_LOGI(" Node: %s", node->Format().c_str());
}
}
}
bool foundScopeWithOp4ToOp9 = false;
for (size_t i = 0; i < scopeInfos.size(); ++i) {
const auto& scope = scopeInfos[i];
const auto& nodes = scope.GetNodes();
bool hasOp4 = false, hasOp5 = false, hasOp6 = false;
bool hasOp7 = false, hasOp8 = false, hasOp9 = false;
for (const auto* node : nodes) {
if (node == nullptr) continue;
uint64_t id = node->GetNodeId();
if (id == 4) hasOp4 = true;
if (id == 5) hasOp5 = true;
if (id == 6) hasOp6 = true;
if (id == 7) hasOp7 = true;
if (id == 8) hasOp8 = true;
if (id == 9) hasOp9 = true;
}
if (hasOp4 && hasOp5 && hasOp6 && hasOp7 && hasOp8 && hasOp9) {
foundScopeWithOp4ToOp9 = true;
SK_LOGI("Found scope with op4-op9 at index %zu", i);
break;
}
}
EXPECT_TRUE(foundScopeWithOp4ToOp9);
}