* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_remove_redundant_reshape.cpp
* \brief Unit test for RemoveRedundantReshape pass.
*/
#include <gtest/gtest.h>
#include "interface/function/function.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "passes/pass_mgr/pass_manager.h"
#include "interface/configs/config_manager.h"
#include <fstream>
#include <vector>
#include <string>
using namespace npu::tile_fwk;
class RemoveRedundantReshapeTest : public ::testing::Test {
public:
static void SetUpTestCase() {}
static void TearDownTestCase() {}
void SetUp() override
{
Program::GetInstance().Reset();
config::Reset();
config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);
config::SetHostConfig(KEY_STRATEGY, "ReshapeTestStrategy");
config::SetPlatformConfig(KEY_ENABLE_COST_MODEL, false);
}
void TearDown() override {}
};
TEST_F(RemoveRedundantReshapeTest, TestReshapeChain)
{
std::vector<int64_t> shape1{1, 256, 512};
std::vector<int64_t> shape2{1, 512, 256};
std::vector<int64_t> shape3{1, 128, 1024};
Tensor in_tensor(DT_FP32, shape1, "in_tensor");
Tensor out_tensor_B(DT_FP32, shape3, "out_tensor_B");
PassManager& passManager = PassManager::Instance();
passManager.RegisterStrategy(
"ReshapeTestStrategy", {
{"RemoveRedundantReshape", PassName::REMOVE_REDUNDANT_RESHAPE},
});
ConfigManager::Instance();
Function* originFunction = nullptr;
std::vector<int64_t> originOpmagic;
FUNCTION("ReshapeChainFunction")
{
Tensor out_tensor_A = Reshape(in_tensor, shape2);
out_tensor_B = Reshape(out_tensor_A, shape3);
originFunction = Program::GetInstance().GetCurrentFunction();
ASSERT_NE(originFunction, nullptr) << "Current function pointer is null";
auto operations = originFunction->Operations();
for (const auto& op : operations) {
originOpmagic.emplace_back(op.opmagic);
}
}
Function* currentFunction = Program::GetInstance().GetFunctionByRawName("TENSOR_ReshapeChainFunction");
auto updated_operations = currentFunction->Operations();
EXPECT_EQ(updated_operations.size(), 3)
<< "After the Pass, there should be 3 operations (View + Reshape + Assemble)";
EXPECT_EQ(updated_operations[0].GetOpcode(), Opcode::OP_VIEW) << "View operation should be kept";
EXPECT_EQ(updated_operations[1].GetOpcode(), Opcode::OP_RESHAPE) << "The second Reshape (valid) should be kept";
EXPECT_EQ(updated_operations[1].opmagic, originOpmagic.back()) << "The kept Reshape is the second one";
EXPECT_EQ(updated_operations[2].GetOpcode(), Opcode::OP_ASSEMBLE) << "Assemble operation should be kept";
auto& view_op = updated_operations[0];
auto& remaining_reshape_op = updated_operations[1];
EXPECT_EQ(remaining_reshape_op.GetIOperands()[0], view_op.GetOOperands()[0])
<< "The output of View should connect to the input of the second Reshape";
auto& reshape_input = remaining_reshape_op.GetIOperands()[0];
EXPECT_EQ(reshape_input->shape, shape1)
<< "The input shape of the remaining Reshape operation should be the same as shape1";
auto& reshape_output = remaining_reshape_op.GetOOperands()[0];
EXPECT_EQ(reshape_output->shape, shape3)
<< "The output shape of the remaining Reshape operation should be the same as shape3";
}
TEST_F(RemoveRedundantReshapeTest, TestReplaceInput)
{
std::vector<int64_t> shape1{1, 256, 512};
std::vector<int64_t> shape2{1, 512, 256};
std::vector<int64_t> shape3{1, 128, 1024};
Tensor in_tensor(DT_FP32, shape1, "in_tensor");
Tensor out_tensor_B(DT_FP32, shape2, "out_tensor_B");
Tensor out_tensor_C(DT_FP32, shape3, "out_tensor_C");
PassManager& passManager = PassManager::Instance();
passManager.RegisterStrategy(
"ReshapeTestStrategy", {{"RemoveRedundantReshape", PassName::REMOVE_REDUNDANT_RESHAPE}});
TileShape::Current().SetVecTile({1, 64, 64});
Function* originFunction = nullptr;
std::vector<int64_t> reshape_opmagics;
FUNCTION("ReplaceInputFunction")
{
Tensor out_tensor_A = Reshape(in_tensor, shape2);
out_tensor_B = Reshape(out_tensor_A, shape3);
Tensor add_tensor(DT_FP32, shape2, "add_tensor");
out_tensor_C = Add(out_tensor_A, add_tensor);
originFunction = Program::GetInstance().GetCurrentFunction();
ASSERT_NE(originFunction, nullptr) << "Current function pointer is null";
auto operations = originFunction->Operations();
for (const auto& op : operations) {
if (op.GetOpcodeStr() == "RESHAPE") {
reshape_opmagics.push_back(op.opmagic);
}
}
}
Function* currentFunction = Program::GetInstance().GetFunctionByRawName("TENSOR_ReplaceInputFunction");
auto updated_operations = currentFunction->Operations();
Operation* first_reshape_op = nullptr;
for (size_t i = 0; i < updated_operations.size(); i++) {
auto& op = updated_operations[i];
if (op.opmagic == reshape_opmagics[0]) {
first_reshape_op = &op;
}
if (op.opmagic == reshape_opmagics[1]) {
auto& second_reshape_op = op;
if (first_reshape_op) {
EXPECT_EQ(second_reshape_op.GetIOperands()[0], first_reshape_op->GetIOperands()[0])
<< "The input of the second Reshape operation should be replaced by the input of the first Reshape "
"operation";
auto& reshape_input = second_reshape_op.GetIOperands()[0];
EXPECT_EQ(reshape_input->shape, shape1)
<< "The input shape of the second Reshape operation should be the same as shape1";
} else {
std::cerr << "Error: first_reshape_op not found!" << std::endl;
}
}
}
EXPECT_EQ(updated_operations.size(), 6)
<< "After the Pass, there should be 5 operations (View + Reshape + Reshape + Add + Assemble)";
}