* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file fia_tiling_shape.cpp
* \brief
*/
#include <vector>
#include <algorithm>
#include "fia_tiling_shape.h"
namespace optiling {
// Ordered axis list for every layout this file understands.
// Entry i of a layout's vector names dimension i of a shape in that layout: GetAxisIdx() returns
// the vector position as the dimension index, and CompareShape() pairs layoutAxes[i] with
// shape_.GetDim(i). Layouts missing from this table are rejected by GetLayoutAxes() and report no
// axes in HasAxis().
// FiaAxis::CONST marks a dimension whose expected size is taken from param.CONST rather than from a
// named axis (see GetExpectedShapeSpecial).
static const std::map<FiaLayout, std::vector<FiaAxis>> FIA_LAYOUT_AXIS_MAP = {
{FiaLayout::BSH, {FiaAxis::B, FiaAxis::S, FiaAxis::H}},
{FiaLayout::BSND, {FiaAxis::B, FiaAxis::S, FiaAxis::N, FiaAxis::D}},
{FiaLayout::BNSD, {FiaAxis::B, FiaAxis::N, FiaAxis::S, FiaAxis::D}},
// NZ stores D split in two dims: D1 (expected as D / D0 in GetExpectedShape) and D0.
{FiaLayout::NZ, {FiaAxis::Bn, FiaAxis::N, FiaAxis::D1, FiaAxis::Bs, FiaAxis::D0}},
{FiaLayout::TND, {FiaAxis::T, FiaAxis::N, FiaAxis::D}},
{FiaLayout::NBSD, {FiaAxis::N, FiaAxis::B, FiaAxis::S, FiaAxis::D}},
{FiaLayout::NTD, {FiaAxis::N, FiaAxis::T, FiaAxis::D}},
{FiaLayout::BS2, {FiaAxis::B, FiaAxis::S2}},
{FiaLayout::S1S2, {FiaAxis::S1, FiaAxis::S2}},
{FiaLayout::BnBsH, {FiaAxis::Bn, FiaAxis::Bs, FiaAxis::H}},
{FiaLayout::BnNBsD, {FiaAxis::Bn, FiaAxis::N, FiaAxis::Bs, FiaAxis::D}},
{FiaLayout::BNS1S2, {FiaAxis::B, FiaAxis::N, FiaAxis::S1, FiaAxis::S2}},
{FiaLayout::INS1S2, {FiaAxis::CONST, FiaAxis::N, FiaAxis::S1, FiaAxis::S2}},
{FiaLayout::BNS11, {FiaAxis::B, FiaAxis::N, FiaAxis::S1, FiaAxis::CONST}},
{FiaLayout::TN1, {FiaAxis::T, FiaAxis::N, FiaAxis::CONST}},
{FiaLayout::BS1S2, {FiaAxis::B, FiaAxis::S1, FiaAxis::S2}},
{FiaLayout::B1S1S2, {FiaAxis::B, FiaAxis::CONST, FiaAxis::S1, FiaAxis::S2}},
{FiaLayout::IS1S2, {FiaAxis::CONST, FiaAxis::S1, FiaAxis::S2}},
{FiaLayout::I1S1S2, {FiaAxis::CONST, FiaAxis::CONST, FiaAxis::S1, FiaAxis::S2}},
{FiaLayout::S1S1, {FiaAxis::S1}},
};
// Binary predicates used by FiaTilingShapeCompare to check an actual dimension size (first
// argument) against its expected size (second argument). The signatures stay
// (const int64_t&, const int64_t&) so they fit CompareFunc<int64_t>.

// a == b
static bool equal_to(const int64_t& a, const int64_t& b)
{
    return a == b;
}

// a > b, expressed through operator<
static bool greater(const int64_t& a, const int64_t& b)
{
    return b < a;
}

// a >= b, i.e. a is not less than b
static bool greater_equal(const int64_t& a, const int64_t& b)
{
    return !(a < b);
}

// a < b
static bool less(const int64_t& a, const int64_t& b)
{
    return a < b;
}

// a <= b, i.e. b is not less than a
static bool less_equal(const int64_t& a, const int64_t& b)
{
    return !(b < a);
}

// a != b
static bool not_equal_to(const int64_t& a, const int64_t& b)
{
    return !(a == b);
}

// Accepts every pair; used for axes whose size should not be checked.
static bool ignore_input(const int64_t& a, const int64_t& b)
{
    static_cast<void>(a);
    static_cast<void>(b);
    return true;
}
// Copies the ordered axis list registered for `layout` in FIA_LAYOUT_AXIS_MAP into `layoutAxes`.
// If the layout has no entry, logs an error (prefixed with `funcName`, attributed to `opName`),
// leaves `layoutAxes` unchanged and returns GRAPH_FAILED.
static ge::graphStatus GetLayoutAxes(std::vector<FiaAxis> &layoutAxes, const FiaLayout &layout,
    const std::string &opName, const std::string &funcName)
{
    const auto entry = FIA_LAYOUT_AXIS_MAP.find(layout);
    if (entry != FIA_LAYOUT_AXIS_MAP.end()) {
        layoutAxes = entry->second;
        return ge::GRAPH_SUCCESS;
    }
    OP_LOGE(opName, "[%s] compare layout %s is unsupported.",
        funcName.c_str(), LayoutToSerialString(layout).c_str());
    return ge::GRAPH_FAILED;
}
// Dispatch table from FiaCompareType to its predicate. CompareShape() calls the predicate as
// compareFunc(actualDim, expectedDim). IGNORE_INPUT accepts every pair. Any compare type missing
// from this table makes GetCompareFunc() log an error and fail.
const std::map<FiaCompareType, CompareFunc<int64_t>> FiaTilingShapeCompare::compareFuncMap_ = {
{FiaCompareType::EQUAL, equal_to},
{FiaCompareType::GREATER, greater},
{FiaCompareType::GREATER_EQUAL, greater_equal},
{FiaCompareType::LESS, less},
{FiaCompareType::LESS_EQUAL, less_equal},
{FiaCompareType::NOT_EQUAL, not_equal_to},
{FiaCompareType::IGNORE_INPUT, ignore_input}
};
bool FiaTilingShape::HasAxis(const FiaAxis &axis) const
{
const auto& layoutIt = FIA_LAYOUT_AXIS_MAP.find(layout_);
if (layoutIt == FIA_LAYOUT_AXIS_MAP.end()) {
return false;
}
const std::vector<FiaAxis>& axes = layoutIt->second;
const auto& axisIt = std::find(axes.begin(), axes.end(), axis);
if (axisIt == axes.end()) {
return false;
}
return true;
}
size_t FiaTilingShape::GetAxisIdx(const FiaAxis &axis) const
{
if (HasAxis(axis)) {
const std::vector<FiaAxis>& axes = FIA_LAYOUT_AXIS_MAP.find(layout_)->second;
const auto& axisIt = std::find(axes.begin(), axes.end(), axis);
return std::distance(axes.begin(), axisIt);
}
return 0;
}
// Returns the size of the shape dimension that holds `axis`, or invalidDimValue_ when
// layout_ does not contain that axis.
int64_t FiaTilingShape::GetAxisNum(const FiaAxis &axis) const
{
    if (!HasAxis(axis)) {
        return invalidDimValue_;
    }
    const size_t dimIndex = GetAxisIdx(axis);
    return shape_.GetDim(dimIndex);
}
// Validates that this shape can provide `axis`. `funcName` tags every log line.
// Checks, in order:
//   1. shape_ has at least one dimension;
//   2. layout_ is registered and its axis count equals shape_'s dimension count;
//   3. the axis is available:
//      - D is available either directly (HasShapeD()) or by deriving it from H. Deriving it
//        requires H to exist, N to have been set, N != 0, and H to be an integer multiple of N;
//      - any other axis must be listed for layout_.
// Logs the reason and returns GRAPH_FAILED at the first failed check; otherwise GRAPH_SUCCESS.
ge::graphStatus FiaTilingShape::CheckHasAxis(const FiaAxis &axis, const std::string &funcName) const
{
    if (shape_.GetDimNum() == 0) {
        OP_LOGE(opName_, "[%s] the dim number of %s is 0.", funcName.c_str(), name_.c_str());
        return ge::GRAPH_FAILED;
    }
    std::vector<FiaAxis> layoutAxes;
    if (GetLayoutAxes(layoutAxes, layout_, opName_, funcName) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if (shape_.GetDimNum() != layoutAxes.size()) {
        OP_LOGE(opName_,
            "[%s] %s shape dimension is %zu, expected shape dimension is %zu, layout(%s) axes size is %zu, they should be equal.",
            funcName.c_str(), name_.c_str(), shape_.GetDimNum(), layoutAxes.size(),
            LayoutToSerialString(layout_).c_str(), layoutAxes.size());
        return ge::GRAPH_FAILED;
    }
    if (axis == FiaAxis::D) {
        if (HasShapeD()) {
            return ge::GRAPH_SUCCESS;
        }
        if (!HasShapeH()) {
            OP_LOGE(opName_, "[%s] %s's layout is %s, do not have D and H.",
                funcName.c_str(), name_.c_str(), LayoutToSerialString(layout_).c_str());
            return ge::GRAPH_FAILED;
        }
        if (!hasSetN_) {
            OP_LOGE(opName_, "[%s] %s's N is not specified, cannot caculate D by H.", funcName.c_str(), name_.c_str());
            return ge::GRAPH_FAILED;
        }
        if (N_ == 0) {
            OP_LOGE(opName_, "[%s] %s's N is 0.", funcName.c_str(), name_.c_str());
            return ge::GRAPH_FAILED;
        }
        if (GetShapeH() % N_ != 0) {
            OP_LOGE(opName_, "[%s] %s's H(%ld) should be an integer multiple of N(%ld).",
                funcName.c_str(), name_.c_str(), GetShapeH(), N_);
            return ge::GRAPH_FAILED;
        }
        // Every precondition for deriving D as H / N holds. Without this return, control would
        // fall through to the generic "not exists" error below and report a false failure.
        return ge::GRAPH_SUCCESS;
    }
    if (HasAxis(axis)) {
        return ge::GRAPH_SUCCESS;
    }
    OP_LOGE(opName_, "[%s] %s's layout is %s, %s is not exists.",
        funcName.c_str(), name_.c_str(), LayoutToSerialString(layout_).c_str(),
        AxisToSerialString(axis).c_str());
    return ge::GRAPH_FAILED;
}
// Returns the enumerator name of `compareType` for log messages. Every type registered in
// compareFuncMap_ is covered, including IGNORE_INPUT, which previously fell through to "UNKNOWN".
// Returns "UNKNOWN" for any other value.
std::string FiaTilingShapeCompare::CompareTypeToSerialString(const FiaCompareType compareType) const
{
    switch (compareType) {
        case FiaCompareType::EQUAL:
            return "EQUAL";
        case FiaCompareType::GREATER:
            return "GREATER";
        case FiaCompareType::GREATER_EQUAL:
            return "GREATER_EQUAL";
        case FiaCompareType::LESS:
            return "LESS";
        case FiaCompareType::LESS_EQUAL:
            return "LESS_EQUAL";
        case FiaCompareType::NOT_EQUAL:
            return "NOT_EQUAL";
        case FiaCompareType::IGNORE_INPUT:
            return "IGNORE_INPUT";
        default:
            return "UNKNOWN";
    }
}
// Returns the C++ operator spelling of `compareType` for log messages (for example "<=" for
// LESS_EQUAL). Types with no operator form, such as IGNORE_INPUT, yield "UNKNOWN".
std::string FiaTilingShapeCompare::CompareTypeToSerialSymbolString(const FiaCompareType &compareType) const
{
    const char *symbol = "UNKNOWN";
    switch (compareType) {
        case FiaCompareType::EQUAL:
            symbol = "==";
            break;
        case FiaCompareType::GREATER:
            symbol = ">";
            break;
        case FiaCompareType::GREATER_EQUAL:
            symbol = ">=";
            break;
        case FiaCompareType::LESS:
            symbol = "<";
            break;
        case FiaCompareType::LESS_EQUAL:
            symbol = "<=";
            break;
        case FiaCompareType::NOT_EQUAL:
            symbol = "!=";
            break;
        default:
            break;
    }
    return std::string(symbol);
}
// Builds the expected shape for the mask / score style layouts (the ones built from S1, S2 and
// CONST dimensions). Dimension order matches FIA_LAYOUT_AXIS_MAP. Positions marked CONST take
// param.CONST. Logs an error and returns GRAPH_FAILED for any layout not handled here.
// (The parameter was mis-encoded as `¶m` and is restored to the `&param` reference.)
ge::graphStatus FiaTilingShapeCompare::GetExpectedShapeSpecial(gert::Shape &shapeExpected,
    const FiaTilingShapeCompareParam &param, const std::string &funcName) const
{
    switch (layout_) {
        case FiaLayout::BNS1S2:
            shapeExpected = gert::Shape({param.B, param.N, param.S1, param.S2});
            break;
        case FiaLayout::INS1S2:
            shapeExpected = gert::Shape({param.CONST, param.N, param.S1, param.S2});
            break;
        case FiaLayout::BNS11:
            shapeExpected = gert::Shape({param.B, param.N, param.S1, param.CONST});
            break;
        case FiaLayout::TN1:
            shapeExpected = gert::Shape({param.T, param.N, param.CONST});
            break;
        case FiaLayout::BS2:
            shapeExpected = gert::Shape({param.B, param.S2});
            break;
        case FiaLayout::S1S2:
            shapeExpected = gert::Shape({param.S1, param.S2});
            break;
        case FiaLayout::BS1S2:
            shapeExpected = gert::Shape({param.B, param.S1, param.S2});
            break;
        case FiaLayout::B1S1S2:
            shapeExpected = gert::Shape({param.B, param.CONST, param.S1, param.S2});
            break;
        case FiaLayout::IS1S2:
            shapeExpected = gert::Shape({param.CONST, param.S1, param.S2});
            break;
        case FiaLayout::I1S1S2:
            shapeExpected = gert::Shape({param.CONST, param.CONST, param.S1, param.S2});
            break;
        case FiaLayout::S1S1:
            shapeExpected = gert::Shape({param.S1});
            break;
        default:
            OP_LOGE(opName_, "[%s] layout %s is unsupported", funcName.c_str(), LayoutToSerialString(layout_).c_str());
            return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
// Builds the expected shape for layout_ from the sizes in `param`. Dimension order matches
// FIA_LAYOUT_AXIS_MAP. BSH / BnBsH use param.H, which CompareShape sets to N * D before calling.
// NZ splits D into (D / D0, D0) and is rejected when D0 is 0 to avoid integer division by zero.
// Layouts not handled here are delegated to GetExpectedShapeSpecial.
// (The parameter was mis-encoded as `¶m` and is restored to the `&param` reference.)
ge::graphStatus FiaTilingShapeCompare::GetExpectedShape(gert::Shape &shapeExpected,
    const FiaTilingShapeCompareParam &param, const std::string &funcName) const
{
    switch (layout_) {
        case FiaLayout::BSH:
            shapeExpected = gert::Shape({param.B, param.S, param.H});
            break;
        case FiaLayout::BSND:
            shapeExpected = gert::Shape({param.B, param.S, param.N, param.D});
            break;
        case FiaLayout::BNSD:
            shapeExpected = gert::Shape({param.B, param.N, param.S, param.D});
            break;
        case FiaLayout::BnBsH:
            shapeExpected = gert::Shape({param.Bn, param.Bs, param.H});
            break;
        case FiaLayout::BnNBsD:
            shapeExpected = gert::Shape({param.Bn, param.N, param.Bs, param.D});
            break;
        case FiaLayout::NZ:
            // D1 = D / D0; a zero D0 would be undefined behavior, so report it instead.
            if (param.D0 == 0) {
                OP_LOGE(opName_, "[%s] %s layout is NZ, D0 is 0, cannot split D into D1 and D0.",
                    funcName.c_str(), name_.c_str());
                return ge::GRAPH_FAILED;
            }
            shapeExpected = gert::Shape({param.Bn, param.N, param.D / param.D0, param.Bs, param.D0});
            break;
        case FiaLayout::TND:
            shapeExpected = gert::Shape({param.T, param.N, param.D});
            break;
        case FiaLayout::NBSD:
            shapeExpected = gert::Shape({param.N, param.B, param.S, param.D});
            break;
        case FiaLayout::NTD:
            shapeExpected = gert::Shape({param.N, param.T, param.D});
            break;
        default:
            return GetExpectedShapeSpecial(shapeExpected, param, funcName);
    }
    return ge::GRAPH_SUCCESS;
}
// Returns the comparison configured for `axis` in `compareTypeMap`. Axes with no entry
// default to FiaCompareType::EQUAL.
FiaCompareType FiaTilingShapeCompare::GetCompareType(const std::map<FiaAxis, FiaCompareType> &compareTypeMap,
    const FiaAxis &axis) const
{
    const auto entry = compareTypeMap.find(axis);
    return (entry == compareTypeMap.end()) ? FiaCompareType::EQUAL : entry->second;
}
// Looks up the predicate registered for `compareType` in compareFuncMap_ and stores it in
// `compareFunc`. If no predicate is registered, logs an error tagged with `funcName`, leaves
// `compareFunc` unchanged and returns GRAPH_FAILED.
ge::graphStatus FiaTilingShapeCompare::GetCompareFunc(const FiaCompareType &compareType,
    CompareFunc<int64_t> &compareFunc, const std::string &funcName) const
{
    const auto entry = compareFuncMap_.find(compareType);
    if (entry != compareFuncMap_.end()) {
        compareFunc = entry->second;
        return ge::GRAPH_SUCCESS;
    }
    OP_LOGE(opName_, "[%s] compare type %s is unsupported.", funcName.c_str(),
        CompareTypeToSerialString(compareType).c_str());
    return ge::GRAPH_FAILED;
}
// Checks shape_ against the shape implied by `param` for layout_.
// Side effect: overwrites param.H with param.N * param.D before the expected shape is built.
// Dimension i is checked with the predicate for layout axis i taken from param.compareTypeMap
// (EQUAL if the axis has no entry), called as predicate(actualDim, expectedDim).
// Returns GRAPH_FAILED and logs the reason when the expected shape or layout axes cannot be
// obtained, when the actual / expected / layout dimension counts differ, or at the first failing
// dimension. Otherwise returns GRAPH_SUCCESS.
// (The parameter was mis-encoded as `¶m` and is restored to the `&param` reference.)
ge::graphStatus FiaTilingShapeCompare::CompareShape(FiaTilingShapeCompareParam &param, const std::string &funcName) const
{
    param.H = param.N * param.D;
    gert::Shape shapeExpected;
    if (GetExpectedShape(shapeExpected, param, funcName) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    std::vector<FiaAxis> layoutAxes;
    if (GetLayoutAxes(layoutAxes, layout_, opName_, funcName) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    const size_t dimNum = shape_.GetDimNum();
    if ((dimNum != shapeExpected.GetDimNum()) || (dimNum != layoutAxes.size())) {
        OP_LOGE(opName_,
            "[%s] %s shape dimension is %zu, expected shape dimension is %zu, layout(%s) axes size is %zu, they should be equal.",
            funcName.c_str(), name_.c_str(), shape_.GetDimNum(), shapeExpected.GetDimNum(),
            LayoutToSerialString(layout_).c_str(), layoutAxes.size());
        return ge::GRAPH_FAILED;
    }
    for (size_t i = 0; i < dimNum; i++) {
        const FiaAxis axis = layoutAxes[i];
        const FiaCompareType compareType = GetCompareType(param.compareTypeMap, axis);
        CompareFunc<int64_t> compareFunc;
        if (GetCompareFunc(compareType, compareFunc, funcName) != ge::GRAPH_SUCCESS) {
            return ge::GRAPH_FAILED;
        }
        const int64_t actualDim = shape_.GetDim(i);
        const int64_t expectedDim = shapeExpected.GetDim(i);
        if (compareFunc(actualDim, expectedDim)) {
            continue;
        }
        if (param.compareTypeMap.empty()) {
            // Without per-axis rules every axis is compared for equality, so report whole shapes.
            OP_LOGE(opName_, "[%s] %s layout is %s, shape %s should be equal to %s.",
                funcName.c_str(), name_.c_str(), LayoutToSerialString(layout_).c_str(),
                GetShapeStr(shape_).c_str(), GetShapeStr(shapeExpected).c_str());
        } else {
            OP_LOGE(opName_, "[%s] %s layout is %s, shape is %s, expected shape is %s, axis %s(%ld) should be %s expected %ld.",
                funcName.c_str(), name_.c_str(), LayoutToSerialString(layout_).c_str(),
                GetShapeStr(shape_).c_str(), GetShapeStr(shapeExpected).c_str(), AxisToSerialString(axis).c_str(),
                actualDim, CompareTypeToSerialSymbolString(compareType).c_str(), expectedDim);
        }
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
}