* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "deploy/flowrm/flow_route_planner.h"
#include "common/checker.h"
#include "graph/ge_context.h"
#include "deploy/flowrm/network_manager.h"
#include "deploy/resource/resource_manager.h"
#include "dflow/flow_graph/data_flow_attr_define.h"
namespace ge {
namespace {
// Reports whether the root graph of `flow_model` carries the "_contains_n-mapping_node"
// attribute with a true value. The result of AttrUtils::GetBool is intentionally discarded,
// so the flag keeps its initial value (false) whenever the lookup does not set it.
bool HasNMappingNode(const FlowModelPtr &flow_model) {
  GE_ASSERT_NOTNULL(flow_model);

  constexpr const char *kNMappingAttrName = "_contains_n-mapping_node";
  bool has_n_mapping = false;
  (void)AttrUtils::GetBool(flow_model->GetRootGraph(), kNMappingAttrName, has_n_mapping);

  GELOGI("flow model[%s] attr[%s] value is %d.",
         flow_model->GetModelName().c_str(),
         kNMappingAttrName,
         static_cast<int32_t>(has_n_mapping));
  return has_n_mapping;
}
}
// Resolves one flow route plan per node referenced by the deploy plan and stores each of them
// in `deploy_state` through AddFlowRoutePlanToDeploy.
// Returns SUCCESS immediately when GetLocalFlowRoutePlan().first is already positive; otherwise
// returns the first failing status reported by ResolveFlowRoutePlan.
Status FlowRoutePlanner::ResolveFlowRoutePlans(DeployState &deploy_state) {
  const bool already_resolved = deploy_state.GetLocalFlowRoutePlan().first > 0;
  if (already_resolved) {
    GELOGI("flow route plan already resolved");
    return SUCCESS;
  }

  const auto &deploy_plan = deploy_state.GetDeployPlan();

  // Collect the distinct node ids: the root model queue device's node plus the node of every
  // submodel's queue device.
  std::set<int32_t> node_ids;
  (void) node_ids.insert(deploy_plan.GetRootModelQueueDeviceInfo().GetNodeId());
  for (const auto &submodel : deploy_plan.GetSubmodels()) {
    (void) node_ids.insert(submodel.second.queue_device_info.GetNodeId());
  }

  // Attributes shared by every per-node plan.
  PlanAttrs plan_attrs;
  plan_attrs.root_model_id = deploy_state.GetRootModelId();
  plan_attrs.is_dynamic_sched = deploy_state.GetIsDynamicSched();
  plan_attrs.keep_out_of_order =
      HasNMappingNode(deploy_state.GetFlowModel()) || deploy_state.IsEnableExceptionCatch();

  for (const auto &target_node_id : node_ids) {
    deployer::FlowRoutePlan node_route_plan;
    GE_CHK_STATUS_RET_NOLOG(ResolveFlowRoutePlan(deploy_plan, target_node_id, node_route_plan, plan_attrs));
    deploy_state.AddFlowRoutePlanToDeploy(target_node_id, node_route_plan);
  }
  return SUCCESS;
}
// Fills `flow_route_plan` for a single node: endpoints are resolved first, then bindings
// (ResolveBindings appends further endpoints after the ones added by ResolveEndpoints).
// The finished plan is dumped through PrintFlowRoutePlan. Returns the first failing status,
// or SUCCESS.
Status FlowRoutePlanner::ResolveFlowRoutePlan(const DeployPlan &deploy_plan,
                                              int32_t node_id,
                                              deployer::FlowRoutePlan &flow_route_plan,
                                              const PlanAttrs &plan_attrs) {
  GELOGD("ResolveFlowRoutePlan start, target node id = %d", node_id);

  GE_CHK_STATUS_RET_NOLOG(ResolveEndpoints(deploy_plan, node_id, flow_route_plan, plan_attrs));
  GE_CHK_STATUS_RET_NOLOG(ResolveBindings(deploy_plan, node_id, flow_route_plan, plan_attrs.root_model_id));

  PrintFlowRoutePlan(flow_route_plan);
  return SUCCESS;
}
// Chooses the endpoint type for a non-group queue, in priority order:
//   1. ref_index >= 0 -> kEndpointTypeRefQueue (the ref index is also stored in the queue desc)
//   2. is_dummy       -> kEndpointTypeDummyQueue
//   3. otherwise      -> kEndpointTypeQueue when owned, kEndpointTypeExternalQueue when not
void FlowRoutePlanner::SetEndpointType(const DeployPlan::QueueInfo &endpoint_info,
                                       deployer::EndpointDesc &endpoint_desc) {
  if (endpoint_info.ref_index >= 0) {
    endpoint_desc.set_type(static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeRefQueue));
    endpoint_desc.mutable_queue_desc()->set_ref_index(endpoint_info.ref_index);
    return;
  }

