/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/utils/node_utils_ex.h"
#include "graph_metadef/common/ge_common/util.h"
#include "common/util/trace_manager/trace_manager.h"
#include "graph/refiner/format_refiner.h"
#include "graph/shape_refiner.h"
#include "graph/normal_graph/operator_impl.h"
#include "graph/operator_factory_impl.h"
#include "graph/common_error_codes.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_op_types.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "common/util/mem_utils.h"
#include "graph/utils/op_type_utils.h"
#include "graph/utils/op_desc_utils_ex.h"
#include "base/err_msg.h"
namespace af {
namespace {
// Decides whether the IO names of op_desc should be refreshed from the registered op prototype.
// Returns true when, on either the input or the output side:
//   - the op has IOs but no names are recorded, or
//   - the first recorded name (in map order) carries the default "__input" / "__output" prefix.
bool NeedUpdateIOName(const OpDescPtr &op_desc) {
  // Input side.
  const auto &input_names = op_desc->GetAllInputName();
  if (input_names.empty()) {
    if (op_desc->GetInputsSize() > 0U) {
      return true;
    }
  } else if (StringUtils::StartWith(input_names.cbegin()->first, "__input")) {
    return true;
  }

  // Output side: same rule, different default prefix.
  const auto &output_names = op_desc->GetAllOutputName();
  if (output_names.empty()) {
    return op_desc->GetOutputsSize() > 0U;
  }
  return StringUtils::StartWith(output_names.cbegin()->first, "__output");
}
// Renders an IO name map for debug logging as "<prefix>:" followed either by "empty" or by one
// "[<index>,<name>]" entry per element, in map (key) order.
std::string IoNameToString(const std::string &prefix, const std::map<std::string, uint32_t> &io_names) {
  std::ostringstream oss;
  oss << prefix << ':';
  if (!io_names.empty()) {
    for (auto iter = io_names.cbegin(); iter != io_names.cend(); ++iter) {
      oss << '[' << iter->second << ',' << iter->first << ']';
    }
  } else {
    oss << "empty";
  }
  return oss.str();
}
}
/**
 * Runs shape and data-type inference for a node.
 *
 * @param node node to infer; must not be null
 * @return an error status when node is null, otherwise whatever ShapeRefiner::InferShapeAndType returns
 */
graphStatus NodeUtilsEx::InferShapeAndType(const NodePtr &node) {
  GE_CHECK_NOTNULL(node, ", Node is null for Infer Shape.");
  // The refiner is handed both the node and an Operator view built from it.
  Operator node_operator = OpDescUtils::CreateOperatorFromNode(node);
  const graphStatus infer_status = ShapeRefiner::InferShapeAndType(node, node_operator);
  return infer_status;
}
/**
 * Infers the origin format of a node by invoking its infer-format function.
 *
 * @param node node to process; it and its op desc must not be null
 * @return an error status when node or its op desc is null, otherwise whatever
 *         OpDescUtilsEx::CallInferFormatFunc returns
 */
graphStatus NodeUtilsEx::InferOriginFormat(const NodePtr &node) {
  GE_CHECK_NOTNULL(node, ", Node is null for Infer Format.");
  const auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Format.");
  // Build the Operator view only once both inputs are known to be non-null.
  Operator format_operator = OpDescUtils::CreateOperatorFromNode(node);
  const graphStatus infer_status = OpDescUtilsEx::CallInferFormatFunc(op_desc, format_operator);
  return infer_status;
}
/**
 * Checks that every data input of a node that needs a producer is actually linked.
 *
 * An in-data anchor is accepted when any of the following holds:
 *  - the node is a data node, CONSTANT, VARIABLE or CONSTANTOP;
 *  - the op desc has no input desc at the anchor's index;
 *  - the anchor has at least one peer anchor.
 * Null in-data anchors are skipped with a warning.
 *
 * @param node node whose inputs are checked
 * @return GRAPH_SUCCESS when all inputs are acceptable; GRAPH_FAILED (after reporting E11019) for the
 *         first unlinked input; an error status when node or its op desc is null
 */
graphStatus NodeUtilsEx::IsInputsValid(const NodePtr &node) {
  // This is a public entry point, so do not rely on callers having validated the node.
  GE_CHECK_NOTNULL(node, ", Node is null for Inputs Check.");
  const auto &op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc, ", Op is null for Inputs Check.");

  // The node type does not change while iterating, so classify it once instead of per anchor.
  const auto &node_type = node->GetType();
  const bool inputs_may_be_unlinked = OpTypeUtils::IsDataNode(node_type) || (node_type == CONSTANT) ||
                                      (node_type == VARIABLE) || (node_type == CONSTANTOP);

