* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "transfer_shape_utils.h"
#include <mutex>
#include "axis_constants.h"
#include "common/ge_common/string_util.h"
#include "graph/ge_error_codes.h"
#include "platform/platform_info.h"
#include "graph/utils/type_utils.h"
#include "graph/types.h"
#include "expand_dimension.h"
namespace transformer {
std::array<uint32_t, static_cast<size_t>(ge::DataType::DT_MAX)> TransferShapeUtils::m0_list_{};
std::array<uint32_t, static_cast<size_t>(ge::DataType::DT_MAX)> TransferShapeUtils::k0_list_{};
std::array<uint32_t, static_cast<size_t>(ge::DataType::DT_MAX)> TransferShapeUtils::n0_list_{};
const std::map<TransferShapeType, GetAlignedShapeFunc> TransferShapeUtils::get_aligned_shape_func_map = {
{TransferShapeType::ND_TO_ND, GetNdToNdAlignedShape},
{TransferShapeType::ND_TO_NZ, GetNdToNzAlignedShape},
{TransferShapeType::FULL_SIZE, GetFullSizeAlignedShape},
{TransferShapeType::NOT_FULL_SIZE, GetNotFullSizeAlignedShape},
};
const std::map<TransferShapeType, TransferDimsFunc> TransferShapeUtils::transfer_dims_func_map = {
{TransferShapeType::ND_TO_ND, GetNdToNdAxisIndexMapping},
{TransferShapeType::ND_TO_NZ, GetNdToNzAxisIndexMapping},
{TransferShapeType::FULL_SIZE, GetFullSizeAxisIndexMapping},
{TransferShapeType::NOT_FULL_SIZE, GetNotFullSizeAxisIndexMapping},
};
namespace {
const int64_t SHAPE_NUMBER_32 = 32;
const int64_t SHAPE_NUMBER_16 = 16;
const int64_t SHAPE_NUMBER_8 = 8;
const int64_t SHAPE_NUMBER_4 = 4;
const int64_t SHAPE_NUMBER_2 = 2;
const int64_t NI = 16;
const int64_t LSTM_NI = 4;
const int64_t GROUPS_DEFAULT_VALUE = 1;
const int64_t UNKNOWN_SHAPE_VALUE = -1;
const int64_t RNN_STATE_SIZE_DEFAULT_VALUE = -1;
const size_t NUMBER_2 = 2;
const size_t MINUS_VALUE_ONE = 1;
const size_t MINUS_VALUE_TWO = 2;
const size_t DIM_INDEX_N = 0;
const size_t DIM_INDEX_C = 1;
const size_t DIM_INDEX_H = 2;
const size_t DIM_INDEX_W = 3;
const size_t DIM_INDEX_D = 4;
const size_t DIM_INDEX_ZERO = 0;
const size_t DIM_INDEX_ONE = 1;
const size_t DIM_INDEX_TWO = 2;
const size_t DIM_INDEX_THREE = 3;
const size_t DIM_INDEX_FOUR = 4;
const size_t kM0Index = 0;
const size_t kK0Index = 1;
const size_t kN0Index = 2;
const size_t kNzMinDimNum = 2;
constexpr size_t MOKOCO_CONFIG_SIZE = 3;
const std::string kPltDtypeMKN = "DtypeMKN";
const std::string kPltDefault = "Default";
const std::map<ge::Format, FormatIndex> kFormatIndexMap = {
{ge::FORMAT_NCHW, {DIM_INDEX_ZERO, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_THREE, DIM_INDEX_FOUR}},
{ge::FORMAT_NHWC, {DIM_INDEX_ZERO, DIM_INDEX_THREE, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_FOUR}},
{ge::FORMAT_HWCN, {DIM_INDEX_THREE, DIM_INDEX_TWO, DIM_INDEX_ZERO, DIM_INDEX_ONE, DIM_INDEX_FOUR}},
{ge::FORMAT_CHWN, {DIM_INDEX_THREE, DIM_INDEX_ZERO, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_FOUR}},
{ge::FORMAT_ND, {DIM_INDEX_ZERO, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_THREE, DIM_INDEX_FOUR}},
{ge::FORMAT_NCDHW, {DIM_INDEX_ZERO, DIM_INDEX_ONE, DIM_INDEX_THREE, DIM_INDEX_FOUR, DIM_INDEX_TWO}},
{ge::FORMAT_NDHWC, {DIM_INDEX_ZERO, DIM_INDEX_FOUR, DIM_INDEX_TWO, DIM_INDEX_THREE, DIM_INDEX_ONE}},
{ge::FORMAT_DHWCN, {DIM_INDEX_FOUR, DIM_INDEX_THREE, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_ZERO}},
{ge::FORMAT_DHWNC, {DIM_INDEX_THREE, DIM_INDEX_FOUR, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_ZERO}}
};
const std::set<ge::Format> kOriginFormatVec = {
ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_HWCN,
ge::FORMAT_CHWN, ge::FORMAT_NDHWC, ge::FORMAT_NCDHW,
ge::FORMAT_DHWCN, ge::FORMAT_DHWNC, ge::FORMAT_ND
};
inline int64_t GetGreatestCommonDivisor(int64_t x, int64_t y) {
if (y == 0) {
return x;
}
int64_t z = y;
while (x % y != 0) {
z = x % y;
x = y;
y = z;
}
return z;
}
inline int64_t GetLeastCommonMultiple(int64_t x, int64_t y) {
if (x == 0 || y == 0) {
return 0;
}
return (x * y) / GetGreatestCommonDivisor(x, y);
}
inline int64_t GetAsisEnlargeValue(int64_t cin, int64_t cout, int64_t c0, int64_t group) {
if (cin == 0 || cout == 0) {
return 0;
}
return std::min(GetLeastCommonMultiple(c0 / GetGreatestCommonDivisor(cin, c0),
NI / GetGreatestCommonDivisor(cout, NI)), group);
}
}
bool TransferShapeUtils::InitM0K0CO(const fe::PlatFormInfos *platform_infos) {
std::string default_mkn;
fe::PlatFormInfos *temp_platform_infos = const_cast<fe::PlatFormInfos *>(platform_infos);
if (temp_platform_infos->GetPlatformResWithLock(kPltDtypeMKN, kPltDefault, default_mkn)) {
GELOGD("Default MKN value from platform is [%s].", default_mkn.c_str());
std::vector<string> infos = ge::StringUtils::Split(default_mkn, ',');
if (infos.size() != MOKOCO_CONFIG_SIZE) {
return false;
}
m0_list_.fill(static_cast<uint32_t>(std::atoi(infos[kM0Index].c_str())));
k0_list_.fill(static_cast<uint32_t>(std::atoi(infos[kK0Index].c_str())));
n0_list_.fill(static_cast<uint32_t>(std::atoi(infos[kN0Index].c_str())));
}
std::map<std::string, std::string> m0_k0_n0_info;
temp_platform_infos->GetPlatformResWithLock(kPltDtypeMKN, m0_k0_n0_info);
for (auto &item : m0_k0_n0_info) {
if (item.first == kPltDefault) {
continue;
}
ge::DataType dtype = ge::TypeUtils::SerialStringToDataType(item.first);
if (dtype == ge::DT_UNDEFINED) {
continue;
}
std::vector<string> infos = ge::StringUtils::Split(item.second, ',');
if (infos.size() != MOKOCO_CONFIG_SIZE) {
continue;
}
m0_list_[static_cast<size_t>(dtype)] = static_cast<uint32_t>(std::atoi(infos[kM0Index].c_str()));
k0_list_[static_cast<size_t>(dtype)] = static_cast<uint32_t>(std::atoi(infos[kK0Index].c_str()));
n0_list_[static_cast<size_t>(dtype)] = static_cast<uint32_t>(std::atoi(infos[kN0Index].c_str()));
}
return true;
}
bool TransferShapeUtils::InitPlatformInfo(const fe::PlatFormInfos *platform_infos_ptr) {
static std::once_flag flag;
std::call_once(flag, [&platform_infos_ptr]() {
m0_list_.fill(SHAPE_NUMBER_16);
k0_list_.fill(SHAPE_NUMBER_16);
n0_list_.fill(SHAPE_NUMBER_16);
fe::PlatFormInfos platform_infos;
fe::OptionalInfos optional_infos;
if (platform_infos_ptr == nullptr) {
GELOGI("Input platform is null; now retrieving m0k0c0 from platformmanager autonomously.");
fe::PlatformInfoManager::GeInstance().InitializePlatformInfo();
if (fe::PlatformInfoManager::GeInstance().GetPlatformInfoWithOutSocVersion(platform_infos, optional_infos) != 0) {
GELOGW("Failed to get platform info, using default MKN value.");
return false;
}
}
return platform_infos_ptr == nullptr ? TransferShapeUtils::InitM0K0CO(&platform_infos) : TransferShapeUtils::InitM0K0CO(platform_infos_ptr);
});
return true;
}
bool TransferShapeUtils::TransferShape(const ge::Format &origin_format, const ge::Format &format,
const ge::DataType &data_type, const ExtAxisValue &ext_axis,
gert::Shape &shape, const fe::PlatFormInfos *platform_infos_ptr) {
if (!InitPlatformInfo(platform_infos_ptr)) {
GELOGW("Init platform info failed");
}
ge::Format primary_format = static_cast<ge::Format>(GetPrimaryFormat(format));
ge::Format origin_primary_format = static_cast<ge::Format>(GetPrimaryFormat(origin_format));
GELOGD("Original format is %s, new format is %s",
(ge::TypeUtils::FormatToSerialString(origin_primary_format)).c_str(),
(ge::TypeUtils::FormatToSerialString(primary_format)).c_str());
if (!IsNeedTransferShape(origin_primary_format, primary_format, shape)) {
return true;
}
if (!CheckInputParam(origin_primary_format, primary_format, data_type)) {
return false;
}
AxisValue axis_value;
axis_value.fill(1);
int64_t group = static_cast<int64_t>(ge::GetSubFormat(format));
if (group > GROUPS_DEFAULT_VALUE) {
axis_value[AXIS_G] = group;
}
axis_value[AXIS_C0] = GetC0Value(data_type, format);
axis_value[AXIS_M0] = GetM0ByDtype(data_type);
if (primary_format == ge::FORMAT_FRACTAL_ZN_RNN || primary_format == ge::FORMAT_ND_RNN_BIAS) {
axis_value[AXIS_INPUT_SIZE] = ext_axis[EXT_INDEX_INPUT_SIZE];
axis_value[AXIS_HIDDEN_SIZE] = ext_axis[EXT_INDEX_HIDDEN_SIZE];
axis_value[AXIS_STATE_SIZE] = ext_axis[EXT_INDEX_STATE_SIZE];
}
if (!IsNeedAxisValue(primary_format, shape.GetDimNum())) {
return TransferShapeByFormat(primary_format, axis_value, shape);
}
if (!AxisUtil::GetAxisValueByOriginFormat(origin_primary_format, shape, axis_value)) {
return true;
}
return TransferShapeByAxisValue(primary_format, axis_value, shape);
}
bool TransferShapeUtils::TransferShape(const ge::Format &origin_format, const ge::Format &format,
const ge::DataType &data_type, const ExtAxisValue &ext_axis,
const gert::Shape &origin_shape, gert::Shape &shape, const fe::PlatFormInfos *platform_infos_ptr) {
if (!InitPlatformInfo(platform_infos_ptr)) {
GELOGW("Init platform info failed");
}
ge::Format primary_format = static_cast<ge::Format>(GetPrimaryFormat(format));
ge::Format origin_primary_format = static_cast<ge::Format>(GetPrimaryFormat(origin_format));
GELOGD("Transfer shape from original format[%s] to format [%s].",
(ge::TypeUtils::FormatToSerialString(origin_primary_format)).c_str(),
(ge::TypeUtils::FormatToSerialString(primary_format)).c_str());
if (!IsNeedTransferShape(origin_primary_format, primary_format, origin_shape)) {
return true;
}
if (!CheckInputParam(origin_primary_format, primary_format, data_type)) {
return false;
}
int64_t c0 = GetC0Value(data_type, format);
int64_t m0 = GetM0ByDtype(data_type);
if (!IsNeedAxisValue(primary_format, origin_shape.GetDimNum())) {
return TransferShapeByOriginShape(primary_format, c0, m0, ext_axis, origin_shape, shape);
} else {
return TransferShapeByFormatIndex(origin_primary_format, format, c0, origin_shape, shape);
}
}
int64_t TransferShapeUtils::GetC0ByDtype(const ge::DataType &data_type) {
if (static_cast<size_t>(data_type) < k0_list_.size()) {
return static_cast<int64_t>(k0_list_[static_cast<size_t>(data_type)]);
}
return SHAPE_NUMBER_16;
}
int64_t TransferShapeUtils::GetM0ByDtype(const ge::DataType &data_type) {
if (static_cast<size_t>(data_type) < m0_list_.size()) {
return static_cast<int64_t>(m0_list_[static_cast<size_t>(data_type)]);
}
return SHAPE_NUMBER_16;
}
int64_t TransferShapeUtils::GetN0ByDtype(const ge::DataType &data_type) {
if (static_cast<size_t>(data_type) < n0_list_.size()) {
return static_cast<int64_t>(n0_list_[static_cast<size_t>(data_type)]);
}
return SHAPE_NUMBER_16;
}
bool TransferShapeUtils::IsNeedTransferShape(const ge::Format &origin_format, const ge::Format &format,
const gert::Shape &shape) {
if (origin_format == ge::FORMAT_ND && kOriginFormatVec.count(format) > 0) {
GELOGD("No need for shape transformation from ND to the original format.");
return false;
}
if (shape.IsScalar()) {
GELOGD("Do not need to do shape transformation if the shape is scalar.");
return false;
}
return true;
}
bool TransferShapeUtils::CheckInputParam(const ge::Format &origin_format, const ge::Format &primary_format,
const ge::DataType &data_type) {
bool invalid_format = (origin_format == ge::FORMAT_RESERVED || origin_format >= ge::FORMAT_END) ||
(primary_format == ge::FORMAT_RESERVED || primary_format >= ge::FORMAT_END);
if (invalid_format) {
GELOGE(ge::GRAPH_FAILED, "Origin format %u or new format %u is invalid.", origin_format, primary_format);
return false;
}
if (data_type == ge::DT_UNDEFINED || data_type >= ge::DT_MAX) {
GELOGE(ge::GRAPH_FAILED, "DataType %u is invalid.", origin_format);
return false;
}
return true;
}
int64_t TransferShapeUtils::GetC0Value(const ge::DataType &data_type, const ge::Format &format) {
ge::Format primary_format = static_cast<ge::Format>(GetPrimaryFormat(format));
if (primary_format == ge::FORMAT_FRACTAL_NZ_C0_2) {
return SHAPE_NUMBER_2;
}
if (primary_format == ge::FORMAT_NC1HWC0_C04 || primary_format == ge::FORMAT_FRACTAL_NZ_C0_4) {
return SHAPE_NUMBER_4;
}
if (primary_format == ge::FORMAT_FRACTAL_NZ_C0_8) {
return SHAPE_NUMBER_8;
}
if (primary_format == ge::FORMAT_FRACTAL_NZ_C0_16) {
return SHAPE_NUMBER_16;
}
if (primary_format == ge::FORMAT_FRACTAL_NZ_C0_32) {
return SHAPE_NUMBER_32;
}
if (ge::HasC0Format(format)) {
return ge::GetC0Value(format);
}
return GetC0ByDtype(data_type);
}
bool TransferShapeUtils::IsNeedAxisValue(const ge::Format &format, const size_t &origin_dim_size) {
if (kFormatNZSet.count(format) > 0 || format == ge::FORMAT_FRACTAL_ZN_RNN ||
format == ge::FORMAT_ND_RNN_BIAS || format == ge::FORMAT_NYUV_A) {
return false;
}
if (format == ge::FORMAT_FRACTAL_Z && origin_dim_size == DIM_SIZE_TWO) {
return false;
}
return true;
}
bool TransferShapeUtils::TransferShapeByFormat(const ge::Format &primary_format, const AxisValue &axis_value,
gert::Shape &shape) {
switch (primary_format) {
case ge::FORMAT_FRACTAL_Z:
return GetFzShapeByAxisValue(axis_value, shape);
case ge::FORMAT_FRACTAL_NZ:
case ge::FORMAT_FRACTAL_NZ_C0_2:
case ge::FORMAT_FRACTAL_NZ_C0_4:
case ge::FORMAT_FRACTAL_NZ_C0_8:
case ge::FORMAT_FRACTAL_NZ_C0_16:
case ge::FORMAT_FRACTAL_NZ_C0_32:
return GetNzShapeByAxisValue(axis_value, shape);
case ge::FORMAT_FRACTAL_ZN_RNN:
return GetFznRNNShapeByAxisValue(axis_value, shape);
case ge::FORMAT_ND_RNN_BIAS:
return GetNDRNNShapeByAxisValue(axis_value, shape);
case ge::FORMAT_NYUV_A:
return GetNYUVShape(shape);
default:
GELOGD("Cannot obtain new shape with format %d.", primary_format);
return true;
}
}
bool TransferShapeUtils::TransferShapeByAxisValue(const ge::Format &primary_format, const AxisValue &axis_value,
gert::Shape &shape) {
switch (primary_format) {
case ge::FORMAT_NCHW:
return GetNCHWShapeByAxisValue(axis_value, shape);
case ge::FORMAT_NHWC:
return GetNHWCShapeByAxisValue(axis_value, shape);
case ge::FORMAT_HWCN:
return GetHWCNShapeByAxisValue(axis_value, shape);
case ge::FORMAT_CHWN:
return GetCHWNShapeByAxisValue(axis_value, shape);
case ge::FORMAT_NDHWC:
return GetNDHWCShapeByAxisValue(axis_value, shape);
case ge::FORMAT_NCDHW:
return GetNCDHWShapeByAxisValue(axis_value, shape);
case ge::FORMAT_DHWCN:
return GetDHWCNShapeByAxisValue(axis_value, shape);
case ge::FORMAT_DHWNC:
return GetDHWNCShapeByAxisValue(axis_value, shape);
case ge::FORMAT_NC1HWC0:
case ge::FORMAT_NC1HWC0_C04:
return GetNC1HWC0ShapeByAxisValue(axis_value, shape);
case ge::FORMAT_C1HWC0:
return GetC1HWC0ShapeByAxisValue(axis_value, shape);
case ge::FORMAT_NDC1HWC0:
return GetNDC1HWC0ShapeByAxisValue(axis_value, shape);
case ge::FORMAT_C1HWNCoC0:
return GetC1HWNCoC0ShapeByAxisValue(axis_value, shape);
case ge::FORMAT_FRACTAL_Z:
return GetFzShapeByAxisValue(axis_value, shape);
case ge::FORMAT_FRACTAL_Z_WINO:
return GetFzWinoShapeByAxisValue(axis_value, shape);
case ge::FORMAT_FRACTAL_Z_C04:
return GetFzC04ShapeByAxisValue(axis_value, shape);
case ge::FORMAT_FRACTAL_Z_3D:
return GetFz3DShapeByAxisValue(axis_value, shape);
case ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE:
return GetFz3DTransposeShapeByAxisValue(axis_value, shape);
case ge::FORMAT_FRACTAL_ZN_LSTM:
return GetFzLstmShapeByAxisValue(axis_value, shape);
default:
GELOGD("Cannot obtain new shape with format %d.", primary_format);
return true;
}
}
bool TransferShapeUtils::TransferShapeByOriginShape(const ge::Format &primary_format,
const int64_t &c0, const int64_t &m0, const ExtAxisValue &ext_axis,
const gert::Shape &origin_shape, gert::Shape &shape) {
switch (primary_format) {
case ge::FORMAT_FRACTAL_Z:
return GetFractalZShape(c0, origin_shape, shape);
case ge::FORMAT_FRACTAL_NZ:
case ge::FORMAT_FRACTAL_NZ_C0_2:
case ge::FORMAT_FRACTAL_NZ_C0_4:
case ge::FORMAT_FRACTAL_NZ_C0_8:
case ge::FORMAT_FRACTAL_NZ_C0_16:
case ge::FORMAT_FRACTAL_NZ_C0_32:
return GetFractalNzShape(c0, m0, origin_shape, shape);
case ge::FORMAT_FRACTAL_ZN_RNN:
return GetFractalZnRnnShape(ext_axis, c0, origin_shape, shape);
case ge::FORMAT_ND_RNN_BIAS:
return GetNdRnnBiasShape(ext_axis, c0, origin_shape, shape);
default:
GELOGD("Cannot obtain new shape with format %d.", primary_format);
return true;
}
}
bool TransferShapeUtils::TransferShapeByFormatIndex(const ge::Format &origin_format, const ge::Format &format,
const int64_t &c0, const gert::Shape &origin_shape,
gert::Shape &shape) {
std::map<ge::Format, FormatIndex>::const_iterator iter = kFormatIndexMap.find(origin_format);
if (iter == kFormatIndexMap.end()) {
return true;
}
ge::Format primary_format = static_cast<ge::Format>(GetPrimaryFormat(format));
int64_t group = static_cast<int64_t>(ge::GetSubFormat(format));
switch (primary_format) {
case ge::FORMAT_NCHW:
return GetNCHWShape(iter->second, origin_shape, shape);
case ge::FORMAT_NHWC:
return GetNHWCShape(iter->second, origin_shape, shape);
case ge::FORMAT_HWCN:
return GetHWCNShape(iter->second, origin_shape, shape);
case ge::FORMAT_CHWN:
return GetCHWNShape(iter->second, origin_shape, shape);
case ge::FORMAT_NDHWC:
return GetNDHWCShape(iter->second, origin_shape, shape);
case ge::FORMAT_NCDHW:
return GetNCDHWShape(iter->second, origin_shape, shape);
case ge::FORMAT_DHWCN:
return GetDHWCNShape(iter->second, origin_shape, shape);
case ge::FORMAT_DHWNC:
return GetDHWNCShape(iter->second, origin_shape, shape);
case ge::FORMAT_NC1HWC0:
case ge::FORMAT_NC1HWC0_C04:
return GetNC1HWC0Shape(iter->second, c0, origin_shape, shape);
case ge::FORMAT_C1HWC0:
return GetC1HWC0Shape(iter->second, c0, origin_shape, shape);
case ge::FORMAT_NDC1HWC0:
return GetNDC1HWC0Shape(iter->second, c0, origin_shape, shape);
case ge::FORMAT_C1HWNCoC0:
return GetC1HWNCoC0Shape(iter->second, c0, origin_shape, shape);
case ge::FORMAT_FRACTAL_Z:
return GetFractalZShape(iter->second, c0, group, origin_shape, shape);
case ge::FORMAT_FRACTAL_Z_3D:
return GetFractalZ3DShape(iter->second, c0, group, origin_shape, shape);
case ge::FORMAT_FRACTAL_Z_C04:
return GetFractalZC04Shape(iter->second, c0, origin_shape, shape);
case ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE:
return GetFractalZ3DTransposeShape(iter->second, c0, origin_shape, shape);
case ge::FORMAT_FRACTAL_ZN_LSTM:
return GetFractalZLstmShape(iter->second, origin_shape, shape);
case ge::FORMAT_FRACTAL_Z_WINO:
return GetFractalZWinoShape(iter->second, c0, origin_shape, shape);
default: {
GELOGD("Cannot obtain new shape with format %d.", primary_format);
return true;
}
}
}
bool TransferShapeUtils::GetNCHWShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_N]);
shape.AppendDim(axis_value[AXIS_C]);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
return true;
}
bool TransferShapeUtils::GetNHWCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_N]);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
shape.AppendDim(axis_value[AXIS_C]);
return true;
}
bool TransferShapeUtils::GetHWCNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
shape.AppendDim(axis_value[AXIS_C]);
shape.AppendDim(axis_value[AXIS_N]);
return true;
}
bool TransferShapeUtils::GetCHWNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_C]);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
shape.AppendDim(axis_value[AXIS_N]);
return true;
}
bool TransferShapeUtils::GetNDHWCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_N]);
shape.AppendDim(axis_value[AXIS_D]);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
shape.AppendDim(axis_value[AXIS_C]);
return true;
}
bool TransferShapeUtils::GetNCDHWShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_N]);
shape.AppendDim(axis_value[AXIS_C]);
shape.AppendDim(axis_value[AXIS_D]);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
return true;
}
bool TransferShapeUtils::GetDHWCNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_D]);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
shape.AppendDim(axis_value[AXIS_C]);
shape.AppendDim(axis_value[AXIS_N]);
return true;
}
bool TransferShapeUtils::GetDHWNCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_D]);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
shape.AppendDim(axis_value[AXIS_N]);
shape.AppendDim(axis_value[AXIS_C]);
return true;
}
bool TransferShapeUtils::GetNC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_N]);
shape.AppendDim(axis_value[AXIS_C1]);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
bool TransferShapeUtils::GetC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(DivisionCeiling(axis_value[AXIS_N], axis_value[AXIS_C0]));
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
bool TransferShapeUtils::GetNDC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_N]);
shape.AppendDim(axis_value[AXIS_D]);
shape.AppendDim(axis_value[AXIS_C1]);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
bool TransferShapeUtils::GetC1HWNCoC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_C1]);
shape.AppendDim(axis_value[AXIS_H]);
shape.AppendDim(axis_value[AXIS_W]);
shape.AppendDim(DivisionCeiling(axis_value[AXIS_N], NI));
shape.AppendDim(axis_value[AXIS_Co]);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
bool TransferShapeUtils::GetFzWinoShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
shape.AppendDim(axis_value[AXIS_C1]);
shape.AppendDim(DivisionCeiling(axis_value[AXIS_N], SHAPE_NUMBER_16));
shape.AppendDim(NUMBER_2);
int64_t hw = axis_value[AXIS_H];
MUL_OVERFLOW(hw, axis_value[AXIS_W], hw);
shape.AppendDim(hw);
shape.AppendDim(SHAPE_NUMBER_8);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
bool TransferShapeUtils::GetNzShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
CHECK(shape.IsScalar(), GELOGD("Origin shape is empty."), return true);
size_t dim_size = shape.GetDimNum();
if (dim_size < DIM_SIZE_TWO) {
GELOGD("nd_value's dim num is less than 2.");
shape.AppendDim(1);
dim_size++;
}
* dim_size - 2 mean the second last value of original vec */
int64_t dim_back_two = shape.GetDim(dim_size - MINUS_VALUE_TWO);
int64_t dim_back_one = shape.GetDim(dim_size - MINUS_VALUE_ONE);
shape.SetDim((dim_size - MINUS_VALUE_ONE), DivisionCeiling(dim_back_two, axis_value[AXIS_M0]));
shape.SetDim((dim_size - MINUS_VALUE_TWO), DivisionCeiling(dim_back_one, axis_value[AXIS_C0]));
shape.AppendDim(axis_value[AXIS_M0]);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
bool TransferShapeUtils::GetFzShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
size_t size_of_original_vec = shape.GetDimNum();
if (size_of_original_vec == DIM_SIZE_TWO) {
* size_of_original_vec - 2 mean the second last value of original vec */
shape.SetDim((size_of_original_vec - MINUS_VALUE_ONE),
DivisionCeiling(shape.GetDim(size_of_original_vec - MINUS_VALUE_ONE), SHAPE_NUMBER_16));
shape.SetDim((size_of_original_vec - MINUS_VALUE_TWO),
DivisionCeiling(shape.GetDim(size_of_original_vec - MINUS_VALUE_TWO), axis_value[AXIS_C0]));
shape.AppendDim(SHAPE_NUMBER_16);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
return GetFz3DShapeByAxisValue(axis_value, shape);
}
bool TransferShapeUtils::GetFz3DShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
bool has_unknown_shape = axis_value[AXIS_D] == UNKNOWN_SHAPE_VALUE || axis_value[AXIS_H] == UNKNOWN_SHAPE_VALUE ||
axis_value[AXIS_W] == UNKNOWN_SHAPE_VALUE || axis_value[AXIS_C] == UNKNOWN_SHAPE_VALUE;
int64_t gdhwc1 = UNKNOWN_SHAPE_VALUE;
int64_t axis_g_val = GROUPS_DEFAULT_VALUE;
int64_t axis_n_val = axis_value[AXIS_N];
int64_t axis_c_val = axis_value[AXIS_C];
int64_t axis_c1_val = axis_value[AXIS_C1];
if (!has_unknown_shape) {
if (axis_value[AXIS_G] > GROUPS_DEFAULT_VALUE && axis_n_val >= axis_value[AXIS_G]) {
axis_n_val = axis_n_val / axis_value[AXIS_G];
int64_t enlarge_value = GetAsisEnlargeValue(axis_c_val, axis_n_val, axis_value[AXIS_C0], axis_value[AXIS_G]);
axis_g_val = DivisionCeiling(axis_value[AXIS_G], enlarge_value);
MUL_OVERFLOW(axis_c_val, enlarge_value, axis_c_val);
MUL_OVERFLOW(axis_n_val, enlarge_value, axis_n_val);
axis_c1_val = DivisionCeiling(axis_c_val, axis_value[AXIS_C0]);
}
MUL_OVERFLOW(axis_g_val, axis_c1_val, gdhwc1);
MUL_OVERFLOW(gdhwc1, axis_value[AXIS_D], gdhwc1);
MUL_OVERFLOW(gdhwc1, axis_value[AXIS_H], gdhwc1);
MUL_OVERFLOW(gdhwc1, axis_value[AXIS_W], gdhwc1);
}
shape.SetDimNum(0);
shape.AppendDim(gdhwc1);
shape.AppendDim(DivisionCeiling(axis_n_val, NI));
shape.AppendDim(NI);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
bool TransferShapeUtils::GetFz3DTransposeShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
int64_t dhwn1 = UNKNOWN_SHAPE_VALUE;
if (axis_value[AXIS_N] != UNKNOWN_SHAPE_VALUE && axis_value[AXIS_H] != UNKNOWN_SHAPE_VALUE &&
axis_value[AXIS_W] != UNKNOWN_SHAPE_VALUE && axis_value[AXIS_D] != UNKNOWN_SHAPE_VALUE) {
dhwn1 = DivisionCeiling(axis_value[AXIS_N], NI);
MUL_OVERFLOW(dhwn1, axis_value[AXIS_D], dhwn1);
MUL_OVERFLOW(dhwn1, axis_value[AXIS_H], dhwn1);
MUL_OVERFLOW(dhwn1, axis_value[AXIS_W], dhwn1);
}
shape.SetDimNum(0);
shape.AppendDim(dhwn1);
if (axis_value[AXIS_C] == UNKNOWN_SHAPE_VALUE) {
shape.AppendDim(UNKNOWN_SHAPE_VALUE);
} else {
shape.AppendDim(axis_value[AXIS_C1]);
}
shape.AppendDim(NI);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
bool TransferShapeUtils::GetFzLstmShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
int64_t h = axis_value[AXIS_N] >> NUMBER_2;
int64_t i = axis_value[AXIS_C] - h;
int64_t first_element_of_fz_lstm = DivisionCeiling(i, NI) + DivisionCeiling(h, NI);
int64_t second_element_of_fz_lstm = DivisionCeiling(h, NI);
MUL_OVERFLOW(second_element_of_fz_lstm, LSTM_NI, second_element_of_fz_lstm);
if (axis_value[AXIS_N] == UNKNOWN_SHAPE_VALUE || axis_value[AXIS_C] == UNKNOWN_SHAPE_VALUE) {
first_element_of_fz_lstm = UNKNOWN_SHAPE_VALUE;
second_element_of_fz_lstm = UNKNOWN_SHAPE_VALUE;
}
shape.SetDimNum(0);
shape.AppendDim(first_element_of_fz_lstm);
shape.AppendDim(second_element_of_fz_lstm);
shape.AppendDim(NI);
shape.AppendDim(NI);
return true;
}
bool TransferShapeUtils::GetFzC04ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
shape.SetDimNum(0);
if (axis_value[AXIS_H] == UNKNOWN_SHAPE_VALUE || axis_value[AXIS_W] == UNKNOWN_SHAPE_VALUE) {
shape.AppendDim(UNKNOWN_SHAPE_VALUE);
} else {
int64_t x = SHAPE_NUMBER_4;
MUL_OVERFLOW(x, axis_value[AXIS_H], x);
MUL_OVERFLOW(x, axis_value[AXIS_W], x);
shape.AppendDim(DivisionCeiling(x, axis_value[AXIS_C0]));
}
shape.AppendDim(DivisionCeiling(axis_value[AXIS_N], NI));
shape.AppendDim(NI);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
bool TransferShapeUtils::GetFznRNNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
size_t origin_shape_size = shape.GetDimNum();
CHECK(origin_shape_size < DIM_SIZE_TWO, GELOGW("The number of dimensions in ndValue is less than 2!"), return true);
int64_t k_value = shape.GetDim(origin_shape_size - MINUS_VALUE_TWO);
int64_t hidden_or_state_size = axis_value[AXIS_HIDDEN_SIZE];
if (axis_value[AXIS_STATE_SIZE] != RNN_STATE_SIZE_DEFAULT_VALUE) {
hidden_or_state_size = axis_value[AXIS_STATE_SIZE];
}
if (k_value == hidden_or_state_size + axis_value[AXIS_INPUT_SIZE]) {
shape.SetDim(origin_shape_size - MINUS_VALUE_TWO,
DivisionCeiling(axis_value[AXIS_INPUT_SIZE], SHAPE_NUMBER_16) +
DivisionCeiling(hidden_or_state_size, SHAPE_NUMBER_16));
} else if (k_value == hidden_or_state_size || k_value == axis_value[AXIS_INPUT_SIZE]) {
shape.SetDim(origin_shape_size - MINUS_VALUE_TWO, DivisionCeiling(k_value, SHAPE_NUMBER_16));
} else {
return true;
}
int64_t n_value = shape.GetDim(origin_shape_size - MINUS_VALUE_ONE);
INT64_ZEROCHECK(axis_value[AXIS_HIDDEN_SIZE]);
int64_t n_num = n_value / axis_value[AXIS_HIDDEN_SIZE];
MUL_OVERFLOW(n_num, DivisionCeiling(axis_value[AXIS_HIDDEN_SIZE], axis_value[AXIS_C0]), n_num);
shape.SetDim(origin_shape_size - MINUS_VALUE_ONE, n_num);
shape.AppendDim(SHAPE_NUMBER_16);
shape.AppendDim(axis_value[AXIS_C0]);
return true;
}
bool TransferShapeUtils::GetNDRNNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) {
CHECK(axis_value[AXIS_HIDDEN_SIZE] == 0, GELOGD("hidden_size is zero"), return true);
size_t size_of_original_vec = shape.GetDimNum();
int64_t n_num = shape.GetDim(size_of_original_vec - MINUS_VALUE_ONE) / axis_value[AXIS_HIDDEN_SIZE];
MUL_OVERFLOW(n_num, DivisionCeiling(axis_value[AXIS_HIDDEN_SIZE], axis_value[AXIS_C0]), n_num);
MUL_OVERFLOW(n_num, axis_value[AXIS_C0], n_num);
shape.SetDim(size_of_original_vec - MINUS_VALUE_ONE, n_num);
return true;
}
bool TransferShapeUtils::GetNCHWShape(const FormatIndex& format_index, const gert::Shape &origin_shape,
gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
shape.SetDimNum(0);
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
return true;
}
bool TransferShapeUtils::GetNHWCShape(const FormatIndex& format_index, const gert::Shape &origin_shape,
gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
shape.SetDimNum(0);
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C]));
return true;
}
bool TransferShapeUtils::GetHWCNShape(const FormatIndex& format_index, const gert::Shape &origin_shape,
gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
shape.SetDimNum(0);
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N]));
return true;
}
bool TransferShapeUtils::GetCHWNShape(const FormatIndex& format_index, const gert::Shape &origin_shape,
gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
shape.SetDimNum(0);
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N]));
return true;
}
bool TransferShapeUtils::GetNDHWCShape(const FormatIndex& format_index, const gert::Shape &origin_shape,
gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true);
shape.SetDimNum(0);
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_D]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C]));
return true;
}
bool TransferShapeUtils::GetNCDHWShape(const FormatIndex& format_index, const gert::Shape &origin_shape,
gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true);
shape.SetDimNum(0);
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_D]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
return true;
}
bool TransferShapeUtils::GetDHWCNShape(const FormatIndex& format_index, const gert::Shape &origin_shape,
gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true);
shape.SetDimNum(0);
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_D]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N]));
return true;
}
bool TransferShapeUtils::GetDHWNCShape(const FormatIndex& format_index, const gert::Shape &origin_shape,
gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true);
shape.SetDimNum(0);
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_D]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C]));
return true;
}
bool TransferShapeUtils::GetNC1HWC0Shape(const FormatIndex& format_index, const int64_t &c0,
const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
shape.SetDimNum(0);
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N]));
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_C]), c0));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetC1HWC0Shape(const FormatIndex &format_index, const int64_t &c0,
const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
shape.SetDimNum(0);
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_N]), c0));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetNDC1HWC0Shape(const FormatIndex& format_index, const int64_t &c0,
const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
shape.SetDimNum(0);
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N]));
if (origin_shape.GetDimNum() == DIM_SIZE_FOUR) {
shape.AppendDim(1);
} else {
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_D]));
}
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_C]), c0));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetC1HWNCoC0Shape(const FormatIndex& format_index, const int64_t &c0,
const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
shape.SetDimNum(0);
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_C]), c0));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H]));
shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W]));
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_N]), NI));
shape.AppendDim(c0);
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetFractalNzShape(const int64_t &c0, const int64_t &m0,
const gert::Shape &origin_shape, gert::Shape &shape) {
size_t dim_size = origin_shape.GetDimNum();
shape.SetDimNum(0);
if (dim_size > DIM_SIZE_TWO) {
for (size_t i = 0; i < dim_size - DIM_SIZE_TWO; i++) {
shape.AppendDim(origin_shape.GetDim(i));
}
}
* dim_size - 2 mean the second last value of original vec */
if (dim_size < DIM_SIZE_TWO) {
shape.AppendDim(1);
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(dim_size - MINUS_VALUE_ONE), m0));
} else {
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(dim_size - MINUS_VALUE_ONE), c0));
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(dim_size - MINUS_VALUE_TWO), m0));
}
shape.AppendDim(m0);
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetFractalZShape(const int64_t &c0, const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() != DIM_SIZE_TWO, GELOGD("Dim size is not 2."), return true);
* size_of_original_vec - 2 mean the second last value of original vec */
shape.SetDimNum(0);
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(DIM_SIZE_TWO - MINUS_VALUE_TWO), c0));
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(DIM_SIZE_TWO - MINUS_VALUE_ONE), SHAPE_NUMBER_16));
shape.AppendDim(SHAPE_NUMBER_16);
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetFractalZShape(const FormatIndex& format_index, const int64_t &c0, const int64_t &group,
const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
int64_t axis_n_val = origin_shape.GetDim(format_index[DIM_INDEX_N]);
int64_t axis_c_val = origin_shape.GetDim(format_index[DIM_INDEX_C]);
int64_t axis_h_val = origin_shape.GetDim(format_index[DIM_INDEX_H]);
int64_t axis_w_val = origin_shape.GetDim(format_index[DIM_INDEX_W]);
int64_t ghwc1 = UNKNOWN_SHAPE_VALUE;
bool is_unknown_shape = axis_c_val == UNKNOWN_SHAPE_VALUE || axis_h_val == UNKNOWN_SHAPE_VALUE ||
axis_w_val == UNKNOWN_SHAPE_VALUE;
if (!is_unknown_shape) {
int64_t axis_g_val = GROUPS_DEFAULT_VALUE;
int64_t axis_c1_val = 0;
if (group > GROUPS_DEFAULT_VALUE && axis_n_val >= group) {
axis_n_val = axis_n_val / group;
int64_t enlarge_value = GetAsisEnlargeValue(axis_c_val, axis_n_val, c0, group);
axis_g_val = DivisionCeiling(group, enlarge_value);
MUL_OVERFLOW(axis_c_val, enlarge_value, axis_c_val);
MUL_OVERFLOW(axis_n_val, enlarge_value, axis_n_val);
axis_c1_val = DivisionCeiling(axis_c_val, c0);
} else {
axis_c1_val = DivisionCeiling(axis_c_val, c0);
}
MUL_OVERFLOW(axis_g_val, axis_c1_val, ghwc1);
MUL_OVERFLOW(ghwc1, origin_shape.GetDim(format_index[DIM_INDEX_H]), ghwc1);
MUL_OVERFLOW(ghwc1, origin_shape.GetDim(format_index[DIM_INDEX_W]), ghwc1);
}
shape.SetDimNum(0);
shape.AppendDim(ghwc1);
shape.AppendDim(DivisionCeiling(axis_n_val, NI));
shape.AppendDim(NI);
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetFractalZWinoShape(const FormatIndex& format_index, const int64_t &c0,
const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
int64_t axis_n_val = origin_shape.GetDim(format_index[DIM_INDEX_N]);
int64_t axis_c_val = origin_shape.GetDim(format_index[DIM_INDEX_C]);
int64_t axis_h_val = origin_shape.GetDim(format_index[DIM_INDEX_H]);
int64_t axis_w_val = origin_shape.GetDim(format_index[DIM_INDEX_W]);
int64_t ghw = axis_h_val;
MUL_OVERFLOW(ghw, axis_w_val, ghw);
shape.SetDimNum(0);
shape.AppendDim(DivisionCeiling(axis_c_val, c0));
shape.AppendDim(DivisionCeiling(axis_n_val, NI));
shape.AppendDim(NUMBER_2);
shape.AppendDim(ghw);
shape.AppendDim(SHAPE_NUMBER_8);
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetFractalZ3DShape(const FormatIndex& format_index, const int64_t &c0, const int64_t &group,
const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true);
int64_t axis_n_val = origin_shape.GetDim(format_index[DIM_INDEX_N]);
int64_t axis_c_val = origin_shape.GetDim(format_index[DIM_INDEX_C]);
int64_t axis_h_val = origin_shape.GetDim(format_index[DIM_INDEX_H]);
int64_t axis_w_val = origin_shape.GetDim(format_index[DIM_INDEX_W]);
int64_t axis_d_val = origin_shape.GetDim(format_index[DIM_INDEX_D]);
int64_t gdhwc1 = UNKNOWN_SHAPE_VALUE;
bool is_unknown_shape = axis_c_val == UNKNOWN_SHAPE_VALUE || axis_d_val == UNKNOWN_SHAPE_VALUE ||
axis_h_val == UNKNOWN_SHAPE_VALUE || axis_w_val == UNKNOWN_SHAPE_VALUE;
if (!is_unknown_shape) {
int64_t axis_c1_val = 0;
int64_t axis_g_val = GROUPS_DEFAULT_VALUE;
if (group > GROUPS_DEFAULT_VALUE && axis_n_val >= group) {
axis_n_val = axis_n_val / group;
int64_t enlarge_value = GetAsisEnlargeValue(axis_c_val, axis_n_val, c0, group);
axis_g_val = DivisionCeiling(group, enlarge_value);
MUL_OVERFLOW(axis_c_val, enlarge_value, axis_c_val);
MUL_OVERFLOW(axis_n_val, enlarge_value, axis_n_val);
axis_c1_val = DivisionCeiling(axis_c_val, c0);
} else {
axis_c1_val = DivisionCeiling(axis_c_val, c0);
}
MUL_OVERFLOW(axis_g_val, axis_c1_val, gdhwc1);
MUL_OVERFLOW(gdhwc1, origin_shape.GetDim(format_index[DIM_INDEX_D]), gdhwc1);
MUL_OVERFLOW(gdhwc1, origin_shape.GetDim(format_index[DIM_INDEX_H]), gdhwc1);
MUL_OVERFLOW(gdhwc1, origin_shape.GetDim(format_index[DIM_INDEX_W]), gdhwc1);
}
shape.SetDimNum(0);
shape.AppendDim(gdhwc1);
shape.AppendDim(DivisionCeiling(axis_n_val, NI));
shape.AppendDim(NI);
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetFractalZ3DTransposeShape(const FormatIndex& format_index, const int64_t &c0,
const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true);
int64_t dhwn1 = UNKNOWN_SHAPE_VALUE;
if (origin_shape.GetDim(format_index[DIM_INDEX_N]) != UNKNOWN_SHAPE_VALUE &&
origin_shape.GetDim(format_index[DIM_INDEX_H]) != UNKNOWN_SHAPE_VALUE &&
origin_shape.GetDim(format_index[DIM_INDEX_W]) != UNKNOWN_SHAPE_VALUE &&
origin_shape.GetDim(format_index[DIM_INDEX_D]) != UNKNOWN_SHAPE_VALUE) {
dhwn1 = DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_N]), NI);
MUL_OVERFLOW(dhwn1, origin_shape.GetDim(format_index[DIM_INDEX_D]), dhwn1);
MUL_OVERFLOW(dhwn1, origin_shape.GetDim(format_index[DIM_INDEX_H]), dhwn1);
MUL_OVERFLOW(dhwn1, origin_shape.GetDim(format_index[DIM_INDEX_W]), dhwn1);
}
shape.SetDimNum(0);
shape.AppendDim(dhwn1);
if (origin_shape.GetDim(format_index[DIM_INDEX_C]) == UNKNOWN_SHAPE_VALUE) {
shape.AppendDim(UNKNOWN_SHAPE_VALUE);
} else {
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_C]), c0));
}
shape.AppendDim(NI);
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetFractalZLstmShape(const FormatIndex& format_index, const gert::Shape &origin_shape,
gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
int64_t axis_n_val = origin_shape.GetDim(format_index[DIM_INDEX_N]);
int64_t axis_c_val = origin_shape.GetDim(format_index[DIM_INDEX_C]);
int64_t h = axis_n_val / LSTM_NI;
int64_t i = axis_c_val - h;
int64_t first_element_of_fz_lstm = DivisionCeiling(i, NI) + DivisionCeiling(h, NI);
int64_t second_element_of_fz_lstm = DivisionCeiling(h, NI);
MUL_OVERFLOW(second_element_of_fz_lstm, LSTM_NI, second_element_of_fz_lstm);
if (axis_n_val == UNKNOWN_SHAPE_VALUE || axis_c_val == UNKNOWN_SHAPE_VALUE) {
first_element_of_fz_lstm = UNKNOWN_SHAPE_VALUE;
second_element_of_fz_lstm = UNKNOWN_SHAPE_VALUE;
}
shape.SetDimNum(0);
shape.AppendDim(first_element_of_fz_lstm);
shape.AppendDim(second_element_of_fz_lstm);
shape.AppendDim(NI);
shape.AppendDim(NI);
return true;
}
bool TransferShapeUtils::GetFractalZC04Shape(const FormatIndex& format_index, const int64_t &c0,
const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true);
int64_t axis_h_val = origin_shape.GetDim(format_index[DIM_INDEX_H]);
int64_t axis_w_val = origin_shape.GetDim(format_index[DIM_INDEX_W]);
shape.SetDimNum(0);
if (axis_h_val == UNKNOWN_SHAPE_VALUE || axis_w_val == UNKNOWN_SHAPE_VALUE) {
shape.AppendDim(UNKNOWN_SHAPE_VALUE);
} else {
int64_t x = SHAPE_NUMBER_4;
MUL_OVERFLOW(x, origin_shape.GetDim(format_index[DIM_INDEX_H]), x);
MUL_OVERFLOW(x, origin_shape.GetDim(format_index[DIM_INDEX_W]), x);
shape.AppendDim(DivisionCeiling(x, c0));
}
shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_N]), NI));
shape.AppendDim(NI);
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetFractalZnRnnShape(const ExtAxisValue &ext_axis, const int64_t &c0,
const gert::Shape &origin_shape, gert::Shape &shape) {
size_t origin_shape_size = origin_shape.GetDimNum();
CHECK(origin_shape_size < DIM_SIZE_TWO, GELOGD("Dim size is less than 2."), return true);
shape.SetDimNum(0);
for (size_t i = 0; i < origin_shape_size - DIM_SIZE_TWO; i++) {
shape.AppendDim(origin_shape.GetDim(i));
}
int64_t k_value = origin_shape.GetDim(origin_shape_size - MINUS_VALUE_TWO);
int64_t hidden_or_state_size = ext_axis[EXT_INDEX_HIDDEN_SIZE];
if (ext_axis[EXT_INDEX_STATE_SIZE] != RNN_STATE_SIZE_DEFAULT_VALUE) {
hidden_or_state_size = ext_axis[EXT_INDEX_STATE_SIZE];
}
if (k_value == hidden_or_state_size + ext_axis[EXT_INDEX_INPUT_SIZE]) {
shape.AppendDim(DivisionCeiling(ext_axis[EXT_INDEX_INPUT_SIZE], SHAPE_NUMBER_16) +
DivisionCeiling(hidden_or_state_size, SHAPE_NUMBER_16));
} else if (k_value == hidden_or_state_size || k_value == ext_axis[EXT_INDEX_INPUT_SIZE]) {
shape.AppendDim(DivisionCeiling(k_value, SHAPE_NUMBER_16));
} else {
return true;
}
int64_t n_value = origin_shape.GetDim(origin_shape_size - MINUS_VALUE_ONE);
INT64_ZEROCHECK(ext_axis[EXT_INDEX_HIDDEN_SIZE]);
int64_t n_num = n_value / ext_axis[EXT_INDEX_HIDDEN_SIZE];
MUL_OVERFLOW(n_num, DivisionCeiling(ext_axis[EXT_INDEX_HIDDEN_SIZE], c0), n_num);
shape.AppendDim(n_num);
shape.AppendDim(SHAPE_NUMBER_16);
shape.AppendDim(c0);
return true;
}
bool TransferShapeUtils::GetNdRnnBiasShape(const ExtAxisValue &ext_axis, const int64_t &c0,
const gert::Shape &origin_shape, gert::Shape &shape) {
CHECK(ext_axis[EXT_INDEX_HIDDEN_SIZE] == 0, GELOGD("hidden_size is zero"), return true);
size_t size_of_original_vec = origin_shape.GetDimNum();
shape.SetDimNum(0);
for (size_t i = 0; i < size_of_original_vec - MINUS_VALUE_ONE; i++) {
shape.AppendDim(origin_shape.GetDim(i));
}
int64_t n_num = origin_shape.GetDim(size_of_original_vec - MINUS_VALUE_ONE) / ext_axis[EXT_INDEX_HIDDEN_SIZE];
MUL_OVERFLOW(n_num, DivisionCeiling(ext_axis[EXT_INDEX_HIDDEN_SIZE], c0), n_num);
MUL_OVERFLOW(n_num, c0, n_num);
shape.AppendDim(n_num);
return true;
}
bool TransferShapeUtils::GetNYUVShape(gert::Shape &shape) {
const size_t kSize4 = 4U;
const size_t kSize3 = 3U;
size_t shape_size = shape.GetDimNum();
CHECK(((shape_size != kSize3) && (shape_size != kSize4)),
GELOGD("Dim size is not 3 or 4."), return false);
const size_t kWdimOffset = 2U;
const size_t kHdimOffset = 3U;
const int64_t kAlaign64 = 64;
const int64_t kAlaign16 = 16;
auto width = shape.GetDim(shape_size - kWdimOffset);
auto height = shape.GetDim(shape_size - kHdimOffset);
width = (width + kAlaign64 - 1) / kAlaign64 * kAlaign64;
height = (height + kAlaign16 - 1) / kAlaign16 * kAlaign16;
shape.SetDim(shape_size - kWdimOffset, width);
shape.SetDim(shape_size - kHdimOffset, height);
return true;
}
TransferShapeType TransferShapeUtils::GetTransferShapeType(const ge::Format &src_format, const ge::Format &dst_format,
const gert::Shape &src_shape) {
auto primary_src_format = static_cast<ge::Format>(ge::GetPrimaryFormat(src_format));
auto primary_dst_format = static_cast<ge::Format>(ge::GetPrimaryFormat(dst_format));
if (primary_src_format == primary_dst_format) {
return TransferShapeType::ND_TO_ND;
} else {
if (primary_src_format == ge::FORMAT_ND || kFormatNZSet.count(primary_dst_format) > 0) {
return TransferShapeType::ND_TO_NZ;
}
size_t src_format_full_size = 0;
if (!ExpandDimension::GetFormatFullSize(primary_src_format, src_format_full_size)) {
return TransferShapeType::INVALID;
}
if (src_shape.GetDimNum() == src_format_full_size) {
return TransferShapeType::FULL_SIZE;
}
if (src_shape.GetDimNum() < src_format_full_size) {
return TransferShapeType::NOT_FULL_SIZE;
}
}
return TransferShapeType::INVALID;
}
bool TransferShapeUtils::GetAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape) {
auto func = get_aligned_shape_func_map.find(GetTransferShapeType(align_shape_info.src_format,
align_shape_info.dst_format, align_shape_info.src_shape));
if (func == get_aligned_shape_func_map.end()) {
GELOGW("Do not support src_format %s to src_shape %s.",
ge::TypeUtils::FormatToSerialString(align_shape_info.src_format).c_str(),
ge::TypeUtils::FormatToSerialString(align_shape_info.dst_format).c_str());
return false;
}
return func->second(align_shape_info, aligned_shape);
}
bool TransferShapeUtils::TransferDims(const TransferDimsInfo &transfer_dims_info,
AxisIndexMapping &axis_index_mapping) {
auto func = transfer_dims_func_map.find(GetTransferShapeType(transfer_dims_info.src_format,
transfer_dims_info.dst_format, transfer_dims_info.src_shape));
if (func == transfer_dims_func_map.end()) {
GELOGW("Do not support src_format %s to src_shape %s.",
ge::TypeUtils::FormatToSerialString(transfer_dims_info.src_format).c_str(),
ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str());
return false;
}
return func->second(transfer_dims_info, axis_index_mapping);
}
* nd-to-nd
* The aligned value for each axis is set to 1.
**/
bool TransferShapeUtils::GetNdToNdAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape) {
GELOGD("There are no alignment requirements for the nd-to-nd scenario.");
(void)aligned_shape.SetDimNum(align_shape_info.src_shape.GetDimNum());
for (size_t i = 0; i < align_shape_info.src_shape.GetDimNum(); ++i) {
(void)aligned_shape.SetDim(i, 1);
}
return true;
}
* nd-to-nz
* eg: ND(n, a, b) -> NZ(n, b//c0, a//m0, m0, c0)
* |
* target
* The aligned value for the axis before target is set to 1.
* The aligned value for the target axis is set to m0.
* The aligned value for the axis behind target is set to c0.
**/
bool TransferShapeUtils::GetNdToNzAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape) {
size_t src_shape_dim_num = align_shape_info.src_shape.GetDimNum();
if (src_shape_dim_num < kNzMinDimNum) {
GELOGW("The src_shape_dim_num %zu is invalid for the nd-to-nz scenario.", src_shape_dim_num);
return false;
}
ge::Format format = static_cast<ge::Format>(ge::GetPrimaryFormat(align_shape_info.dst_format));
if (kFormatNZSet.count(format) == 0) {
GELOGW("Unsupported dst_format %s for nd-to-nz scenario.",
ge::TypeUtils::FormatToSerialString(align_shape_info.dst_format).c_str());
return false;
}
(void)aligned_shape.SetDimNum(src_shape_dim_num);
size_t target_index = src_shape_dim_num - 2;
for (size_t i = 0; i < target_index; ++i) {
(void)aligned_shape.SetDim(i, 1);
}
int64_t m0 = GetM0ByDtype(align_shape_info.data_type);
int64_t c0 = GetC0ByDtype(align_shape_info.data_type);
(void)aligned_shape.SetDim(target_index, m0);
(void)aligned_shape.SetDim(target_index + 1, c0);
return true;
}
* full-size: no need to expand dims
* NCHW/HWCN/NHWC/CHWN(a,b,c,d) -> FZ(c//c0*h*w, n//n0, n0, c0)
* NCHW/HWCN/NHWC/CHWN(a,b,c,d) -> NCHW/HWCN/NHWC/CHWN/NC1HWC0
* NDHWC/DHWCN/DHWNC(a,b,c,d,e) -> NDHWC/DHWCN/DHWNC/NDC1HWC0
* Traverse each axis, if current axis is "C" and needs to be split, the aligned value is set to c0.
* Otherwise, set it to 1.
**/
bool TransferShapeUtils::GetFullSizeAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape) {
std::vector<std::string> src_axis_vec = AxisUtil::GetAxisVecByFormat(align_shape_info.src_format);
if (src_axis_vec.empty()) {
GELOGW("Does not support src_format %s for dst_format in full-size scene.",
ge::TypeUtils::FormatToSerialString(align_shape_info.src_format).c_str(),
ge::TypeUtils::FormatToSerialString(align_shape_info.dst_format).c_str());
return false;
}
int64_t n0 = GetN0ByDtype(align_shape_info.data_type);
int64_t c0 = GetC0ByDtype(align_shape_info.data_type);
(void)aligned_shape.SetDimNum(src_axis_vec.size());
for (size_t i = 0; i < src_axis_vec.size(); ++i) {
if (AxisUtil::GetSplitOrConcatAxisByFormat(align_shape_info.dst_format, src_axis_vec.at(i)).size() > 1) {
if (src_axis_vec.at(i) == "N") {
(void)aligned_shape.SetDim(i, n0);
} else if (src_axis_vec.at(i) == "C") {
(void)aligned_shape.SetDim(i, c0);
} else {
(void)aligned_shape.SetDim(i, 1);
}
} else {
(void)aligned_shape.SetDim(i, 1);
}
}
return true;
}
* not-full-size: need to expand dims
* NCHW/HWCN/NHWC/CHWN(a,b) -> FZ(c//c0*h*w, n//n0, n0, c0)
* NCHW/HWCN/NHWC/CHWN(a,b) -> NC1HWC0
* NDHWC/DHWCN/DHWNC(a,b,c) -> NDC1HWC0
* Traverse each reshape type axis, if current axis is "C", the aligned value is set to c0.
* Otherwise, set it to 1.
**/
bool TransferShapeUtils::GetNotFullSizeAlignedShape(const AlignShapeInfo &align_shape_info,
gert::Shape &aligned_shape) {
if (align_shape_info.reshape_type_mask <= 0) {
GELOGW("The reshape type mask %lld is invalid in a non-full-size scenario.", align_shape_info.reshape_type_mask);
return false;
}
std::vector<std::string> reshape_type_axis_vec =
AxisUtil::GetReshapeTypeAxisVec(align_shape_info.src_format, align_shape_info.reshape_type_mask);
if (reshape_type_axis_vec.empty()) {
GELOGW("Does not support conversion from src_format %s to dst_format %s for non-full-size scenarios.",
ge::TypeUtils::FormatToSerialString(align_shape_info.src_format).c_str(),
ge::TypeUtils::FormatToSerialString(align_shape_info.dst_format).c_str());
return false;
}
int64_t n0 = GetN0ByDtype(align_shape_info.data_type);
int64_t c0 = GetC0ByDtype(align_shape_info.data_type);
(void)aligned_shape.SetDimNum(reshape_type_axis_vec.size());
for (size_t i = 0; i < reshape_type_axis_vec.size(); ++i) {
if (AxisUtil::GetSplitOrConcatAxisByFormat(align_shape_info.dst_format, reshape_type_axis_vec.at(i)).size() > 1) {
if (reshape_type_axis_vec.at(i) == "N") {
(void)aligned_shape.SetDim(i, n0);
} else if (reshape_type_axis_vec.at(i) == "C") {
(void)aligned_shape.SetDim(i, c0);
} else {
(void)aligned_shape.SetDim(i, 1);
}
} else {
(void)aligned_shape.SetDim(i, 1);
}
}
return true;
}
* nd-to-nd
* No change.
**/
bool TransferShapeUtils::GetNdToNdAxisIndexMapping(const TransferDimsInfo &transfer_dims_info,
AxisIndexMapping &axis_index_mapping) {
for (int32_t i = 0; i < static_cast<int32_t>(transfer_dims_info.src_shape.GetDimNum()); ++i) {
std::vector<int32_t> cur_transfer_dims;
cur_transfer_dims.emplace_back(i);
axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims);
}
axis_index_mapping.dst_to_src_transfer_dims = axis_index_mapping.src_to_dst_transfer_dims;
return true;
}
* nd-to-nz
* eg: ND(n, a, b) -> NZ(n, b//c0, a//m0, m0, c0)
* |
* target
* The aligned value for the axis before target is set to 1.
* The aligned value for the target axis is set to m0.
* The aligned value for the axis behind target is set to c0.
**/
bool TransferShapeUtils::GetNdToNzAxisIndexMapping(const TransferDimsInfo &transfer_dims_info,
AxisIndexMapping &axis_index_mapping) {
size_t src_shape_dim_num = transfer_dims_info.src_shape.GetDimNum();
if (src_shape_dim_num < kNzMinDimNum) {
GELOGW("The src_shape_dim_num %zu is invalid for the nd-to-nz scenario.", src_shape_dim_num);
return false;
}
ge::Format format = static_cast<ge::Format>(GetPrimaryFormat(transfer_dims_info.dst_format));
if (kFormatNZSet.count(format) == 0) {
GELOGW("Unsupported dst_format %s for nd-to-nz scenario.",
ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str());
return false;
}
int32_t target_index = static_cast<int32_t>(src_shape_dim_num - 2);
for (int32_t i = 0; i < target_index; ++i) {
std::vector<int32_t> cur_transfer_dims;
cur_transfer_dims.emplace_back(i);
axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims);
}
std::vector<int32_t> cur_transfer_dims;
cur_transfer_dims.emplace_back(target_index + 1);
axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims);
cur_transfer_dims.clear();
cur_transfer_dims.emplace_back(target_index);
axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims);
axis_index_mapping.dst_to_src_transfer_dims.insert(axis_index_mapping.dst_to_src_transfer_dims.end(),
axis_index_mapping.src_to_dst_transfer_dims.begin(), axis_index_mapping.src_to_dst_transfer_dims.end());
cur_transfer_dims.clear();
cur_transfer_dims.emplace_back(-1);
axis_index_mapping.dst_to_src_transfer_dims.emplace_back(cur_transfer_dims);
axis_index_mapping.dst_to_src_transfer_dims.emplace_back(cur_transfer_dims);
return true;
}
* full-size: no need to expand dims
* NCHW/HWCN/NHWC/CHWN(a,b,c,d) -> FZ(c//c0*h*w, n//n0, n0, c0)
* NCHW/HWCN/NHWC/CHWN(a,b,c,d) -> NCHW/HWCN/NHWC/CHWN/NC1HWC0
* NDHWC/DHWCN/DHWNC(a,b,c,d,e) -> NDHWC/DHWCN/DHWNC/NDC1HWC0
* Traverse each axis in src_format, and the get its mapping-index in dst_format.
* If current axis is "C" and needs to be split, get "C0" and "C1" mapping-index in dst_format.
* Traverse each axis in dst_format, and the get its mapping-index in src_format.
* If current axis is "C0" or "C1", get "C" mapping-index in src_format.
**/
bool TransferShapeUtils::GetFullSizeAxisIndexMapping(const TransferDimsInfo &transfer_dims_info,
AxisIndexMapping &axis_index_mapping) {
std::vector<std::string> src_axis_vec = AxisUtil::GetAxisVecByFormat(transfer_dims_info.src_format);
if (src_axis_vec.empty()) {
GELOGW("Does not support conversion from src_format %s to dst_format %s for full-size scenes.",
ge::TypeUtils::FormatToSerialString(transfer_dims_info.src_format).c_str(),
ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str());
return false;
}
for (auto &cur_axis : src_axis_vec) {
std::vector<int32_t> cur_transfer_dims;
std::vector<std::string> split_or_concat_axis =
AxisUtil::GetSplitOrConcatAxisByFormat(transfer_dims_info.dst_format, cur_axis);
if (!split_or_concat_axis.empty()) {
for (auto &cur_split_or_concat_axis : split_or_concat_axis) {
cur_transfer_dims.emplace_back(AxisUtil::GetAxisIndexByFormat(transfer_dims_info.dst_format,
cur_split_or_concat_axis));
}
} else {
cur_transfer_dims.emplace_back(AxisUtil::GetAxisIndexByFormat(transfer_dims_info.dst_format, cur_axis));
}
axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims);
}
std::vector<std::string> dst_axis_vec = AxisUtil::GetAxisVecByFormat(transfer_dims_info.dst_format);
if (dst_axis_vec.empty()) {
GELOGW("Does not support conversion from src_format %s to dst_format %s for full-size scenes.",
ge::TypeUtils::FormatToSerialString(transfer_dims_info.src_format).c_str(),
ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str());
return false;
}
for (auto &cur_axis : dst_axis_vec) {
std::vector<int32_t> cur_transfer_dims;
std::vector<std::string> split_or_concat_axis =
AxisUtil::GetSplitOrConcatAxisByFormat(transfer_dims_info.dst_format, cur_axis);
if (!split_or_concat_axis.empty()) {
for (auto cur_split_or_concat_axis : split_or_concat_axis) {
cur_transfer_dims.emplace_back(AxisUtil::GetAxisIndexByFormat(transfer_dims_info.src_format,
cur_split_or_concat_axis));
}
} else {
cur_transfer_dims.emplace_back(AxisUtil::GetAxisIndexByFormat(transfer_dims_info.src_format, cur_axis));
}
axis_index_mapping.dst_to_src_transfer_dims.emplace_back(cur_transfer_dims);
}
return true;
}
* not-full-size: need to expand dims
* NCHW/HWCN/NHWC/CHWN(a,b) -> NC1HWC0
* NDHWC/DHWCN/DHWNC(a,b,c) -> NDC1HWC0
* Traverse each axis in reshape_type, and the get its mapping-index in dst_format.
* If current axis is "C" and needs to be split, get "C0" and "C1" mapping-index in dst_format.
* Traverse each axis in dst_format, and the get its mapping-index in src_format.
* If current axis is "C0" or "C1", get "C" mapping-index in src_format.
**/
bool TransferShapeUtils::GetNotFullSizeAxisIndexMapping(const TransferDimsInfo &transfer_dims_info,
AxisIndexMapping &axis_index_mapping) {
if (transfer_dims_info.reshape_type_mask <= 0) {
GELOGW("Does not support conversion from src_format %s to dst_format %s for non-full-size scenarios.",
ge::TypeUtils::FormatToSerialString(transfer_dims_info.src_format).c_str(),
ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str());
return false;
}
std::vector<std::string> reshape_type_axis_vec =
AxisUtil::GetReshapeTypeAxisVec(transfer_dims_info.src_format, transfer_dims_info.reshape_type_mask);
if (reshape_type_axis_vec.empty()) {
GELOGW("Does not support conversion from src_format %s to dst_format %s for non-full-size scenarios.",
ge::TypeUtils::FormatToSerialString(transfer_dims_info.src_format).c_str(),
ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str());
return false;
}
for (auto &cur_axis : reshape_type_axis_vec) {
std::vector<int32_t> cur_transfer_dims;
std::vector<std::string> split_or_concat_axis =
AxisUtil::GetSplitOrConcatAxisByFormat(transfer_dims_info.dst_format, cur_axis);
if (!split_or_concat_axis.empty()) {
for (auto &cur_split_or_concat_axis : split_or_concat_axis) {
cur_transfer_dims.emplace_back(
AxisUtil::GetAxisIndexByFormat(transfer_dims_info.dst_format, cur_split_or_concat_axis));
}
} else {
cur_transfer_dims.emplace_back(AxisUtil::GetAxisIndexByFormat(transfer_dims_info.dst_format, cur_axis));
}
axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims);
}
std::vector<std::string> dst_axis_vec = AxisUtil::GetAxisVecByFormat(transfer_dims_info.dst_format);
if (dst_axis_vec.empty()) {
return false;
}
std::map<std::string, int32_t> valid_axis_map =
AxisUtil::GetReshapeTypeAxisMap(transfer_dims_info.src_format, transfer_dims_info.reshape_type_mask);
for (auto &cur_axis : dst_axis_vec) {
std::vector<int32_t> cur_transfer_dims;
std::vector<std::string> split_or_concat_axis =
AxisUtil::GetSplitOrConcatAxisByFormat(transfer_dims_info.dst_format, cur_axis);
if (!split_or_concat_axis.empty()) {
for (auto &cur_split_or_concat_axis : split_or_concat_axis) {
cur_transfer_dims.emplace_back(
AxisUtil::GetAxisIndexByFormat(transfer_dims_info.src_format, cur_split_or_concat_axis, valid_axis_map));
}
} else {
cur_transfer_dims.emplace_back(
AxisUtil::GetAxisIndexByFormat(transfer_dims_info.src_format, cur_axis, valid_axis_map));
}
axis_index_mapping.dst_to_src_transfer_dims.emplace_back(cur_transfer_dims);
}
return true;
}
}