  if (endpoint_info.is_dummy) {
    endpoint_desc.set_type(static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeDummyQueue));
    return;
  }

  if (endpoint_info.owned) {
    endpoint_desc.set_type(static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeQueue));
  } else {
    endpoint_desc.set_type(static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeExternalQueue));
  }
}
// Appends one endpoint to `flow_route_plan` for every entry of the deploy plan's queue info
// list, in list order, so an endpoint's position equals its queue index.
// Endpoints whose device is on `node_id` are fully described (group or queue); all others get
// the common fields only and are typed kEndpointTypeNone.
// Returns the failing status of FillGroupEndpoint, or SUCCESS.
Status FlowRoutePlanner::ResolveEndpoints(const DeployPlan &deploy_plan,
                                          int32_t node_id,
                                          deployer::FlowRoutePlan &flow_route_plan,
                                          const PlanAttrs &plan_attrs) {
  const auto &queue_infos = deploy_plan.GetQueueInfoList();
  for (size_t pos = 0U; pos < queue_infos.size(); ++pos) {
    const auto endpoint_index = static_cast<int32_t>(pos);
    const auto &queue_info = queue_infos[pos];
    const auto &device_info = queue_info.device_info;

    // Fields that are filled regardless of which node owns the endpoint.
    auto *const desc = flow_route_plan.add_endpoints();
    desc->set_name(queue_info.name);
    desc->set_node_id(device_info.GetNodeId());
    desc->set_device_id(device_info.GetDeviceId());
    desc->set_device_type(device_info.GetType());
    desc->set_model_id(queue_info.model_id);
    desc->set_is_dynamic_sched(plan_attrs.is_dynamic_sched);
    desc->set_root_model_id(plan_attrs.root_model_id);

    if (device_info.GetNodeId() != node_id) {
      // Endpoint lives on another node: keep the slot but mark it as "none".
      desc->set_type(static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeNone));
      continue;
    }

    if (!deploy_plan.IsGroupEndpoint(endpoint_index)) {
      SetEndpointType(queue_info, *desc);
      auto *const queue_desc = desc->mutable_queue_desc();
      queue_desc->set_depth(queue_info.depth);
      queue_desc->set_enqueue_policy(queue_info.enqueue_policy);
      queue_desc->set_name(queue_info.name);
      queue_desc->set_fusion_offset(queue_info.fusion_offset);
      GELOGD("[Queue] added, index = %d, name = %s, device = %s, type = %d, ref index = %d, fusion offset = %d",
             endpoint_index, desc->name().c_str(), device_info.GetDesc().c_str(),
             desc->type(), queue_info.ref_index, queue_info.fusion_offset);
      continue;
    }

    GE_CHK_STATUS_RET_NOLOG(
        FillGroupEndpoint(deploy_plan, endpoint_index, plan_attrs.keep_out_of_order, *desc));
    const auto &group_desc = desc->group_desc();
    GELOGD("[Group] added ,index = %d, name = %s, elements = %s",
           endpoint_index,
           desc->name().c_str(),
           ToString(std::vector<int32_t>(group_desc.endpoint_indices().begin(),
                                         group_desc.endpoint_indices().end())).c_str());
  }
  return SUCCESS;
}
// Turns `endpoint_desc` into a group endpoint for the group at `group_index`.
// Each group entry contributes one endpoint index: its ref_index when that is non-negative,
// otherwise the entry index offset by the size of the queue info list.
// Returns the status of GetQueueInfo on failure, FAILED when the deploy plan has no group for
// `group_index`, SUCCESS otherwise.
Status FlowRoutePlanner::FillGroupEndpoint(const DeployPlan &deploy_plan,
                                           int32_t group_index, bool keep_out_of_order,
                                           deployer::EndpointDesc &endpoint_desc) {
  endpoint_desc.set_type(static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeGroup));

  const DeployPlan::QueueInfo *group_queue_info = nullptr;
  GE_CHK_STATUS_RET_NOLOG(deploy_plan.GetQueueInfo(group_index, group_queue_info));
  endpoint_desc.set_instance_num(group_queue_info->instance_num);
  endpoint_desc.set_instance_idx(group_queue_info->instance_idx);

  const auto &groups = deploy_plan.GetGroups();
  const auto group_iter = groups.find(group_index);
  GE_CHK_BOOL_RET_STATUS(group_iter != groups.end(),
                         FAILED, "Get group[%d] info failed.", group_index);

  // Offset applied to entry indices that do not reference an existing queue.
  const auto entry_index_base = static_cast<int32_t>(deploy_plan.GetQueueInfoList().size());
  const auto &entry_infos = deploy_plan.GetGroupEntryInfoList();
  auto *const group_desc = endpoint_desc.mutable_group_desc();
  for (const auto entry_index : group_iter->second) {
    const auto &entry_info = entry_infos[entry_index];
    if (entry_info.ref_index >= 0) {
      group_desc->add_endpoint_indices(entry_info.ref_index);
    } else {
      group_desc->add_endpoint_indices(entry_index_base + entry_index);
    }
  }

  group_desc->set_keep_out_of_order(keep_out_of_order);
  GELOGD("group[%s] keep out of order value=%d", endpoint_desc.name().c_str(), static_cast<int32_t>(keep_out_of_order));
  return SUCCESS;
}
// Adds the bindings that involve `node_id` (via AddFlowRoutePlanBindings) and then appends one
// endpoint per group entry of the deploy plan, in entry order.
// A group entry becomes a tag endpoint when, for one of the devices collected by
// AddFlowRoutePlanBindings, all of the following hold:
//   - the entry has no ref index (ref_index < 0),
//   - the entry's own device differs from that device (compared by GetKey()),
//   - the entry index is among the tag indices recorded for that device.
// When several devices match, the same endpoint is written again for each of them, so the
// values from the last matching device remain. Entries with no match are typed
// kEndpointTypeNone. Returns the first failing status, or SUCCESS.
Status FlowRoutePlanner::ResolveBindings(const DeployPlan &deploy_plan,
                                         int32_t node_id,
                                         deployer::FlowRoutePlan &flow_route_plan,
                                         const uint32_t root_model_id) {
  std::vector<DeployPlan::DeviceInfo> devices;
  std::map<std::string, std::set<size_t>> device_2_tag_indices;
  GE_CHK_STATUS_RET(AddFlowRoutePlanBindings(deploy_plan, node_id, devices, flow_route_plan, device_2_tag_indices),
                    "Failed to add flow route bindings.");

  const auto &entry_infos = deploy_plan.GetGroupEntryInfoList();
  for (size_t entry_idx = 0U; entry_idx < entry_infos.size(); ++entry_idx) {
    const auto &entry_info = entry_infos[entry_idx];
    const auto &peer_device = entry_info.device_info;
    auto *const endpoint_desc = flow_route_plan.add_endpoints();

    bool tag_assigned = false;
    for (const auto &target_device : devices) {
      const auto &tag_indices = device_2_tag_indices[target_device.GetKey()];
      const bool needs_tag = (entry_info.ref_index < 0) &&
                             (peer_device.GetKey() != target_device.GetKey()) &&
                             (tag_indices.count(entry_idx) > 0U);
      if (!needs_tag) {
        continue;
      }

      uint32_t target_rank_id = UINT32_MAX;
      GE_CHK_STATUS_RET(GetDeviceRankId(target_device, target_rank_id),
                        "Failed to get target device rank id, entry_index[%zu], name[%s].",
                        entry_idx, entry_info.name.c_str());

      // The endpoint itself is described from the target device's point of view ...
      endpoint_desc->set_name(entry_info.name);
      endpoint_desc->set_node_id(target_device.GetNodeId());
      endpoint_desc->set_device_id(target_device.GetHcomDeviceId());
      endpoint_desc->set_device_type(target_device.GetType());
      endpoint_desc->set_type(static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeTag));
      endpoint_desc->set_root_model_id(root_model_id);

      // ... while the tag's peer fields come from the group entry's own device.
      auto *const tag_desc = endpoint_desc->mutable_tag_desc();
      GE_CHK_STATUS_RET(AssignTagInfo(peer_device, *tag_desc, entry_info.name, entry_idx));
      tag_desc->set_rank_id(target_rank_id);
      tag_desc->set_name(entry_info.name);
      tag_desc->set_depth(entry_info.depth);
      tag_desc->set_peer_node_id(peer_device.GetNodeId());
      tag_desc->set_peer_device_id(peer_device.GetHcomDeviceId());
      tag_desc->set_peer_device_type(peer_device.GetType());
      GELOGI("[Tag] added, entry_index[%zu], name[%s], device[%s], rank_id[%u], peer_rank_id[%u], "
             "tag_id[%u], hcom_device_id[%d], target device_type[%d]",
             entry_idx, tag_desc->name().c_str(), peer_device.GetDesc().c_str(), tag_desc->rank_id(),
             tag_desc->peer_rank_id(), tag_desc->tag_id(),
             target_device.GetHcomDeviceId(), target_device.GetType());
      tag_assigned = true;
    }

    if (!tag_assigned) {
      endpoint_desc->set_type(static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeNone));
    }
  }
  return SUCCESS;
}
// Copies into `flow_route_plan` every queue binding of the deploy plan whose source or
// destination queue sits on `node_id`.
// For each such binding the group entries of both ends are registered through AddGroupEntries
// (which fills `devices` and `relative_tag_indices`). The binding is stored in
// bindings_before_load when the source queue is owned, and in bindings otherwise.
// Returns the first failing status, or SUCCESS.
Status FlowRoutePlanner::AddFlowRoutePlanBindings(const DeployPlan &deploy_plan,
                                                  int32_t node_id,
                                                  std::vector<DeployPlan::DeviceInfo> &devices,
                                                  deployer::FlowRoutePlan &flow_route_plan,
                                                  std::map<std::string, std::set<size_t>> &relative_tag_indices) {
  for (const auto &queue_binding : deploy_plan.GetQueueBindings()) {
    const auto src_q_idx = queue_binding.first;
    const auto dst_q_idx = queue_binding.second;

    const DeployPlan::QueueInfo *src_queue_info = nullptr;
    GE_CHK_STATUS_RET_NOLOG(deploy_plan.GetQueueInfo(src_q_idx, src_queue_info));
    const DeployPlan::QueueInfo *dst_queue_info = nullptr;
    GE_CHK_STATUS_RET_NOLOG(deploy_plan.GetQueueInfo(dst_q_idx, dst_queue_info));

    const bool touches_node = (src_queue_info->device_info.GetNodeId() == node_id) ||
                              (dst_queue_info->device_info.GetNodeId() == node_id);
    if (!touches_node) {
      continue;
    }

    GE_CHK_STATUS_RET(AddGroupEntries(deploy_plan, src_q_idx, devices, relative_tag_indices),
                      "Failed to add group entries.");
    GE_CHK_STATUS_RET(AddGroupEntries(deploy_plan, dst_q_idx, devices, relative_tag_indices),
                      "Failed to add group entries.");
    GELOGD("src index = %d, dst index = %d, src = %s@%s, dst = %s@%s, src is group = %d, dst is group = %d",
           src_q_idx,
           dst_q_idx,
           src_queue_info->name.c_str(),
           src_queue_info->device_info.GetDesc().c_str(),
           dst_queue_info->name.c_str(),
           dst_queue_info->device_info.GetDesc().c_str(),
           deploy_plan.IsGroupEndpoint(src_q_idx),
           deploy_plan.IsGroupEndpoint(dst_q_idx));

    auto *const plan_binding = src_queue_info->owned ? flow_route_plan.add_bindings_before_load()
                                                     : flow_route_plan.add_bindings();
    GE_CHECK_NOTNULL(plan_binding);
    plan_binding->set_src_index(src_q_idx);
    plan_binding->set_dst_index(dst_q_idx);
  }
  return SUCCESS;
}
// Registers the entries of the group at `index` under the key of the group's device.
// Does nothing (and returns SUCCESS) when `index` is not a group endpoint.
// A device is appended to `devices` the first time its key shows up in `relative_tag_indices`;
// the group's entries are then merged into the index set stored for that key.
// Returns the status of GetQueueInfo on failure, SUCCESS otherwise.
Status FlowRoutePlanner::AddGroupEntries(const DeployPlan &deploy_plan,
                                         int32_t index,
                                         std::vector<DeployPlan::DeviceInfo> &devices,
                                         std::map<std::string, std::set<size_t>> &relative_tag_indices) {
  if (!deploy_plan.IsGroupEndpoint(index)) {
    return SUCCESS;
  }

  const DeployPlan::QueueInfo *group_queue_info = nullptr;
  GE_CHK_STATUS_RET_NOLOG(deploy_plan.GetQueueInfo(index, group_queue_info));
  const auto &group_device = group_queue_info->device_info;
  const auto &group_entries = deploy_plan.GetGroups().at(index);
  GELOGI("Add group entries, group index = %d, entries = %s", index, ToString(group_entries).c_str());

  const bool first_seen = relative_tag_indices.count(group_device.GetKey()) == 0U;
  if (first_seen) {
    devices.emplace_back(group_device);
  }
  relative_tag_indices[group_device.GetKey()].insert(group_entries.begin(), group_entries.end());
  return SUCCESS;
}
// Looks up the rank id of `device_info` in the local context's rank table.
// The device is first resolved through ResourceManager (by node id, device id and type); its
// IP and hcom device id are then passed to GetRankIdByDeviceId.
// On success `rank_id` receives the rank id cast to uint32_t. Returns FAILED when the rank
// lookup fails; a null device info is rejected by GE_CHECK_NOTNULL.
Status FlowRoutePlanner::GetDeviceHcomInfo(const DeployPlan::DeviceInfo &device_info, uint32_t &rank_id) {
  auto dev_info = ResourceManager::GetInstance().GetDeviceInfo(device_info.GetNodeId(),
                                                               device_info.GetDeviceId(),
                                                               device_info.GetType());
  GE_CHECK_NOTNULL(dev_info);

  const auto &ip = dev_info->GetDeviceIp();
  auto device_id = dev_info->GetHcomDeviceId();
  GELOGI("Get remote device info success, ip = %s, device_id = %d", ip.c_str(), device_id);

  int32_t tmp_rank_id = 0;
  auto ret = DeployContext::LocalContext().GetRankTableBuilder().GetRankIdByDeviceId(ip, device_id, tmp_rank_id);
  // Report the exact (ip, device id) pair that was queried, so the failure log matches the
  // lookup that actually failed.
  GE_CHK_BOOL_RET_STATUS(ret, FAILED, "Get rank id info failed, ip is %s, devid = %d.",
                         ip.c_str(), device_id);
  rank_id = static_cast<uint32_t>(tmp_rank_id);
  return SUCCESS;
}
// Forwards to GetDeviceHcomInfo: writes the rank id of `device_info` into `rank_id` and
// returns that call's status unchanged.
Status FlowRoutePlanner::GetDeviceRankId(const DeployPlan::DeviceInfo &device_info, uint32_t &rank_id) {
  const Status status = GetDeviceHcomInfo(device_info, rank_id);
  return status;
}
// Returns the tag id associated with `tag_name`, allocating a new one on first use.
// Ids come from a function-local counter that is incremented before use, so the first id
// handed out is 1. The name-to-id table and the counter are function-local statics guarded by
// a function-local mutex.
uint32_t FlowRoutePlanner::GetTagIdByName(const std::string &tag_name) {
  static std::mutex mu;
  std::lock_guard<std::mutex> lk(mu);
  static uint32_t index = 0U;
  static std::unordered_map<std::string, uint32_t> tag_name_map;

  // Single lookup: reuse the iterator instead of searching again with operator[].
  const auto found = tag_name_map.find(tag_name);
  if (found != tag_name_map.end()) {
    return found->second;
  }

  ++index;
  (void) tag_name_map.emplace(tag_name, index);
  // index is unsigned, so it is logged with %u.
  GELOGI("GetTagIdByName tag_name is %s, tag_id is %u.", tag_name.c_str(), index);
  return index;
}
// Dumps `flow_route_plan` at debug level; returns at once when debug logging is disabled.
// All bindings are printed first (with both of their endpoints), followed by every endpoint
// that was not reached through a binding or a printed group.
void FlowRoutePlanner::PrintFlowRoutePlan(const deployer::FlowRoutePlan &flow_route_plan) {
  if (!IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) {
    return;
  }

  std::set<int32_t> printed;
  // AddFlowRoutePlanBindings stores a binding in bindings_before_load when its source queue is
  // owned and in bindings otherwise. Walk both lists so that endpoints bound through
  // bindings_before_load are not reported as unbound.
  for (const auto *binding_list : {&flow_route_plan.bindings_before_load(), &flow_route_plan.bindings()}) {
    for (const auto &binding : *binding_list) {
      (void) printed.emplace(binding.src_index());
      (void) printed.emplace(binding.dst_index());
      const auto &src_endpoint = flow_route_plan.endpoints(binding.src_index());
      const auto &dst_endpoint = flow_route_plan.endpoints(binding.dst_index());
      GELOGD("[Binding] src endpoint index = %d, type = %d, dst endpoint index = %d, type = %d",
             binding.src_index(), src_endpoint.type(),
             binding.dst_index(), dst_endpoint.type());
      GELOGD("Src endpoint: ");
      PrintEndpointDesc(flow_route_plan, src_endpoint, printed);
      GELOGD("Dst endpoint: ");
      PrintEndpointDesc(flow_route_plan, dst_endpoint, printed);
    }
  }

  GELOGD("Unbound endpoints: ");
  for (int32_t i = 0; i < flow_route_plan.endpoints_size(); ++i) {
    if (printed.count(i) > 0U) {
      continue;
    }
    PrintEndpointDesc(flow_route_plan, flow_route_plan.endpoints(i), printed);
  }
}
// Logs a single endpoint at debug level according to its type.
// For a group endpoint every member index is added to `printed` and the member itself is
// logged. Tag and queue members are logged at debug level; any other member type is logged as
// a warning. Endpoint types other than external queue, queue, ref queue and group produce no
// output.
void FlowRoutePlanner::PrintEndpointDesc(const deployer::FlowRoutePlan &flow_route_plan,
                                         const deployer::EndpointDesc &endpoint_desc,
                                         std::set<int32_t> &printed) {
  const auto endpoint_type = endpoint_desc.type();
  if (endpoint_type == static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeExternalQueue)) {
    GELOGD(" [ExternalQueue] name = %s", endpoint_desc.name().c_str());
  } else if (endpoint_type == static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeQueue)) {
    GELOGD(" [Queue] name = %s", endpoint_desc.name().c_str());
  } else if (endpoint_type == static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeRefQueue)) {
    GELOGD(" [RefQueue] name = %s, ref index = %d",
           endpoint_desc.name().c_str(), endpoint_desc.queue_desc().ref_index());
  } else if (endpoint_type == static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeGroup)) {
    GELOGD(" [Group] name = %s, element number = %d",
           endpoint_desc.name().c_str(),
           endpoint_desc.group_desc().endpoint_indices_size());
    for (const auto idx : endpoint_desc.group_desc().endpoint_indices()) {
      (void) printed.emplace(idx);
      const auto &group_entry = flow_route_plan.endpoints(idx);
      if (group_entry.type() == static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeTag)) {
        const auto &tag_desc = group_entry.tag_desc();
        GELOGD(" [GroupEntry] index = %d, tag name = %s, tag_id = %u, rank_id = %u, peer_rank_id = %u",
               idx,
               group_entry.name().c_str(),
               tag_desc.tag_id(),
               tag_desc.rank_id(),
               tag_desc.peer_rank_id());
      } else if (group_entry.type() == static_cast<int32_t>(ExchangeEndpointType::kEndpointTypeQueue)) {
        // Log the member's own name, not the name of the enclosing group.
        GELOGD(" [GroupEntry] index = %d, queue name = %s", idx, group_entry.name().c_str());
      } else {
        // Same here: identify the unexpected member by its own name.
        GELOGW(" [GroupEntry] Unexpected group entry, index = %d, name = %s, type = %d",
               idx,
               group_entry.name().c_str(),
               group_entry.type());
      }
    }
  }
}
// Fills the peer-side identity of `tag_desc`: the tag id obtained from GetTagIdByName for
// `tag_name`, and the rank id of `device_info` as the peer rank id.
// `index` is used only in the failure log. Returns the status of GetDeviceRankId on failure,
// SUCCESS otherwise.
Status FlowRoutePlanner::AssignTagInfo(const DeployPlan::DeviceInfo &device_info, deployer::TagDesc &tag_desc,
                                       const std::string &tag_name, uint32_t index) {
  uint32_t rank_id = UINT32_MAX;
  // index is a uint32_t, so it must be formatted with %u; %zu expects a size_t argument.
  GE_CHK_STATUS_RET(GetDeviceRankId(device_info, rank_id),
                    "Failed to get device rank id, entry_index[%u], name[%s].",
                    index, tag_name.c_str());
  tag_desc.set_tag_id(GetTagIdByName(tag_name));
  tag_desc.set_peer_rank_id(rank_id);
  return SUCCESS;
}
}