/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
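/*!
* \file fia_tiling_shape.h
* \brief Helpers for querying a layout-tagged gert::Shape by axis and for comparing it with an expected shape.
*/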
#ifndef FIA_TILING_SHAPE_H
#define FIA_TILING_SHAPE_H
#include <cstdint>
#include <limits>
#include <map>
#include <sstream>
#include <string>
#include <utility>

#include "fia_tiling_info.h"
namespace optiling {
/**
 * @brief Pointer to a binary predicate over two values of type T.
 *
 * Both operands are taken by const reference and the predicate yields a bool.
 */
template <typename T>
using CompareFunc = bool (*)(const T &, const T &);
/**
 * @brief Relation to apply when a dimension is compared with its reference value.
 *
 * The numeric values are spelled out explicitly so they stay stable if enumerators are reordered.
 * NOTE(review): which operand is the actual value and which is the expected one is decided by the
 * comparison implementation, which is not part of this header.
 */
enum class FiaCompareType : uint32_t {
    EQUAL = 0,          // ==
    GREATER = 1,        // >
    GREATER_EQUAL = 2,  // >=
    LESS = 3,           // <
    LESS_EQUAL = 4,     // <=
    NOT_EQUAL = 5,      // !=
    IGNORE_INPUT = 6    // presumably "skip the check for this axis" -- confirm against the implementation
};
/**
 * @brief Parameter bundle handed to FiaTilingShapeCompare (GetExpectedShape, GetExpectedShapeSpecial
 *        and CompareShape all take it).
 *
 * Holds one int64_t value per named axis plus a per-axis comparison mode. Every value defaults to 1
 * except D0, which defaults to 16.
 * NOTE(review): the values are presumably the reference sizes substituted for each axis when the
 * expected shape is built -- confirm against the out-of-line implementation.
 */
struct FiaTilingShapeCompareParam {
    int64_t B{1};
    int64_t S{1};
    int64_t N{1};
    int64_t D{1};
    int64_t H{1};
    int64_t T{1};
    int64_t Bn{1};
    int64_t Bs{1};
    int64_t D0{16};  // the only axis value whose default is not 1
    int64_t S1{1};
    int64_t S2{1};
    int64_t CONST{1};
    // Comparison mode per axis; empty by default. How an axis without an entry is treated is decided
    // by FiaTilingShapeCompare::GetCompareType, whose definition is not in this header.
    std::map<FiaAxis, FiaCompareType> compareTypeMap{};
};
static std::string GetShapeStr(gert::Shape shape)
{
std::ostringstream oss;
oss << "[";
if (shape.GetDimNum() > 0) {
for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
oss << shape.GetDim(i) << ", ";
}
oss << shape.GetDim(shape.GetDimNum() - 1);
}
oss << "]";
return oss.str();
}
/**
 * @brief Axis-oriented, read-only accessor over a gert::Shape that is tagged with a FiaLayout.
 *
 * The object does not own the shape: it stores a reference to the gert::Shape passed to the
 * constructor, so that shape must outlive this object (never construct it from a temporary).
 * Axis lookup (HasAxis / GetAxisIdx / GetAxisNum / CheckHasAxis) is implemented out of line.
 */
class FiaTilingShape {
    // Sentinel for "no value": the default of the constructor's N argument and the result of
    // GetShapeD() when D can be neither read nor derived.
    static constexpr int64_t invalidDimValue_ = std::numeric_limits<int64_t>::min();

public:
    /**
     * @param shape  Shape to inspect; only a reference is kept, see the class note on lifetime.
     * @param layout Layout associated with the shape.
     * @param name   Name stored in name_ (presumably used in diagnostics -- implementation not visible).
     * @param opName Operator name stored in opName_ (presumably used in diagnostics).
     * @param N      Optional N value. It is recorded (and hasSetN_ set) only when the shape has an
     *               H axis and N differs from the sentinel default; otherwise N_ keeps its default 1.
     */
    FiaTilingShape(const gert::Shape &shape, FiaLayout layout, std::string name, std::string opName,
        int64_t N = std::numeric_limits<int64_t>::min()) :
        shape_(shape), layout_(layout), name_(std::move(name)), opName_(std::move(opName))
    {
        // shape_ and layout_ are already initialised here, so the axis query is safe to call.
        if (HasShapeH() && N != invalidDimValue_) {
            N_ = N;
            hasSetN_ = true;
        }
    }

public:
    const gert::Shape &shape_;  // borrowed, not owned
    FiaLayout layout_;
    std::string name_;
    std::string opName_;
    bool hasSetN_ = false;  // true only when the constructor accepted an explicit N
    int64_t N_ = 1;

    size_t GetDimNum() const { return shape_.GetDimNum(); }

    // Presence queries, one per axis; each forwards to HasAxis().
    bool HasShapeB() const { return HasAxis(FiaAxis::B); }
    bool HasShapeS() const { return HasAxis(FiaAxis::S); }
    bool HasShapeH() const { return HasAxis(FiaAxis::H); }
    bool HasShapeN() const { return HasAxis(FiaAxis::N); }
    bool HasShapeT() const { return HasAxis(FiaAxis::T); }
    bool HasShapeD1() const { return HasAxis(FiaAxis::D1); }
    bool HasShapeD0() const { return HasAxis(FiaAxis::D0); }

    /**
     * @brief Whether a D value is available.
     *
     * True when the shape has an explicit D axis, or D can be derived as H / N (see
     * CanDeriveDFromH), or both D1 and D0 are present.
     */
    bool HasShapeD() const
    {
        return HasAxis(FiaAxis::D) || CanDeriveDFromH() || (HasShapeD1() && HasShapeD0());
    }

    // Size queries, one per axis; each forwards to GetAxisNum().
    // NOTE(review): the result for an axis that is absent is defined by GetAxisNum (not visible here).
    int64_t GetShapeB() const { return GetAxisNum(FiaAxis::B); }
    int64_t GetShapeS() const { return GetAxisNum(FiaAxis::S); }
    int64_t GetShapeN() const { return GetAxisNum(FiaAxis::N); }
    int64_t GetShapeH() const { return GetAxisNum(FiaAxis::H); }
    int64_t GetShapeT() const { return GetAxisNum(FiaAxis::T); }
    int64_t GetShapeD1() const { return GetAxisNum(FiaAxis::D1); }
    int64_t GetShapeD0() const { return GetAxisNum(FiaAxis::D0); }

    /**
     * @brief D value, using the same precedence as HasShapeD().
     *
     * @return The explicit D dimension if present; else H / N when derivable; else D1 * D0 when
     *         both exist; else invalidDimValue_.
     */
    int64_t GetShapeD() const
    {
        if (HasAxis(FiaAxis::D)) {
            return shape_.GetDim(GetAxisIdx(FiaAxis::D));
        }
        if (CanDeriveDFromH()) {
            return GetShapeH() / N_;
        }
        if (HasShapeD1() && HasShapeD0()) {
            return GetShapeD1() * GetShapeD0();
        }
        return invalidDimValue_;
    }

    // Status-returning presence checks; each forwards to CheckHasAxis() with the caller's name.
    // NOTE(review): CheckHasShapeD passes FiaAxis::D only, whereas HasShapeD() also accepts a D
    // derived from H / N or D1 * D0. Whether CheckHasAxis accounts for that cannot be told from
    // this header -- confirm the two are meant to differ.
    ge::graphStatus CheckHasShapeB(const std::string &funcName) const { return CheckHasAxis(FiaAxis::B, funcName); }
    ge::graphStatus CheckHasShapeS(const std::string &funcName) const { return CheckHasAxis(FiaAxis::S, funcName); }
    ge::graphStatus CheckHasShapeD(const std::string &funcName) const { return CheckHasAxis(FiaAxis::D, funcName); }
    ge::graphStatus CheckHasShapeN(const std::string &funcName) const { return CheckHasAxis(FiaAxis::N, funcName); }
    ge::graphStatus CheckHasShapeH(const std::string &funcName) const { return CheckHasAxis(FiaAxis::H, funcName); }
    ge::graphStatus CheckHasShapeT(const std::string &funcName) const { return CheckHasAxis(FiaAxis::T, funcName); }

private:
    // D is derivable from H only when an H axis exists, an N was supplied at construction, and that
    // N is non-zero and divides H exactly. Shared by HasShapeD() and GetShapeD() so they cannot drift.
    bool CanDeriveDFromH() const
    {
        return HasShapeH() && hasSetN_ && N_ != 0 && GetShapeH() % N_ == 0;
    }

    bool HasAxis(const FiaAxis &axis) const;
    size_t GetAxisIdx(const FiaAxis &axis) const;
    int64_t GetAxisNum(const FiaAxis &axis) const;
    ge::graphStatus CheckHasAxis(const FiaAxis &axis, const std::string &funcName) const;
};
/**
 * @brief Compares a layout-tagged gert::Shape against an expected shape built from a
 *        FiaTilingShapeCompareParam.
 *
 * The object does not own the shape: it stores a reference to the gert::Shape passed to the
 * constructor, so that shape must outlive this object (never construct it from a temporary).
 * All member functions other than the constructor are defined out of line.
 */
class FiaTilingShapeCompare {
    // Lookup from comparison type to predicate; defined out of line.
    static const std::map<FiaCompareType, CompareFunc<int64_t>> compareFuncMap_;

public:
    /**
     * @param shape  Shape to check; only a reference is kept, see the class note on lifetime.
     * @param layout Layout associated with the shape.
     * @param name   Name stored in name_ (presumably used in diagnostics -- implementation not visible).
     * @param opName Operator name stored in opName_ (presumably used in diagnostics).
     */
    FiaTilingShapeCompare(const gert::Shape &shape, FiaLayout layout, std::string name, std::string opName) :
        shape_(shape), layout_(layout), name_(std::move(name)), opName_(std::move(opName)) {}

public:
    const gert::Shape &shape_;  // borrowed, not owned
    FiaLayout layout_;
    std::string name_;
    std::string opName_;

    // String renderings of a comparison type (presumably a name and an operator symbol such as
    // "==" respectively -- confirm against the implementation).
    std::string CompareTypeToSerialString(const FiaCompareType compareType) const;
    std::string CompareTypeToSerialSymbolString(const FiaCompareType &compareType) const;

    // Build the expected shape into shapeExpected from param; funcName identifies the caller.
    // NOTE(review): what distinguishes the "Special" variant is not visible in this header.
    ge::graphStatus GetExpectedShapeSpecial(gert::Shape &shapeExpected,
        const FiaTilingShapeCompareParam &param, const std::string &funcName) const;
    ge::graphStatus GetExpectedShape(gert::Shape &shapeExpected,
        const FiaTilingShapeCompareParam &param, const std::string &funcName) const;

    // Resolve the comparison type to use for one axis from the per-axis map.
    FiaCompareType GetCompareType(const std::map<FiaAxis, FiaCompareType> &compareTypeMap,
        const FiaAxis &axis) const;

    // Fetch the predicate for compareType into compareFunc; the status reports success or failure.
    ge::graphStatus GetCompareFunc(const FiaCompareType &compareType,
        CompareFunc<int64_t> &compareFunc, const std::string &funcName) const;

    // Entry point of the check. param is taken by non-const reference, so the implementation may
    // modify it -- callers should not assume it is left untouched.
    ge::graphStatus CompareShape(FiaTilingShapeCompareParam &param, const std::string &funcName) const;
};
}
#endif