* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "slice/data_slice_adapter.h"
#include <algorithm>
#include <map>
#include <set>
#include <sstream>
#include "graph/debug/ge_attr_define.h"
#include "framework/common/debug/ge_log.h"
namespace ge {
// Axis letters of every layout this adapter can translate, ordered by dimension index:
// the position of a letter in the vector is the index of that axis in the layout's shape.
const std::map<Format, std::vector<std::string>> FORMAT_MAP = {
    {Format::FORMAT_NCHW, {"N", "C", "H", "W"}},
    {Format::FORMAT_NHWC, {"N", "H", "W", "C"}},
    {Format::FORMAT_CHWN, {"C", "H", "W", "N"}},
    {Format::FORMAT_HWCN, {"H", "W", "C", "N"}},
    {Format::FORMAT_NC1HWC0, {"N", "C1", "H", "W", "C0"}},
    {Format::FORMAT_NC1HWC0_C04, {"N", "C1", "H", "W", "C0"}},
    {Format::FORMAT_NCDHW, {"N", "C", "D", "H", "W"}},
    {Format::FORMAT_NDHWC, {"N", "D", "H", "W", "C"}},
    {Format::FORMAT_DHWCN, {"D", "H", "W", "C", "N"}},
    {Format::FORMAT_DHWNC, {"D", "H", "W", "N", "C"}},
    {Format::FORMAT_NDC1HWC0, {"N", "D", "C1", "H", "W", "C0"}}
};
// Plain 4D layouts; TransAxis translates into them with TransAxisForNoSplit(..., DIM_NUM_4).
const std::set<Format> FORMAT_4D_SET = {
    Format::FORMAT_NCHW,
    Format::FORMAT_NHWC,
    Format::FORMAT_CHWN,
    Format::FORMAT_HWCN
};
// Plain 5D layouts; TransAxis translates into them with TransAxisForNoSplit(..., DIM_NUM_5).
const std::set<Format> FORMAT_5D_SET = {
    Format::FORMAT_NCDHW,
    Format::FORMAT_NDHWC,
    Format::FORMAT_DHWCN,
    Format::FORMAT_DHWNC
};
// Readable layout names used when building debug strings (GetTensorStr).
// Every key is listed once; a duplicated key in a std::map initializer list is silently ignored.
const std::map<Format, std::string> FORMAT_MAP_STR = {
    {Format::FORMAT_NCHW, "NCHW"},
    {Format::FORMAT_NHWC, "NHWC"},
    {Format::FORMAT_CHWN, "CHWN"},
    {Format::FORMAT_HWCN, "HWCN"},
    {Format::FORMAT_NC1HWC0, "NC1HWC0"},
    {Format::FORMAT_NC1HWC0_C04, "NC1HWC0"},
    {Format::FORMAT_NCDHW, "NCDHW"},
    {Format::FORMAT_NDHWC, "NDHWC"},
    {Format::FORMAT_DHWCN, "DHWCN"},
    {Format::FORMAT_DHWNC, "DHWNC"},
    {Format::FORMAT_NDC1HWC0, "NDC1HWC0"},
    {Format::FORMAT_FRACTAL_NZ, "NZ"},
    {Format::FORMAT_ND, "ND"},
};
// Positions of the C1 / C0 axes: (1, 4) in NC1HWC0 and (2, 5) in NDC1HWC0.
constexpr int64_t AXIS_INDEX_1 = 1;
constexpr int64_t AXIS_INDEX_2 = 2;
constexpr int64_t AXIS_INDEX_4 = 4;
constexpr int64_t AXIS_INDEX_5 = 5;
// Offsets from the origin rank used when mapping the two innermost axes to FRACTAL_NZ.
constexpr int64_t DIM_NUM_1 = 1;
constexpr int64_t DIM_NUM_2 = 2;
constexpr int64_t DIM_NUM_3 = 3;
// Ranks of the plain 4D / 5D layouts.
constexpr size_t DIM_NUM_4 = 4;
constexpr size_t DIM_NUM_5 = 5;
// Threshold on the number of distinct axis types (see GetAxisTypeForTransAxis / GetAxisTypeForTransSlice).
constexpr size_t MAX_TYPE_SIZE = 2;
// A slice range is a [begin, end] pair.
constexpr size_t RANGE_NUM_SIZE = 2;
void DataSliceAdapter::PrintAxisItem(const AxisTypeInfo &axis_type, bool print_ori, std::stringstream &ss)
{
ss << "{type:" << static_cast<int>(axis_type.GetAxisType());
ss << ",relate_inputs:[";
for (const auto &relate_input : axis_type.GetRelateInputs()) {
ss << "{" << relate_input.first << ",{";
for (const auto &axis : relate_input.second) {
ss << axis << ",";
}
ss << "}}";
}
ss << "],relate_outputs:[";
for (const auto &relate_output : axis_type.GetRelateOutputs()) {
ss << "{" << relate_output.first << ",{";
for (const auto &axis : relate_output.second) {
ss << axis << ",";
}
ss << "}}";
}
ss << "],";
if (print_ori) {
ss << "ori_relate_inputs:[";
for (const auto &relate_input : axis_type.GetOriRelateInputs()) {
ss << "{" << relate_input.first << ",{";
for (const auto &axis : relate_input.second) {
ss << axis << ",";
}
ss << "}}";
}
ss << "],ori_relate_outputs:[";
for (const auto &relate_output : axis_type.GetOriRelateOutputs()) {
ss << "{" << relate_output.first << ",{";
for (const auto &axis : relate_output.second) {
ss << axis << ",";
}
ss << "}}";
}
ss << "]";
}
ss << "},";
}
void DataSliceAdapter::PrintAxis(const OpDescPtr &op, const std::vector<AxisTypeInfo> &axis_type_vec,
const std::string &type, bool print_ori)
{
if (!IsLogEnable(GE, DLOG_DEBUG)) {
return;
}
std::stringstream ss;
ss << "OpName:[" << op->GetName() << "] " << type << " op_axis_type info:[";
for (const auto &axis_type : axis_type_vec) {
PrintAxisItem(axis_type, print_ori, ss);
}
ss << "]";
GELOGD("%s", ss.str().c_str());
}
void DataSliceAdapter::PrintSlice(const OpDescPtr &op, const DataSliceType &slice_info,
const std::string &tensor_type, const std::string &tag)
{
if (!IsLogEnable(GE, DLOG_DEBUG)) {
return;
}
std::stringstream ss;
ss << "OpName[" << op->GetName() << "] " << tag << ":";
for (size_t tensor_idx = 0; tensor_idx < slice_info.size(); tensor_idx++) {
ss << tensor_type << "[" << tensor_idx << "]={";
for (size_t axis_idx = 0; axis_idx < slice_info[tensor_idx].size(); axis_idx++) {
ss << "{";
for (size_t range_idx = 0; range_idx < slice_info[tensor_idx][axis_idx].size(); range_idx++) {
ss << slice_info[tensor_idx][axis_idx][range_idx] << ",";
}
ss << "},";
}
ss << "};";
}
GELOGD("%s", ss.str().c_str());
}
// Logs (at debug level) the format/shape description of all inputs and outputs of `op`.
void DataSliceAdapter::PrintOp(const OpDescPtr &op)
{
    if (!IsLogEnable(GE, DLOG_DEBUG)) {
        return;
    }
    const std::string &op_name = op->GetName();
    const std::string inputs_text = GetTensorStr(op->GetAllInputsDescPtr());
    GELOGD("OpName[%s] input:%s", op_name.c_str(), inputs_text.c_str());
    const std::string outputs_text = GetTensorStr(op->GetAllOutputsDescPtr());
    GELOGD("OpName[%s] output:%s", op_name.c_str(), outputs_text.c_str());
}
// Builds a debug description "ori_format:..,ori_shape:..,format:..,shape:..,reshape_type:..;"
// for every tensor in `all_tensor_desc`. Known layouts are printed by name (FORMAT_MAP_STR),
// unknown ones by their numeric Format value. Tensors without a reshape type attribute are
// still reported, with an empty reshape_type.
std::string DataSliceAdapter::GetTensorStr(const OpDesc::Vistor<ge::GeTensorDescPtr> all_tensor_desc)
{
    std::stringstream ss;
    // Writes the readable layout name when known, otherwise the numeric Format value.
    const auto append_format = [&ss](const Format fmt) {
        const auto iter = FORMAT_MAP_STR.find(fmt);
        if (iter != FORMAT_MAP_STR.cend()) {
            ss << iter->second;
        } else {
            ss << static_cast<int>(fmt);
        }
    };
    for (const auto &tensor : all_tensor_desc) {
        if (tensor == nullptr) {
            continue;
        }
        const Format ori_format = tensor->GetOriginFormat();
        const Format format = static_cast<Format>(GetPrimaryFormat(tensor->GetFormat()));
        // The attribute is optional: a missing reshape type is printed as an empty string
        // instead of hiding the whole tensor from the log.
        const std::string *reshape_type = AttrUtils::GetStr(tensor, ATTR_NAME_RESHAPE_INFER_TYPE);
        ss << "ori_format:";
        append_format(ori_format);
        ss << ",ori_shape:" << tensor->GetOriginShape().ToString();
        ss << ",format:";
        append_format(format);
        ss << ",shape:" << tensor->GetShape().ToString();
        ss << ",reshape_type:" << (reshape_type == nullptr ? "" : *reshape_type) << ";";
    }
    return ss.str();
}
// Returns a copy of `slice_info` whose relate inputs/outputs are replaced by the backed-up
// origin relations (GetOriRelateInputs / GetOriRelateOutputs).
AxisTypeInfo DataSliceAdapter::GetTmpAxisTypeInfo(const AxisTypeInfo &slice_info)
{
    AxisTypeInfo tmp_info(slice_info);
    tmp_info.SetRelateInputs(slice_info.GetOriRelateInputs());
    tmp_info.SetRelateOutputs(slice_info.GetOriRelateOutputs());
    return tmp_info;
}
// Reads the ATTR_NAME_DATA_SLICE ranges of every related output tensor (current layout) and
// translates them into origin-layout ranges in `ori_output_slice`.
// Returns FAILED when the translation fails.
Status DataSliceAdapter::GetOriOutputSlice(const OpDescPtr &op, const AxisTypeInfo &slice_info,
                                           DataSliceType &ori_output_slice)
{
    DataSliceType cur_output_slice;
    for (const auto &cut_info : slice_info.GetRelateOutputs()) {
        GeTensorDesc output_desc = op->GetOutputDesc(cut_info.first);
        std::vector<std::vector<int64_t>> slice_ranges;
        // A missing attribute simply yields an empty range list for this tensor.
        (void)AttrUtils::GetListListInt(output_desc, ATTR_NAME_DATA_SLICE, slice_ranges);
        cur_output_slice.emplace_back(slice_ranges);
    }
    PrintSlice(op, cur_output_slice, "output", "current");
    const Status trans_ret =
        TransSliceInfo(op, slice_info, TransType::CUR_TO_ORI, cur_output_slice, ori_output_slice);
    if (trans_ret != SUCCESS) {
        GELOGE(FAILED, "Failed to trans slice info from cur to ori, op_name = %s", op->GetName().c_str());
        return FAILED;
    }
    PrintSlice(op, ori_output_slice, "output", "origin");
    return SUCCESS;
}
// Translates origin-layout input slice ranges (`ori_input_slice`) into current-layout ranges
// (`cur_input_slice`). Returns FAILED when the translation fails.
Status DataSliceAdapter::GetCurInputSlice(const OpDescPtr &op, const AxisTypeInfo &slice_info,
                                          const DataSliceType &ori_input_slice, DataSliceType &cur_input_slice)
{
    PrintSlice(op, ori_input_slice, "input", "origin");
    if (TransSliceInfo(op, slice_info, TransType::ORI_TO_CUR, ori_input_slice, cur_input_slice) != SUCCESS) {
        // The direction in this message must match TransType::ORI_TO_CUR above.
        GELOGE(FAILED, "Failed to trans slice info from ori to cur, op_name = %s", op->GetName().c_str());
        return FAILED;
    }
    PrintSlice(op, cur_input_slice, "input", "current");
    return SUCCESS;
}
// Returns false when any non-null input or output tensor has an empty origin shape while its
// current shape is non-empty (i.e. the origin info cannot be trusted); true otherwise.
bool DataSliceAdapter::CheckOriInfo(const OpDescPtr &op)
{
    // The tensor counts are loop-invariant, so query them once instead of on every iteration.
    const size_t input_num = op->GetAllInputsDescPtr().size();
    for (size_t idx = 0; idx < input_num; idx++) {
        const auto cur_tensor = op->MutableInputDesc(idx);
        if (cur_tensor == nullptr) {
            GELOGW("op_name = %s, input_tensor[%zu] is nullptr", op->GetName().c_str(), idx);
            continue;
        }
        const auto &ori_shape = cur_tensor->GetOriginShape();
        const auto &shape = cur_tensor->GetShape();
        if (ori_shape.GetShapeSize() == 0 && shape.GetShapeSize() != 0) {
            GELOGW("op_name = %s, input_tensor[%zu] ori_shape is empty", op->GetName().c_str(), idx);
            return false;
        }
    }
    const size_t output_num = op->GetAllOutputsDescPtr().size();
    for (size_t idx = 0; idx < output_num; idx++) {
        const auto cur_tensor = op->MutableOutputDesc(idx);
        if (cur_tensor == nullptr) {
            GELOGW("op_name = %s, output_tensor[%zu] is nullptr", op->GetName().c_str(), idx);
            continue;
        }
        const auto &ori_shape = cur_tensor->GetOriginShape();
        const auto &shape = cur_tensor->GetShape();
        if (ori_shape.GetShapeSize() == 0 && shape.GetShapeSize() != 0) {
            GELOGW("op_name = %s, output_tensor[%zu] ori_shape is empty", op->GetName().c_str(), idx);
            return false;
        }
    }
    return true;
}
// Switches every non-null input/output tensor of `op` to its origin format and shape,
// appending the previous (format, shape) of each such tensor, in order, to the given caches.
void DataSliceAdapter::SetOriOpInfo(OpDescPtr &op,
                                    std::vector<std::pair<Format, GeShape>> &cache_input_info,
                                    std::vector<std::pair<Format, GeShape>> &cache_output_info)
{
    // Remembers the current format/shape of `tensor`, then overwrites them with the origin ones.
    const auto cache_then_use_origin = [](const GeTensorDescPtr &tensor,
                                          std::vector<std::pair<Format, GeShape>> &cache) {
        cache.emplace_back(static_cast<Format>(tensor->GetFormat()), tensor->GetShape());
        tensor->SetFormat(tensor->GetOriginFormat());
        tensor->SetShape(tensor->GetOriginShape());
    };

    const uint32_t input_count = static_cast<uint32_t>(op->GetAllInputsDescPtr().size());
    for (uint32_t i = 0; i < input_count; ++i) {
        const GeTensorDescPtr tensor = op->MutableInputDesc(i);
        if (tensor == nullptr) {
            GELOGW("op_name = %s, input_tensor[%u] is nullptr", op->GetName().c_str(), i);
            continue;
        }
        cache_then_use_origin(tensor, cache_input_info);
    }

    const uint32_t output_count = static_cast<uint32_t>(op->GetAllOutputsDescPtr().size());
    for (uint32_t i = 0; i < output_count; ++i) {
        const GeTensorDescPtr tensor = op->MutableOutputDesc(i);
        if (tensor == nullptr) {
            GELOGW("op_name = %s, output_tensor[%u] is nullptr", op->GetName().c_str(), i);
            continue;
        }
        cache_then_use_origin(tensor, cache_output_info);
    }
}
// Restores the (format, shape) pairs previously cached (e.g. by SetOriOpInfo) onto the
// non-null input/output tensors of `op`, consuming each cache in order. If a cache holds fewer
// entries than there are non-null tensors, the remaining tensors are left untouched instead of
// reading past the end of the cache.
void DataSliceAdapter::SetCurOpInfo(OpDescPtr &op,
                                    const std::vector<std::pair<Format, GeShape>> &cache_input_info,
                                    const std::vector<std::pair<Format, GeShape>> &cache_output_info)
{
    size_t item_idx = 0;
    const uint32_t input_size = static_cast<uint32_t>(op->GetAllInputsDescPtr().size());
    for (uint32_t idx = 0; idx < input_size; idx++) {
        auto cur_tensor = op->MutableInputDesc(idx);
        if (cur_tensor == nullptr) {
            GELOGW("op_name = %s, input_tensor[%u] is nullptr", op->GetName().c_str(), idx);
            continue;
        }
        if (item_idx >= cache_input_info.size()) {
            GELOGW("op_name = %s, input cache size[%zu] is smaller than the number of input tensors",
                   op->GetName().c_str(), cache_input_info.size());
            break;
        }
        cur_tensor->SetFormat(cache_input_info[item_idx].first);
        cur_tensor->SetShape(cache_input_info[item_idx].second);
        item_idx++;
    }
    item_idx = 0;
    const uint32_t output_size = static_cast<uint32_t>(op->GetAllOutputsDescPtr().size());
    for (uint32_t idx = 0; idx < output_size; idx++) {
        auto cur_tensor = op->MutableOutputDesc(idx);
        if (cur_tensor == nullptr) {
            GELOGW("op_name = %s, output_tensor[%u] is nullptr", op->GetName().c_str(), idx);
            continue;
        }
        if (item_idx >= cache_output_info.size()) {
            GELOGW("op_name = %s, output cache size[%zu] is smaller than the number of output tensors",
                   op->GetName().c_str(), cache_output_info.size());
            break;
        }
        cur_tensor->SetFormat(cache_output_info[item_idx].first);
        cur_tensor->SetShape(cache_output_info[item_idx].second);
        item_idx++;
    }
}
// Maps an origin-shape axis to its FRACTAL_NZ axes. Outer axes keep their index; each of the
// two innermost origin axes maps to two NZ axes (the first returned is the outer one).
// Returns an empty vector when the axis cannot be mapped: out-of-range axes, or origin
// shapes with fewer than two dims, would otherwise produce negative indices.
std::vector<int64_t> DataSliceAdapter::TransAxisToNZ(const GeTensorDescPtr &tensor, int64_t axis)
{
    const auto ori_shape = tensor->GetOriginShape();
    const int64_t rank = static_cast<int64_t>(ori_shape.GetDims().size());
    std::vector<int64_t> axis_vec;
    if (rank < DIM_NUM_2 || axis < 0 || axis >= rank) {
        return axis_vec;
    }
    if (axis <= rank - DIM_NUM_3) {
        axis_vec.push_back(axis);
    } else if (axis == rank - DIM_NUM_2) {
        axis_vec.push_back(rank - DIM_NUM_1);
        axis_vec.push_back(rank);
    } else {
        // axis == rank - 1 (guaranteed by the range check above).
        axis_vec.push_back(rank - DIM_NUM_2);
        axis_vec.push_back(rank + DIM_NUM_1);
    }
    return axis_vec;
}
bool DataSliceAdapter::CheckReshape(const GeTensorDescPtr &tensor, const std::string &reshape_type,
int64_t axis, int64_t &format_match_axis)
{
if (axis >= static_cast<int64_t>(reshape_type.size())) {
GELOGW("The axis [%ld] >= reshape_type size [%zu]", axis, reshape_type.size());
return false;
}
const auto format = tensor->GetOriginFormat();
auto iter = FORMAT_MAP.find(format);
const std::vector<std::string> format_vec = iter->second;
const std::string reshape_char = reshape_type.substr(axis, 1);
format_match_axis = std::find(format_vec.cbegin(), format_vec.cend(), reshape_char) - format_vec.cbegin();
return true;
}
// Accepts an origin rank equal to the layout's `dim_num`; a smaller rank is accepted only when
// `reshape_type` has exactly one character per origin dimension. Larger ranks are rejected.
bool DataSliceAdapter::CheckRank(size_t rank, size_t dim_num, const std::string &reshape_type)
{
    if (rank == dim_num) {
        return true;
    }
    return (rank < dim_num) && (rank == reshape_type.size());
}
std::vector<int64_t> DataSliceAdapter::TransAxisForSplit(const GeTensorDescPtr &tensor, const int64_t axis,
size_t dim_num)
{
const auto ori_shape = tensor->GetOriginShape();
const size_t rank = ori_shape.GetDims().size();
std::string reshape_type;
(void)AttrUtils::GetStr(tensor, ATTR_NAME_RESHAPE_INFER_TYPE, reshape_type);
std::vector<int64_t> axis_vec;
if (!(CheckRank(rank, dim_num, reshape_type))) {
GELOGW("Failed to CheckRank rank = %zu, dim_num = %zu, reshape_type = %s", rank, dim_num, reshape_type.c_str());
return axis_vec;
}
int64_t format_match_axis = axis;
if (rank != dim_num && !CheckReshape(tensor, reshape_type, axis, format_match_axis)) {
GELOGW("Failed to CheckReshape");
return axis_vec;
}
const auto ori_format = tensor->GetOriginFormat();
const auto format = static_cast<Format>(GetPrimaryFormat(tensor->GetFormat()));
auto iter = FORMAT_MAP.find(ori_format);
auto iter_dst = FORMAT_MAP.find(format);
if (iter == FORMAT_MAP.cend() || iter_dst == FORMAT_MAP.cend()) {
GELOGW("Cannot find ori_format[%d] or format[%d] in FORMAT_MAP", static_cast<int>(ori_format),
static_cast<int>(format));
return axis_vec;
}
const std::vector<std::string> ori_format_vec = iter->second;
if (format_match_axis >= static_cast<int64_t>(ori_format_vec.size())) {
GELOGW("format_match_axis[%ld] is out of range of format_vec_size[%zu]", format_match_axis, ori_format_vec.size());
return axis_vec;
}
const std::string format_char = ori_format_vec[format_match_axis];
if (format_char == "C") {
const std::vector<int64_t> vec_4d = {AXIS_INDEX_1, AXIS_INDEX_4};
const std::vector<int64_t> vec_5d = {AXIS_INDEX_2, AXIS_INDEX_5};
axis_vec = (dim_num == DIM_NUM_4) ? vec_4d : vec_5d;
return axis_vec;
}
std::vector<std::string> dst_fmt_vec = iter_dst->second;
axis_vec.push_back(std::find(dst_fmt_vec.cbegin(), dst_fmt_vec.cend(), format_char) - dst_fmt_vec.cbegin());
return axis_vec;
}
std::vector<int64_t> DataSliceAdapter::TransAxisForNoSplit(const GeTensorDescPtr &tensor, const int64_t axis,
size_t dim_num)
{
const auto ori_shape = tensor->GetOriginShape();
const size_t rank = ori_shape.GetDims().size();
std::vector<int64_t> axis_vec;
if (rank != dim_num) {
GELOGW("rank[%zu] != to dim_num[%zu] in non_split_axis scene", rank, dim_num);
return axis_vec;
}
const auto ori_format = tensor->GetOriginFormat();
auto iter = FORMAT_MAP.find(ori_format);
if (iter == FORMAT_MAP.cend()) {
GELOGW("Cannot find ori_format[%d] in FORMAT_MAP", static_cast<int>(ori_format));
return axis_vec;
}
const std::vector<std::string> ori_format_vec = iter->second;
if (axis >= static_cast<int64_t>(ori_format_vec.size())) {
GELOGW("axis[%ld] is out of range of format_vec_size[%zu]", axis, ori_format_vec.size());
return axis_vec;
}
const std::string format_char = ori_format_vec[axis];
auto format = static_cast<Format>(GetPrimaryFormat(tensor->GetFormat()));
auto iter_dst = FORMAT_MAP.find(format);
if (iter_dst == FORMAT_MAP.cend()) {
GELOGW("Cannot find format[%d] in FORMAT_MAP", format);
return axis_vec;
}
std::vector<std::string> dst_fmt_vec = iter_dst->second;
axis_vec.push_back(std::find(dst_fmt_vec.cbegin(), dst_fmt_vec.cend(), format_char) - dst_fmt_vec.cbegin());
return axis_vec;
}
bool DataSliceAdapter::IsFormatInSet(const Format format, const std::set<Format> &format_set)
{
auto iter = format_set.find(format);
return iter != format_set.cend();
}
// Maps an origin-layout axis of `tensor` to the corresponding axis (or axes) of its current
// primary format. An empty result means the target layout or axis is not supported.
std::vector<int64_t> DataSliceAdapter::TransAxis(const GeTensorDescPtr &tensor, int64_t ori_axis)
{
    const auto ori_format = tensor->GetOriginFormat();
    const auto cur_format = static_cast<Format>(GetPrimaryFormat(tensor->GetFormat()));
    if (cur_format == ori_format) {
        // Same layout: the axis index is unchanged.
        return std::vector<int64_t>{ori_axis};
    }
    if (cur_format == Format::FORMAT_FRACTAL_NZ) {
        return TransAxisToNZ(tensor, ori_axis);
    }
    if (cur_format == Format::FORMAT_NC1HWC0 || cur_format == Format::FORMAT_NC1HWC0_C04) {
        return TransAxisForSplit(tensor, ori_axis, DIM_NUM_4);
    }
    if (cur_format == Format::FORMAT_NDC1HWC0) {
        return TransAxisForSplit(tensor, ori_axis, DIM_NUM_5);
    }
    if (IsFormatInSet(cur_format, FORMAT_4D_SET)) {
        return TransAxisForNoSplit(tensor, ori_axis, DIM_NUM_4);
    }
    if (IsFormatInSet(cur_format, FORMAT_5D_SET)) {
        return TransAxisForNoSplit(tensor, ori_axis, DIM_NUM_5);
    }
    return std::vector<int64_t>();
}
// Checks that every relate input/output was mapped to the same number of axes, then keeps only
// the first axis of each tensor. Returns FAILED when the axis counts differ.
Status DataSliceAdapter::FixAxisTypeInfoToOne(AxisTypeInfo &axis_type_info)
{
    size_t count = 0;
    std::vector<CutInfo> input_cut_info_vec = axis_type_info.GetRelateInputs();
    for (const auto &item : input_cut_info_vec) {
        count = (count == 0) ? item.second.size() : count;
        if (count != item.second.size()) {
            GELOGW("The split axis size is not same in all input tensors.");
            return FAILED;
        }
    }
    // Reduce-type outputs may carry no axis at all; those are skipped in the size check.
    const AxisType axis_type = axis_type_info.GetAxisType();
    const bool is_reduce = (axis_type == AxisType::REDUCESUM) || (axis_type == AxisType::REDUCEMAX) ||
                           (axis_type == AxisType::REDUCEMIN);
    std::vector<CutInfo> output_cut_info_vec = axis_type_info.GetRelateOutputs();
    for (const auto &item : output_cut_info_vec) {
        if (item.second.empty() && is_reduce) {
            continue;
        }
        count = (count == 0) ? item.second.size() : count;
        if (count != item.second.size()) {
            GELOGW("The split axis size is not same in all input output tensors.");
            return FAILED;
        }
    }
    // Keep only the first axis of each tensor. Empty axis lists (which can pass the check above
    // when they come first) are left untouched: reading [0] from them would be out of bounds.
    for (auto &item : input_cut_info_vec) {
        std::vector<int64_t> &axis_vec = item.second;
        if (!axis_vec.empty()) {
            axis_vec = {axis_vec[0]};
        }
    }
    for (auto &item : output_cut_info_vec) {
        std::vector<int64_t> &axis_vec = item.second;
        if (!axis_vec.empty()) {
            axis_vec = {axis_vec[0]};
        }
    }
    axis_type_info.SetRelateInputs(input_cut_info_vec);
    axis_type_info.SetRelateOutputs(output_cut_info_vec);
    return SUCCESS;
}
// Rewrites every relate input's origin axes into axes of the tensor's current layout (TransAxis).
// For "reduce_type", an origin axis must map to exactly one axis. Returns FAILED when a tensor
// is missing or an axis cannot be mapped; axis_type_info is only updated on success.
Status DataSliceAdapter::TransAxisForInputTensor(const OpDescPtr &op, const std::string &axis_type_str,
                                                 AxisTypeInfo &axis_type_info)
{
    std::vector<CutInfo> tmp_relate_puts;
    for (const auto &item : axis_type_info.GetRelateInputs()) {
        std::vector<int64_t> trans_axis_list;
        if (!item.second.empty()) {
            // The tensor depends only on item.first, so it is looked up once per tensor, not per axis.
            const GeTensorDescPtr cur_tensor = op->MutableInputDesc(static_cast<uint32_t>(item.first));
            if (cur_tensor == nullptr) {
                GELOGW("op_name = %s, input_tensor[%ld] is nullptr", op->GetName().c_str(), item.first);
                return FAILED;
            }
            for (const int64_t axis : item.second) {
                const std::vector<int64_t> axis_vec = TransAxis(cur_tensor, axis);
                if (axis_type_str == "reduce_type" && axis_vec.size() > 1) {
                    GELOGW("axis_type is reduce_type and axis_vec_size > 1");
                    return FAILED;
                }
                if (axis_vec.empty()) {
                    GELOGW("TransAxis failed: op_name = %s, input_tensor[%ld], ori_axis[%ld]",
                           op->GetName().c_str(), item.first, axis);
                    return FAILED;
                }
                trans_axis_list.insert(trans_axis_list.end(), axis_vec.cbegin(), axis_vec.cend());
            }
        }
        tmp_relate_puts.emplace_back(item.first, trans_axis_list);
    }
    axis_type_info.SetRelateInputs(tmp_relate_puts);
    return SUCCESS;
}
// Rewrites every relate output's origin axes into axes of the tensor's current layout (TransAxis).
// For "reduce_type", an origin axis must map to exactly one axis. Returns FAILED when a tensor
// is missing or an axis cannot be mapped; axis_type_info is only updated on success.
Status DataSliceAdapter::TransAxisForOutputTensor(const OpDescPtr &op, const std::string &axis_type_str,
                                                  AxisTypeInfo &axis_type_info)
{
    std::vector<CutInfo> tmp_relate_puts;
    for (const auto &item : axis_type_info.GetRelateOutputs()) {
        std::vector<int64_t> trans_axis_list;
        if (!item.second.empty()) {
            // The tensor depends only on item.first, so it is looked up once per tensor, not per axis.
            const GeTensorDescPtr cur_tensor = op->MutableOutputDesc(static_cast<uint32_t>(item.first));
            if (cur_tensor == nullptr) {
                GELOGW("op_name = %s, output_tensor[%ld] is nullptr", op->GetName().c_str(), item.first);
                return FAILED;
            }
            for (const int64_t axis : item.second) {
                const std::vector<int64_t> axis_vec = TransAxis(cur_tensor, axis);
                if (axis_type_str == "reduce_type" && axis_vec.size() > 1) {
                    GELOGW("axis_type is reduce_type and axis_vec_size > 1");
                    return FAILED;
                }
                if (axis_vec.empty()) {
                    GELOGW("TransAxis failed: op_name = %s, output_tensor[%ld], ori_axis[%ld]",
                           op->GetName().c_str(), item.first, axis);
                    return FAILED;
                }
                trans_axis_list.insert(trans_axis_list.end(), axis_vec.cbegin(), axis_vec.cend());
            }
        }
        tmp_relate_puts.emplace_back(item.first, trans_axis_list);
    }
    axis_type_info.SetRelateOutputs(tmp_relate_puts);
    return SUCCESS;
}
// Translates both relate inputs and outputs of `axis_type_info` into current-layout axes.
// For "element_type", it then keeps a single axis per tensor (FixAxisTypeInfoToOne).
// Returns FAILED as soon as any step fails.
Status DataSliceAdapter::TransByAxisTypeStr(const OpDescPtr &op, const std::string &axis_type_str,
                                            AxisTypeInfo &axis_type_info)
{
    if (TransAxisForInputTensor(op, axis_type_str, axis_type_info) != SUCCESS) {
        GELOGW("Failed to trans axis type for input tensor.");
        return FAILED;
    }
    if (TransAxisForOutputTensor(op, axis_type_str, axis_type_info) != SUCCESS) {
        GELOGW("Failed to trans axis type for output tensor.");
        return FAILED;
    }
    if (axis_type_str == "element_type" && FixAxisTypeInfoToOne(axis_type_info) != SUCCESS) {
        GELOGW("Fix axis type info to one for element_type failed");
        return FAILED;
    }
    return SUCCESS;
}
// Saves the current relate inputs/outputs as the "ori" copies, so they remain available after
// the relations are rewritten into current-layout axes.
void DataSliceAdapter::BackupOriAxisTypeInfo(AxisTypeInfo &axis_type_info)
{
    const std::vector<CutInfo> current_inputs = axis_type_info.GetRelateInputs();
    axis_type_info.SetOriRelateInputs(current_inputs);
    const std::vector<CutInfo> current_outputs = axis_type_info.GetRelateOutputs();
    axis_type_info.SetOriRelateOutputs(current_outputs);
}
// Clears the backed-up ori relate inputs/outputs.
void DataSliceAdapter::ResetOriAxisTypeInfo(AxisTypeInfo &axis_type_info)
{
    const std::vector<CutInfo> no_cut_infos;
    axis_type_info.SetOriRelateInputs(no_cut_infos);
    axis_type_info.SetOriRelateOutputs(no_cut_infos);
}
// True only when `axis_type_info` relates at least one input and at least one output.
bool DataSliceAdapter::ValidateRelateInputOutput(const AxisTypeInfo &axis_type_info)
{
    return !axis_type_info.GetRelateInputs().empty() && !axis_type_info.GetRelateOutputs().empty();
}
// Backs up the origin relations of `axis_type_info`, then translates them to current-layout
// axes using the strategy selected by `axis_type`. On any failure the backup is cleared again.
Status DataSliceAdapter::TransAxisByType(const AxisType axis_type, const OpDescPtr &op,
                                         AxisTypeInfo &axis_type_info)
{
    if (!ValidateRelateInputOutput(axis_type_info)) {
        GELOGW("ValidateRelateInputOutput failed");
        return FAILED;
    }
    BackupOriAxisTypeInfo(axis_type_info);

    // Select the strategy string understood by TransByAxisTypeStr; nullptr = unsupported.
    const char *strategy = nullptr;
    switch (axis_type) {
        case AxisType::ELEMENTWISE:
        case AxisType::TRANSPOSE:
        case AxisType::REDUCESUM:
        case AxisType::REDUCEMAX:
        case AxisType::REDUCEMIN:
            strategy = "element_type";
            break;
        case AxisType::REDUCEMEAN:
        case AxisType::REDUCEGATHER:
        case AxisType::ELEMENTWITHSHAPEVALUE:
            strategy = "reduce_type";
            break;
        case AxisType::SLIDINGWINDOW:
        case AxisType::SLIDINGWINDOWGRAD:
            strategy = "other_type";
            break;
        default:
            break;
    }

    Status result = FAILED;
    if (strategy == nullptr) {
        GELOGW("Unsupport axis_type = %d", static_cast<int>(axis_type));
    } else {
        result = TransByAxisTypeStr(op, strategy, axis_type_info);
    }
    if (result != SUCCESS) {
        ResetOriAxisTypeInfo(axis_type_info);
    }
    return result;
}
// Chooses the axis type that drives axis translation:
//  - more than MAX_TYPE_SIZE distinct types -> UNSPLIT;
//  - exactly {ELEMENTWISE, REDUCESUM}        -> SLIDINGWINDOW;
//  - exactly one distinct type               -> that type;
//  - otherwise                               -> axis_type_info.GetAxisType().
AxisType DataSliceAdapter::GetAxisTypeForTransAxis(const AxisTypeInfo &axis_type_info)
{
    const std::vector<AxisType> tmp_vec = axis_type_info.GetAxisTypes();
    const std::set<AxisType> tmp_set(tmp_vec.cbegin(), tmp_vec.cend());
    if (tmp_set.size() > MAX_TYPE_SIZE) {
        return AxisType::UNSPLIT;
    }
    // Use the set's own lookup instead of a linear std::find over it.
    if (tmp_set.size() == MAX_TYPE_SIZE && tmp_set.count(AxisType::ELEMENTWISE) != 0 &&
        tmp_set.count(AxisType::REDUCESUM) != 0) {
        GELOGI("axis_type is ELEMENTWISE+REDUCESUM.");
        return AxisType::SLIDINGWINDOW;
    }
    if (tmp_set.size() == 1) {
        return *tmp_set.cbegin();
    }
    return axis_type_info.GetAxisType();
}
void DataSliceAdapter::TransAxisInfo(const OpDescPtr &op, std::vector<AxisTypeInfo> &axis_type_vec)
{
for (auto iter = axis_type_vec.begin(); iter != axis_type_vec.end();) {
AxisType axis_type = GetAxisTypeForTransAxis(*iter);
if (TransAxisByType(axis_type, op, *iter) == SUCCESS) {
++iter;
} else {
GELOGI("remove one axis type info");
iter = axis_type_vec.erase(iter);
}
}
}
int64_t DataSliceAdapter::SearchOriAxis(const std::vector<CutInfo> &ori_relate, int64_t tensor_idx,
int64_t axis_idx)
{
for (const auto &item : ori_relate) {
if (item.first == tensor_idx) {
if (axis_idx < static_cast<int64_t>(item.second.size())) {
return item.second[axis_idx];
}
}
}
return -1;
}
// Checks that `from_axis` indexes a [begin, end] pair (RANGE_NUM_SIZE entries) in `slice_info`
// and that `to_axis` is a valid index into `cur_tensor_range`. Negative axes are rejected.
bool DataSliceAdapter::ValidateAxisIndex(int64_t from_axis,
                                         const std::vector<std::vector<int64_t>> &slice_info,
                                         int64_t to_axis, const std::vector<std::vector<int64_t>> &cur_tensor_range)
{
    if (from_axis < 0 || to_axis < 0 ||
        from_axis >= static_cast<int64_t>(slice_info.size()) ||
        to_axis >= static_cast<int64_t>(cur_tensor_range.size())) {
        GELOGE(FAILED, "from_axis:%ld,slice_info_size:%zu,to_axis:%ld, cur_tensor_range_size:%zu",
               from_axis, slice_info.size(), to_axis, cur_tensor_range.size());
        return false;
    }
    if (slice_info[from_axis].size() != RANGE_NUM_SIZE) {
        GELOGE(FAILED, "slice_info[%ld].size:%zu != RANGE_NUM_SIZE", from_axis, slice_info[from_axis].size());
        return false;
    }
    return true;
}
// Converts current-layout output slice ranges (`slice_info_list`, one entry per relate output)
// into origin-layout ranges appended to `ori_slice_info_list`. A range [b, e] on the first
// transformed axis covers [b * P, e * P + P - 1] origin elements, where P is the product of the
// remaining transformed (inner) axes such as C0; the end is clipped to the origin dim.
// Returns FAILED on inconsistent relations, missing tensors, invalid indices or non-positive
// inner dims.
Status DataSliceAdapter::TransSliceInfoToOriForElement(const OpDescPtr &op, const AxisTypeInfo &axis_type_info,
    const DataSliceType &slice_info_list, DataSliceType &ori_slice_info_list)
{
    const std::vector<CutInfo> ori_relate_outputs = axis_type_info.GetOriRelateOutputs();
    const std::vector<CutInfo> relate_outputs = axis_type_info.GetRelateOutputs();
    if (ori_relate_outputs.empty() || relate_outputs.size() != slice_info_list.size()) {
        GELOGW("op_name = %s, ori_relate_outputs_size[%zu], relate_outputs_size[%zu], slice_info_list_size[%zu]",
               op->GetName().c_str(), ori_relate_outputs.size(), relate_outputs.size(), slice_info_list.size());
        return FAILED;
    }
    for (size_t index = 0; index < relate_outputs.size(); index++) {
        const int64_t tensor_idx = relate_outputs[index].first;
        const auto output_tensor = op->MutableOutputDesc(static_cast<uint32_t>(tensor_idx));
        if (output_tensor == nullptr) {
            GELOGW("op_name = %s, output_tensor[%ld] is nullptr", op->GetName().c_str(), tensor_idx);
            return FAILED;
        }
        const auto ori_shape = output_tensor->GetOriginShape();
        const GeShape cur_shape = output_tensor->GetShape();
        const size_t rank = ori_shape.GetDims().size();
        std::vector<std::vector<int64_t>> cur_tensor_range(rank);
        const std::vector<int64_t> &axis_vec = relate_outputs[index].second;
        for (size_t idx = 0; idx < axis_vec.size(); idx++) {
            const int64_t ori_axis = SearchOriAxis(ori_relate_outputs, tensor_idx, static_cast<int64_t>(idx));
            if (ori_axis < 0) {
                GELOGW("op_name = %s, get_ori_axis for output_tensor[%ld] axis[%zu] return ori_axis [-1]",
                       op->GetName().c_str(), tensor_idx, idx);
                return FAILED;
            }
            const int64_t cur_axis = axis_vec[idx];
            if (!ValidateAxisIndex(cur_axis, slice_info_list[index], ori_axis, cur_tensor_range)) {
                return FAILED;
            }
            const std::vector<int64_t> transed_axis_list = TransAxis(output_tensor, ori_axis);
            if (transed_axis_list.empty()) {
                GELOGW("op_name = %s, TransAxis failed for output_tensor[%ld] ori_axis[%ld]",
                       op->GetName().c_str(), tensor_idx, ori_axis);
                return FAILED;
            }
            // Signed product of the inner transformed axes; a non-positive (unknown/empty) dim
            // would turn the range arithmetic below into garbage.
            int64_t prod_rest_axis = 1;
            for (size_t i = 1; i < transed_axis_list.size(); i++) {
                const int64_t dim = cur_shape.GetDim(transed_axis_list[i]);
                if (dim <= 0) {
                    GELOGW("op_name = %s, output_tensor[%ld] has invalid dim[%ld] on axis[%ld]",
                           op->GetName().c_str(), tensor_idx, dim, transed_axis_list[i]);
                    return FAILED;
                }
                prod_rest_axis *= dim;
            }
            cur_tensor_range[ori_axis] = slice_info_list[index][cur_axis];
            std::vector<int64_t> &ori_slice_piece = cur_tensor_range[ori_axis];
            ori_slice_piece[0] = ori_slice_piece[0] * prod_rest_axis;
            ori_slice_piece[1] = ori_slice_piece[1] * prod_rest_axis + prod_rest_axis - 1;
            ori_slice_piece[1] = std::min(ori_slice_piece[1], ori_shape.GetDim(ori_axis) - 1);
        }
        ori_slice_info_list.emplace_back(cur_tensor_range);
    }
    return SUCCESS;
}
// Converts origin-layout input slice ranges (`slice_info_list`, one entry per relate input)
// into current-layout ranges appended to `cur_slice_info_list`. Both ends of a range are
// divided by P, the product of the inner transformed axes (e.g. C0) after the first one.
// Returns FAILED on inconsistent relations, missing tensors, invalid indices or non-positive
// inner dims (which would otherwise cause a division by zero).
Status DataSliceAdapter::TransSliceInfoToCurForElement(const OpDescPtr &op, const AxisTypeInfo &axis_type_info,
    const DataSliceType &slice_info_list, DataSliceType &cur_slice_info_list)
{
    const std::vector<CutInfo> ori_relate_inputs = axis_type_info.GetOriRelateInputs();
    const std::vector<CutInfo> relate_inputs = axis_type_info.GetRelateInputs();
    if (ori_relate_inputs.empty() || relate_inputs.size() != slice_info_list.size()) {
        GELOGW("op_name = %s, ori_relate_inputs_size[%zu], relate_inputs_size[%zu], slice_info_list_size[%zu]",
               op->GetName().c_str(), ori_relate_inputs.size(), relate_inputs.size(), slice_info_list.size());
        return FAILED;
    }
    for (size_t index = 0; index < relate_inputs.size(); index++) {
        const int64_t tensor_idx = relate_inputs[index].first;
        const auto input_tensor = op->MutableInputDesc(static_cast<uint32_t>(tensor_idx));
        if (input_tensor == nullptr) {
            GELOGW("op_name = %s, input_tensor[%ld] is nullptr", op->GetName().c_str(), tensor_idx);
            return FAILED;
        }
        const auto cur_shape = input_tensor->GetShape();
        const size_t rank = cur_shape.GetDims().size();
        std::vector<std::vector<int64_t>> cur_tensor_range(rank);
        const std::vector<int64_t> &axis_vec = relate_inputs[index].second;
        for (size_t idx = 0; idx < axis_vec.size(); idx++) {
            const int64_t ori_axis = SearchOriAxis(ori_relate_inputs, tensor_idx, static_cast<int64_t>(idx));
            if (ori_axis < 0) {
                GELOGW("op_name = %s, get_ori_axis for input_tensor[%ld] axis[%zu] return ori_axis [-1]",
                       op->GetName().c_str(), tensor_idx, idx);
                return FAILED;
            }
            const int64_t cur_axis = axis_vec[idx];
            if (!ValidateAxisIndex(ori_axis, slice_info_list[index], cur_axis, cur_tensor_range)) {
                return FAILED;
            }
            const std::vector<int64_t> transed_axis_list = TransAxis(input_tensor, ori_axis);
            if (transed_axis_list.empty()) {
                GELOGW("op_name = %s, TransAxis failed for input_tensor[%ld] ori_axis[%ld]",
                       op->GetName().c_str(), tensor_idx, ori_axis);
                return FAILED;
            }
            // Signed, strictly positive divisor: a zero dim would divide by zero, and an unsigned
            // divisor would convert negative range values to huge unsigned numbers.
            int64_t prod_rest_axis = 1;
            for (size_t i = 1; i < transed_axis_list.size(); i++) {
                const int64_t dim = cur_shape.GetDim(transed_axis_list[i]);
                if (dim <= 0) {
                    GELOGW("op_name = %s, input_tensor[%ld] has invalid dim[%ld] on axis[%ld]",
                           op->GetName().c_str(), tensor_idx, dim, transed_axis_list[i]);
                    return FAILED;
                }
                prod_rest_axis *= dim;
            }
            cur_tensor_range[cur_axis] = slice_info_list[index][ori_axis];
            std::vector<int64_t> &cur_slice_piece = cur_tensor_range[cur_axis];
            cur_slice_piece[0] /= prod_rest_axis;
            cur_slice_piece[1] /= prod_rest_axis;
        }
        cur_slice_info_list.emplace_back(cur_tensor_range);
    }
    return SUCCESS;
}
// Chooses the axis type that drives slice translation: a single distinct type is used as-is,
// MAX_TYPE_SIZE or more distinct types give UNSPLIT, and no types falls back to GetAxisType().
AxisType DataSliceAdapter::GetAxisTypeForTransSlice(const AxisTypeInfo &axis_type_info)
{
    const std::vector<AxisType> all_types = axis_type_info.GetAxisTypes();
    const std::set<AxisType> distinct_types(all_types.begin(), all_types.end());
    const size_t distinct_count = distinct_types.size();
    if (distinct_count == 1) {
        return *distinct_types.cbegin();
    }
    if (distinct_count >= MAX_TYPE_SIZE) {
        return AxisType::UNSPLIT;
    }
    return axis_type_info.GetAxisType();
}
// Translates slice ranges between the current and origin layouts in the direction given by
// `trans_type`. Element/reduce axis types are converted axis by axis; sliding-window and
// element-with-shape-value types copy the ranges unchanged; other types return FAILED.
Status DataSliceAdapter::TransSliceInfo(const OpDescPtr &op, const AxisTypeInfo &axis_type_info,
    TransType trans_type, const DataSliceType &slice_info_list, DataSliceType &out_slice_info_list)
{
    const AxisType axis_type = GetAxisTypeForTransSlice(axis_type_info);
    switch (axis_type) {
        case AxisType::ELEMENTWISE:
        case AxisType::REDUCESUM:
        case AxisType::REDUCEMAX:
        case AxisType::REDUCEMIN:
        case AxisType::REDUCEMEAN:
            if (trans_type == TransType::CUR_TO_ORI) {
                return TransSliceInfoToOriForElement(op, axis_type_info, slice_info_list, out_slice_info_list);
            }
            return TransSliceInfoToCurForElement(op, axis_type_info, slice_info_list, out_slice_info_list);
        case AxisType::SLIDINGWINDOW:
        case AxisType::SLIDINGWINDOWGRAD:
        case AxisType::ELEMENTWITHSHAPEVALUE:
            out_slice_info_list = slice_info_list;
            GELOGI("op_name[%s], axis_type[%d], keep slice info.", op->GetName().c_str(), static_cast<int>(axis_type));
            return SUCCESS;
        default:
            GELOGW("op_name[%s], unsupport axis_type[%d]", op->GetName().c_str(), static_cast<int>(axis_type));
            return FAILED;
    }
}
}