/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file test_remove_redundant_op.cpp
* \brief Unit test for RemoveRedundantOp pass.
*/
#include "gtest/gtest.h"
#include "symbolic_scalar_test_utils.h"
#include "tilefwk/tilefwk_op.h"
#include "interface/function/function.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "passes/pass_mgr/pass_manager.h"
#include "interface/configs/config_manager.h"
#include "ut_json/ut_json_tool.h"
#include "passes/tile_graph_pass/graph_optimization/remove_redundant_op.h"
#include "computational_graph_builder.h"
#include "interface/operation/attribute.h"
#include <fstream>
#include <vector>
#include <string>
#include "interface/tensor/irbuilder.h"
using namespace npu::tile_fwk;
/**
 * \brief Dump the operations of a function to stdout (debugging aid for the tests).
 *
 * Prints the operation count first. Then, for every operation, prints its magic number and
 * opcode string, followed by the producers of each input operand and the consumers of each
 * output operand as "(magic, opcode) " pairs.
 *
 * \param func Function to dump. It is dereferenced without a null check.
 */
void PrintGraphInfoRemoveRedundantOp(Function* func)
{
    // One formatter shared by the producer list and the consumer list of an operand.
    const auto printNeighbours = [](auto&& neighbours) {
        for (const auto& neighbour : neighbours) {
            std::cout << "(" << neighbour->opmagic << ", " << neighbour->GetOpcodeStr() << ") ";
        }
    };

    std::cout << "func->Operations().size() = " << func->Operations().size() << std::endl;
    for (auto& currentOp : func->Operations()) {
        std::cout << "Op:" << currentOp.GetOpMagic() << " " << currentOp.GetOpcodeStr() << std::endl;

        std::cout << "input operation:";
        for (const auto& inOperand : currentOp.GetIOperands()) {
            printNeighbours(inOperand->GetProducers());
        }
        std::cout << std::endl;

        std::cout << "output operation:";
        for (const auto& outOperand : currentOp.GetOOperands()) {
            printNeighbours(outOperand->GetConsumers());
        }
        std::cout << std::endl;
    }
}
void SetUpPassStrategy()
{
PassManager& passManager = PassManager::Instance();
passManager.RegisterStrategy(
"RemoveRedundantOpTestStrategy", {
{"RemoveRedundantReshape", PassName::REMOVE_REDUNDANT_RESHAPE},
{"InferMemoryConflict", PassName::INFER_MEMORY_CONFLICT},
{"ExpandFunction", PassName::EXPAND_FUNCTION},
{"DuplicateOp", PassName::DUPLICATE_OP},
{"MergeViewAssemble", PassName::MERGE_VIEW_ASSEMBLE},
{"AssignMemoryType", PassName::ASSIGN_MEMORY_TYPE},
{"SplitLargeFanoutTensor", PassName::SPLIT_LARGE_FANOUT_TENSOR},
{"SplitReshape", PassName::SPLIT_RESHAPE},
});
}
/**
 * \brief Test fixture shared by the RemoveRedundantOp tests.
 *
 * SetUp() resets the Program singleton and the configuration, then sets the compile stage,
 * the strategy name and the cost-model switch. The suite-level hooks and TearDown() are
 * intentionally empty.
 */
class RemoveRedundantOpTest : public testing::Test {
public:
    // Suite-level hooks: nothing to prepare or release.
    static void SetUpTestCase() {}

    static void TearDownTestCase() {}

    // Runs before every test so each case starts from a freshly reset Program and configuration.
    void SetUp() override
    {
        Program::GetInstance().Reset();
        config::Reset();

        config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);
        config::SetHostConfig(KEY_STRATEGY, "RemoveRedundantOpTestStrategy");
        config::SetPlatformConfig(KEY_ENABLE_COST_MODEL, false);
    }

    // No per-test cleanup is performed here.
    void TearDown() override {}
};
/**
 * \brief Run RemoveRedundantOp on a graph whose Transpose result is both a function argument
 *        (`output`) and the input of a following Add.
 *
 * Checks the operation counts before the pass (15 ops, 5 VIEW, 5 ASSEMBLE) and after it
 * (14 ops, 5 VIEW, 4 ASSEMBLE), and that the first remaining operation is a VIEW.
 */
