/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2026-2026. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
/**
 * \file fia_tiling_shape.cpp
* \brief
*/
#include <vector>
#include <algorithm>
#include "fia_tiling_shape.h"
namespace optiling {
/**
 * Axis order of every layout handled in this file. The position of an axis inside its vector is the
 * dimension index of that axis in a shape with the corresponding layout (see GetAxisIdx / GetAxisNum).
 * FiaAxis::CONST marks a dimension whose expected extent is FiaTilingShapeCompareParam::CONST
 * (see GetExpectedShapeSpecial).
 */
static const std::map<FiaLayout, std::vector<FiaAxis>> FIA_LAYOUT_AXIS_MAP = {
    {FiaLayout::BSH, {FiaAxis::B, FiaAxis::S, FiaAxis::H}},
    {FiaLayout::BSND, {FiaAxis::B, FiaAxis::S, FiaAxis::N, FiaAxis::D}},
    {FiaLayout::BNSD, {FiaAxis::B, FiaAxis::N, FiaAxis::S, FiaAxis::D}},
    {FiaLayout::NZ, {FiaAxis::Bn, FiaAxis::N, FiaAxis::D1, FiaAxis::Bs, FiaAxis::D0}},
    {FiaLayout::TND, {FiaAxis::T, FiaAxis::N, FiaAxis::D}},
    {FiaLayout::NBSD, {FiaAxis::N, FiaAxis::B, FiaAxis::S, FiaAxis::D}},
    {FiaLayout::NTD, {FiaAxis::N, FiaAxis::T, FiaAxis::D}},
    {FiaLayout::BS2, {FiaAxis::B, FiaAxis::S2}},
    {FiaLayout::S1S2, {FiaAxis::S1, FiaAxis::S2}},
    {FiaLayout::BnBsH, {FiaAxis::Bn, FiaAxis::Bs, FiaAxis::H}},
    {FiaLayout::BnNBsD, {FiaAxis::Bn, FiaAxis::N, FiaAxis::Bs, FiaAxis::D}},
    {FiaLayout::BNS1S2, {FiaAxis::B, FiaAxis::N, FiaAxis::S1, FiaAxis::S2}},
    {FiaLayout::INS1S2, {FiaAxis::CONST, FiaAxis::N, FiaAxis::S1, FiaAxis::S2}},
    {FiaLayout::BNS11, {FiaAxis::B, FiaAxis::N, FiaAxis::S1, FiaAxis::CONST}},
    {FiaLayout::TN1, {FiaAxis::T, FiaAxis::N, FiaAxis::CONST}},
    {FiaLayout::BS1S2, {FiaAxis::B, FiaAxis::S1, FiaAxis::S2}},
    {FiaLayout::B1S1S2, {FiaAxis::B, FiaAxis::CONST, FiaAxis::S1, FiaAxis::S2}},
    {FiaLayout::IS1S2, {FiaAxis::CONST, FiaAxis::S1, FiaAxis::S2}},
    {FiaLayout::I1S1S2, {FiaAxis::CONST, FiaAxis::CONST, FiaAxis::S1, FiaAxis::S2}},
    {FiaLayout::S1S1, {FiaAxis::S1}},
};

// Per-axis predicates used by FiaTilingShapeCompare::CompareShape:
// `a` is the actual dimension, `b` is the expected dimension.
static bool equal_to(const int64_t &a, const int64_t &b) { return (a == b); }
static bool greater(const int64_t &a, const int64_t &b) { return (a > b); }
static bool greater_equal(const int64_t &a, const int64_t &b) { return (a >= b); }
static bool less(const int64_t &a, const int64_t &b) { return (a < b); }
static bool less_equal(const int64_t &a, const int64_t &b) { return (a <= b); }
static bool not_equal_to(const int64_t &a, const int64_t &b) { return (a != b); }
// Accepts every pair, so the corresponding axis is effectively not checked.
static bool ignore_input(const int64_t &a, const int64_t &b) {
    (void)a;
    (void)b;
    return true;
}

/**
 * Looks up the axis order registered for `layout`.
 * @return pointer into FIA_LAYOUT_AXIS_MAP, or nullptr when the layout is not registered.
 */
static const std::vector<FiaAxis> *FindLayoutAxes(const FiaLayout &layout) {
    const auto it = FIA_LAYOUT_AXIS_MAP.find(layout);
    return (it == FIA_LAYOUT_AXIS_MAP.end()) ? nullptr : &it->second;
}

/**
 * Copies the axis order of `layout` into `layoutAxes`.
 * Logs an error and returns GRAPH_FAILED when the layout is not registered in FIA_LAYOUT_AXIS_MAP.
 */
static ge::graphStatus GetLayoutAxes(
    std::vector<FiaAxis> &layoutAxes, const FiaLayout &layout, const std::string &opName, const std::string &funcName) {
    const std::vector<FiaAxis> *axes = FindLayoutAxes(layout);
    if (axes == nullptr) {
        OP_LOGE(
            opName, "[%s] compare layout %s is unsupported.", funcName.c_str(), LayoutToSerialString(layout).c_str());
        return ge::GRAPH_FAILED;
    }
    layoutAxes = *axes;
    return ge::GRAPH_SUCCESS;
}

// Maps each FiaCompareType to the predicate CompareShape applies per axis.
const std::map<FiaCompareType, CompareFunc<int64_t>> FiaTilingShapeCompare::compareFuncMap_ = {
    {FiaCompareType::EQUAL, equal_to},
    {FiaCompareType::GREATER, greater},
    {FiaCompareType::GREATER_EQUAL, greater_equal},
    {FiaCompareType::LESS, less},
    {FiaCompareType::LESS_EQUAL, less_equal},
    {FiaCompareType::NOT_EQUAL, not_equal_to},
    {FiaCompareType::IGNORE_INPUT, ignore_input}
};

/**
 * Returns a printable name for `layout` (used in log messages), or "UNKNOWN" for unlisted values.
 * The table is built once (function-local static) instead of on every call.
 */
std::string LayoutToSerialString(FiaLayout layout) {
    static const std::map<FiaLayout, std::string> layout2Str = {
        {FiaLayout::BSH, "BSH"}, {FiaLayout::BSND, "BSND"}, {FiaLayout::BNSD, "BNSD"}, {FiaLayout::NZ, "NZ"},
        {FiaLayout::TND, "TND"}, {FiaLayout::NBSD, "NBSD"}, {FiaLayout::NTD, "NTD"}, {FiaLayout::S1S2, "S1S2"},
        {FiaLayout::BS2, "BS2"}, {FiaLayout::BnBsH, "BnBsH"}, {FiaLayout::BnNBsD, "BnNBsD"},
        {FiaLayout::BNS1S2, "BNS1S2"}, {FiaLayout::INS1S2, "1NS1S2"}, {FiaLayout::BNS11, "BNS11"},
        {FiaLayout::TN1, "TN1"}, {FiaLayout::BS1S2, "BS1S2"}, {FiaLayout::B1S1S2, "B1S1S2"},
        {FiaLayout::IS1S2, "1S1S2"}, {FiaLayout::I1S1S2, "11S1S2"}, {FiaLayout::S1S1, "S1S1"}};
    const auto it = layout2Str.find(layout);
    if (it != layout2Str.end()) {
        return it->second;
    }
    return "UNKNOWN";
}

// Returns a printable name for `axis` (used in log messages), or "UNKNOWN" for unlisted values.
std::string AxisToSerialString(FiaAxis axis) {
    switch (axis) {
        case FiaAxis::B:
            return "B";
        case FiaAxis::S:
            return "S";
        case FiaAxis::N:
            return "N";
        case FiaAxis::D:
            return "D";
        case FiaAxis::H:
            return "H";
        case FiaAxis::T:
            return "T";
        case FiaAxis::D1:
            return "D1";
        case FiaAxis::D0:
            return "D0";
        case FiaAxis::S1:
            return "S1";
        case FiaAxis::S2:
            return "S2";
        case FiaAxis::Bn:
            return "Bn";
        case FiaAxis::Bs:
            return "Bs";
        case FiaAxis::CONST:
            return "CONST";
        default:
            return "UNKNOWN";
    }
}

// True when `axis` appears in the axis order registered for this shape's layout.
bool FiaTilingShape::HasAxis(const FiaAxis &axis) const {
    const std::vector<FiaAxis> *axes = FindLayoutAxes(layout_);
    return (axes != nullptr) && (std::find(axes->begin(), axes->end(), axis) != axes->end());
}

/**
 * Dimension index of `axis` within this shape's layout.
 * Returns 0 when the layout is unregistered or lacks the axis; callers must check HasAxis first,
 * because 0 is also a valid index.
 */
size_t FiaTilingShape::GetAxisIdx(const FiaAxis &axis) const {
    const std::vector<FiaAxis> *axes = FindLayoutAxes(layout_);
    if (axes == nullptr) {
        return 0;
    }
    const auto axisIt = std::find(axes->begin(), axes->end(), axis);
    if (axisIt == axes->end()) {
        return 0;
    }
    return static_cast<size_t>(std::distance(axes->begin(), axisIt));
}

// Extent of `axis` in shape_, or invalidDimValue_ when the layout does not contain the axis.
int64_t FiaTilingShape::GetAxisNum(const FiaAxis &axis) const {
    return HasAxis(axis) ? shape_.GetDim(GetAxisIdx(axis)) : invalidDimValue_;
}

/**
 * Validates that `axis` can be read from this shape. Every failure is logged.
 * The checks are: shape_ must be non-empty, the layout must be registered, and the rank must match the layout.
 * For FiaAxis::D, success also covers the case where D is not stored directly but can be derived as H / N.
 * That requires H to be present, N_ to be set and non-zero, and H to be divisible by N_.
 */
ge::graphStatus FiaTilingShape::CheckHasAxis(const FiaAxis &axis, const std::string &funcName) const {
    if (shape_.GetDimNum() == 0) {
        OP_LOGE(opName_, "[%s] the dim number of %s is 0.", funcName.c_str(), name_.c_str());
        return ge::GRAPH_FAILED;
    }
    std::vector<FiaAxis> layoutAxes;
    if (GetLayoutAxes(layoutAxes, layout_, opName_, funcName) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if (shape_.GetDimNum() != layoutAxes.size()) {
        OP_LOGE(opName_,
            "[%s] %s shape dimension is %zu, expected shape dimension is %zu, layout(%s) axes size is %zu, they should "
            "be equal.",
            funcName.c_str(), name_.c_str(), shape_.GetDimNum(), layoutAxes.size(),
            LayoutToSerialString(layout_).c_str(), layoutAxes.size());
        return ge::GRAPH_FAILED;
    }
    if (axis == FiaAxis::D) {
        if (HasShapeD()) {
            return ge::GRAPH_SUCCESS;
        }
        if (!HasShapeH()) {
            OP_LOGE(opName_, "[%s] %s's layout is %s, do not have D and H.", funcName.c_str(), name_.c_str(),
                LayoutToSerialString(layout_).c_str());
            return ge::GRAPH_FAILED;
        }
        if (!hasSetN_) {
            OP_LOGE(opName_, "[%s] %s's N is not specified, cannot caculate D by H.", funcName.c_str(), name_.c_str());
            return ge::GRAPH_FAILED;
        }
        if (N_ == 0) {
            OP_LOGE(opName_, "[%s] %s's N is 0.", funcName.c_str(), name_.c_str());
            return ge::GRAPH_FAILED;
        }
        if (GetShapeH() % N_ != 0) {
            OP_LOGE(opName_, "[%s] %s's H(%ld) should be an integer multiple of N(%ld).", funcName.c_str(),
                name_.c_str(), GetShapeH(), N_);
            return ge::GRAPH_FAILED;
        }
        // D is derivable as H / N. Previously this case fell through to the "not exists" error below.
        return ge::GRAPH_SUCCESS;
    }
    if (HasAxis(axis)) {
        return ge::GRAPH_SUCCESS;
    }
    OP_LOGE(opName_, "[%s] %s's layout is %s, %s is not exists.", funcName.c_str(), name_.c_str(),
        LayoutToSerialString(layout_).c_str(), AxisToSerialString(axis).c_str());
    return ge::GRAPH_FAILED;
}

// Printable name of `compareType` for log messages, or "UNKNOWN" for unlisted values.
std::string FiaTilingShapeCompare::CompareTypeToSerialString(const FiaCompareType compareType) const {
    switch (compareType) {
        case FiaCompareType::EQUAL:
            return "EQUAL";
        case FiaCompareType::GREATER:
            return "GREATER";
        case FiaCompareType::GREATER_EQUAL:
            return "GREATER_EQUAL";
        case FiaCompareType::LESS:
            return "LESS";
        case FiaCompareType::LESS_EQUAL:
            return "LESS_EQUAL";
        case FiaCompareType::NOT_EQUAL:
            return "NOT_EQUAL";
        case FiaCompareType::IGNORE_INPUT:
            return "IGNORE_INPUT";
        default:
            return "UNKNOWN";
    }
}

// Operator symbol of `compareType` for log messages, or "UNKNOWN" when there is no symbol.
std::string FiaTilingShapeCompare::CompareTypeToSerialSymbolString(const FiaCompareType &compareType) const {
    switch (compareType) {
        case FiaCompareType::EQUAL:
            return "==";
        case FiaCompareType::GREATER:
            return ">";
        case FiaCompareType::GREATER_EQUAL:
            return ">=";
        case FiaCompareType::LESS:
            return "<";
        case FiaCompareType::LESS_EQUAL:
            return "<=";
        case FiaCompareType::NOT_EQUAL:
            return "!=";
        default:
            return "UNKNOWN";
    }
}

/**
 * Builds the expected shape for the mask/score-style layouts (S1/S2/CONST based) from `param`.
 * Logs an error and returns GRAPH_FAILED for any layout not handled here.
 */
ge::graphStatus FiaTilingShapeCompare::GetExpectedShapeSpecial(
    gert::Shape &shapeExpected, const FiaTilingShapeCompareParam &param, const std::string &funcName) const {
    if (layout_ == FiaLayout::BNS1S2) {
        shapeExpected = gert::Shape({param.B, param.N, param.S1, param.S2});
    } else if (layout_ == FiaLayout::INS1S2) {
        shapeExpected = gert::Shape({param.CONST, param.N, param.S1, param.S2});
    } else if (layout_ == FiaLayout::BNS11) {
        shapeExpected = gert::Shape({param.B, param.N, param.S1, param.CONST});
    } else if (layout_ == FiaLayout::TN1) {
        shapeExpected = gert::Shape({param.T, param.N, param.CONST});
    } else if (layout_ == FiaLayout::BS2) {
        shapeExpected = gert::Shape({param.B, param.S2});
    } else if (layout_ == FiaLayout::S1S2) {
        shapeExpected = gert::Shape({param.S1, param.S2});
    } else if (layout_ == FiaLayout::BS1S2) {
        shapeExpected = gert::Shape({param.B, param.S1, param.S2});
    } else if (layout_ == FiaLayout::B1S1S2) {
        shapeExpected = gert::Shape({param.B, param.CONST, param.S1, param.S2});
    } else if (layout_ == FiaLayout::IS1S2) {
        shapeExpected = gert::Shape({param.CONST, param.S1, param.S2});
    } else if (layout_ == FiaLayout::I1S1S2) {
        shapeExpected = gert::Shape({param.CONST, param.CONST, param.S1, param.S2});
    } else if (layout_ == FiaLayout::S1S1) {
        shapeExpected = gert::Shape({param.S1});
    } else {
        OP_LOGE(opName_, "[%s] layout %s is unsupported", funcName.c_str(), LayoutToSerialString(layout_).c_str());
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}

/**
 * Builds the expected shape for layout_ from `param`, delegating non-BSND-family layouts to
 * GetExpectedShapeSpecial. For NZ, D is split into D / D0 and D0, so D0 must be non-zero.
 */
ge::graphStatus FiaTilingShapeCompare::GetExpectedShape(
    gert::Shape &shapeExpected, const FiaTilingShapeCompareParam &param, const std::string &funcName) const {
    if (layout_ == FiaLayout::BSH) {
        shapeExpected = gert::Shape({param.B, param.S, param.H});
    } else if (layout_ == FiaLayout::BSND) {
        shapeExpected = gert::Shape({param.B, param.S, param.N, param.D});
    } else if (layout_ == FiaLayout::BNSD) {
        shapeExpected = gert::Shape({param.B, param.N, param.S, param.D});
    } else if (layout_ == FiaLayout::BnBsH) {
        shapeExpected = gert::Shape({param.Bn, param.Bs, param.H});
    } else if (layout_ == FiaLayout::BnNBsD) {
        shapeExpected = gert::Shape({param.Bn, param.N, param.Bs, param.D});
    } else if (layout_ == FiaLayout::NZ) {
        // Guard the D / D0 split: integer division by zero is undefined behaviour.
        if (param.D0 == 0) {
            OP_LOGE(opName_, "[%s] %s layout is %s, D0 is 0, cannot split D into D1 and D0.", funcName.c_str(),
                name_.c_str(), LayoutToSerialString(layout_).c_str());
            return ge::GRAPH_FAILED;
        }
        shapeExpected = gert::Shape({param.Bn, param.N, param.D / param.D0, param.Bs, param.D0});
    } else if (layout_ == FiaLayout::TND) {
        shapeExpected = gert::Shape({param.T, param.N, param.D});
    } else if (layout_ == FiaLayout::NBSD) {
        shapeExpected = gert::Shape({param.N, param.B, param.S, param.D});
    } else if (layout_ == FiaLayout::NTD) {
        shapeExpected = gert::Shape({param.N, param.T, param.D});
    } else {
        return GetExpectedShapeSpecial(shapeExpected, param, funcName);
    }
    return ge::GRAPH_SUCCESS;
}

// Comparison type configured for `axis`; axes absent from `compareTypeMap` default to EQUAL.
FiaCompareType FiaTilingShapeCompare::GetCompareType(
    const std::map<FiaAxis, FiaCompareType> &compareTypeMap, const FiaAxis &axis) const {
    const auto it = compareTypeMap.find(axis);
    return (it != compareTypeMap.end()) ? it->second : FiaCompareType::EQUAL;
}

/**
 * Resolves the predicate for `compareType` into `compareFunc`.
 * Logs an error and returns GRAPH_FAILED when the type is not in compareFuncMap_.
 */
ge::graphStatus FiaTilingShapeCompare::GetCompareFunc(
    const FiaCompareType &compareType, CompareFunc<int64_t> &compareFunc, const std::string &funcName) const {
    const auto it = compareFuncMap_.find(compareType);
    if (it == compareFuncMap_.end()) {
        OP_LOGE(opName_, "[%s] compare type %s is unsupported.", funcName.c_str(),
            CompareTypeToSerialString(compareType).c_str());
        return ge::GRAPH_FAILED;
    }
    compareFunc = it->second;
    return ge::GRAPH_SUCCESS;
}

/**
 * Compares shape_ axis by axis against the shape expected for layout_.
 * Side effect: overwrites param.H with param.N * param.D before building the expected shape.
 * Each axis uses the comparison from param.compareTypeMap, defaulting to EQUAL.
 * Returns GRAPH_FAILED, after logging, on the first mismatch or on rank/layout inconsistencies.
 */
ge::graphStatus FiaTilingShapeCompare::CompareShape(
    FiaTilingShapeCompareParam &param, const std::string &funcName) const {
    param.H = param.N * param.D;
    gert::Shape shapeExpected;
    if (GetExpectedShape(shapeExpected, param, funcName) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    std::vector<FiaAxis> layoutAxes;
    if (GetLayoutAxes(layoutAxes, layout_, opName_, funcName) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if ((shape_.GetDimNum() != shapeExpected.GetDimNum()) || (shape_.GetDimNum() != layoutAxes.size())) {
        OP_LOGE(opName_,
            "[%s] %s shape dimension is %zu, expected shape dimension is %zu, layout(%s) axes size is %zu, they should "
            "be equal.",
            funcName.c_str(), name_.c_str(), shape_.GetDimNum(), shapeExpected.GetDimNum(),
            LayoutToSerialString(layout_).c_str(), layoutAxes.size());
        return ge::GRAPH_FAILED;
    }
    for (size_t i = 0; i < shape_.GetDimNum(); i++) {
        const FiaAxis axis = layoutAxes[i];
        const FiaCompareType compareType = GetCompareType(param.compareTypeMap, axis);
        CompareFunc<int64_t> compareFunc;
        if (GetCompareFunc(compareType, compareFunc, funcName) != ge::GRAPH_SUCCESS) {
            return ge::GRAPH_FAILED;
        }
        if (compareFunc(shape_.GetDim(i), shapeExpected.GetDim(i))) {
            continue;
        }
        if (param.compareTypeMap.empty()) {
            OP_LOGE(opName_, "[%s] %s layout is %s, shape %s should be equal to %s.", funcName.c_str(),
                name_.c_str(), LayoutToSerialString(layout_).c_str(), GetShapeStr(shape_).c_str(),
                GetShapeStr(shapeExpected).c_str());
        } else {
            OP_LOGE(opName_,
                "[%s] %s layout is %s, shape is %s, expected shape is %s, axis %s(%ld) should be %s expected %ld.",
                funcName.c_str(), name_.c_str(), LayoutToSerialString(layout_).c_str(), GetShapeStr(shape_).c_str(),
                GetShapeStr(shapeExpected).c_str(), AxisToSerialString(axis).c_str(), shape_.GetDim(i),
                CompareTypeToSerialSymbolString(compareType).c_str(), shapeExpected.GetDim(i));
        }
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
}  // namespace optiling