* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "symbolic_cond_remove_pass.h"
#include "common/checker.h"
#include "graph/utils/node_utils.h"
#include "common/plugin/ge_make_unique_util.h"
#include "graph/utils/graph_utils_ex.h"
#include "graph/debug/ge_util.h"
#include "attribute_group/attr_group_symbolic_desc.h"
#include "attribute_group/attr_group_shape_env.h"
#include "framework/common/framework_types_internal.h"
#include "graph/utils/node_utils_ex.h"
#include "graph/utils/op_type_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/optimize/symbolic/symbol_compute_context.h"
#include "graph/optimize/symbolic/infer_symbolic_shape/symbolic_shape_inference.h"
#include "base/registry/op_impl_space_registry_v2.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/passes/control_flow_and_stream/cond_remove_pass.h"
#include "graph/expression/const_values.h"
#include "graph/optimize/symbolic/infer_symbolic_shape/symbolic_infer_util.h"
namespace ge {
namespace {
// Input index handed to CondRemovePass::RemoveDeadCondLink when the cond node is replaced.
constexpr int kCondInputIndex = 0;
// Maps the condition symbol of an If node to a 0/1 value.
// cond_index_val becomes 1 when the symbol is expected to be greater than zero, otherwise 0.
void GetIfCondIndexValue(const Expression &cond_index_sym, uint32_t &cond_index_val) {
  cond_index_val = (EXPECT_SYMBOL_GT(cond_index_sym, sym::kSymbolZero)) ? 1U : 0U;
}
// Maps the condition symbol of a Case node to the index of the subgraph that will be taken.
//
// node            Case node whose subgraph instance names define the number of branches.
// cond_index_sym  Symbolic expression of the branch index.
// cond_index_val  [out] Selected branch index; only written when SUCCESS is returned.
//
// Returns SUCCESS when the symbol is expected to be >= the branch count (the last branch is
// selected) or expected to equal one of the valid indices. Returns FAILED when no branch
// matches, and a failure status when the node, its OpDesc or its subgraph list is unusable.
Status GetCaseCondIndexValue(const NodePtr &node, const Expression &cond_index_sym, uint32_t &cond_index_val) {
  GE_ASSERT_NOTNULL(node);
  const auto op_desc = node->GetOpDesc();
  // The OpDesc is dereferenced right below, so reject a node without one instead of crashing.
  GE_ASSERT_NOTNULL(op_desc);
  const uint32_t subgraph_size = static_cast<uint32_t>(op_desc->GetSubgraphInstanceNames().size());
  GE_ASSERT_TRUE(subgraph_size > 0);
  // An index at or beyond the branch count falls back to the last branch.
  if (EXPECT_SYMBOL_GE(cond_index_sym, Symbol(subgraph_size))) {
    cond_index_val = subgraph_size - 1U;
    return SUCCESS;
  }
  for (uint32_t subgraph_idx = 0U; subgraph_idx < subgraph_size; ++subgraph_idx) {
    if (EXPECT_SYMBOL_EQ(cond_index_sym, Symbol(subgraph_idx))) {
      cond_index_val = subgraph_idx;
      return SUCCESS;
    }
  }
  return FAILED;
}
}
// Produces the source text of an immediately-invoked lambda that reads the condition value
// from graph input tensor number input_data_idx_ and yields it as int64_t.
// In the emitted code, -1 is returned when the tensor is null or smaller than the element
// type being read, and a DT_STRING tensor yields 1 or 0 depending on its byte size.
std::string InputValueForCondSource::GetSourceStr() const {
  // Text up to the opening parenthesis of the GetGraphInputTensor call.
  std::string source_text = R"([&]() -> int64_t {
const auto *tensor = context->GetGraphInputTensor()";
  // The graph input index is spliced in as the call argument.
  source_text += std::to_string(input_data_idx_);
  // Remainder of the lambda: per-dtype decoding of the first element, then the invocation.
  source_text += R"();
if (tensor == nullptr) {
return -1;
}
const uint8_t *data_ptr = tensor->GetData<uint8_t>();
// 字节数
size_t tensor_size = tensor->GetSize();
const auto type = tensor->GetDataType();
switch (type) {
case ge::DT_STRING:
return (tensor_size > (sizeof(ge::StringHead) + 1)) ? 1 : 0;
case ge::DT_BOOL:
return (tensor_size >= sizeof(bool)) ? static_cast<int64_t>(*reinterpret_cast<const bool *>(data_ptr)) : -1;
case ge::DT_FLOAT:
return (tensor_size >= sizeof(float)) ? static_cast<int64_t>(*reinterpret_cast<const float *>(data_ptr)) : -1;
case ge::DT_DOUBLE:
return (tensor_size >= sizeof(double)) ?
static_cast<int64_t>(*reinterpret_cast<const double *>(data_ptr)) : -1;
case ge::DT_INT8:
case ge::DT_UINT8:
case ge::DT_HIFLOAT8:
case ge::DT_FLOAT8_E5M2:
case ge::DT_FLOAT8_E4M3FN:
return (tensor_size >= sizeof(uint8_t)) ? static_cast<int64_t>(*data_ptr) : -1;
case ge::DT_FLOAT16:
case ge::DT_INT16:
case ge::DT_UINT16:
return (tensor_size >= sizeof(int16_t)) ?
static_cast<int64_t>(*reinterpret_cast<const int16_t *>(data_ptr)) : -1;
case ge::DT_INT32:
case ge::DT_UINT32:
return (tensor_size >= sizeof(int32_t)) ?
static_cast<int64_t>(*reinterpret_cast<const int32_t *>(data_ptr)) : -1;
case ge::DT_INT64:
case ge::DT_UINT64:
return (tensor_size >= sizeof(int64_t)) ? *reinterpret_cast<const int64_t *>(data_ptr) : -1;
default:
return (tensor_size >= sizeof(uint8_t)) ? static_cast<int64_t>(*data_ptr) : -1;
}
}())";
  return source_text;
}
// Resolves the symbolic expression that represents the condition value supplied by a Data node.
//
// cond_input      Data node feeding the condition input; its "index" attribute selects the
//                 entry of graph_inputs_ that is inspected.
// cond_index_sym  [out] Resulting expression.
// node_type       Type of the consuming cond node; used to tell Case-like nodes apart.
//
// Scalar input: a symbol is created from the tensor's current condition value (or reused from
// created_sym when one was already created for the same graph input).
// Non-scalar input: the symbolic shape size of the Data node's first output is used.
//
// Returns SUCCESS on success, UNSUPPORTED when a Case-like node has a negative condition
// value, and a failure status when a required attribute or object is missing.
Status SymbolicCondRemovePass::GetCondIndexSymbol(const NodePtr &cond_input, Expression &cond_index_sym,
                                                  const std::string &node_type) {
  GE_ASSERT_NOTNULL(cond_input);
  auto op_desc = cond_input->GetOpDesc();
  GE_ASSERT_NOTNULL(op_desc);
  int32_t data_index = -1;
  GE_ASSERT_TRUE(AttrUtils::GetInt(op_desc, "index", data_index));
  // at() takes an unsigned index: a negative attribute value would wrap to a huge number and
  // throw std::out_of_range, so both bounds are validated before the lookup.
  GE_ASSERT_TRUE(data_index >= 0);
  GE_ASSERT_TRUE(data_index < static_cast<int32_t>(graph_inputs_.size()));
  const auto &tensor = graph_inputs_.at(static_cast<size_t>(data_index));
  const auto &ge_tensor_desc = tensor.GetTensorDesc();
  auto ori_shape = ge_tensor_desc.IsOriginShapeInitialized() ? ge_tensor_desc.GetOriginShape() :
                                                               ge_tensor_desc.GetShape();
  if (!ori_shape.IsScalar()) {
    auto output_desc = op_desc->GetOutputDesc(0U);
    auto sym_attr = output_desc.GetAttrsGroup<SymbolicDescAttr>();
    GE_ASSERT_NOTNULL(sym_attr);
    cond_index_sym = sym_attr->symbolic_tensor.GetOriginSymbolShape().GetSymbolShapeSize();
    return SUCCESS;
  }

  // Reuse the symbol already created for this graph input, if any (single lookup).
  const auto cached = created_sym.find(data_index);
  if (cached != created_sym.end()) {
    cond_index_sym = cached->second;
    return SUCCESS;
  }

  auto cond_value = CondRemovePass::GetCondIndex(&tensor);
  if ((cond_value < 0) && (kCaseOpTypes.find(node_type) != kCaseOpTypes.end())) {
    GELOGI("Get cond index[%d] from node[%s][%s] success, cond_index is negative.", cond_value, cond_input->GetNamePtr(), cond_input->GetTypePtr());
    return UNSUPPORTED;
  }
  GELOGD("Get cond index[%d] from node[%s][%s] success, ", cond_value, cond_input->GetNamePtr(), cond_input->GetTypePtr());
  auto value_source = MakeShared<InputValueForCondSource>(data_index, cond_value);
  // Do not hand a null source to CreateSymbol.
  GE_ASSERT_NOTNULL(value_source);
  auto shape_env = GetCurShapeEnvContext();
  GE_ASSERT_NOTNULL(shape_env);
  GELOGD("Get shape env success.");
  // Use the pointer that was just validated instead of fetching the context a second time.
  cond_index_sym = shape_env->CreateSymbol(cond_value, value_source);
  created_sym[data_index] = cond_index_sym;
  return SUCCESS;
}
// Pass entry point for one node.
// For a supported cond node whose condition input is a Data node, the branch to take is
// derived from the condition symbol, the cond node is replaced through the CondRemovePass
// helpers with a call to the chosen subgraph, and the original node is isolated and deleted.
// Nodes that are unsupported, not fed by Data, or reported as UNSUPPORTED by
// GetCondIndexSymbol are left untouched and SUCCESS is returned.
Status SymbolicCondRemovePass::Run(NodePtr &node) {
  GE_ASSERT_NOTNULL(node);
  if (!SymbolicInferUtil::IsSupportCondNode(node)) {
    return SUCCESS;
  }

  auto cond_src = SymbolicInferUtil::GetCondInput(node);
  GE_ASSERT_NOTNULL(cond_src);
  GELOGD("Get cond[%s] input[%s] success.", node->GetNamePtr(), cond_src->GetNamePtr());
  if (!OpTypeUtils::IsDataNode(cond_src->GetType())) {
    GELOGI("CondNode[%s] input[%s] is not Data.", node->GetNamePtr(), cond_src->GetNamePtr());
    return SUCCESS;
  }

  const auto &node_type = node->GetType();
  Expression cond_index_sym;
  const auto status = GetCondIndexSymbol(cond_src, cond_index_sym, node_type);
  if (status == UNSUPPORTED) {
    GELOGI("condIndex of cond[%s] is negative, skip.", node->GetNamePtr());
    return SUCCESS;
  }
  if (status != SUCCESS) {
    GELOGI("Get condIndex of cond[%s] unavailable.", node->GetNamePtr());
    return status;
  }

  // Pick the subgraph that the condition selects; Case-like and If-like nodes differ only here.
  const bool is_case = (kCaseOpTypes.find(node_type) != kCaseOpTypes.end());
  uint32_t branch_value = 0;
  ComputeGraphPtr chosen_graph = nullptr;
  if (is_case) {
    GE_ASSERT_SUCCESS(GetCaseCondIndexValue(node, cond_index_sym, branch_value));
    GE_ASSERT_SUCCESS(CondRemovePass::GetCaseChosenBranch(node, branch_value, chosen_graph));
  } else {
    GetIfCondIndexValue(cond_index_sym, branch_value);
    GE_ASSERT_SUCCESS(CondRemovePass::GetIfChosenBranch(node, branch_value, chosen_graph));
  }

  // Rewire the graph: drop the cond link, substitute the node, then remove the original.
  GE_ASSERT_SUCCESS(CondRemovePass::RemoveDeadCondLink(kCondInputIndex, node));
  GE_ASSERT_SUCCESS(CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(node, chosen_graph));
  GE_ASSERT_SUCCESS(IsolateAndDeleteNode(node, std::vector<int32_t>()));
  return SUCCESS;
}
// Registers the pass option "SymbolicCondRemovePass" with level OoLevel::kO1 and ties it to the
// ge::OO_DEAD_CODE_ELIMINATION switch.
// NOTE(review): presumably this enables the pass from O1 upwards unless dead-code elimination
// is switched off - confirm against the REG_PASS_OPTION definition.
REG_PASS_OPTION("SymbolicCondRemovePass").LEVELS(OoLevel::kO1).SWITCH_OPT(ge::OO_DEAD_CODE_ELIMINATION);
}