TEST_F(RemoveRedundantOpTest, TestIntermediateOutcast)
{
    config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);
    int bs = 1;
    int n = 32;
    int d = 128;
    std::vector<int64_t> shape{bs, n, d};
    std::vector<int64_t> resShape{bs, n, d};
    SetUpPassStrategy();
    ConfigManager::Instance();
    Tensor input(DataType::DT_FP32, shape, "input");
    Tensor output(DataType::DT_FP32, resShape, "res");
    Tensor output_add(DataType::DT_FP32, resShape, "res_add");
    config::SetBuildStatic(true);
    FUNCTION("RemoveRedundantOpFunction", {input, output, output_add})
    {
        TileShape::Current().SetVecTile(1, 32, 128);
        output = Transpose(input, {0, 1});
        TileShape::Current().SetVecTile(8, 1, 128);
        output_add = Add(output, Element(DataType::DT_FP32, 0.0));
    }

    // Counts the operations in `ops` whose opcode equals `opcode`.
    auto countOpcode = [](auto& ops, Opcode opcode) {
        int count = 0;
        for (auto& op : ops) {
            if (op.GetOpcode() == opcode) {
                count += 1;
            }
        }
        return count;
    };

    Function* func = Program::GetInstance().GetFunctionByRawName("TENSOR_RemoveRedundantOpFunction");
    // Fatal check: every statement below dereferences func.
    ASSERT_NE(func, nullptr) << "Function TENSOR_RemoveRedundantOpFunction should be registered";
    npu::tile_fwk::RemoveRedundantOp removeRedundantOp;

    auto oriOpList = func->Operations(true);
    EXPECT_EQ(oriOpList.size(), 15) << "Before the Pass, there should be 15 operations";
    int ori_view_count = countOpcode(oriOpList, Opcode::OP_VIEW);
    int ori_assemble_count = countOpcode(oriOpList, Opcode::OP_ASSEMBLE);
    EXPECT_EQ(ori_view_count, 5) << "There should be 5 VIEW ops before RemoveRedundantOp";
    EXPECT_EQ(ori_assemble_count, 5) << "There should be 5 ASSEMBLE ops before RemoveRedundantOp";

    removeRedundantOp.PreCheck(*func);
    removeRedundantOp.RunOnFunction(*func);
    removeRedundantOp.PostCheck(*func);
    PrintGraphInfoRemoveRedundantOp(func);

    auto updated_operations = func->Operations(true);
    int opSize = 14;
    // Fatal on mismatch so that the element access below never runs on an unexpected list.
    ASSERT_EQ(updated_operations.size(), opSize)
        << "After the Pass, there should be 14 operations; no VIEW should be deleted";
    EXPECT_EQ(updated_operations[0].GetOpcode(), Opcode::OP_VIEW) << "The first operation should be VIEW";

    int view_count = countOpcode(updated_operations, Opcode::OP_VIEW);
    int assemble_count = countOpcode(updated_operations, Opcode::OP_ASSEMBLE);
    EXPECT_EQ(view_count, 5) << "There should be 5 VIEW ops after RemoveRedundantOp";
    EXPECT_EQ(assemble_count, 4) << "There should be 4 ASSEMBLE ops after RemoveRedundantOp";
}
/**
 * \brief Run RemoveRedundantOp on a Transpose -> Add graph whose intermediate tensor is local
 *        to the function and whose two stages use different vector tile shapes.
 *
 * Checks that the pass leaves the total operation count, the VIEW count and the ASSEMBLE
 * count unchanged.
 */
