* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef SLICE_DATASLICE_DATA_SLICE_ADAPTER_H_
#define SLICE_DATASLICE_DATA_SLICE_ADAPTER_H_
#include <vector>
#include "ge/ge_api_types.h"
#include "framework/common/debug/ge_log.h"
#include "graph/axis_type_info.h"
#include "graph/utils/op_desc_utils.h"
#include "slice/data_slice_infer_base.h"
namespace ge {
enum class TransType {
CUR_TO_ORI = 0,
ORI_TO_CUR,
};
class DataSliceAdapter {
public:
using DataSliceType = std::vector<std::vector<std::vector<int64_t>>>;
static void PrintOp(const OpDescPtr &op);
static void PrintAxis(const OpDescPtr &op, const std::vector<AxisTypeInfo> &axis_type_vec,
const std::string &type, bool print_ori);
static void PrintSlice(const OpDescPtr &op, const DataSliceType &slice_info, const std::string &tensor_type,
const std::string &tag);
static void TransAxisInfo(const OpDescPtr &op, std::vector<AxisTypeInfo> &axis_type_vec);
static Status GetOriOutputSlice(const OpDescPtr &op, const AxisTypeInfo &slice_info, DataSliceType &ori_output_slice);
static Status GetCurInputSlice(const OpDescPtr &op, const AxisTypeInfo &slice_info,
const DataSliceType &ori_input_slice, DataSliceType &cur_input_slice);
static AxisTypeInfo GetTmpAxisTypeInfo(const AxisTypeInfo &slice_info);
static bool CheckOriInfo(const OpDescPtr &op);
static void SetOriOpInfo(OpDescPtr &op,
std::vector<std::pair<Format, GeShape>> &cache_input_info,
std::vector<std::pair<Format, GeShape>> &cache_output_info);
static void SetCurOpInfo(OpDescPtr &op,
const std::vector<std::pair<Format, GeShape>> &cache_input_info,
const std::vector<std::pair<Format, GeShape>> &cache_output_info);
private:
static inline bool IsLogEnable(const int32_t module_name, const int32_t log_level) {
const int32_t enable = CheckLogLevel(module_name, log_level);
return (enable == 1);
}
static void PrintAxisItem(const AxisTypeInfo &axis_type, bool print_ori, std::stringstream &ss);
static std::string GetTensorStr(const OpDesc::Vistor<ge::GeTensorDescPtr> all_tensor_desc);
static std::vector<int64_t> TransAxisToNZ(const GeTensorDescPtr &tensor, int64_t axis);
static bool CheckReshape(const GeTensorDescPtr &tensor, const std::string &reshape_type, int64_t axis,
int64_t &format_match_axis);
static bool CheckRank(size_t rank, size_t dim_num, const std::string &reshape_type);
static std::vector<int64_t> TransAxisForSplit(const GeTensorDescPtr &tensor, const int64_t axis,
size_t dim_num);
static std::vector<int64_t> TransAxisForNoSplit(const GeTensorDescPtr &tensor, const int64_t axis,
size_t dim_num);
static bool IsFormatInSet(const Format format, const std::set<Format> &format_set);
static std::vector<int64_t> TransAxis(const GeTensorDescPtr &tensor, int64_t ori_axis);
static Status FixAxisTypeInfoToOne(AxisTypeInfo &axis_type_info);
static bool ValidateRelateInputOutput(const AxisTypeInfo &axis_type_info);
static Status TransAxisForInputTensor(const OpDescPtr &op, const std::string &axis_type_str,
AxisTypeInfo &axis_type_info);
static Status TransAxisForOutputTensor(const OpDescPtr &op, const std::string &axis_type_str,
AxisTypeInfo &axis_type_info);
static Status TransByAxisTypeStr(const OpDescPtr &op, const std::string &axis_type_str,
AxisTypeInfo &axis_type_info);
static Status TransReduceMean(const OpDescPtr &op, AxisTypeInfo &axis_type_info);
static Status TransOther(const OpDescPtr &op, AxisTypeInfo &axis_type_info);
static void BackupOriAxisTypeInfo(AxisTypeInfo &axis_type_info);
static void ResetOriAxisTypeInfo(AxisTypeInfo &axis_type_info);
static Status TransAxisByType(const AxisType axis_type, const OpDescPtr &op, AxisTypeInfo &axis_type_info);
static int64_t SearchOriAxis(const std::vector<CutInfo> &ori_relate, int64_t tensor_idx, int64_t axis_idx);
static bool ValidateAxisIndex(int64_t from_axis, const std::vector<std::vector<int64_t>> &slice_info,
int64_t to_axis, const std::vector<std::vector<int64_t>> &cur_tensor_range);
static Status TransSliceInfoToOriForElement(const OpDescPtr &op, const AxisTypeInfo &axis_type_info,
const DataSliceType &slice_info_list, DataSliceType &ori_slice_info_list);
static Status TransSliceInfoToCurForElement(const OpDescPtr &op, const AxisTypeInfo &axis_type_info,
const DataSliceType &slice_info_list, DataSliceType &cur_slice_info_list);
static AxisType GetAxisTypeForTransAxis(const AxisTypeInfo &axis_type_info);
static AxisType GetAxisTypeForTransSlice(const AxisTypeInfo &axis_type_info);
static Status TransSliceInfo(const OpDescPtr &op, const AxisTypeInfo &axis_type_info, TransType trans_type,
const DataSliceType &slice_info_list, DataSliceType &out_slice_info_list);
};
}
#endif