* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "common/ge_inner_attrs.h"
#include "engine/node_converter_utils.h"
#include "exe_graph/lowering/lowering_definitions.h"
#include "exe_graph/lowering/value_holder_utils.h"
#include "framework/common/framework_types_internal.h"
#include "graph_builder/bg_identity.h"
#include "graph_builder/converter_checker.h"
#include "graph_builder/multi_stream/bg_event.h"
#include "graph/utils/graph_dump_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "lowering/graph_converter.h"
#include "lowering/placement/placed_lowering_result.h"
#include "register/node_converter_registry.h"
namespace gert {
namespace {
// Names of the While-specific net-output converters defined in this file. Each name is registered through
// REGISTER_NODE_CONVERTER and is written as the "_ge_attr_lowering_func" string attribute onto the net-output
// node of the corresponding subgraph before that subgraph is lowered.
// Converter name for the net-output of the while body subgraph.
constexpr const char *kLowerWhileBodyOutput = "_lower_while_body_output";
// Converter name for the net-output of the while cond subgraph.
constexpr const char *kLowerWhileCondOutput = "_lower_while_cond_output";
// Gathers the lowered shape and address of every data input of `node`, as placed on `placement` for
// bg::kMainStream.
//
// For each (producer node, out anchor) pair feeding `node`, the producer's PlacedLoweringResult ext attribute is
// queried for the output at that anchor index. The LOWER_REQUIRE* checks stop the function when a producer has no
// placed lowering result, when the queried output result is null, or when its `has_init` flag is not set.
//
// Returns a LowerResult whose out_shapes / out_addrs are filled in input order.
LowerResult GetNodeLowerInputPlacedOn(const ge::NodePtr &node, TensorPlacement placement,
                                      LoweringGlobalData &global_data) {
  LowerResult placed_inputs;
  for (const auto &peer : node->GetInDataNodesAndAnchors()) {
    const auto &producer = peer.first;
    const auto producer_out_index = peer.second->GetIdx();
    auto placed_result = producer->GetOpDescBarePtr()->GetExtAttr<PlacedLoweringResult>(kLoweringResult);
    LOWER_REQUIRE_NOTNULL(placed_result, "The in node %s of net-output %s have no placed lowering result",
                          producer->GetNamePtr(), node->GetNamePtr());
    const auto *result =
        placed_result->GetOutputResult(global_data, producer_out_index, {placement, bg::kMainStream}, false);
    LOWER_REQUIRE_NOTNULL(result);
    LOWER_REQUIRE(result->has_init);
    // Shapes and addresses are kept in two parallel lists indexed by input position.
    placed_inputs.out_shapes.emplace_back(result->shape);
    placed_inputs.out_addrs.emplace_back(result->address);
  }
  return placed_inputs;
}
// Checks whether the holders in `targets` either do not appear in `ordered` at all, or appear only at position
// `index`.
//
// Returns false (after logging the two positions) as soon as a non-null entry of `targets` compares equal to
// `ordered[pos]` for some `pos != index`; returns true otherwise. Null entries of `targets` are ignored.
template <typename T>
bool IsNotAliasOrOrdered(const std::vector<T> &targets, const std::vector<T> &ordered, size_t index) {
  const size_t ordered_count = ordered.size();
  for (const auto &target : targets) {
    if (target == nullptr) {
      continue;
    }
    for (size_t pos = 0U; pos < ordered_count; ++pos) {
      // Matching the entry at the expected position is fine; only a match elsewhere is a mismatch.
      if ((pos == index) || !(target == ordered[pos])) {
        continue;
      }
      GELOGI("While node input %zu flow to mismatch-ordered %zu", index, pos);
      return false;
    }
  }
  return true;
}
// Decides whether the while body that owns `net_output` can hand its outputs back without an extra identity copy.
//
// Every direct node of type ge::DATA in the owning graph is inspected: its ATTR_NAME_PARENT_NODE_INDEX is read as
// the expected position, and the addresses / shapes of its lowering result are compared against
// `lower_input.input_addrs` / `lower_input.input_shapes` with IsNotAliasOrOrdered.
//
// Returns false when any Data output holder shows up among the inputs at a position other than its parent index,
// and also when one of the GE_ASSERT* checks fails (missing graph, op desc, attribute or lowering result).
bool IsWhileBodyCanInplace(const ge::NodePtr &net_output, const LowerInput &lower_input) {
  auto body = net_output->GetOwnerComputeGraphBarePtr();
  GE_ASSERT_NOTNULL(body);
  for (const auto &node : body->GetDirectNode()) {
    GE_ASSERT_NOTNULL(node);
    const bool is_data_node = (node->GetType() == ge::DATA);
    if (!is_data_node) {
      continue;
    }
    auto data_node = node->GetOpDescBarePtr();
    GE_ASSERT_NOTNULL(data_node);
    size_t index = 0U;
    GE_ASSERT(ge::AttrUtils::GetInt(data_node, ge::ATTR_NAME_PARENT_NODE_INDEX, index));
    auto placed_lower_result = data_node->GetExtAttr<PlacedLoweringResult>(kLoweringResult);
    GE_ASSERT_NOTNULL(placed_lower_result);
    auto lower_result = placed_lower_result->GetResult();
    GE_ASSERT_NOTNULL(lower_result);
    // Addresses are checked first; shapes are only checked when the addresses passed.
    const bool in_order = IsNotAliasOrOrdered(lower_result->out_addrs, lower_input.input_addrs, index) &&
                          IsNotAliasOrOrdered(lower_result->out_shapes, lower_input.input_shapes, index);
    if (!in_order) {
      return false;
    }
  }
  return true;
}
// Converter for the net-output node of a while body subgraph (registered as kLowerWhileBodyOutput).
//
// Steps:
//   1. Fetch the lowered inputs of `node` placed on kOnDeviceHbm and run LoweringAccessMemCrossStream on the
//      addresses.
//   2. If IsWhileBodyCanInplace() reports false, route shapes and addresses through IdentityShape / IdentityAddr
//      and attach a "FreeMemory" guarder to each new address.
//   3. Pass everything through IdentityShapeAndAddr; the first half of its outputs becomes out_shapes, the second
//      half out_addrs, and its first output is recorded as the order holder.
//
// Returns the resulting LowerResult, or whatever the LOWER_* macros produce when a step fails.
LowerResult LoweringWhileBodyOutput(const ge::NodePtr &node, const LowerInput &lower_input) {
  GELOGI("Start lowering while body output %s", node->GetNamePtr());
  auto device_inputs = GetNodeLowerInputPlacedOn(node, kOnDeviceHbm, *lower_input.global_data);
  LOWER_RETURN_IF_ERROR(device_inputs);
  LOWER_REQUIRE_HYPER_SUCCESS(LoweringAccessMemCrossStream(node, device_inputs.out_addrs));
  const auto op_desc = node->GetOpDescBarePtr();
  LOWER_REQUIRE_NOTNULL(op_desc);
  auto &out_shapes = device_inputs.out_shapes;
  auto &out_addrs = device_inputs.out_addrs;
  if (!IsWhileBodyCanInplace(node, lower_input)) {
    GELOGI("Add identity as barrier for update inputs for next loop as %s body can not inplace", node->GetNamePtr());
    // Validate the identity outputs before using them, as CreateLowerWhileNode does, so that a failed identity
    // creation is reported here instead of a guarder being attached to an invalid holder.
    out_shapes = bg::IdentityShape(out_shapes);
    LOWER_REQUIRE_VALID_HOLDER(out_shapes);
    out_addrs = bg::IdentityAddr(out_addrs, op_desc->GetStreamId());
    LOWER_REQUIRE_VALID_HOLDER(out_addrs);
    for (auto &out_addr : out_addrs) {
      bg::ValueHolder::CreateVoidGuarder("FreeMemory", out_addr, {});
    }
  }
  auto shape_and_addrs = bg::IdentityShapeAndAddr(out_shapes, out_addrs, op_desc->GetStreamId());
  LOWER_REQUIRE_VALID_HOLDER(shape_and_addrs);
  // front() below is undefined on an empty vector, so reject that case explicitly.
  LOWER_REQUIRE(!shape_and_addrs.empty());
  size_t num_shapes = shape_and_addrs.size() >> 1U;
  LowerResult ret;
  ret.out_shapes.insert(ret.out_shapes.end(), shape_and_addrs.begin(), shape_and_addrs.begin() + num_shapes);
  ret.out_addrs.insert(ret.out_addrs.end(), shape_and_addrs.begin() + num_shapes, shape_and_addrs.end());
  ret.order_holders.push_back(shape_and_addrs.front());
  return ret;
}
// Converter for the net-output node of a while cond subgraph (registered as kLowerWhileCondOutput).
// Returns the lowered inputs of `node` placed on kOnHost.
LowerResult LoweringWhileCondOutput(const ge::NodePtr &node, const LowerInput &lower_input) {
  GELOGI("Start lowering while cond output %s", node->GetNamePtr());
  auto &global_data = *lower_input.global_data;
  return GetNodeLowerInputPlacedOn(node, kOnHost, global_data);
}
// Register the body / cond net-output converters under the names stored in kLowerWhileBodyOutput and
// kLowerWhileCondOutput.
REGISTER_NODE_CONVERTER(kLowerWhileBodyOutput, LoweringWhileBodyOutput);
REGISTER_NODE_CONVERTER(kLowerWhileCondOutput, LoweringWhileCondOutput);
// Resolves the instance name of the subgraph bound to the graph attribute `attr` of `node`.
//
// The attribute name is looked up in the op desc's subgraph name -> index map. When it is missing, a warning is
// logged and a positional fallback is used: index 0 if `attr` equals "cond", index 1 for any other name.
//
// Returns the subgraph instance name reported by the op desc for the chosen index.
std::string GetGraphTypeAttrInstanceGraphName(const ge::NodePtr &node, const char *attr) {
  const auto op_desc = node->GetOpDescBarePtr();
  GE_ASSERT_NOTNULL(op_desc);
  const auto &name_to_index = op_desc->GetSubgraphNameIndexes();
  const auto found = name_to_index.find(attr);
  if (found == name_to_index.cend()) {
    GELOGW("Node %s has no graph attr named %s", node->GetNamePtr(), attr);
    const bool is_cond = (std::string(attr) == "cond");
    return op_desc->GetSubgraphInstanceName(is_cond ? 0U : 1U);
  }
  return op_desc->GetSubgraphInstanceName(found->second);
}
// Finds the subgraph object bound to graph attribute `attr` of `node`.
//
// The instance name is resolved first, then looked up on the root graph of the graph that owns `node`.
// The GE_ASSERT* checks end the function early when the instance name is empty, when no root graph is found,
// or when the root graph has no subgraph with that name.
ge::ComputeGraphPtr GetAttrSubgraph(const ge::NodePtr &node, const char *attr) {
  const auto instance_name = GetGraphTypeAttrInstanceGraphName(node, attr);
  GE_ASSERT(!instance_name.empty(), "Empty instance graph name for attr %s", attr);
  const auto owner_graph = node->GetOwnerComputeGraph();
  auto root_compute_graph = ge::GraphUtils::FindRootGraph(owner_graph);
  GE_ASSERT_NOTNULL(root_compute_graph);
  auto subgraph = root_compute_graph->GetSubgraph(instance_name);
  GE_ASSERT_NOTNULL(subgraph, "Root graph %s has no subgraph named %s", root_compute_graph->GetName().c_str(),
                    instance_name.c_str());
  return subgraph;
}
// Creates the execute-graph node that stands for the While op itself.
//
// Inputs of the new node: identity copies of all input shapes followed by identity copies of all input addresses
// (nothing when `lower_input.input_addrs` is empty).
// Outputs:
//   - node without data outputs: a single void holder, recorded in order_holders;
//   - otherwise: 2 * N data outputs (N = number of out data anchors). The first N become out_shapes and the
//     remaining N out_addrs. Output i calls RefFrom on input shape i / input address i, and each output address
//     gets a "FreeMemory" guarder. The first output is also recorded in order_holders.
LowerResult CreateLowerWhileNode(const ge::NodePtr &node, const LowerInput &lower_input) {
  const auto op_desc = node->GetOpDescBarePtr();
  GE_ASSERT_NOTNULL(op_desc);
  const size_t num_tensors = lower_input.input_addrs.size();
  std::vector<bg::ValueHolderPtr> inputs;
  if (num_tensors != 0U) {
    auto identified_shapes = bg::IdentityShape(lower_input.input_shapes);
    LOWER_REQUIRE_VALID_HOLDER(identified_shapes);
    const auto &identified_addrs = bg::IdentityAddr(lower_input.input_addrs, op_desc->GetStreamId());
    LOWER_REQUIRE_VALID_HOLDER(identified_addrs);
    // Layout of `inputs`: [shape 0 .. shape k-1, addr 0 .. addr n-1].
    inputs.assign(identified_shapes.cbegin(), identified_shapes.cend());
    inputs.insert(inputs.cend(), identified_addrs.cbegin(), identified_addrs.cend());
  }

  LowerResult ret;
  const auto out_anchor_count = node->GetAllOutDataAnchorsSize();
  if (out_anchor_count == 0U) {
    ret.order_holders.push_back(bg::ValueHolder::CreateVoid<bg::ValueHolder>(node->GetTypePtr(), inputs));
    LOWER_REQUIRE_VALID_HOLDER(ret.order_holders);
    return ret;
  }

  size_t num_outputs = out_anchor_count << 1U;
  auto holders =
      bg::DevMemValueHolder::CreateDataOutput(node->GetTypePtr(), inputs, num_outputs, op_desc->GetStreamId());
  LOWER_REQUIRE_VALID_HOLDER(holders);
  ret.order_holders.push_back(holders.front());
  // Split the outputs: shapes come first, addresses follow.
  const auto first_addr = holders.cbegin() + out_anchor_count;
  ret.out_shapes.assign(holders.cbegin(), first_addr);
  ret.out_addrs.assign(first_addr, holders.cend());
  for (size_t i = 0U; i < ret.out_shapes.size(); i++) {
    LOWER_REQUIRE(i + num_tensors < inputs.size());
    ret.out_shapes[i]->RefFrom(inputs[i]);
    ret.out_addrs[i]->RefFrom(inputs[i + num_tensors]);
    bg::ValueHolder::CreateVoidGuarder("FreeMemory", ret.out_addrs[i], {});
  }
  return ret;
}
// Builds the "control_frame" sub-frame of a SubgraphCall node: a BranchPivot node with a dependency edge to a
// BranchDone node.
//
// Returns true when the frame was built and popped and it owns an execute graph; false when any step fails.
bool BuildGraphCallControlFrame(const bg::ValueHolderPtr &subgraph_call) {
  GE_ASSERT_NOTNULL(bg::ValueHolder::PushGraphFrame(subgraph_call, "control_frame"));
  auto pivot = bg::ValueHolder::CreateVoid<bg::ValueHolder>("BranchPivot", {});
  GE_ASSERT(IsValidHolder(pivot));
  auto done = bg::ValueHolder::CreateVoid<bg::ValueHolder>("BranchDone", {});
  GE_ASSERT(IsValidHolder(done));
  // Do not drop the status: a missing pivot -> done edge would leave the frame without its ordering constraint.
  GE_ASSERT_HYPER_SUCCESS(bg::ValueHolder::AddDependency(pivot, done));
  auto frame = bg::ValueHolder::PopGraphFrame();
  GE_ASSERT_NOTNULL(frame);
  return frame->GetExecuteGraph() != nullptr;
}
// Creates a "SubgraphCall" node for `graph` and lowers `graph` into the call's "f" frame.
//
// The call node gets 2 * (number of net-output data inputs) outputs, returned through `outputs`. Inside the "f"
// frame the lowered graph must yield exactly one output address; that address, its shape and the data type of
// the net-output's first input feed a "GenCondForWhile" node whose output becomes the frame output.
//
// Returns ge::GRAPH_SUCCESS on success; the GE_ASSERT* checks return an error status otherwise.
ge::graphStatus BuildGraphCall(const ge::ComputeGraphPtr &graph, const std::vector<bg::ValueHolderPtr> &inputs,
                               LoweringGlobalData &global_data, std::vector<bg::ValueHolderPtr> &outputs) {
  auto net_output = graph->FindFirstNodeMatchType(ge::NETOUTPUT);
  GE_ASSERT_NOTNULL(net_output, "Failed build graph call for %s as has no net-output node", graph->GetName().c_str());
  size_t num_compute_tensors = net_output->GetAllInDataAnchorsSize();
  outputs = bg::ValueHolder::CreateDataOutput("SubgraphCall", inputs, num_compute_tensors << 1U);
  GE_ASSERT(IsValidHolder(outputs));
  GE_ASSERT_TRUE(!outputs.empty());
  GE_ASSERT_TRUE(BuildGraphCallControlFrame(outputs.front()));
  GE_ASSERT_NOTNULL(bg::ValueHolder::PushGraphFrame(outputs.front(), "f"));
  auto lower_result = ConvertComputeSubgraphToExecuteGraph(graph, global_data, 0);
  GE_ASSERT_NOTNULL(lower_result, "Failed lower cond graph %s", graph->GetName().c_str());
  GE_ASSERT(lower_result->result.IsSuccess());
  GE_ASSERT_EQ(lower_result->out_addrs.size(), 1U);
  // out_shapes[0] is read below; make sure it exists instead of relying on the out_addrs check alone.
  GE_ASSERT_TRUE(!lower_result->out_shapes.empty());
  GE_ASSERT_NOTNULL(net_output->GetOpDescBarePtr());
  const auto &input_desc = net_output->GetOpDescBarePtr()->GetInputDescPtr(0U);
  GE_ASSERT_NOTNULL(input_desc);
  auto dt = input_desc->GetDataType();
  auto data_type = bg::ValueHolder::CreateConst(&dt, sizeof(dt));
  GE_ASSERT(IsValidHolder(data_type));
  auto cond_output = bg::ValueHolder::CreateSingleDataOutput(
      "GenCondForWhile", {lower_result->out_addrs[0U], lower_result->out_shapes[0U], data_type});
  // Reject an invalid holder before it is installed as the frame output.
  GE_ASSERT(IsValidHolder(cond_output));
  auto graph_frame = bg::ValueHolder::PopGraphFrame({cond_output}, {});
  GE_ASSERT_NOTNULL(graph_frame);
  ge::DumpGraph(graph_frame->GetExecuteGraph().get(), "LowerWhile_CondGraph");
  return ge::GRAPH_SUCCESS;
}
// Builds the subgraph call for the "cond" subgraph of `while_node`.
//
// The cond graph's net-output node is tagged with kLowerWhileCondOutput in its "_ge_attr_lowering_func"
// attribute, then the call is built by BuildGraphCall.
//
// Returns the status of BuildGraphCall, or an error status when the cond graph, its net-output node or its op
// desc is missing, or when the attribute cannot be set.
ge::graphStatus BuildCondGraphCall(const ge::NodePtr &while_node, const std::vector<bg::ValueHolderPtr> &inputs,
                                   LoweringGlobalData &global_data, std::vector<bg::ValueHolderPtr> &outputs) {
  auto graph = GetAttrSubgraph(while_node, "cond");
  GE_ASSERT_NOTNULL(graph);
  auto net_output = graph->FindFirstNodeMatchType(ge::NETOUTPUT);
  GE_ASSERT_NOTNULL(net_output, "Cond graph %s has no net-output node", graph->GetName().c_str());
  GE_ASSERT_NOTNULL(net_output->GetOpDescBarePtr());
  // If the tag is not written, the net-output is not marked for the While cond converter, so fail loudly
  // instead of continuing with an untagged node.
  GE_ASSERT(ge::AttrUtils::SetStr(net_output->GetOpDescBarePtr(), "_ge_attr_lowering_func", kLowerWhileCondOutput));
  return BuildGraphCall(graph, inputs, global_data, outputs);
}
// Lowers the "body" subgraph of `while_node` into the "body" frame of the lowered while node `lower_while`.
//
// The body graph's net-output node is tagged with kLowerWhileBodyOutput in its "_ge_attr_lowering_func"
// attribute. The frame outputs are all lowered output shapes followed by all lowered output addresses.
//
// Returns ge::GRAPH_SUCCESS on success; the GE_ASSERT* checks return an error status otherwise.
ge::graphStatus LowerBodyGraph(const ge::NodePtr &while_node, const bg::ValueHolderPtr &lower_while,
                               LoweringGlobalData &global_data) {
  // Resolve the body graph and its net-output before opening the frame: these lookups can fail and do not touch
  // the frame stack, so doing them first avoids leaving a pushed frame behind on an early return.
  auto body_graph = GetAttrSubgraph(while_node, "body");
  GE_ASSERT_NOTNULL(body_graph);
  auto net_output = body_graph->FindFirstNodeMatchType(ge::NETOUTPUT);
  GE_ASSERT_NOTNULL(net_output, "Body graph %s has no net-output node", body_graph->GetName().c_str());
  // The result of SetStr is deliberately discarded here, as in the original code.
  (void)ge::AttrUtils::SetStr(net_output->GetOpDescBarePtr(), "_ge_attr_lowering_func", kLowerWhileBodyOutput);

  GE_ASSERT_NOTNULL(bg::ValueHolder::PushGraphFrame(lower_while, "body"));
  auto lower_body_result = ConvertComputeSubgraphToExecuteGraph(body_graph, global_data, 0);
  GE_ASSERT_NOTNULL(lower_body_result, "Failed lower body graph %s", body_graph->GetName().c_str());
  GE_ASSERT(lower_body_result->result.IsSuccess());
  // Frame output layout: [shape 0 .. shape n-1, addr 0 .. addr n-1].
  std::vector<bg::ValueHolderPtr> body_outputs = lower_body_result->out_shapes;
  body_outputs.insert(body_outputs.cend(), lower_body_result->out_addrs.cbegin(), lower_body_result->out_addrs.cend());
  auto body_graph_frame = bg::ValueHolder::PopGraphFrame(body_outputs, {});
  GE_ASSERT_NOTNULL(body_graph_frame);
  ge::DumpGraph(body_graph_frame->GetExecuteGraph().get(), "LowerWhile_BodyGraph");
  return ge::GRAPH_SUCCESS;
}
// Builds the "control_frame" of the lowered while node `lower_while`.
//
// Nodes created inside the frame, in order:
//   - 2 * (number of in data anchors of `while_node`) "InnerData" nodes, each tagged with its "index";
//   - "Enter" -> "WaitAnyone" (dependency edge);
//   - the cond subgraph call fed by the InnerData holders, with a dependency edge from WaitAnyone to its first
//     output;
//   - "CondSwitchNotify" with three outputs: output 0 is marked with ge::kRequestWatcher and feeds a
//     "WatcherPlaceholder", output 1 feeds the body "BranchPivot" (which has a dependency edge to "BranchDone"),
//     output 2 feeds "Exit".
//
// Returns ge::GRAPH_SUCCESS on success; the GE_ASSERT* checks return an error status otherwise.
ge::graphStatus BuildControlFrame(const ge::NodePtr &while_node, const bg::ValueHolderPtr &lower_while,
                                  LoweringGlobalData &global_data) {
  GE_ASSERT_NOTNULL(bg::ValueHolder::PushGraphFrame(lower_while, "control_frame"));
  size_t num_inputs = (while_node->GetAllInDataAnchorsSize() << 1U);
  std::vector<bg::ValueHolderPtr> inputs;
  inputs.reserve(num_inputs);
  for (size_t i = 0U; i < num_inputs; i++) {
    inputs.emplace_back(bg::ValueHolder::CreateSingleDataOutput("InnerData", {}));
    GE_ASSERT(IsValidHolder(inputs.back()));
    GE_ASSERT(ge::AttrUtils::SetInt(bg::ValueHolderUtils::GetNodeOpDescBarePtr(inputs.back()), "index", i));
  }

  auto enter = bg::ValueHolder::CreateVoid<bg::ValueHolder>("Enter", {});
  GE_ASSERT(IsValidHolder(enter));
  auto wait_anyone = bg::ValueHolder::CreateVoid<bg::ValueHolder>("WaitAnyone", {});
  GE_ASSERT(IsValidHolder(wait_anyone));
  // Every dependency edge is checked, like the pivot -> done edge below.
  GE_ASSERT_HYPER_SUCCESS(bg::ValueHolder::AddDependency(enter, wait_anyone));

  std::vector<bg::ValueHolderPtr> cond_outputs;
  GE_ASSERT_GRAPH_SUCCESS(BuildCondGraphCall(while_node, inputs, global_data, cond_outputs));
  // front() is undefined on an empty vector.
  GE_ASSERT_TRUE(!cond_outputs.empty());
  GE_ASSERT_HYPER_SUCCESS(bg::ValueHolder::AddDependency(wait_anyone, cond_outputs.front()));

  const size_t kNotifyOutputNum = 3U;
  auto switch_notifies = bg::ValueHolder::CreateDataOutput("CondSwitchNotify", cond_outputs, kNotifyOutputNum);
  GE_ASSERT(IsValidHolder(switch_notifies));
  // Indices 0..2 are read below.
  GE_ASSERT_EQ(switch_notifies.size(), kNotifyOutputNum);
  GE_ASSERT(ge::AttrUtils::SetBool(bg::ValueHolderUtils::GetNodeOpDescBarePtr(switch_notifies[0]), ge::kRequestWatcher,
                                   true));
  auto watcher_placeholder = bg::ValueHolder::CreateVoid<bg::ValueHolder>("WatcherPlaceholder", {switch_notifies[0]});
  GE_ASSERT(IsValidHolder(watcher_placeholder));

  auto body_pivot = bg::ValueHolder::CreateVoid<bg::ValueHolder>("BranchPivot", {switch_notifies[1]});
  GE_ASSERT(IsValidHolder(body_pivot));
  auto body_done = bg::ValueHolder::CreateVoid<bg::ValueHolder>("BranchDone", {});
  GE_ASSERT(IsValidHolder(body_done));
  GE_ASSERT_HYPER_SUCCESS(bg::ValueHolder::AddDependency(body_pivot, body_done));

  auto exit = bg::ValueHolder::CreateVoid<bg::ValueHolder>("Exit", {switch_notifies[2]});
  GE_ASSERT(IsValidHolder(exit));

  auto control_frame = bg::ValueHolder::PopGraphFrame();
  GE_ASSERT_NOTNULL(control_frame);
  return ge::GRAPH_SUCCESS;
}
}
// Node converter for While / StatelessWhile.
//
// Creates the lowered while node, then builds its control frame (which contains the cond subgraph call) and
// lowers the body subgraph into its body frame. The holder used to anchor both frames is the first order holder
// of the lowered node, or its first output shape when there is no order holder.
//
// Returns the LowerResult of the lowered while node, or an error result when any step fails.
LowerResult LoweringWhile(const ge::NodePtr &node, const LowerInput &lower_input) {
  LowerResult ret = CreateLowerWhileNode(node, lower_input);
  LOWER_RETURN_IF_ERROR(ret);
  // front() on an empty vector is undefined; require that at least one candidate holder exists.
  LOWER_REQUIRE((!ret.order_holders.empty()) || (!ret.out_shapes.empty()));
  auto &lower_while = ret.order_holders.empty() ? ret.out_shapes.front() : ret.order_holders.front();
  LOWER_REQUIRE_GRAPH_SUCCESS(BuildControlFrame(node, lower_while, *lower_input.global_data));
  LOWER_REQUIRE_GRAPH_SUCCESS(LowerBodyGraph(node, lower_while, *lower_input.global_data));
  return ret;
}
// "While" and "StatelessWhile" nodes share the same converter.
// NOTE(review): the second macro argument (0) is presumably the placement required for the converter's inputs -
// confirm against the definition of REGISTER_NODE_CONVERTER_PLACEMENT.
REGISTER_NODE_CONVERTER_PLACEMENT("While", 0, LoweringWhile);
REGISTER_NODE_CONVERTER_PLACEMENT("StatelessWhile", 0, LoweringWhile);
}