TEST_F(RemoveRedundantOpTest, TestInternalAssembleView)
{
    config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);
    int bs = 4;
    int n = 32;
    int d = 128;
    std::vector<int64_t> shape{bs, n, d};
    std::vector<int64_t> resShape{bs, n, d};
    SetUpPassStrategy();
    ConfigManager::Instance();
    Tensor input(DataType::DT_FP32, shape, "input");
    Tensor output(DataType::DT_FP32, resShape, "res");
    config::SetBuildStatic(true);
    FUNCTION("RemoveRedundantOpFunction", {input, output})
    {
        TileShape::Current().SetVecTile(1, 32, 128);
        auto tmp = Transpose(input, {0, 1});
        TileShape::Current().SetVecTile(8, 1, 64);
        output = Add(tmp, Element(DataType::DT_FP32, 3.0));
    }

    // Counts the operations in `ops` whose opcode equals `opcode`.
    auto countOpcode = [](auto& ops, Opcode opcode) {
        int count = 0;
        for (auto& op : ops) {
            if (op.GetOpcode() == opcode) {
                count += 1;
            }
        }
        return count;
    };

    Function* func = Program::GetInstance().GetFunctionByRawName("TENSOR_RemoveRedundantOpFunction");
    // Fatal check: every statement below dereferences func.
    ASSERT_NE(func, nullptr) << "Function TENSOR_RemoveRedundantOpFunction should be registered";
    npu::tile_fwk::RemoveRedundantOp removeRedundantOp;

    // Snapshot of the graph before the pass, used as the reference for all comparisons below.
    auto oriOpList = func->Operations(true);
    int ori_view_count = countOpcode(oriOpList, Opcode::OP_VIEW);
    int ori_assemble_count = countOpcode(oriOpList, Opcode::OP_ASSEMBLE);

    removeRedundantOp.PreCheck(*func);
    removeRedundantOp.RunOnFunction(*func);
    removeRedundantOp.PostCheck(*func);
    PrintGraphInfoRemoveRedundantOp(func);

    auto updated_operations = func->Operations(true);
    int view_count = countOpcode(updated_operations, Opcode::OP_VIEW);
    int assemble_count = countOpcode(updated_operations, Opcode::OP_ASSEMBLE);

    EXPECT_EQ(updated_operations.size(), oriOpList.size()) << "No op should be removed in RemoveRedundantOp";
    EXPECT_EQ(view_count, ori_view_count) << "No VIEW op should be removed in RemoveRedundantOp";
    EXPECT_EQ(assemble_count, ori_assemble_count) << "No ASSEMBLE op should be removed in RemoveRedundantOp";
}
std::shared_ptr<Function> SetUpParallelAssembleWithReshapeGraph()
{
auto func = std::make_shared<Function>(
Program::GetInstance(), "ProcessRedundantOpParallelAssembleWithReshape",
"ProcessRedundantOpParallelAssembleWithReshape", nullptr);
std::vector<int64_t> inputShape = {32, 128};
std::vector<int64_t> outputShape1 = {64, 128};
std::vector<int64_t> outputShape2 = {32, 128};
auto oriInput = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, inputShape, CreateTestConstIntVector(inputShape));
oriInput->SetMemoryTypeBoth(MemoryType::MEM_UB, true);
auto sharedInput = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, inputShape, CreateTestConstIntVector(inputShape));
sharedInput->SetMemoryTypeBoth(MemoryType::MEM_UB, true);
auto anotherInput = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, inputShape, CreateTestConstIntVector(inputShape));
anotherInput->SetMemoryTypeBoth(MemoryType::MEM_UB, true);
auto outputA = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, outputShape1, CreateTestConstIntVector(outputShape1));
outputA->SetMemoryTypeBoth(MemoryType::MEM_UB, true);
auto outputB = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, outputShape2, CreateTestConstIntVector(outputShape2));
outputB->SetMemoryTypeBoth(MemoryType::MEM_UB, true);
IRBuilder().CreateTensorOpStmt(*func, Opcode::OP_ADDS, {oriInput}, {sharedInput});
IRBuilder().CreateTensorOpStmt(*func, Opcode::OP_ADDS, {oriInput}, {anotherInput});
auto& assemble1 = IRBuilder().CreateTensorOpStmt(*func, Opcode::OP_ASSEMBLE, {anotherInput}, {outputA});
assemble1.SetOpAttribute(std::make_shared<AssembleOpAttribute>(std::vector<int64_t>{0, 0}));
auto& assemble2 = IRBuilder().CreateTensorOpStmt(*func, Opcode::OP_ASSEMBLE, {sharedInput}, {outputA});
assemble2.SetOpAttribute(std::make_shared<AssembleOpAttribute>(std::vector<int64_t>{32, 0}));
auto& assemble3 = IRBuilder().CreateTensorOpStmt(*func, Opcode::OP_ASSEMBLE, {sharedInput}, {outputB});
assemble3.SetOpAttribute(std::make_shared<AssembleOpAttribute>(std::vector<int64_t>{0, 0}));
std::vector<int64_t> reshapeShape = {4096};
auto reshapeOut = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, reshapeShape, CreateTestConstIntVector(reshapeShape));
reshapeOut->SetMemoryTypeBoth(MemoryType::MEM_UB, true);
IRBuilder().CreateTensorOpStmt(*func, Opcode::OP_RESHAPE, {outputB}, {reshapeOut});
func->inCasts_.push_back(oriInput);
func->outCasts_.push_back(outputA);
func->outCasts_.push_back(reshapeOut);
return func;
}
/**
 * \brief Run RemoveRedundantOp on the graph from SetUpParallelAssembleWithReshapeGraph().
 *
 * Checks that the graph has 3 ASSEMBLE ops before the pass and that the pass removes none
 * of them.
 */
TEST_F(RemoveRedundantOpTest, ProcessParallelAssembleWithReshape)
{
    auto currFunctionPtr = SetUpParallelAssembleWithReshapeGraph();
    // Fatal check: the function is dereferenced by every statement below.
    ASSERT_TRUE(currFunctionPtr != nullptr);
    Function* func = currFunctionPtr.get();

    // Counts the ASSEMBLE operations in `ops`.
    auto countAssemble = [](auto& ops) {
        int count = 0;
        for (auto& op : ops) {
            if (op.GetOpcode() == Opcode::OP_ASSEMBLE) {
                count++;
            }
        }
        return count;
    };

    auto oriOpList = func->Operations(true);
    int oriAssembleCount = countAssemble(oriOpList);
    EXPECT_EQ(oriAssembleCount, 3) << "Should have 3 ASSEMBLE ops before pass";

    RemoveRedundantOp removeRedundantOp;
    removeRedundantOp.PreCheck(*func);
    removeRedundantOp.RunOnFunction(*func);
    removeRedundantOp.PostCheck(*func);

    auto updatedOps = func->Operations(true);
    int newAssembleCount = countAssemble(updatedOps);
    EXPECT_EQ(newAssembleCount, oriAssembleCount)
        << "ASSEMBLE ops should NOT be deleted when hasParallelAssemble=true and hasReshapeConsumer=true";
}