* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/ir_definitions_recover.h"
#include <algorithm>
#include <cinttypes>
#include <ostream>
#include <sstream>
#include "graph/operator_factory.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/normal_graph/op_desc_impl.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/debug/ge_op_types.h"
#include "framework/common/debug/ge_log.h"
#include "common/checker.h"
#include "graph/utils/op_type_utils.h"
#include "graph/utils/recover_ir_utils.h"
using IrDefinition = ge::RecoverIrUtils::IrDefinition;
namespace {
/**
 * @brief Joins IR attribute names into a single human-readable string for log messages.
 * @param attr_names attribute names, in order
 * @return the names separated by ", " (empty string for an empty list)
 */
std::string IrAttrNamesToString(const std::vector<std::string> &attr_names) {
  std::string joined;
  for (size_t idx = 0U; idx < attr_names.size(); ++idx) {
    if (idx != 0U) {
      joined += ", ";
    }
    joined += attr_names[idx];
  }
  return joined;
}
/**
 * @brief Renders an IR input type as text for diagnostics.
 * @param type IR input kind
 * @return "Required", "Optional", "Dynamic", or "Unknown(<numeric value>)" for any other value
 */
std::string IrInputTypeToString(ge::IrInputType type) {
  const char *label = nullptr;
  switch (type) {
    case ge::kIrInputRequired:
      label = "Required";
      break;
    case ge::kIrInputOptional:
      label = "Optional";
      break;
    case ge::kIrInputDynamic:
      label = "Dynamic";
      break;
    default:
      break;
  }
  if (label != nullptr) {
    return label;
  }
  return "Unknown(" + std::to_string(static_cast<int>(type)) + ")";
}
/**
 * @brief Renders an IR output type as text for diagnostics.
 * @param type IR output kind
 * @return "Required", "Dynamic", or "Unknown(<numeric value>)" for any other value
 */
std::string IrOutputTypeToString(ge::IrOutputType type) {
  const char *label = nullptr;
  switch (type) {
    case ge::kIrOutputRequired:
      label = "Required";
      break;
    case ge::kIrOutputDynamic:
      label = "Dynamic";
      break;
    default:
      break;
  }
  if (label != nullptr) {
    return label;
  }
  return "Unknown(" + std::to_string(static_cast<int>(type)) + ")";
}
/**
 * @brief Generic textual form of an IR type value, used by IrDefsToString.
 *
 * This primary template simply forwards to std::to_string; the IR input/output enum types are
 * given dedicated explicit specializations elsewhere in this file.
 */
template <typename IrType>
auto IrTypeToString(const IrType &ir_type) -> std::string {
  const std::string text = std::to_string(ir_type);
  return text;
}
template <>
std::string IrTypeToString<ge::IrInputType>(const ge::IrInputType &ir_type) {
return IrInputTypeToString(ir_type);
}
template <>
std::string IrTypeToString<ge::IrOutputType>(const ge::IrOutputType &ir_type) {
return IrOutputTypeToString(ir_type);
}
/**
 * @brief Formats a sequence of (name, type) IR definitions for log messages.
 * @param ir_defs container of pairs whose `first` is streamable and whose `second` is accepted by IrTypeToString
 * @return entries rendered as "[name, type]" and separated by ", " (empty string for an empty container)
 */
template <typename IrDef>
std::string IrDefsToString(const IrDef &ir_defs) {
  std::ostringstream stream;
  const char *separator = "";
  for (const auto &entry : ir_defs) {
    stream << separator << "[" << entry.first << ", " << IrTypeToString(entry.second) << "]";
    separator = ", ";
  }
  return stream.str();
}
/**
 * @brief Appends to a node the IR definitions it is missing relative to the registered IR (backward compatibility).
 *
 * Entries of `ir_defs` at positions [ir_ins.size(), ir_defs.size()) are handed to `appender` one by one.
 * If the node already records more entries than the registered IR declares, nothing is appended and an
 * error is logged.
 *
 * @param op_desc      node being recovered
 * @param ir_ins       IR definitions currently recorded on the node
 * @param ir_defs      IR definitions registered in the running environment
 * @param appender     callback that records one (name, type) definition on op_desc
 * @param ir_type_name "input"/"output"; used only in log text
 * @return ge::GRAPH_SUCCESS, or ge::FAILED when the registered IR has fewer entries than the node
 */
template <typename IrDef, typename IrType>
ge::graphStatus AppendIrDefs(const ge::OpDescPtr &op_desc, const IrDef &ir_ins, const IrDef &ir_defs,
                             const ge::RecoverIrUtils::IrDefAppender<IrType> appender,
                             const std::string &ir_type_name) {
  const size_t recorded_num = ir_ins.size();
  const bool shrunk = ir_defs.size() < recorded_num;
  if (shrunk) {
    const char *const kind = ir_type_name.c_str();
    GELOGE(ge::FAILED,
           "In the current running version, the number of operator[%s][%s] %s has been reduced, "
           "ir_def.%s size[%zu] is less than ir_%s_in_node size[%zu], ir_def.%s is [%s], "
           "ir_%s_in_node is [%s]",
           op_desc->GetName().c_str(), op_desc->GetType().c_str(), kind, kind, ir_defs.size(), kind, recorded_num,
           kind, IrDefsToString<IrDef>(ir_defs).c_str(), kind, IrDefsToString<IrDef>(ir_ins).c_str());
    return ge::FAILED;
  }
  // ir_defs is re-indexed after each append on purpose: the appender mutates op_desc, so no element
  // references are cached across the call.
  for (size_t pos = recorded_num; pos < ir_defs.size(); ++pos) {
    appender(op_desc, ir_defs[pos].first, ir_defs[pos].second);
    GELOGD("Append ir %s:[%s, %s] for node[%s(%s)]", ir_type_name.c_str(), ir_defs[pos].first.c_str(),
           IrTypeToString(ir_defs[pos].second).c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str());
  }
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Tells whether `attr_name` appears among the op's required attributes.
 * @param desc      operator description (must be non-null)
 * @param attr_name attribute name to look up
 * @return true when the name is a key of desc->GetRequiredAttrWithType()
 */
bool IsRequiredAttr(const ge::OpDescPtr &desc, const std::string &attr_name) {
  return desc->GetRequiredAttrWithType().count(attr_name) != 0U;
}
/**
 * @brief Reports whether the op has an input tensor description registered under the given IR input name.
 * @param desc          operator description (must be non-null)
 * @param ir_input_name IR input name
 * @return true when desc->MutableInputDesc(ir_input_name) yields a non-null descriptor
 */
bool IsInputConnected(const ge::OpDescPtr &desc, const std::string &ir_input_name) {
  const auto input_desc = desc->MutableInputDesc(ir_input_name);
  return input_desc != nullptr;
}
/**
 * @brief Checks that two IR definition lists agree on their common prefix.
 *
 * Only the first min(size) entries are compared; extra trailing entries on either side are allowed.
 * The first mismatching index is logged as an error.
 *
 * @return true when every entry of the common prefix is equal
 */
template <typename IrDef>
bool ValidateIrOrderCompatibility(const IrDef &node_ir_defs, const IrDef &compatible_ir_defs) {
  const size_t common_len = std::min(node_ir_defs.size(), compatible_ir_defs.size());
  size_t pos = 0U;
  while ((pos < common_len) && (node_ir_defs[pos] == compatible_ir_defs[pos])) {
    ++pos;
  }
  if (pos == common_len) {
    return true;
  }
  GELOGE(ge::GRAPH_FAILED, "ValidateIrOrderCompatibility failed: index is %zu", pos);
  return false;
}
/**
 * @brief Validates that the node's IR inputs/outputs are an order-compatible match for the registered IR.
 *
 * Checks, in order:
 *   1. the common prefix of IR inputs is identical (name and type);
 *   2. the node records no more IR outputs than the registered IR declares;
 *   3. the common prefix of IR outputs is identical.
 * The first failing check logs an error and makes the function return a failure status.
 */
ge::graphStatus ValidateIrInputOutputOrderCompatibility(const ge::OpDescPtr &desc,
                                                        const ge::RecoverIrUtils::InputIrDefs &ir_inputs_in_node,
                                                        const ge::RecoverIrUtils::OutputIrDefs &ir_outputs_in_node,
                                                        const IrDefinition &ir_def) {
  using InputDefs = ge::RecoverIrUtils::InputIrDefs;
  using OutputDefs = ge::RecoverIrUtils::OutputIrDefs;

  const bool input_order_ok = ValidateIrOrderCompatibility(ir_inputs_in_node, ir_def.inputs);
  GE_ASSERT_TRUE(input_order_ok,
                 "Compatibility failed: operator[%s][%s] input order or type has changed. "
                 "ir_inputs_in_node is [%s], ir_def.inputs is [%s]",
                 desc->GetName().c_str(), desc->GetType().c_str(),
                 IrDefsToString<InputDefs>(ir_inputs_in_node).c_str(),
                 IrDefsToString<InputDefs>(ir_def.inputs).c_str());

  const bool output_size_ok = ir_outputs_in_node.size() <= ir_def.outputs.size();
  GE_ASSERT_TRUE(output_size_ok,
                 "Compatibility failed: operator[%s][%s] output size has changed. "
                 "ir_outputs_in_node size is %zu, ir_def.outputs size is %zu",
                 desc->GetName().c_str(), desc->GetType().c_str(), ir_outputs_in_node.size(), ir_def.outputs.size());

  const bool output_order_ok = ValidateIrOrderCompatibility(ir_outputs_in_node, ir_def.outputs);
  GE_ASSERT_TRUE(output_order_ok,
                 "Compatibility failed: operator[%s][%s] output order or type has changed. "
                 "ir_outputs_in_node is [%s], ir_def.outputs is [%s]",
                 desc->GetName().c_str(), desc->GetType().c_str(),
                 IrDefsToString<OutputDefs>(ir_outputs_in_node).c_str(),
                 IrDefsToString<OutputDefs>(ir_def.outputs).c_str());
  return ge::GRAPH_SUCCESS;
}
}
namespace ge {
/**
 * @brief Forward-compatibility handling of IR inputs.
 *
 * The node records more IR inputs than the IR registered in the running environment. Every trailing input
 * (index >= ir_def.inputs.size()) must be optional and must have no input descriptor on the node; such inputs
 * are then removed from the node's IR meta.
 *
 * All extra inputs are validated before any of them is removed. A failure therefore leaves the node untouched.
 *
 * @param desc              operator description being recovered
 * @param ir_inputs_in_node IR inputs recorded on the node; may be a reference to desc's own IR-input list
 * @param ir_def            IR definition registered in the running environment
 * @return GRAPH_SUCCESS, or a failure status when an extra input is required or connected
 */
graphStatus RecoverIrUtils::ProcessForwardCompatInputs(const ge::OpDescPtr &desc,
                                                       const InputIrDefs &ir_inputs_in_node,
                                                       const IrDefinition &ir_def) {
  // Snapshot the extra inputs first. ir_inputs_in_node can alias the vector that RemoveIrInput() shrinks.
  // Indexing it forward while erasing would skip the entry following every removed one, leaving trailing
  // unknown inputs unvalidated and unremoved.
  InputIrDefs extra_inputs;
  for (size_t i = ir_def.inputs.size(); i < ir_inputs_in_node.size(); ++i) {
    extra_inputs.emplace_back(ir_inputs_in_node[i]);
  }

  for (const auto &extra_input : extra_inputs) {
    const std::string &ir_input_name = extra_input.first;
    const ge::IrInputType ir_input_type = extra_input.second;
    GE_ASSERT_TRUE(ir_input_type == ge::kIrInputOptional,
                   "Forward compatibility failed: operator[%s][%s] has required input[%s] "
                   "(type: %s) that does not exist in the compatible IR version. "
                   "This is an incompatible change.",
                   desc->GetName().c_str(), desc->GetType().c_str(), ir_input_name.c_str(),
                   IrInputTypeToString(ir_input_type).c_str());
    GE_ASSERT_TRUE(!IsInputConnected(desc, ir_input_name),
                   "Forward compatibility failed: operator[%s][%s] uses optional input[%s] "
                   "that does not exist in the compatible IR version. "
                   "The input is connected in the node but not supported by the runtime environment.",
                   desc->GetName().c_str(), desc->GetType().c_str(), ir_input_name.c_str());
  }

  for (const auto &extra_input : extra_inputs) {
    desc->impl_->MutableIRMeta().RemoveIrInput(extra_input.first);
    GELOGD("Forward compatibility: removed unused optional input[%s] from node[%s(%s)]",
           extra_input.first.c_str(), desc->GetName().c_str(), desc->GetType().c_str());
  }
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Reconciles the node's IR attribute names with the registered IR, according to ir_def.strategy.
 *
 * The common prefix of attribute names must match in order.
 * - kForward: trailing attribute names unknown to the running IR must be neither required nor set on the node.
 *   They are then removed from the node's IR meta. All of them are validated before any is removed.
 * - kBackward: attribute names declared by the running IR but missing on the node are appended.
 * - any other strategy: only the prefix check is performed.
 *
 * @param desc   operator description being recovered
 * @param ir_def IR definition registered in the running environment (strategy already derived)
 * @return GRAPH_SUCCESS, or a failure status on an incompatible attribute layout
 */
graphStatus RecoverIrUtils::RecoverIrAttrNames(const ge::OpDescPtr &desc, IrDefinition &ir_def) {
  const auto &ir_attr_names_in_node = desc->GetIrAttrNames();
  GE_ASSERT_TRUE(ValidateIrOrderCompatibility(ir_attr_names_in_node, ir_def.attr_names),
                 "Compatibility failed: operator[%s][%s] attribute order has changed. "
                 "ir_attr_names_in_node is [%s], ir_def.attr_names is [%s]",
                 desc->GetName().c_str(), desc->GetType().c_str(), IrAttrNamesToString(ir_attr_names_in_node).c_str(),
                 IrAttrNamesToString(ir_def.attr_names).c_str());
  if (ir_def.strategy == CompatibilityStrategy::kForward) {
    // Snapshot first. ir_attr_names_in_node can alias the list RemoveIrAttrName() erases from, and erasing
    // while indexing forward would skip the name after each removed one.
    std::vector<std::string> extra_attr_names;
    for (size_t i = ir_def.attr_names.size(); i < ir_attr_names_in_node.size(); ++i) {
      extra_attr_names.emplace_back(ir_attr_names_in_node[i]);
    }
    for (const auto &attr_name : extra_attr_names) {
      GE_ASSERT_TRUE(!IsRequiredAttr(desc, attr_name),
                     "Forward compatibility failed: operator[%s][%s] has required attribute[%s] "
                     "that does not exist in the compatible IR version. This is an incompatible change.",
                     desc->GetName().c_str(), desc->GetType().c_str(), attr_name.c_str());
      GE_ASSERT_TRUE(!desc->HasAttr(attr_name),
                     "Forward compatibility failed: operator[%s][%s] has optional attribute[%s] "
                     "but is configured as non-default value in the node. This is an incompatible change.",
                     desc->GetName().c_str(), desc->GetType().c_str(), attr_name.c_str());
    }
    for (const auto &attr_name : extra_attr_names) {
      desc->impl_->MutableIRMeta().RemoveIrAttrName(attr_name);
      GELOGD("Forward compatibility: removed optional attribute[%s] from node[%s(%s)]", attr_name.c_str(),
             desc->GetName().c_str(), desc->GetType().c_str());
    }
    return ge::GRAPH_SUCCESS;
  }
  if (ir_def.strategy == CompatibilityStrategy::kBackward) {
    for (size_t i = ir_attr_names_in_node.size(); i < ir_def.attr_names.size(); ++i) {
      desc->AppendIrAttrName(ir_def.attr_names[i]);
      // Guard the index: is_required_attr is filled by InitIrDefinitionsIfNeed and may be shorter if ir_def
      // was populated elsewhere. The explicit bool/int conversion keeps the vararg matching %d even when the
      // container yields a proxy reference (e.g. std::vector<bool>).
      const bool is_required =
          (i < ir_def.is_required_attr.size()) && static_cast<bool>(ir_def.is_required_attr[i]);
      GELOGD("Append ir attr name:%s for desc[%s(%s), is_required_attr:%d]", ir_def.attr_names[i].c_str(),
             desc->GetName().c_str(), desc->GetType().c_str(), static_cast<int>(is_required));
    }
    return ge::GRAPH_SUCCESS;
  }
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Lazily fills `ir_def` from the IR registered for `op_type`.
 *
 * This is a no-op once ir_def.inited is set. Otherwise a temporary operator named "temp" is created through
 * OperatorFactory and its op desc is inspected. If that yields no op desc, only has_ir_definition=false is
 * recorded. If it succeeds, the following are copied into ir_def:
 * - IR attribute names and their required flags;
 * - IR inputs and outputs;
 * - all attribute values;
 * - the op desc itself.
 *
 * @param op_type operator type to look up
 * @param ir_def  cache entry to populate
 */
void RecoverIrUtils::InitIrDefinitionsIfNeed(const string &op_type, IrDefinition &ir_def) {
  if (ir_def.inited) {
    return;
  }
  auto temp_op = ge::OperatorFactory::CreateOperator("temp", op_type.c_str());
  temp_op.BreakConnect();
  const auto temp_desc = ge::OpDescUtils::GetOpDescFromOperator(temp_op);
  ir_def.inited = true;
  if (temp_desc == nullptr) {
    GELOGW("Failed to construct operator from type %s", op_type.c_str());
    ir_def.has_ir_definition = false;
    return;
  }
  ir_def.attr_names = temp_desc->GetIrAttrNames();
  const size_t attr_num = ir_def.attr_names.size();
  ir_def.is_required_attr.resize(attr_num, 0);
  for (size_t idx = 0U; idx < attr_num; ++idx) {
    ir_def.is_required_attr[idx] = IsRequiredAttr(temp_desc, ir_def.attr_names[idx]);
  }
  ir_def.inputs = temp_desc->GetIrInputs();
  ir_def.outputs = temp_desc->GetIrOutputs();
  ir_def.attr_value = ge::AttrUtils::GetAllAttrs(temp_desc);
  ir_def.has_ir_definition = true;
  ir_def.op_desc = temp_desc;
}
/**
 * @brief Fills in registered default values for IR attributes that are absent from the node.
 *
 * For each name in ir_def.attr_names that the node does not carry:
 * - if ir_def.attr_value holds a value for it, that value is set on the node;
 * - otherwise it is logged (possibly a required attribute) and skipped.
 *
 * @param desc    operator description being recovered
 * @param op_type operator type, used only in log text
 * @param ir_def  registered IR definition providing names and default values
 * @return GRAPH_SUCCESS, or a failure status if setting an attribute fails
 */
graphStatus RecoverIrUtils::RecoverIrAttrDefaultValue(const ge::OpDescPtr &desc, const string &op_type,
                                                      IrDefinition &ir_def) {
  const auto attrs_on_node = ge::AttrUtils::GetAllAttrs(desc);
  const auto &default_values = ir_def.attr_value;
  for (const auto &attr_name : ir_def.attr_names) {
    const bool already_present = attrs_on_node.find(attr_name) != attrs_on_node.cend();
    if (already_present) {
      continue;
    }
    const auto default_iter = default_values.find(attr_name);
    if (default_iter != default_values.cend()) {
      GELOGD("node[%s(%s)] missing attr name[%s], set default value.", desc->GetName().c_str(), op_type.c_str(),
             attr_name.c_str());
      GE_ASSERT_GRAPH_SUCCESS(desc->AttrHolder::SetAttr(attr_name, default_iter->second));
    } else {
      GELOGI(
          "node[%s(%s)] missing attr name[%s], and cannot find default value for the attr,"
          " it may be REQUIRED_ATTR.",
          desc->GetName().c_str(), op_type.c_str(), attr_name.c_str());
    }
  }
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Decides how a node's IR meta relates to the IR registered in the running environment.
 *
 * The direction comes from the sign of (node count - registered count), computed for IR attribute names and
 * for IR inputs:
 * - opposite signs: kFailed;
 * - any positive difference (the other one non-negative): kForward;
 * - any negative difference (the other one non-positive): kBackward;
 * - both zero: kNone.
 * IR outputs are not considered here.
 */
CompatibilityStrategy RecoverIrUtils::DeriveCompatibilityStrategy(const ge::OpDescPtr &desc,
                                                                  const IrDefinition &ir_def) {
  const size_t node_attr_num = desc->GetIrAttrNames().size();
  const size_t node_input_num = desc->GetIrInputs().size();
  const int64_t attr_diff = static_cast<int64_t>(node_attr_num) - static_cast<int64_t>(ir_def.attr_names.size());
  const int64_t input_diff = static_cast<int64_t>(node_input_num) - static_cast<int64_t>(ir_def.inputs.size());

  const bool attrs_ahead = attr_diff > 0;
  const bool attrs_behind = attr_diff < 0;
  const bool inputs_ahead = input_diff > 0;
  const bool inputs_behind = input_diff < 0;

  if ((attrs_ahead && inputs_behind) || (attrs_behind && inputs_ahead)) {
    GELOGE(ge::GRAPH_FAILED,
           "Compatibility strategy derivation failed: operator[%s][%s] has inconsistent compatibility direction. "
           "Node has %zu attributes and %zu inputs, compatible IR has %zu attributes and %zu inputs. "
           "Attributes suggest %s compatibility, but inputs suggest %s compatibility. This is an incompatible change.",
           desc->GetName().c_str(), desc->GetType().c_str(), node_attr_num, node_input_num,
           ir_def.attr_names.size(), ir_def.inputs.size(), (attrs_ahead ? "forward" : "backward"),
           (inputs_ahead ? "forward" : "backward"));
    return CompatibilityStrategy::kFailed;
  }
  if (attrs_ahead || inputs_ahead) {
    return CompatibilityStrategy::kForward;
  }
  if (attrs_behind || inputs_behind) {
    return CompatibilityStrategy::kBackward;
  }
  return CompatibilityStrategy::kNone;
}
/**
 * @brief Recovers one op desc's IR meta (attribute names, inputs/outputs, dtype symbols, attribute defaults).
 *
 * NetOutput and data nodes are skipped. Op types without a registered IR are skipped as well.
 *
 * @param desc    operator description to recover; must not be null
 * @param op_type operator type whose registered IR is used
 * @param ir_def  cached IR definition for op_type; lazily initialized and its strategy overwritten per call
 * @return GRAPH_SUCCESS, or a failure status on a null input or an incompatible IR
 */
graphStatus RecoverIrUtils::RecoverOpDescIrDefinition(const ge::OpDescPtr &desc, const string &op_type,
                                                      IrDefinition &ir_def) {
  GE_ASSERT_NOTNULL(desc);
  if ((desc->GetType() == ge::NETOUTPUT) || ge::OpTypeUtils::IsDataNode(desc->GetType())) {
    return ge::GRAPH_SUCCESS;
  }
  InitIrDefinitionsIfNeed(op_type, ir_def);
  if (!ir_def.has_ir_definition) {
    GELOGI("Op type:%s has no registered IR, maybe no need to recover.", op_type.c_str());
    return ge::GRAPH_SUCCESS;
  }
  // ir_def.op_desc is dereferenced below. Check it before the node is mutated so that a malformed
  // cache entry cannot leave the node half-recovered.
  GE_ASSERT_NOTNULL(ir_def.op_desc);
  ir_def.strategy = DeriveCompatibilityStrategy(desc, ir_def);
  if (ir_def.strategy == CompatibilityStrategy::kFailed) {
    return ge::GRAPH_FAILED;
  }
  GE_ASSERT_GRAPH_SUCCESS(RecoverIrAttrNames(desc, ir_def), "%s %s recover ir attr names failed.",
                          desc->GetNamePtr(), desc->GetTypePtr());
  GE_ASSERT_GRAPH_SUCCESS(RecoverIrInputAndOutput(desc, ir_def), "%s %s recover ir input and output failed.",
                          desc->GetNamePtr(), desc->GetTypePtr());
  desc->ShareDtypeSymbolsFrom(*ir_def.op_desc);
  GE_ASSERT_GRAPH_SUCCESS(RecoverIrAttrDefaultValue(desc, op_type, ir_def),
                          "%s %s recover ir attr default value failed.", desc->GetNamePtr(), desc->GetTypePtr());
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Reconciles the node's IR inputs and outputs with the registered IR.
 *
 * Steps:
 * 1. Validate that the common prefixes match (ValidateIrInputOutputOrderCompatibility).
 * 2. Adjust inputs according to ir_def.strategy:
 *    - kForward: drop unused extra optional inputs;
 *    - kBackward: append missing inputs.
 * 3. Always append any missing outputs.
 *
 * @return GRAPH_SUCCESS, or GRAPH_FAILED on an incompatible layout
 */
graphStatus RecoverIrUtils::RecoverIrInputAndOutput(const OpDescPtr &desc, IrDefinition &ir_def) {
  const auto &node_inputs = desc->GetIrInputs();
  const auto &node_outputs = desc->GetIrOutputs();
  GE_ASSERT_GRAPH_SUCCESS(ValidateIrInputOutputOrderCompatibility(desc, node_inputs, node_outputs, ir_def),
                          "%s %s validate ir input output order compatibility failed.", desc->GetNamePtr(),
                          desc->GetTypePtr());

  switch (ir_def.strategy) {
    case CompatibilityStrategy::kForward:
      GE_ASSERT_GRAPH_SUCCESS(RecoverIrUtils::ProcessForwardCompatInputs(desc, node_inputs, ir_def));
      break;
    case CompatibilityStrategy::kBackward: {
      auto append_input = [](const ge::OpDescPtr &op_desc, const std::string &ir_name,
                             const ge::IrInputType ir_type) -> void { op_desc->AppendIrInput(ir_name, ir_type); };
      const ge::graphStatus input_status =
          AppendIrDefs<InputIrDefs, ge::IrInputType>(desc, desc->GetIrInputs(), ir_def.inputs, append_input, "input");
      if (input_status != ge::GRAPH_SUCCESS) {
        GELOGE(ge::GRAPH_FAILED, "recover ir inputs failed.");
        return ge::GRAPH_FAILED;
      }
      break;
    }
    default:
      break;
  }

  auto append_output = [](const ge::OpDescPtr &op_desc, const std::string &ir_name,
                          const ge::IrOutputType ir_type) -> void { op_desc->AppendIrOutput(ir_name, ir_type); };
  const ge::graphStatus output_status = AppendIrDefs<OutputIrDefs, ge::IrOutputType>(
      desc, desc->GetIrOutputs(), ir_def.outputs, append_output, "output");
  if (output_status != ge::GRAPH_SUCCESS) {
    GELOGE(ge::GRAPH_FAILED, "recover ir outputs failed.");
    return ge::GRAPH_FAILED;
  }
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Recovers the IR meta of a graph node's op desc against the registered IR for `op_type`.
 * @param node    graph node; must not be null
 * @param op_type operator type whose registered IR is used (read-only)
 * @param ir_def  per-type cached IR definition shared across nodes of the same type
 * @return status of RecoverIrUtils::RecoverOpDescIrDefinition, or failure for a null node
 */
static graphStatus RecoverNodeIrDefinitions(const ge::NodePtr &node, const std::string &op_type,
                                            IrDefinition &ir_def) {
  GE_ASSERT_NOTNULL(node);
  return RecoverIrUtils::RecoverOpDescIrDefinition(node->GetOpDesc(), op_type, ir_def);
}
/**
 * @brief Recovers a single op desc's IR meta using a fresh (uncached) IR definition.
 * @param desc    operator description to recover; must not be null
 * @param op_type IR type to recover against; when empty, desc's own type is used
 * @return GRAPH_SUCCESS, or a failure status on a null desc or an incompatible IR
 */
graphStatus RecoverIrUtils::RecoverOpDescIrDefinition(const ge::OpDescPtr &desc, const std::string &op_type) {
  GE_ASSERT_NOTNULL(desc);
  const std::string specified_type = op_type.empty() ? desc->GetType() : op_type;
  IrDefinition ir_def;
  ir_def.inited = false;
  return RecoverIrUtils::RecoverOpDescIrDefinition(desc, specified_type, ir_def);
}
/**
 * @brief Recovers IR meta for every node of `graph`, including graphs stored in node attributes.
 *
 * IR definitions are cached per node type for the duration of one call. For each name in `attr_names`, a
 * graph held in that node attribute is recovered recursively, with the same attribute list, and written back.
 *
 * @param graph      graph to recover; must not be null
 * @param attr_names node attribute names that may hold serialized subgraphs
 * @return GRAPH_SUCCESS, or GRAPH_FAILED on the first node/subgraph that cannot be recovered or written back
 */
ge::graphStatus RecoverIrUtils::RecoverIrDefinitions(const ge::ComputeGraphPtr &graph,
                                                     const vector<std::string> &attr_names) {
  GE_ASSERT_NOTNULL(graph);
  GELOGD("Start to recover all ir definitions for graph:%s.", graph->GetName().c_str());
  std::map<std::string, IrDefinition> op_type_to_ir_def;
  for (const auto &node : graph->GetAllNodes()) {
    std::string op_type = ge::NodeUtils::GetNodeType(node);
    auto &ir_def = op_type_to_ir_def[op_type];
    if (RecoverNodeIrDefinitions(node, op_type, ir_def) != ge::GRAPH_SUCCESS) {
      GELOGE(ge::GRAPH_FAILED, "[Recover][NodeIrDefinitions] failed, node[%s], type[%s]",
             node->GetName().c_str(), node->GetType().c_str());
      return ge::GRAPH_FAILED;
    }
    for (const auto &attr_name : attr_names) {
      ge::ComputeGraphPtr graph_ptr = nullptr;
      // A missing attribute simply means this node holds no such subgraph.
      (void) ge::AttrUtils::GetGraph(node->GetOpDesc(), attr_name, graph_ptr);
      if (graph_ptr == nullptr) {
        continue;
      }
      // Pass attr_names down so graphs nested inside attribute-stored graphs are recovered too.
      if (RecoverIrDefinitions(graph_ptr, attr_names) != ge::GRAPH_SUCCESS) {
        GELOGE(ge::GRAPH_FAILED, "[Recover][IrDefinitions] failed, graph[%s]", graph_ptr->GetName().c_str());
        return ge::GRAPH_FAILED;
      }
      // The recovered graph only takes effect once written back; a failed write must not be silently dropped.
      if (!ge::AttrUtils::SetGraph(node->GetOpDesc(), attr_name, graph_ptr)) {
        GELOGE(ge::GRAPH_FAILED, "[Set][Graph] failed, node[%s], attr[%s], graph[%s]", node->GetName().c_str(),
               attr_name.c_str(), graph_ptr->GetName().c_str());
        return ge::GRAPH_FAILED;
      }
      GELOGD("Success to recover definitions for graph:%s with node:%s and attr:%s.",
             graph->GetName().c_str(), node->GetName().c_str(), attr_name.c_str());
    }
  }
  GELOGD("Success to recover all ir definitions for graph:%s.", graph->GetName().c_str());
  return ge::GRAPH_SUCCESS;
}
/// Free-function entry point; delegates to RecoverIrUtils::RecoverIrDefinitions.
ge::graphStatus RecoverIrDefinitions(const ge::ComputeGraphPtr &graph, const vector<std::string> &attr_names) {
  const ge::graphStatus status = RecoverIrUtils::RecoverIrDefinitions(graph, attr_names);
  return status;
}
/// Free-function entry point; delegates to RecoverIrUtils::RecoverOpDescIrDefinition(desc, op_type).
ge::graphStatus RecoverOpDescIrDefinition(const ge::OpDescPtr &desc, const std::string &op_type) {
  const ge::graphStatus status = RecoverIrUtils::RecoverOpDescIrDefinition(desc, op_type);
  return status;
}
/**
 * @brief Checks whether an op desc matches the IR registered for its type.
 *
 * Input count must equal the registered IR input count, unless the IR has a dynamic or optional input.
 * Output count must equal the registered IR output count, unless the IR has a dynamic output.
 * Every registered IR attribute must either be set on the node or have a registered default value.
 *
 * @param desc operator description to check; a null desc yields false
 * @return true when all checks pass
 */
bool CheckIrSpec(const ge::OpDescPtr &desc) {
  if (desc == nullptr) {
    GELOGW("CheckIrSpec: op desc is nullptr.");
    return false;
  }
  const std::string op_type = desc->GetType();
  IrDefinition ir_def;
  ir_def.inited = false;
  RecoverIrUtils::InitIrDefinitionsIfNeed(op_type, ir_def);
  bool ir_input_include_dynamic = false;
  bool ir_output_include_dynamic = false;
  for (const auto &ir_def_input : ir_def.inputs) {
    if ((ir_def_input.second == kIrInputDynamic) || (ir_def_input.second == kIrInputOptional)) {
      ir_input_include_dynamic = true;
      break;
    }
  }
  for (const auto &ir_def_output : ir_def.outputs) {
    if (ir_def_output.second == kIrOutputDynamic) {
      ir_output_include_dynamic = true;
      break;
    }
  }
  const size_t input_num = desc->GetInputsSize();
  const size_t output_num = desc->GetOutputsSize();
  // All four values are size_t: %zu is required (passing size_t for %d is undefined behavior on LP64).
  GELOGD("Node:%s check input num is %zu and ir input num is %zu, output num is %zu and ir output num is %zu",
         desc->GetName().c_str(), input_num, ir_def.inputs.size(), output_num, ir_def.outputs.size());
  if (((input_num != ir_def.inputs.size()) && !ir_input_include_dynamic) ||
      ((output_num != ir_def.outputs.size()) && !ir_output_include_dynamic)) {
    GELOGI("Node:%s inputs/outputs num changed, compatibility check skipped", desc->GetName().c_str());
    return false;
  }
  const auto node_all_attrs = ge::AttrUtils::GetAllAttrs(desc);
  for (const auto &name : ir_def.attr_names) {
    if (node_all_attrs.find(name) != node_all_attrs.cend()) {
      continue;
    }
    const auto iter = ir_def.attr_value.find(name);
    if (iter == ir_def.attr_value.cend()) {
      GELOGI("node[%s(%s)] missing attr name[%s], and cannot find default value for the attr,"
             " it may be REQUIRED_ATTR.",
             desc->GetName().c_str(), op_type.c_str(), name.c_str());
      return false;
    }
  }
  return true;
}
}