/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "expand_dimension.h"
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <bitset>
#include "axis_constants.h"
#include "exe_graph/runtime/expand_dims_type.h"
#include "graph/types.h"
#include "framework/common/debug/ge_log.h"
#include "graph/ge_error_codes.h"
#include "graph/utils/type_utils.h"
namespace transformer {
namespace {
// Reshape-type value that disables dimension expansion entirely
// (checked by IsExpandNecessary and ExpandDimension::IsNeedExpand).
const std::string RESHAPE_TYPE_FORBIDDEN = "FORBIDDEN";
// Shift applied to the format byte when building the lookup keys below.
const uint32_t kBitsOfByte = 8;
// Number of low mask bits rendered by std::bitset in debug/info logs.
const uint32_t kBitSetDisplaySize = 8;
// Bit offset at which the full rank is stored inside the integer reshape-type mask.
const uint32_t kMaxReshapeTypeSize = 56;
// Primary target formats for which ExpandDimension::IsNeedExpand reports that no expansion is needed.
const std::set<ge::Format> kSupportedTransFormat = {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ_C0_2,
ge::FORMAT_FRACTAL_NZ_C0_4, ge::FORMAT_FRACTAL_NZ_C0_8,
ge::FORMAT_FRACTAL_NZ_C0_16, ge::FORMAT_FRACTAL_NZ_C0_32,
ge::FORMAT_ND_RNN_BIAS, ge::FORMAT_FRACTAL_ZN_RNN};
// Full rank that a shape in each origin format is expanded to. ND is treated as rank 4.
// Origin formats absent from this map are never expanded.
const std::map<ge::Format, size_t> FULL_SIZE_OF_FORMAT {
{ge::FORMAT_NCHW, DIM_SIZE_FOUR},
{ge::FORMAT_NHWC, DIM_SIZE_FOUR},
{ge::FORMAT_HWCN, DIM_SIZE_FOUR},
{ge::FORMAT_CHWN, DIM_SIZE_FOUR},
{ge::FORMAT_NDHWC, DIM_SIZE_FIVE},
{ge::FORMAT_NCDHW, DIM_SIZE_FIVE},
{ge::FORMAT_DHWCN, DIM_SIZE_FIVE},
{ge::FORMAT_DHWNC, DIM_SIZE_FIVE},
{ge::FORMAT_ND, DIM_SIZE_FOUR}
};
// Builds the format half of a lookup key: the low byte of the format id, shifted into bits [8, 16).
inline uint32_t GenerateFormatKey(ge::Format format) {
  const uint32_t format_byte = static_cast<uint32_t>(format) & 0xffU;
  return format_byte << kBitsOfByte;
}
// Key for DEFAULT_RESHAPE_TYPE: format byte in bits [8, 16), low byte of the dim count in bits [0, 8).
inline uint32_t GenerateReshapeTypeKey(ge::Format format, size_t size) {
  const uint32_t size_byte = static_cast<uint32_t>(size) & 0xffU;
  return GenerateFormatKey(format) | size_byte;
}
// Key for AXIS_INDEX_OF_FORMAT: format byte in bits [8, 16), axis letter in bits [0, 8).
inline uint32_t GenerateAxisIndexKey(ge::Format format, char ch) {
  const uint32_t letter_byte = static_cast<uint32_t>(ch) & 0xffU;
  return GenerateFormatKey(format) | letter_byte;
}
// Default reshape type (the axis letters that the original dims occupy), keyed by
// GenerateReshapeTypeKey(origin format, original dim count). It is used when the caller-supplied
// reshape type fails IsReshapeTypeValid. Entries exist only for dim counts below each format's
// full rank. There are no ND entries, so ND shapes with an invalid reshape type are not expanded.
// NOTE(review): most entries keep the trailing axes of the format. The exceptions are the 1-dim
// entries (always "C") and the 2-dim NCHW ("CH") and NHWC ("HW") entries. These are presumably
// intentional; confirm before changing.
const std::unordered_map<uint32_t, std::string> DEFAULT_RESHAPE_TYPE {
{GenerateReshapeTypeKey(ge::FORMAT_NCHW, 0), ""},
{GenerateReshapeTypeKey(ge::FORMAT_NHWC, 0), ""},
{GenerateReshapeTypeKey(ge::FORMAT_HWCN, 0), ""},
{GenerateReshapeTypeKey(ge::FORMAT_CHWN, 0), ""},
{GenerateReshapeTypeKey(ge::FORMAT_NDHWC, 0), ""},
{GenerateReshapeTypeKey(ge::FORMAT_NCDHW, 0), ""},
{GenerateReshapeTypeKey(ge::FORMAT_DHWCN, 0), ""},
{GenerateReshapeTypeKey(ge::FORMAT_DHWNC, 0), ""},
{GenerateReshapeTypeKey(ge::FORMAT_NCHW, 1), "C"},
{GenerateReshapeTypeKey(ge::FORMAT_NHWC, 1), "C"},
{GenerateReshapeTypeKey(ge::FORMAT_HWCN, 1), "C"},
{GenerateReshapeTypeKey(ge::FORMAT_CHWN, 1), "C"},
{GenerateReshapeTypeKey(ge::FORMAT_NDHWC, 1), "C"},
{GenerateReshapeTypeKey(ge::FORMAT_NCDHW, 1), "C"},
{GenerateReshapeTypeKey(ge::FORMAT_DHWCN, 1), "C"},
{GenerateReshapeTypeKey(ge::FORMAT_DHWNC, 1), "C"},
{GenerateReshapeTypeKey(ge::FORMAT_NCHW, 2), "CH"},
{GenerateReshapeTypeKey(ge::FORMAT_NHWC, 2), "HW"},
{GenerateReshapeTypeKey(ge::FORMAT_HWCN, 2), "CN"},
{GenerateReshapeTypeKey(ge::FORMAT_CHWN, 2), "WN"},
{GenerateReshapeTypeKey(ge::FORMAT_NDHWC, 2), "WC"},
{GenerateReshapeTypeKey(ge::FORMAT_NCDHW, 2), "HW"},
{GenerateReshapeTypeKey(ge::FORMAT_DHWCN, 2), "CN"},
{GenerateReshapeTypeKey(ge::FORMAT_DHWNC, 2), "NC"},
{GenerateReshapeTypeKey(ge::FORMAT_NCHW, 3), "CHW"},
{GenerateReshapeTypeKey(ge::FORMAT_NHWC, 3), "HWC"},
{GenerateReshapeTypeKey(ge::FORMAT_HWCN, 3), "WCN"},
{GenerateReshapeTypeKey(ge::FORMAT_CHWN, 3), "HWN"},
{GenerateReshapeTypeKey(ge::FORMAT_NDHWC, 3), "HWC"},
{GenerateReshapeTypeKey(ge::FORMAT_NCDHW, 3), "DHW"},
{GenerateReshapeTypeKey(ge::FORMAT_DHWCN, 3), "WCN"},
{GenerateReshapeTypeKey(ge::FORMAT_DHWNC, 3), "WNC"},
{GenerateReshapeTypeKey(ge::FORMAT_NDHWC, 4), "DHWC"},
{GenerateReshapeTypeKey(ge::FORMAT_NCDHW, 4), "CDHW"},
{GenerateReshapeTypeKey(ge::FORMAT_DHWCN, 4), "HWCN"},
{GenerateReshapeTypeKey(ge::FORMAT_DHWNC, 4), "HWNC"}
};
// Position of an axis letter within the full-rank layout of a format, keyed by
// GenerateAxisIndexKey(format, letter). The index values come from axis_constants.h and are
// presumably the letter's offset in the format name (GenerateReshapeTypeByMask relies on that
// correspondence) — confirm against axis_constants.h. The 'D' axis exists only for the 5-D
// formats. ND has no entries, so no letter resolves for ND.
const std::unordered_map<uint32_t, int32_t> AXIS_INDEX_OF_FORMAT {
{GenerateAxisIndexKey(ge::FORMAT_NCHW, 'N'), AXIS_NCHW_DIM_N},
{GenerateAxisIndexKey(ge::FORMAT_NCHW, 'C'), AXIS_NCHW_DIM_C},
{GenerateAxisIndexKey(ge::FORMAT_NCHW, 'H'), AXIS_NCHW_DIM_H},
{GenerateAxisIndexKey(ge::FORMAT_NCHW, 'W'), AXIS_NCHW_DIM_W},
{GenerateAxisIndexKey(ge::FORMAT_HWCN, 'N'), AXIS_HWCN_DIM_N},
{GenerateAxisIndexKey(ge::FORMAT_HWCN, 'C'), AXIS_HWCN_DIM_C},
{GenerateAxisIndexKey(ge::FORMAT_HWCN, 'H'), AXIS_HWCN_DIM_H},
{GenerateAxisIndexKey(ge::FORMAT_HWCN, 'W'), AXIS_HWCN_DIM_W},
{GenerateAxisIndexKey(ge::FORMAT_NHWC, 'N'), AXIS_NHWC_DIM_N},
{GenerateAxisIndexKey(ge::FORMAT_NHWC, 'C'), AXIS_NHWC_DIM_C},
{GenerateAxisIndexKey(ge::FORMAT_NHWC, 'H'), AXIS_NHWC_DIM_H},
{GenerateAxisIndexKey(ge::FORMAT_NHWC, 'W'), AXIS_NHWC_DIM_W},
{GenerateAxisIndexKey(ge::FORMAT_CHWN, 'N'), AXIS_CHWN_DIM_N},
{GenerateAxisIndexKey(ge::FORMAT_CHWN, 'C'), AXIS_CHWN_DIM_C},
{GenerateAxisIndexKey(ge::FORMAT_CHWN, 'H'), AXIS_CHWN_DIM_H},
{GenerateAxisIndexKey(ge::FORMAT_CHWN, 'W'), AXIS_CHWN_DIM_W},
{GenerateAxisIndexKey(ge::FORMAT_NDHWC, 'N'), NDHWC_DIM_N},
{GenerateAxisIndexKey(ge::FORMAT_NDHWC, 'C'), NDHWC_DIM_C},
{GenerateAxisIndexKey(ge::FORMAT_NDHWC, 'H'), NDHWC_DIM_H},
{GenerateAxisIndexKey(ge::FORMAT_NDHWC, 'W'), NDHWC_DIM_W},
{GenerateAxisIndexKey(ge::FORMAT_NDHWC, 'D'), NDHWC_DIM_D},
{GenerateAxisIndexKey(ge::FORMAT_NCDHW, 'N'), NCDHW_DIM_N},
{GenerateAxisIndexKey(ge::FORMAT_NCDHW, 'C'), NCDHW_DIM_C},
{GenerateAxisIndexKey(ge::FORMAT_NCDHW, 'H'), NCDHW_DIM_H},
{GenerateAxisIndexKey(ge::FORMAT_NCDHW, 'W'), NCDHW_DIM_W},
{GenerateAxisIndexKey(ge::FORMAT_NCDHW, 'D'), NCDHW_DIM_D},
{GenerateAxisIndexKey(ge::FORMAT_DHWCN, 'N'), DHWCN_DIM_N},
{GenerateAxisIndexKey(ge::FORMAT_DHWCN, 'C'), DHWCN_DIM_C},
{GenerateAxisIndexKey(ge::FORMAT_DHWCN, 'H'), DHWCN_DIM_H},
{GenerateAxisIndexKey(ge::FORMAT_DHWCN, 'W'), DHWCN_DIM_W},
{GenerateAxisIndexKey(ge::FORMAT_DHWCN, 'D'), DHWCN_DIM_D},
{GenerateAxisIndexKey(ge::FORMAT_DHWNC, 'N'), DHWNC_DIM_N},
{GenerateAxisIndexKey(ge::FORMAT_DHWNC, 'C'), DHWNC_DIM_C},
{GenerateAxisIndexKey(ge::FORMAT_DHWNC, 'H'), DHWNC_DIM_H},
{GenerateAxisIndexKey(ge::FORMAT_DHWNC, 'W'), DHWNC_DIM_W},
{GenerateAxisIndexKey(ge::FORMAT_DHWNC, 'D'), DHWNC_DIM_D}
};
// Copies every dim of a ge::GeShape, in order, into a gert::Shape, replacing its previous contents.
void GeShapeToRtShape(const ge::GeShape &ge_shape, gert::Shape &rt_shape) {
  rt_shape.SetDimNum(0);
  const size_t dim_count = ge_shape.GetDimNum();
  for (size_t axis = 0; axis < dim_count; ++axis) {
    rt_shape.AppendDim(ge_shape.GetDim(axis));
  }
}
// Copies every dim of a gert::Shape, in order, into a ge::GeShape, replacing its previous contents.
void RtShapeToGeShape(const gert::Shape &rt_shape, ge::GeShape &ge_shape) {
  ge_shape.SetDimNum(0);
  const size_t dim_count = rt_shape.GetDimNum();
  for (size_t axis = 0; axis < dim_count; ++axis) {
    ge_shape.AppendDim(rt_shape.GetDim(axis));
  }
}
}
/**
 * Looks up the default reshape type for a shape of `old_dims_size` dims in `original_format`.
 * @param reshape_type receives the default axis letters on success; untouched on failure.
 * @return true if DEFAULT_RESHAPE_TYPE has an entry; false (with a warning) otherwise.
 */
bool GetDefaultReshapeType(const ge::Format &original_format, const size_t &old_dims_size, std::string &reshape_type) {
  const uint32_t lookup_key = GenerateReshapeTypeKey(original_format, old_dims_size);
  const auto entry = DEFAULT_RESHAPE_TYPE.find(lookup_key);
  if (entry != DEFAULT_RESHAPE_TYPE.end()) {
    reshape_type = entry->second;
    return true;
  }
  GELOGW("dim size %zu is invalid.", old_dims_size);
  return false;
}
/**
 * Decides whether a shape of `old_dims_size` dims in `original_format` should be expanded.
 * @param final_format primary target format.
 * @param full_size    set to the full rank of `original_format` only when the result is true.
 * @return false when the origin format has no known full rank (with a warning), when the shape is
 *         already at or above that rank, when the reshape type is FORBIDDEN, when the target is
 *         in kFormatNZSet, or for ND -> FRACTAL_Z. Returns true otherwise.
 */
bool IsExpandNecessary(const size_t &old_dims_size, const ge::Format &original_format, const ge::Format &final_format,
                       const std::string &reshape_type, size_t &full_size) {
  const auto rank_entry = FULL_SIZE_OF_FORMAT.find(original_format);
  if (rank_entry == FULL_SIZE_OF_FORMAT.end()) {
    GELOGW("Original Format %u is invalid.", static_cast<uint32_t>(original_format));
    return false;
  }
  const size_t target_rank = rank_entry->second;
  if (old_dims_size >= target_rank) {
    return false;
  }
  const bool forbidden = reshape_type == RESHAPE_TYPE_FORBIDDEN;
  const bool nz_target = kFormatNZSet.count(final_format) > 0;
  const bool nd_to_fractal_z = (original_format == ge::FORMAT_ND) && (final_format == ge::FORMAT_FRACTAL_Z);
  if (forbidden || nz_target || nd_to_fractal_z) {
    return false;
  }
  full_size = target_rank;
  return true;
}
/**
 * Checks that every letter of `reshape_type` is an axis of `original_format`, and that the letters
 * appear in strictly increasing axis order (so no repeats and no reordering).
 * An empty reshape type is valid only for a zero-dim shape.
 */
bool IsReshapeTypeValid(const ge::Format &original_format, const size_t &old_dims_size,
                        const std::string &reshape_type) {
  if (reshape_type.empty()) {
    return old_dims_size == 0;
  }
  const uint32_t format_part = GenerateFormatKey(original_format);
  int32_t previous_index = -1;
  for (const char letter : reshape_type) {
    const auto hit = AXIS_INDEX_OF_FORMAT.find(format_part | (static_cast<uint32_t>(letter) & 0xff));
    // Unknown letters and non-increasing positions both invalidate the reshape type.
    if (hit == AXIS_INDEX_OF_FORMAT.end() || hit->second <= previous_index) {
      return false;
    }
    previous_index = hit->second;
  }
  return true;
}
/**
 * Expands `shape` in place to `full_size` dims.
 *
 * First, 1s are appended until the rank is `full_size`. Then each of the first `old_dims_size`
 * dims is moved to the axis position named by the matching letter of `reshape_type`, and the
 * vacated slot is set to 1. Dims are processed from last to first. Assuming the letters map to
 * strictly increasing positions (as IsReshapeTypeValid checks), each target is at or after its
 * source, so no dim is overwritten before it has been moved.
 */
void ExpandByReshapeType(ge::GeShape &shape, const ge::Format &original_format,
                         const size_t &old_dims_size, const size_t &full_size, const std::string &reshape_type) {
  GELOGD("Expand tensor through reshape of type %s.", reshape_type.c_str());
  size_t rank = old_dims_size;
  while (rank < full_size) {
    shape.AppendDim(1);
    ++rank;
  }
  const bool nothing_to_move = reshape_type.empty() || (old_dims_size == 0);
  if (nothing_to_move) {
    return;
  }
  const uint32_t format_part = GenerateFormatKey(original_format);
  for (size_t pos = old_dims_size; pos-- > 0;) {
    const int32_t source = static_cast<int32_t>(pos);
    const uint32_t lookup = format_part | (static_cast<uint32_t>(reshape_type.at(pos)) & 0xff);
    const auto target = AXIS_INDEX_OF_FORMAT.find(lookup);
    if (target == AXIS_INDEX_OF_FORMAT.end() || target->second == source) {
      continue;
    }
    shape.SetDim(target->second, shape.GetDim(source));
    shape.SetDim(source, 1);
  }
}
/**
 * Expands `shape` in place from its current rank to the full rank of `original_format`
 * (4 for NCHW/NHWC/HWCN/CHWN/ND, 5 for the 5-D formats). Each existing dim is placed at the axis
 * named by `reshape_type`, and every other axis becomes 1.
 *
 * @param op_type         operator type; used only in the warning log.
 * @param original_format format whose full rank and axis layout drive the expansion.
 * @param final_format    target format; its primary format decides whether expansion applies.
 * @param tensor_index    tensor index; used only in the warning log.
 * @param reshape_type    axis letters of the existing dims. When invalid, the format's default is used.
 * @param shape           shape to expand in place.
 * @return always true. Every "need not / cannot expand" case leaves `shape` unchanged.
 */
bool ExpandDimension(const std::string &op_type, const ge::Format &original_format, const ge::Format &final_format,
                     const uint32_t &tensor_index, const std::string &reshape_type, ge::GeShape &shape) {
  size_t full_size = 0;
  const size_t old_dims_size = shape.GetDimNum();
  const auto primary_format = static_cast<ge::Format>(ge::GetPrimaryFormat(final_format));
  if (!IsExpandNecessary(old_dims_size, original_format, primary_format, reshape_type, full_size)) {
    return true;
  }
  /* Check whether the reshape type is consistent with the original format.
   * If it is not, fall back to the default reshape type for this format and dim count;
   * if no default exists either, return without expanding. */
  std::string valid_reshape_type = reshape_type;
  if (!IsReshapeTypeValid(original_format, old_dims_size, reshape_type)) {
    if (!GetDefaultReshapeType(original_format, old_dims_size, valid_reshape_type)) {
      return true;
    }
  }
  /* The dimension of the original shape must not exceed the length of the reshape type.
   * If it is larger, we cannot find a suitable position for every axis of the original
   * shape, so we report a warning and return without expanding. */
  if (old_dims_size > valid_reshape_type.length()) {
    GELOGW("Dimension %zu of tensor %u in %s exceeds the length of the reshape type, which is %zu.",
           old_dims_size, tensor_index, op_type.c_str(), valid_reshape_type.length());
    return true;
  }
  ExpandByReshapeType(shape, original_format, old_dims_size, full_size, valid_reshape_type);
  return true;
}
/**
 * Expands a list of [low, high] dim ranges the same way ExpandDimension expands a shape.
 * The lower and upper bounds are each treated as a shape, expanded independently, and then
 * zipped back into `ranges`.
 * @return false if either expansion fails or the two expanded ranks differ (`ranges` is then
 *         left unchanged). Returns true otherwise.
 */
bool ExpandRangeDimension(const std::string &op_type, const ge::Format &original_format,
                          const ge::Format &final_format, const uint32_t &tensor_index, const std::string &reshape_type,
                          std::vector<std::pair<int64_t, int64_t>> &ranges) {
  std::vector<int64_t> lower_dims;
  std::vector<int64_t> upper_dims;
  lower_dims.reserve(ranges.size());
  upper_dims.reserve(ranges.size());
  for (const auto &bound : ranges) {
    lower_dims.emplace_back(bound.first);
    upper_dims.emplace_back(bound.second);
  }
  ge::GeShape lower_shape(lower_dims);
  ge::GeShape upper_shape(upper_dims);
  const auto primary_format = static_cast<ge::Format>(ge::GetPrimaryFormat(final_format));
  if (!ExpandDimension(op_type, original_format, primary_format, tensor_index, reshape_type, lower_shape)) {
    return false;
  }
  if (!ExpandDimension(op_type, original_format, primary_format, tensor_index, reshape_type, upper_shape)) {
    return false;
  }
  const size_t expanded_rank = lower_shape.GetDimNum();
  if (expanded_rank != upper_shape.GetDimNum()) {
    return false;
  }
  ranges.clear();
  for (size_t axis = 0; axis < expanded_rank; ++axis) {
    ranges.emplace_back(lower_shape.GetDim(axis), upper_shape.GetDim(axis));
  }
  return true;
}
// Trivial constructor and destructor: no per-instance state is initialised or released here.
ExpandDimension::ExpandDimension() = default;
ExpandDimension::~ExpandDimension() = default;
/**
 * Encodes how a shape of `origin_dim_size` dims in `origin_format` expands to full rank.
 *
 * Mask layout: bit i is set when axis i of the full-rank layout is an inserted (expanded) axis.
 * The full rank is OR-ed in starting at bit kMaxReshapeTypeSize.
 *
 * @return the mask, or 0 when no expansion applies or no usable reshape type exists.
 */
int64_t ExpandDimension::GenerateReshapeType(const ge::Format &origin_format, const ge::Format &format,
                                             const size_t &origin_dim_size, const std::string &reshape_type) {
  const auto primary_format = static_cast<ge::Format>(ge::GetPrimaryFormat(format));
  GELOGD("Begin to generate integer reshape type, original format[%d], format[%d], dim size[%zu], reshape type[%s].",
         origin_format, primary_format, origin_dim_size, reshape_type.c_str());
  const int64_t no_expand = 0;
  size_t full_size = 0;
  if (!GetFormatFullSize(origin_format, full_size) ||
      !IsNeedExpand(origin_format, primary_format, origin_dim_size, full_size, reshape_type)) {
    return no_expand;
  }
  std::string effective_type = reshape_type;
  if (!IsReshapeTypeValid(origin_format, origin_dim_size, reshape_type)) {
    if (!GetDefaultReshapeType(origin_format, origin_dim_size, effective_type)) {
      return no_expand;
    }
    GELOGD("Invalid reshape type [%s], using default reshape type [%s]",
           reshape_type.c_str(), effective_type.c_str());
  }
  if (effective_type.length() < origin_dim_size) {
    GELOGW("The length of reshape type[%s] is shorter than dim size[%zu]. Can not generate integer reshape type.",
           effective_type.c_str(), origin_dim_size);
    return no_expand;
  }
  // Full-rank positions occupied by the original dims; every other position is an expanded axis.
  const uint32_t format_part = GenerateFormatKey(origin_format);
  std::unordered_set<int32_t> occupied_axes;
  for (size_t idx = 0; idx < origin_dim_size; ++idx) {
    const uint32_t lookup = format_part | (static_cast<uint32_t>(effective_type[idx]) & 0xff);
    const auto hit = AXIS_INDEX_OF_FORMAT.find(lookup);
    if (hit != AXIS_INDEX_OF_FORMAT.end()) {
      occupied_axes.emplace(hit->second);
    }
  }
  int64_t mask = 0;
  for (size_t axis = 0; axis < full_size; ++axis) {
    if (occupied_axes.find(static_cast<int32_t>(axis)) == occupied_axes.end()) {
      mask |= (1 << axis);
    }
  }
  mask |= (static_cast<uint64_t>(full_size) << kMaxReshapeTypeSize);
  GELOGD("Integer reshape type [%s] has been generated for the original format [%d], with dim size [%zu] and reshape type [%s].",
         std::bitset<kBitSetDisplaySize>(mask).to_string().c_str(), origin_format, origin_dim_size,
         effective_type.c_str());
  return mask;
}
/**
 * Out-parameter variant of GenerateReshapeType.
 *
 * Mask layout: bit i is set when axis i of the full-rank layout is an inserted (expanded) axis.
 * The full rank is stored starting at bit kMaxReshapeTypeSize.
 *
 * @param reshape_type_mask output. Overwritten with the computed mask when a mask is generated.
 *        Left untouched when no expansion applies and on failure.
 * @return false only when the (possibly defaulted) reshape type has fewer letters than
 *         `origin_dim_size`. Returns true otherwise, including every "no expansion needed" case.
 */
bool ExpandDimension::GenerateReshapeType(const ge::Format &origin_format, const ge::Format &format,
                                          const size_t &origin_dim_size, const std::string &reshape_type,
                                          int64_t &reshape_type_mask) {
  auto primary_format = static_cast<ge::Format>(ge::GetPrimaryFormat(format));
  GELOGD("Begin to generate integer reshape type, original format[%d], format[%d], dim size[%zu], reshape type[%s].",
         origin_format, primary_format, origin_dim_size, reshape_type.c_str());
  size_t full_size = 0;
  if (!GetFormatFullSize(origin_format, full_size)) {
    return true;
  }
  if (!IsNeedExpand(origin_format, primary_format, origin_dim_size, full_size, reshape_type)) {
    return true;
  }
  std::string valid_shape_type = reshape_type;
  if (!IsReshapeTypeValid(origin_format, origin_dim_size, reshape_type)) {
    if (!GetDefaultReshapeType(origin_format, origin_dim_size, valid_shape_type)) {
      return true;
    }
    GELOGD("Invalid reshape type [%s], using default reshape type [%s]",
           reshape_type.c_str(), valid_shape_type.c_str());
  }
  // Every original dim needs an axis letter, so a reshape type shorter than the rank cannot place them all.
  if (origin_dim_size > valid_shape_type.length()) {
    GELOGE(ge::GRAPH_FAILED,
           "The length of reshape type[%s] is shorter than dim size[%zu]. Can not generate integer reshape type.",
           valid_shape_type.c_str(), origin_dim_size);
    return false;
  }
  // Full-rank positions occupied by the original dims.
  uint32_t format_key = GenerateFormatKey(origin_format);
  std::unordered_set<int32_t> dim_pos_set;
  for (size_t i = 0; i < origin_dim_size; ++i) {
    uint32_t axis_key = format_key | (static_cast<uint32_t>(valid_shape_type[i]) & 0xff);
    auto iter_axis_index = AXIS_INDEX_OF_FORMAT.find(axis_key);
    if (iter_axis_index != AXIS_INDEX_OF_FORMAT.end()) {
      dim_pos_set.emplace(iter_axis_index->second);
    }
  }
  // Build the mask locally and assign it once, so whatever value the caller passed in
  // cannot contribute stray bits to the result.
  int64_t mask = 0;
  for (size_t i = 0; i < full_size; ++i) {
    if (dim_pos_set.count(static_cast<int32_t>(i)) == 0) {
      mask |= (static_cast<int64_t>(1) << i);
    }
  }
  mask |= static_cast<int64_t>(static_cast<uint64_t>(full_size) << kMaxReshapeTypeSize);
  reshape_type_mask = mask;
  GELOGD("Integer reshape type [%s] has been generated for the original format [%d], with dim size [%zu] and reshape type [%s].",
         std::bitset<kBitSetDisplaySize>(reshape_type_mask).to_string().c_str(), origin_format, origin_dim_size,
         valid_shape_type.c_str());
  return true;
}
/**
 * Inverse of GenerateReshapeType: rebuilds the reshape-type letters from an integer mask.
 * @param reshape_type   receives the letters of the non-expanded axes. It may hold a partial
 *                       result when the final dim-count check fails.
 * @param failed_reason  receives a human-readable reason whenever false is returned.
 * @return true on success. For ND, only a zero mask is accepted, and `reshape_type` is not touched.
 */
bool ExpandDimension::GenerateReshapeTypeByMask(const ge::Format &origin_format, const size_t &origin_dim_size,
                                                const int64_t &reshape_type_mask, std::string &reshape_type,
                                                std::string &failed_reason) {
  if (origin_format == ge::FORMAT_ND) {
    if (reshape_type_mask != 0) {
      failed_reason = "Can not generate reshape type for ND format.";
      GELOGI("%s", failed_reason.c_str());
      return false;
    }
    return true;
  }
  const std::string format_name = ge::TypeUtils::FormatToSerialString(origin_format);
  size_t full_size = 0;
  if (!GetFormatFullSize(origin_format, full_size)) {
    failed_reason = format_name + " is not supported for expanding dims.";
    GELOGI("%s", failed_reason.c_str());
    return false;
  }
  // An empty mask on an already full-rank shape means "every axis of the format".
  if ((reshape_type_mask == 0) && (origin_dim_size == full_size)) {
    reshape_type = format_name;
    return true;
  }
  const auto encoded_full_size = static_cast<size_t>(reshape_type_mask >> kMaxReshapeTypeSize);
  if (encoded_full_size != full_size) {
    failed_reason = "Full size[" + std::to_string(encoded_full_size) + "] from reshape mask is not correct,";
    failed_reason += " it should be[" + std::to_string(full_size) + "].";
    GELOGI("%s", failed_reason.c_str());
    return false;
  }
  // Keep the letter of every axis whose "expanded" bit is clear.
  reshape_type.clear();
  for (size_t axis = 0; axis < full_size; ++axis) {
    const bool expanded = (reshape_type_mask & (1 << axis)) != 0;
    if (!expanded) {
      reshape_type += format_name.at(axis);
    }
  }
  if (reshape_type.size() != origin_dim_size) {
    const std::string mask_bits = std::bitset<kBitSetDisplaySize>(reshape_type_mask).to_string();
    failed_reason = "[" + mask_bits + "] is not correct when dim size is [" + std::to_string(origin_dim_size) + "].";
    GELOGI("%s", failed_reason.c_str());
    return false;
  }
  return true;
}
/**
 * Returns true unless one of the "skip expansion" conditions holds:
 * the shape is already at or above `full_size`; the reshape type is FORBIDDEN;
 * `format` is in kSupportedTransFormat; or the conversion is ND -> FRACTAL_Z.
 */
bool ExpandDimension::IsNeedExpand(const ge::Format &origin_format, const ge::Format &format,
                                   const size_t &origin_dim_size, const size_t &full_size,
                                   const std::string &reshape_type) {
  const bool skip = (origin_dim_size >= full_size) ||
                    (reshape_type == RESHAPE_TYPE_FORBIDDEN) ||
                    (kSupportedTransFormat.count(format) != 0) ||
                    ((origin_format == ge::FORMAT_ND) && (format == ge::FORMAT_FRACTAL_Z));
  return !skip;
}
/**
 * A reshape type is valid when each letter is an axis of `origin_format` and the letters map to
 * strictly increasing axis positions. An empty reshape type is valid only for a zero-dim shape.
 */
bool ExpandDimension::IsReshapeTypeValid(const ge::Format &origin_format, const size_t &origin_dim_size,
                                         const std::string &reshape_type) {
  if (reshape_type.empty()) {
    return origin_dim_size == 0;
  }
  const uint32_t format_part = GenerateFormatKey(origin_format);
  int32_t highest_seen = -1;
  for (size_t idx = 0; idx < reshape_type.size(); ++idx) {
    const uint32_t lookup = format_part | (static_cast<uint32_t>(reshape_type[idx]) & 0xff);
    const auto hit = AXIS_INDEX_OF_FORMAT.find(lookup);
    if (hit == AXIS_INDEX_OF_FORMAT.end()) {
      return false;
    }
    if (hit->second <= highest_seen) {
      return false;
    }
    highest_seen = hit->second;
  }
  return true;
}
/**
 * Fetches the default reshape type for `origin_dim_size` dims in `origin_format`.
 * @param reshape_type set on success; untouched on failure.
 * @return false (with a warning) when DEFAULT_RESHAPE_TYPE has no entry for this pair.
 */
bool ExpandDimension::GetDefaultReshapeType(const ge::Format &origin_format, const size_t &origin_dim_size,
                                            std::string &reshape_type) {
  const auto entry = DEFAULT_RESHAPE_TYPE.find(GenerateReshapeTypeKey(origin_format, origin_dim_size));
  const bool found = entry != DEFAULT_RESHAPE_TYPE.end();
  if (!found) {
    GELOGW("Dim size %zu is invalid, default reshape type not found.", origin_dim_size);
  } else {
    reshape_type = entry->second;
  }
  return found;
}
void ExpandDimension::ExpandDims(const int64_t &reshape_type, ge::GeShape &shape) {
GELOGD("Begin to expand dims, reshape type[%" PRId64 "], shape[%s].", reshape_type, shape.ToString().c_str());
gert::Shape inner_shape;
GeShapeToRtShape(shape, inner_shape);
ExpandDims(reshape_type, inner_shape);
RtShapeToGeShape(inner_shape, shape);
GELOGD("After expanding dims, shape[%s].", shape.ToString().c_str());
}
// Two-shape variant: converts both shapes to gert::Shape, expands `shape` with respect to
// `origin_shape` via the gert overload, and writes the result back into `shape`.
void ExpandDimension::ExpandDims(const int64_t &reshape_type, const ge::GeShape &origin_shape, ge::GeShape &shape) {
  GELOGD("Begin to expand dims, reshape type[%" PRId64 "], origin shape[%s].", reshape_type,
         origin_shape.ToString().c_str());
  gert::Shape runtime_origin;
  gert::Shape runtime_target;
  GeShapeToRtShape(origin_shape, runtime_origin);
  GeShapeToRtShape(shape, runtime_target);
  ExpandDims(reshape_type, runtime_origin, runtime_target);
  RtShapeToGeShape(runtime_target, shape);
  GELOGD("After expanding dims, shape[%s].", shape.ToString().c_str());
}
void ExpandDimension::ExpandDims(const int64_t &reshape_type, gert::Shape &shape) {
if (reshape_type == 0) {
return;
}
gert::ExpandDimsType expand_dims_type(reshape_type);
expand_dims_type.Expand(shape);
}
// Two-shape variant of the gert overload. Delegates to ExpandDimsType::Expand(origin, out) when the
// mask is non-zero; a zero mask leaves `shape` untouched.
void ExpandDimension::ExpandDims(const int64_t &reshape_type, const gert::Shape &origin_shape, gert::Shape &shape) {
  if (reshape_type != 0) {
    gert::ExpandDimsType expander(reshape_type);
    expander.Expand(origin_shape, shape);
  }
}
// Reports the full rank of `format` from FULL_SIZE_OF_FORMAT; `full_size` is written only on success.
bool ExpandDimension::GetFormatFullSize(const ge::Format &format, size_t &full_size) {
  const auto entry = FULL_SIZE_OF_FORMAT.find(format);
  const bool known = entry != FULL_SIZE_OF_FORMAT.end();
  if (known) {
    full_size = entry->second;
  }
  return known;
}
// Returns the full-rank position of axis letter `ch` in `format`, or -1 if the pair is unknown.
int32_t ExpandDimension::GetAxisIndexByName(char ch, const ge::Format &format) {
  const uint32_t lookup = GenerateFormatKey(format) | (static_cast<uint32_t>(ch) & 0xff);
  const auto hit = AXIS_INDEX_OF_FORMAT.find(lookup);
  return (hit == AXIS_INDEX_OF_FORMAT.end()) ? -1 : hit->second;
}
/**
 * Returns the dim of `shape` at `axis_index`, or -1 when `axis_index` is -1 ("axis not found").
 * NOTE(review): the IsExpandIndex check below only emits a debug log and does not affect the
 * returned value. Confirm whether expanded axes were meant to be handled differently.
 */
int64_t ExpandDimension::GetReshapeAxicValue(const int64_t &reshape_type_mask,
                                             const ge::GeShape &shape, int32_t axis_index) {
  GELOGD("axis_index = %d.", axis_index);
  const int32_t axis_not_found = -1;
  if (axis_index == axis_not_found) {
    return -1;
  }
  gert::ExpandDimsType mask_view(reshape_type_mask);
  if (!mask_view.IsExpandIndex(axis_index)) {
    GELOGD("axis_index is %d.", axis_index);
  }
  const auto position = static_cast<size_t>(axis_index);
  return shape.GetDim(position);
}
// Resolves axis letter `ch` in `format` to its position (-1 if unknown) and returns that dim of `shape`.
int64_t ExpandDimension::GetReshapeAxicValueByName(const int64_t &reshape_type_mask, char ch,
                                                   const ge::GeShape &shape, const ge::Format &format) {
  return GetReshapeAxicValue(reshape_type_mask, shape, GetAxisIndexByName(ch, format));
}
}