  for (const auto &in_anchor : node->GetAllInDataAnchorsPtr()) {
    if (in_anchor == nullptr) {
      GELOGW("[Verify][CheckParam] In data anchor is null");
      continue;
    }
    const bool valid_anchor =
        inputs_may_be_unlinked ||
        (op_desc->MutableInputDesc(static_cast<uint32_t>(in_anchor->GetIdx())) == nullptr) ||
        (in_anchor->GetPeerAnchorsSize() > 0UL);
    if (!valid_anchor) {
      // Keep the strings in named locals so the raw c_str() pointers handed to the error reporter stay
      // valid regardless of how the macro stores or forwards its arguments.
      const auto &node_name = node->GetName();
      const std::string index_str = std::to_string(in_anchor->GetIdx());
      REPORT_PREDEFINED_ERR_MSG("E11019", std::vector<const char *>({"opname", "index"}),
                                std::vector<const char *>({node_name.c_str(), index_str.c_str()}));
      GELOGE(GRAPH_FAILED, "[Check][Param] operator %s's input %d is not linked.", node_name.c_str(),
             in_anchor->GetIdx());
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
/**
 * Verifies a single node.
 *
 * Steps, in order:
 *  1. Nodes whose owner graph has its unknown flag set (GetGraphUnknownFlag) are accepted immediately.
 *  2. IsInputsValid() must succeed.
 *  3. For non-FRAMEWORKOP nodes whose IO names look missing or defaulted (see NeedUpdateIOName), the IO
 *     names are refreshed from the operator prototype registered in OperatorFactory.
 *  4. OpDesc::CommonVerify() must succeed. Afterwards, the op-specific verify function runs if one exists,
 *     taken from the op desc first and from OperatorFactory otherwise.
 *
 * @param node node to verify
 * @return GRAPH_SUCCESS when verification passes; the verify function's result when one runs;
 *         an error status otherwise
 */
graphStatus NodeUtilsEx::Verify(const NodePtr &node) {
  GE_CHECK_NOTNULL(node, ", Node is null for Infer Verify.");
  // The owner graph may be absent (e.g. a detached node). Only take the shortcut when the owner is known to
  // carry the unknown flag, instead of dereferencing a possibly null pointer.
  const auto owner_graph = node->GetOwnerComputeGraph();
  if ((owner_graph != nullptr) && owner_graph->GetGraphUnknownFlag()) {
    return GRAPH_SUCCESS;
  }
  GE_CHK_STATUS_RET_NOLOG(IsInputsValid(node));
  const auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Verify.");
  /*
   * Temporary workaround:
   * The code below builds a temporary op desc with the creator registered in the operator prototype
   * library, then copies its input/output names onto the current op desc. This approach is flawed:
   *   1. only the leading required inputs can be restored;
   *   2. dynamic inputs cannot be restored;
   *   3. it cannot tell which optional inputs were actually supplied, so all of them are restored.
   * This responsibility also belongs to the parsers, not to infershape. However, front ends such as the
   * TF parser do not set input names correctly, so simply removing this step would make infershape fail
   * for some operators.
   * As a temporary measure, the names are therefore refreshed only when they are missing or still carry
   * the default '__input' / '__output' prefixes (see NeedUpdateIOName), and never for FRAMEWORKOP nodes.
   * Long-term solution: the TF, Caffe and ONNX parsers must fully populate the required op desc fields.
   */
  const bool need_update_name = (node->GetType() != FRAMEWORKOP) && NeedUpdateIOName(op_desc);
  GELOGD("Before update %s(%s) io name, input size %zu, %s, output size %zu, %s", op_desc->GetNamePtr(),
         op_desc->GetTypePtr(), op_desc->GetInputsSize(),
         IoNameToString("Input names", op_desc->GetAllInputName()).c_str(),
         op_desc->GetOutputsSize(), IoNameToString("Output names", op_desc->GetAllOutputName()).c_str());
  if (need_update_name) {
    const auto node_op = af::OperatorFactoryImpl::CreateOperator("node_op", node->GetType());
    if (node_op.IsEmpty()) {
      GELOGW("[Verify][CheckParam] Get op from OperatorFactory failed, type: %s", node->GetType().c_str());
    } else {
      GELOGD("get op from OperatorFactory success. opType: %s", node->GetType().c_str());
      const auto temp_op_desc = af::OpDescUtils::GetOpDescFromOperator(node_op);
      if (temp_op_desc == nullptr) {
        REPORT_INNER_ERR_MSG("E18888", "GetOpDescFromOperator failed, as return nullptr, type:%s",
                             node->GetType().c_str());
        GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null, type:%s", node->GetType().c_str());
        // The normal path always calls BreakConnect() on the temporary operator. Do the same on this
        // error path so its cleanup is not skipped.
        node_op.BreakConnect();
        return GRAPH_FAILED;
      }
      if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
        GELOGW("[Verify][Update] Update input name failed");
      }
      if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
        GELOGW("[Verify][Update] Update output name failed");
      }
      GELOGD("After update %s(%s) io name, input size %zu, %s, output size %zu, %s", op_desc->GetNamePtr(),
             op_desc->GetTypePtr(), op_desc->GetInputsSize(),
             IoNameToString("Input names", op_desc->GetAllInputName()).c_str(),
             op_desc->GetOutputsSize(), IoNameToString("Output names", op_desc->GetAllOutputName()).c_str());
    }
    node_op.BreakConnect();
  }
  if (op_desc->CommonVerify() != GRAPH_SUCCESS) {
    REPORT_INNER_ERR_MSG("E18888", "%s(%s) Verify failed.", node->GetName().c_str(), node->GetType().c_str());
    GELOGE(GRAPH_FAILED, "[Call][CommonVerify] %s(%s) failed.", node->GetName().c_str(), node->GetType().c_str());
    return GRAPH_FAILED;
  }
  Operator op = OpDescUtils::CreateOperatorFromNode(node);
  // Prefer a verify function attached to the op desc; otherwise fall back to the one registered by type.
  auto verify_func = op_desc->GetVerifyFunc();
  if (verify_func == nullptr) {
    verify_func = OperatorFactoryImpl::GetVerifyFunc(node->GetType());
  }
  if (verify_func != nullptr) {
    return static_cast<graphStatus>(verify_func(op));
  }
  return GRAPH_SUCCESS;
}
}