* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "common/op_tensor_utils.h"
#include <cmath>
#include "common/comm_error_codes.h"
#include "common/fe_log.h"
#include "common/math_util.h"
#include "common/constants_define.h"
#include "common/string_utils.h"
#include "common/platform_utils.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "register/graph_optimizer/fusion_common/unknown_shape_utils.h"
#include "graph/ge_context.h"
namespace fe {
Status OpTensorUtils::VerifyTensor(const std::vector<int64_t> &dims, const ge::DataType &data_type) {
if (data_type < ge::DT_FLOAT || data_type == ge::DT_UNDEFINED || data_type >= ge::DT_MAX) {
REPORT_FE_ERROR("[GraphOpt][OptWholeGph][VerifyTensor] The data type of this tensor is invalid.");
return TENSOR_DATATYPE_INVALID;
}
if (!dims.empty()) {
for (const int64_t &dim : dims) {
if (dim < 0) {
REPORT_FE_ERROR("[GraphOpt][OptWholeGph][VerfyTensor] The dim value[%ld] is invalid.", dim);
return DIM_VALUE_INVALID;
}
}
}
return SUCCESS;
}
Status OpTensorUtils::ArrayMultiplyInt64WithVerify(const std::vector<int64_t> &dims, int64_t &result) {
result = 1;
if (dims.empty()) {
return SUCCESS;
}
for (const int64_t &num : dims) {
if (num == 0) {
FE_LOGD("num = 0; return 1;");
result = 1;
return SUCCESS;
}
FE_MUL_OVERFLOW(result, num, result);
}
return SUCCESS;
}
Status OpTensorUtils::CalibrateTensorSize(int64_t &tensor_size) {
uint32_t padding_size = PlatformUtils::Instance().GetPaddingSize();
if (PlatformUtils::Instance().GetIsaArchVersion() != ISAArchVersion::EN_ISA_ARCH_V350) {
tensor_size = (tensor_size + DATA_MEMORY_ALIGN_SIZE - 1) / DATA_MEMORY_ALIGN_SIZE;
FE_MUL_OVERFLOW(tensor_size, DATA_MEMORY_ALIGN_SIZE, tensor_size);
} else {
tensor_size = (tensor_size + DATA_MEMORY_ALIGN_SIZE_FOR_V350 - 1) / DATA_MEMORY_ALIGN_SIZE_FOR_V350;
FE_MUL_OVERFLOW(tensor_size, DATA_MEMORY_ALIGN_SIZE_FOR_V350, tensor_size);
}
FE_ADD_OVERFLOW(tensor_size, padding_size, tensor_size);
return SUCCESS;
}
Status OpTensorUtils::CalcTensorSize(const std::vector<int64_t> &dims, const ge::DataType &data_type,
int32_t output_real_calc_flag, int64_t &tensor_size) {
if (VerifyTensor(dims, data_type) != SUCCESS) {
REPORT_FE_ERROR("[GraphOpt][OptWholeGph][CalcTensorSize] Failed to verify this tensor.");
return FAILED;
}
int64_t element_cnt;
if (ArrayMultiplyInt64WithVerify(dims, element_cnt) != SUCCESS) {
tensor_size = INT64_MAX;
FE_LOGW("Tensor size is larger than the upper bound of int64, using INT64_MAX as the tensor size.");
return SUCCESS;
}
tensor_size = ge::GetSizeInBytes(element_cnt, data_type);
if (tensor_size < 0) {
REPORT_FE_ERROR("GetSizeInBytes failed!");
return FAILED;
}
if (!output_real_calc_flag) {
return CalibrateTensorSize(tensor_size);
}
return SUCCESS;
}
Status OpTensorUtils::CalcTensorSize(const ge::GeTensorDesc &tensor_desc,
const int32_t output_real_calc_flag, int64_t &tensor_size) {
std::vector<int64_t> dims = tensor_desc.GetShape().GetDims();
ge::DataType data_type = tensor_desc.GetDataType();
return CalcTensorSize(dims, data_type, output_real_calc_flag, tensor_size);
}
bool OpTensorUtils::IsStaticReuseBinaryOp(const ge::OpDescPtr &op_desc) {
bool stc_dyn_soft_sync = false;
(void)ge::AttrUtils::GetBool(op_desc, kStaticToDynamicSoftSyncOp, stc_dyn_soft_sync);
if (stc_dyn_soft_sync) {
return true;
}
if (UnknownShapeUtils::IsUnknownShapeOp(*op_desc)) {
FE_LOGD("Node [%s] has an unknown shape", op_desc->GetName().c_str());
return false;
}
bool is_unknown_graph = false;
(void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_OWNER_GRAPH_IS_UNKNOWN, is_unknown_graph);
if (is_unknown_graph) {
FE_LOGD("Node:%s is unknown graph", op_desc->GetName().c_str());
return false;
}
bool stc_tiling_depend = false;
(void)ge::AttrUtils::GetBool(op_desc, kDynamicTilingDependOp, stc_tiling_depend);
if (stc_tiling_depend) {
return false;
}
return op_desc->HasAttr(COMPILE_INFO_JSON);
}
bool OpTensorUtils::IsFuzzBuildOp(const ge::OpDesc &op_desc) {
std::string build_mode;
bool support_dyn_shape = false;
if ((ge::GetContext().GetOption("ge.shape_generalized_build_mode", build_mode) == ge::GRAPH_SUCCESS) &&
(build_mode == SHAPE_GENERALIZED) &&
ge::AttrUtils::GetBool(op_desc, ATTR_NAME_SUPPORT_DYNAMIC_SHAPE, support_dyn_shape) && support_dyn_shape) {
return true;
} else {
return false;
}
}
}