* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_derivation_tile_shape.cpp
* \brief Unit test for DerivationTileShape pass.
*/
#include <gtest/gtest.h>
#include <vector>
#include <string>
#include "interface/function/function.h"
#include "tilefwk/tilefwk.h"
#include "ut_json/ut_json_tool.h"
#include "passes/pass_mgr/pass_manager.h"
#include "interface/configs/config_manager.h"
#include "interface/operation/operation.h"
#include "passes/tensor_graph_pass/derivation_tile_shape.h"
#include "computational_graph_builder.h"
using namespace npu::tile_fwk;
namespace npu {
namespace tile_fwk {
class DerivationTileShapeTest : public ::testing::Test {
public:
static void SetUpTestCase() {}
static void TearDownTestCase() {}
void SetUp() override
{
Program::GetInstance().Reset();
config::Reset();
config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);
config::SetHostConfig(KEY_STRATEGY, "ReshapeTestStrategy");
config::SetPlatformConfig(KEY_ENABLE_COST_MODEL, false);
TileShape::Current().SetVecTile({64, 64});
}
void TearDown() override {}
};
static void BuildGraphAndCheck(ComputationalGraphBuilder& G)
{
std::vector<std::string> tensorNames{"t1"};
std::vector<std::string> tensorNames1{"t2"};
std::vector<Opcode> opCodes{Opcode::OP_RESHAPE};
std::vector<std::vector<std::string>> ioperands{{"t1"}};
std::vector<std::vector<std::string>> ooperands{{"t2"}};
std::vector<std::string> opNames{"RESHAPE"};
EXPECT_EQ(G.AddTensors(DataType::DT_FP32, {8, 6}, tensorNames), true);
EXPECT_EQ(G.AddTensors(DataType::DT_FP32, {2, 4, 6}, tensorNames1), true);
EXPECT_EQ(G.AddOps(opCodes, ioperands, ooperands, opNames, true), true);
EXPECT_EQ(G.SetInCast({"t1"}), true);
EXPECT_EQ(G.SetOutCast({"t2"}), true);
Function* function = G.GetFunction();
EXPECT_NE(function, nullptr);
}
static void BuildShapeAndCheckSucc(
Function* function, const Shape& inShape, const Shape& outShape, const std::vector<int64_t>& inTileShape,
const std::vector<int64_t>& resultTileShape)
{
std::vector<int64_t> outTileShape;
DerivationTileShape derivationTileShapePass;
auto opList = function->Operations().DuplicatedOpList();
for (size_t i = 0; i < opList.size(); ++i) {
if (opList[i]->GetOpcode() != Opcode::OP_RESHAPE) {
continue;
}
auto op = opList[i];
auto status =
derivationTileShapePass.DerivationReshapeTileShape(op, inShape, outShape, inTileShape, outTileShape);
EXPECT_EQ(status, SUCCESS);
EXPECT_EQ(outTileShape, resultTileShape);
}
}
static void BuildShapeAndCheckFail(
Function* function, const Shape& inShape, const Shape& outShape, const std::vector<int64_t>& inTileShape)
{
std::vector<int64_t> outTileShape;
DerivationTileShape derivationTileShapePass;
auto opList = function->Operations().DuplicatedOpList();
for (size_t i = 0; i < opList.size(); ++i) {
if (opList[i]->GetOpcode() != Opcode::OP_RESHAPE) {
continue;
}
auto op = opList[i];
auto status =
derivationTileShapePass.DerivationReshapeTileShape(op, inShape, outShape, inTileShape, outTileShape);
EXPECT_EQ(status, WARNING);
}
}
TEST_F(DerivationTileShapeTest, DerivationSplitSuccess)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {8, 6};
Shape outShape = {2, 4, 6};
std::vector<int64_t> inTileShape = {2, 3};
std::vector<int64_t> resultTileShape = {1, 2, 3};
BuildShapeAndCheckSucc(G.GetFunction(), inShape, outShape, inTileShape, resultTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationSplitAndMergeSuccess)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {30, 6};
Shape outShape = {2, 45, 2};
std::vector<int64_t> inTileShape = {5, 6};
std::vector<int64_t> resultTileShape = {1, 15, 2};
BuildShapeAndCheckSucc(G.GetFunction(), inShape, outShape, inTileShape, resultTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationSplitAndMergeInputShape1Success)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {30, 6, 1, 1};
Shape outShape = {2, 45, 2};
std::vector<int64_t> inTileShape = {5, 6, 1, 1};
std::vector<int64_t> resultTileShape = {1, 15, 2};
BuildShapeAndCheckSucc(G.GetFunction(), inShape, outShape, inTileShape, resultTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationSplitAndMergeOutputShape1Success)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {30, 6};
Shape outShape = {2, 45, 2, 1, 1};
std::vector<int64_t> inTileShape = {5, 6};
std::vector<int64_t> resultTileShape = {1, 15, 2, 1, 1};
BuildShapeAndCheckSucc(G.GetFunction(), inShape, outShape, inTileShape, resultTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationLargeSizeSuccess)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {30000, 60000};
Shape outShape = {300, 100, 60000};
std::vector<int64_t> inTileShape = {100, 100};
std::vector<int64_t> resultTileShape = {1, 100, 100};
BuildShapeAndCheckSucc(G.GetFunction(), inShape, outShape, inTileShape, resultTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationAlignShapeFailed)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 3};
Shape outShape = {3, 2};
std::vector<int64_t> inTileShape = {2, 1};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationAlignTileFail)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {30, 6};
Shape outShape = {2, 45, 2};
std::vector<int64_t> inTileShape = {4, 2};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationOutTileFail)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {30, 6};
Shape outShape = {2, 45, 2};
std::vector<int64_t> inTileShape = {5, 2};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationAlignShapeInputProductFailed)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {4, 3};
Shape outShape = {2, 3, 2};
std::vector<int64_t> inTileShape = {2, 1};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationAlignShapeOutputProductFailed)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 3, 2};
Shape outShape = {4, 3};
std::vector<int64_t> inTileShape = {2, 1, 1};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationAlignShapeZeroFailed)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 3, 2};
Shape outShape = {4, 3};
std::vector<int64_t> inTileShape = {2, 0, 1};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationBiggerTileSuccess)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 2};
Shape outShape = {1, 2, 2};
std::vector<int64_t> inTileShape = {2, 32};
std::vector<int64_t> resultTileShape = {1, 2, 32};
BuildShapeAndCheckSucc(G.GetFunction(), inShape, outShape, inTileShape, resultTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationSplitBiggerTileSuccess)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 4};
Shape outShape = {1, 2, 2, 2};
std::vector<int64_t> inTileShape = {2, 32};
std::vector<int64_t> resultTileShape = {1, 2, 2, 16};
BuildShapeAndCheckSucc(G.GetFunction(), inShape, outShape, inTileShape, resultTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationMergeBiggerTileSuccess)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 2, 2};
Shape outShape = {2, 4};
std::vector<int64_t> inTileShape = {2, 2, 32};
std::vector<int64_t> resultTileShape = {2, 64};
BuildShapeAndCheckSucc(G.GetFunction(), inShape, outShape, inTileShape, resultTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationSplitAndMergeBiggerTileSuccess)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {30, 6};
Shape outShape = {2, 45, 2};
std::vector<int64_t> inTileShape = {5, 36};
std::vector<int64_t> resultTileShape = {1, 15, 12};
BuildShapeAndCheckSucc(G.GetFunction(), inShape, outShape, inTileShape, resultTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationMergeBiggerTileFail)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 4};
Shape outShape = {1, 2, 2, 2};
std::vector<int64_t> inTileShape = {2, 33};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationNonComplianceTileFail)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 27};
Shape outShape = {2, 3, 3, 3};
std::vector<int64_t> inTileShape = {2, 6};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationNonComplianceTileSuccess)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 27};
Shape outShape = {2, 3, 3, 3};
std::vector<int64_t> inTileShape = {2, 18};
std::vector<int64_t> resultTileShape = {2, 2, 3, 3};
BuildShapeAndCheckSucc(G.GetFunction(), inShape, outShape, inTileShape, resultTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationDiffSizeFail)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 3};
Shape outShape = {2, 3, 3, 3};
std::vector<int64_t> inTileShape = {2, 2, 2};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationDiffSizeFail1)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {1, 1, 1};
Shape outShape = {1, 1, 1, 1, 2};
std::vector<int64_t> inTileShape = {1, 1, 2};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationCheckTileShapeSizeMismatch)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 4};
Shape outShape = {1, 2, 2, 2};
std::vector<int64_t> inTileShape = {2, 33};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationCheckTileShapeDistanceMismatch)
{
ComputationalGraphBuilder G;
BuildGraphAndCheck(G);
Shape inShape = {2, 2, 6};
Shape outShape = {2, 4, 3};
std::vector<int64_t> inTileShape = {2, 2, 2};
BuildShapeAndCheckFail(G.GetFunction(), inShape, outShape, inTileShape);
}
TEST_F(DerivationTileShapeTest, DerivationReshapeTileShapeNonReshapeOp)
{
ComputationalGraphBuilder G;
G.AddTensor(DataType::DT_FP32, {8, 8}, "a");
G.AddTensor(DataType::DT_FP32, {8, 8}, "b");
G.AddOp(Opcode::OP_ADD, {"a"}, {"b"}, "add1");
G.SetInCast({"a"});
G.SetOutCast({"b"});
Function* function = G.GetFunction();
DerivationTileShape pass;
std::vector<int64_t> outTileShape;
auto addOp = function->Operations().DuplicatedOpList()[0];
EXPECT_EQ(pass.DerivationReshapeTileShape(addOp, {8, 8}, {8, 8}, {4, 4}, outTileShape), WARNING);
}
}
}