/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "bg_event.h"
#include "common/checker.h"
#include "exe_graph/lowering/frame_selector.h"
#include "graph/debug/ge_attr_define.h"
#include "kernel/common_kernel_impl/event.h"
namespace gert {
namespace bg {
namespace {
// Name under which the "CreateGertEvents" output holder is registered in (and later looked up from)
// LoweringGlobalData as a unique value holder.
constexpr const ge::char_t *kGlobalDataGertEvents = "GertEvents";
/// Collects the events used to synchronise the main stream with nodes that start a secondary stream.
///
/// A direct node of compute_graph contributes one EventInfo(kDefaultMainStreamId, <node stream id>) when
///   - its stream id is greater than kDefaultMainStreamId,
///   - it has no input nodes, and
///   - it carries no ATTR_NAME_RECV_EVENT_IDS entries.
/// NOTE(review): the EventInfo constructor arguments are presumably (source stream, destination stream) - confirm.
///
/// @param compute_graph graph whose direct nodes are inspected
/// @param event_infos   output; matching events are appended, existing content is kept
/// @return ge::SUCCESS after all nodes were visited; a node without op desc trips GE_ASSERT_NOTNULL
ge::Status CollectFirstSyncEventsInfo(const ge::ComputeGraphPtr &compute_graph, std::vector<EventInfo> &event_infos) {
  for (const auto node : compute_graph->GetDirectNodePtr()) {
    const auto op_desc = node->GetOpDescBarePtr();
    GE_ASSERT_NOTNULL(op_desc);
    const bool on_secondary_stream = op_desc->GetStreamId() > kDefaultMainStreamId;
    if (!on_secondary_stream || !node->GetInAllNodes().empty()) {
      continue;
    }
    // The attribute is optional: a failed read simply leaves the list empty.
    std::vector<int64_t> recv_event_ids;
    (void)ge::AttrUtils::GetListInt(op_desc, ge::ATTR_NAME_RECV_EVENT_IDS, recv_event_ids);
    if (!recv_event_ids.empty()) {
      continue;
    }
    event_infos.emplace_back(EventInfo(kDefaultMainStreamId, op_desc->GetStreamId()));
  }
  return ge::SUCCESS;
}
/// Collects the events used to synchronise nodes that end a secondary stream back to the main stream.
///
/// A direct node of compute_graph contributes one EventInfo(<node stream id>, kDefaultMainStreamId) when
///   - its stream id is greater than kDefaultMainStreamId,
///   - it has no output nodes, and
///   - it carries no ATTR_NAME_SEND_EVENT_IDS entries.
/// NOTE(review): the EventInfo constructor arguments are presumably (source stream, destination stream) - confirm.
///
/// @param compute_graph graph whose direct nodes are inspected
/// @param event_infos   output; matching events are appended, existing content is kept
/// @return ge::SUCCESS after all nodes were visited; a node without op desc trips GE_ASSERT_NOTNULL
ge::Status CollectLastSyncEventsInfo(const ge::ComputeGraphPtr &compute_graph, std::vector<EventInfo> &event_infos) {
  for (const auto node : compute_graph->GetDirectNodePtr()) {
    const auto op_desc = node->GetOpDescBarePtr();
    GE_ASSERT_NOTNULL(op_desc);
    const bool on_secondary_stream = op_desc->GetStreamId() > kDefaultMainStreamId;
    if (!on_secondary_stream || !node->GetOutAllNodes().empty()) {
      continue;
    }
    // The attribute is optional: a failed read simply leaves the list empty.
    std::vector<int64_t> sent_event_ids;
    (void)ge::AttrUtils::GetListInt(op_desc, ge::ATTR_NAME_SEND_EVENT_IDS, sent_event_ids);
    if (!sent_event_ids.empty()) {
      continue;
    }
    event_infos.emplace_back(EventInfo(op_desc->GetStreamId(), kDefaultMainStreamId));
  }
  return ge::SUCCESS;
}
void CollectLastResourceCleanEventsInfo(int64_t stream_num, std::vector<EventInfo> &event_infos) {
if (stream_num <= 1) {
return;
}
for (int64_t i = 1; i < stream_num; ++i) {
EventInfo event_info(i, kDefaultMainStreamId);
event_infos.emplace_back(event_info);
}
}
/// Builds the event table of the model: entry i describes event id i.
///
/// Every node of compute_graph (sub graphs included) whose owner graph has the unknown flag set is scanned.
/// Each id in ATTR_NAME_SEND_EVENT_IDS stores the node's stream id as src_logic_stream_id of that event,
/// each id in ATTR_NAME_RECV_EVENT_IDS stores it as dst_logic_stream_id.
///
/// Validation performed (each failure trips a GE_ASSERT):
///   - event_num is not negative,
///   - an event id is sent at most once and received at most once,
///   - every event id lies in [0, event_num),
///   - exactly event_num distinct ids were seen,
///   - every table entry reports HasInitialized().
///
/// @param compute_graph root graph to scan
/// @param event_num     expected number of events on the model
/// @param event_infos   output; resized to event_num and filled in place
/// @return ge::GRAPH_SUCCESS when the table is complete and consistent
ge::graphStatus CollectAllEventInfos(const ge::ComputeGraphPtr &compute_graph, int64_t event_num,
                                     std::vector<EventInfo> &event_infos) {
  // A negative count would silently wrap to a huge size_t inside resize().
  GE_ASSERT_TRUE(event_num >= 0, "Invalid event num %ld on model.", event_num);
  event_infos.resize(static_cast<size_t>(event_num));
  std::unordered_set<int64_t> event_ids;
  std::unordered_set<int64_t> send_event_ids;
  std::unordered_set<int64_t> recive_event_ids;
  for (const auto node : compute_graph->GetAllNodesPtr()) {
    const auto owner_graph = node->GetOwnerComputeGraphBarePtr();
    GE_ASSERT_NOTNULL(owner_graph);
    if (!owner_graph->GetGraphUnknownFlag()) {
      continue;
    }
    const auto op_desc = node->GetOpDescBarePtr();
    GE_ASSERT_NOTNULL(op_desc);
    // Both attributes are optional: a failed read leaves the corresponding list empty.
    std::vector<int64_t> send_id_list;
    std::vector<int64_t> recive_id_list;
    (void)ge::AttrUtils::GetListInt(op_desc, ge::ATTR_NAME_SEND_EVENT_IDS, send_id_list);
    (void)ge::AttrUtils::GetListInt(op_desc, ge::ATTR_NAME_RECV_EVENT_IDS, recive_id_list);
    const int64_t stream_id = op_desc->GetStreamId();
    for (const auto send_event_id : send_id_list) {
      event_ids.insert(send_event_id);
      GE_ASSERT_TRUE(send_event_ids.insert(send_event_id).second, "Found duplicated send event id %ld on node %s.",
                     send_event_id, node->GetNamePtr());
      // The id is used as an index below, so negative values must be rejected as well.
      GE_ASSERT_TRUE((send_event_id >= 0) && (send_event_id < event_num),
                     "Found send event id %ld out of range %ld.", send_event_id, event_num);
      event_infos[static_cast<size_t>(send_event_id)].src_logic_stream_id = stream_id;
    }
    for (const auto recive_event_id : recive_id_list) {
      event_ids.insert(recive_event_id);
      GE_ASSERT_TRUE(recive_event_ids.insert(recive_event_id).second,
                     "Found duplicated receive event id %ld on node %s.", recive_event_id, node->GetNamePtr());
      GE_ASSERT_TRUE((recive_event_id >= 0) && (recive_event_id < event_num),
                     "Found recive event id %ld out of range %ld.", recive_event_id, event_num);
      event_infos[static_cast<size_t>(recive_event_id)].dst_logic_stream_id = stream_id;
    }
  }
  GE_ASSERT_TRUE(event_ids.size() == static_cast<size_t>(event_num),
                 "EvenId in nodes is not equal with event num %ld on model.", event_num);
  for (size_t i = 0U; i < event_infos.size(); ++i) {
    GE_ASSERT_TRUE(event_infos[i].HasInitialized(), "Found event %zu no in pair.", i);
  }
  return ge::GRAPH_SUCCESS;
}
/// Copies all elements of event_infos into the array event_info_addrs, starting at slot start_index.
///
/// No bounds checking is done here: the caller must provide room for at least
/// start_index + event_infos.size() elements.
///
/// @param event_info_addrs destination array
/// @param event_infos      elements to copy (taken by const reference to avoid copying the vector per call)
/// @param start_index      index of the first destination slot
void AppendEventInfos(EventInfo *event_info_addrs, const std::vector<EventInfo> &event_infos, size_t start_index) {
  for (size_t i = 0U; i < event_infos.size(); ++i) {
    event_info_addrs[start_index + i] = event_infos[i];
  }
}
/// Packs all event infos into one ContinuousVector and wraps it in a const value holder.
///
/// Layout of the packed vector, in this order:
///   model events | first-sync events | last-sync events | last-resource-clean events
///
/// @param model_event_info    events declared on the model
/// @param stage_2_sync_events extra sync events indexed by SyncEventStage; its size must equal
///                            SyncEventStage::kStageEnd
/// @return the const holder created by ValueHolder::CreateConst; the GE_ASSERT checks fire when the stage
///         count is wrong or the vector buffer is null
ValueHolderPtr CreateConstOfAllEventInfos(const std::vector<EventInfo> &model_event_info,
                                          const std::vector<std::vector<EventInfo>> &stage_2_sync_events) {
  GE_ASSERT_EQ(stage_2_sync_events.size(), static_cast<size_t>(SyncEventStage::kStageEnd));
  const auto &first_stage_events = stage_2_sync_events[static_cast<size_t>(SyncEventStage::kFirstSyncStage)];
  const auto &last_stage_events = stage_2_sync_events[static_cast<size_t>(SyncEventStage::kLastSyncStage)];
  const auto &clean_stage_events = stage_2_sync_events[static_cast<size_t>(SyncEventStage::kLastResourceCleanStage)];

  // Running offsets of each group inside the packed array.
  const size_t first_stage_offset = model_event_info.size();
  const size_t last_stage_offset = first_stage_offset + first_stage_events.size();
  const size_t clean_stage_offset = last_stage_offset + last_stage_events.size();
  const size_t event_count = clean_stage_offset + clean_stage_events.size();

  // buffer_size is filled in by Create and handed to CreateConst unchanged.
  size_t buffer_size = 0U;
  auto buffer = ContinuousVector::Create<EventInfo>(event_count, buffer_size);
  auto packed_events = reinterpret_cast<ContinuousVector *>(buffer.get());
  GE_ASSERT_NOTNULL(packed_events);
  packed_events->SetSize(event_count);

  auto event_slots = reinterpret_cast<EventInfo *>(packed_events->MutableData());
  AppendEventInfos(event_slots, model_event_info, 0U);
  AppendEventInfos(event_slots, first_stage_events, first_stage_offset);
  AppendEventInfos(event_slots, last_stage_events, last_stage_offset);
  AppendEventInfos(event_slots, clean_stage_events, clean_stage_offset);
  return ValueHolder::CreateConst(buffer.get(), buffer_size);
}
}
/// Collects every event of the model and creates the "CreateGertEvents" node in the init frame.
///
/// Steps:
///   1. collect the model events (model_desc.GetReusableEventNum() entries) from compute_graph;
///   2. append the first-sync, last-sync and last-resource-clean events to the matching entries of
///      stage_2_sync_events;
///   3. if there is no event at all, return nullptr;
///   4. otherwise build the const with all event infos plus the "CreateGertEvents" node via
///      FrameSelector::OnInitRoot and register its first output under kGlobalDataGertEvents.
///
/// @param compute_graph       graph to scan
/// @param model_desc          provides the reusable event and stream numbers
/// @param global_data         lowering global data that stores the unique holder
/// @param stage_2_sync_events in/out; must already hold SyncEventStage::kStageEnd entries
/// @return the holder registered under kGlobalDataGertEvents, or nullptr when no events exist
///         (the GE_ASSERT checks also fire on collection or creation failures)
ValueHolderPtr CollectAndCreateGertEvents(const ge::ComputeGraphPtr &compute_graph, const ModelDesc &model_desc,
                                          LoweringGlobalData &global_data,
                                          std::vector<std::vector<EventInfo>> &stage_2_sync_events) {
  std::vector<EventInfo> model_events;
  const int64_t model_event_num = model_desc.GetReusableEventNum();
  GE_ASSERT_GRAPH_SUCCESS(CollectAllEventInfos(compute_graph, model_event_num, model_events));

  GE_ASSERT_EQ(stage_2_sync_events.size(), static_cast<size_t>(SyncEventStage::kStageEnd));
  const auto events_of = [&stage_2_sync_events](SyncEventStage stage) -> std::vector<EventInfo> & {
    return stage_2_sync_events[static_cast<size_t>(stage)];
  };
  auto &first_stage_events = events_of(SyncEventStage::kFirstSyncStage);
  auto &last_stage_events = events_of(SyncEventStage::kLastSyncStage);
  auto &clean_stage_events = events_of(SyncEventStage::kLastResourceCleanStage);
  GE_ASSERT_SUCCESS(CollectFirstSyncEventsInfo(compute_graph, first_stage_events));
  GE_ASSERT_SUCCESS(CollectLastSyncEventsInfo(compute_graph, last_stage_events));
  CollectLastResourceCleanEventsInfo(model_desc.GetReusableStreamNum(), clean_stage_events);

  const bool has_any_event = !model_events.empty() || !first_stage_events.empty() || !last_stage_events.empty() ||
                             !clean_stage_events.empty();
  if (!has_any_event) {
    return nullptr;
  }

  auto init_outputs = FrameSelector::OnInitRoot(
      [&model_events, &stage_2_sync_events]() -> std::vector<ValueHolderPtr> {
        auto packed_event_infos = CreateConstOfAllEventInfos(model_events, stage_2_sync_events);
        GE_ASSERT_NOTNULL(packed_event_infos);
        auto gert_events = ValueHolder::CreateSingleDataOutput("CreateGertEvents", {packed_event_infos});
        return {gert_events};
      });
  GE_ASSERT_TRUE(!init_outputs.empty());
  // The builder only hands out the holder created above.
  auto gert_events_builder = [&init_outputs]() -> ValueHolderPtr {
    return init_outputs[0];
  };
  return global_data.GetOrCreateUniqueValueHolder(bg::kGlobalDataGertEvents, gert_events_builder);
}
/// Lowers the first-stage sync events: one SendEvents on the main stream followed by one WaitEvents per event.
///
/// Event i of first_sync_events gets the logic event id current_logic_event_num + i. A single SendEvents
/// node on kDefaultMainStreamId sends all ids; for every event a WaitEvents node is created on the event's
/// dst_logic_stream_id, and AddDependency(send, wait) is called for each pair. The builder is handed to
/// FrameSelector::OnMainRootFirst.
///
/// @param first_sync_events       events to lower; nothing happens when empty
/// @param current_logic_event_num logic event id assigned to the first event
/// @param global_data             lowering global data used to build the nodes
void LoweringFirstSyncEvents(const std::vector<EventInfo> &first_sync_events, int64_t current_logic_event_num,
                             LoweringGlobalData &global_data) {
  const size_t event_count = first_sync_events.size();
  if (event_count == 0U) {
    return;
  }
  std::vector<int64_t> sync_event_ids(event_count);
  for (size_t idx = 0U; idx < event_count; ++idx) {
    sync_event_ids[idx] = current_logic_event_num + static_cast<int64_t>(idx);
  }
  auto build_first_sync = [&]() -> std::vector<bg::ValueHolderPtr> {
    auto main_stream_send = bg::SendEvents(kDefaultMainStreamId, sync_event_ids, global_data);
    std::vector<bg::ValueHolderPtr> stream_waits;
    for (size_t idx = 0U; idx < sync_event_ids.size(); ++idx) {
      const std::vector<int64_t> single_event_id{sync_event_ids[idx]};
      auto stream_wait = bg::WaitEvents(first_sync_events[idx].dst_logic_stream_id, single_event_id, global_data);
      bg::ValueHolder::AddDependency(main_stream_send, stream_wait);
      stream_waits.emplace_back(stream_wait);
    }
    return stream_waits;
  };
  bg::FrameSelector::OnMainRootFirst(build_first_sync);
}
/// Lowers the last-stage sync events: one SendEvents per event followed by one WaitEvents on the main stream.
///
/// Event i of last_sync_events gets the logic event id current_logic_event_num + i. A single WaitEvents
/// node on kDefaultMainStreamId waits for all ids; for every event a SendEvents node is created on the
/// event's src_logic_stream_id, and AddDependency(send, wait) is called for each of them. The builder is
/// handed to FrameSelector::OnMainRootLastEventSync and returns only the wait holder.
///
/// @param last_sync_events        events to lower; nothing happens when empty
/// @param current_logic_event_num logic event id assigned to the first event
/// @param global_data             lowering global data used to build the nodes
void LoweringLastSyncEvents(const std::vector<EventInfo> &last_sync_events, int64_t current_logic_event_num,
                            LoweringGlobalData &global_data) {
  if (last_sync_events.empty()) {
    return;
  }
  std::vector<int64_t> wait_logic_event_ids(last_sync_events.size());
  for (size_t i = 0U; i < last_sync_events.size(); ++i) {
    wait_logic_event_ids[i] = current_logic_event_num + static_cast<int64_t>(i);
  }
  auto last_event_sync_builder = [&]() -> std::vector<bg::ValueHolderPtr> {
    auto wait_holder = bg::WaitEvents(kDefaultMainStreamId, wait_logic_event_ids, global_data);
    for (size_t i = 0U; i < last_sync_events.size(); ++i) {
      // Each stream sends exactly the id the main stream waits for at the same index.
      const std::vector<int64_t> send_logic_event_ids = {wait_logic_event_ids[i]};
      auto send_holder =
          bg::SendEvents(last_sync_events[i].src_logic_stream_id, send_logic_event_ids, global_data);
      bg::ValueHolder::AddDependency(send_holder, wait_holder);
    }
    return {wait_holder};
  };
  bg::FrameSelector::OnMainRootLastEventSync(last_event_sync_builder);
}
/// Assembles the input holders shared by the send/wait event nodes.
///
/// The returned vector contains, in this order:
///   1. the stream holder of logic_stream_id,
///   2. a const holding logic_event_ids packed into a ContinuousVector,
///   3. the unique holder registered under kGlobalDataGertEvents,
///   4. the unique holder registered under kGlobalDataRtEvents,
///   5. the L2 allocator of logic_stream_id ({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}).
///
/// @param logic_stream_id logic stream the node works on
/// @param logic_event_ids logic event ids handled by the node
/// @param global_data     lowering global data providing streams, event holders and allocators
/// @return the inputs listed above; a GE_ASSERT_NOTNULL check fires when the id buffer, one of the two
///         event holders or the allocator is null
std::vector<ValueHolderPtr> PrepareEventInputs(int64_t logic_stream_id, const std::vector<int64_t> &logic_event_ids,
                                               LoweringGlobalData &global_data) {
  std::vector<ValueHolderPtr> inputs;
  auto stream_holder = global_data.GetStreamById(logic_stream_id);
  inputs.emplace_back(stream_holder);

  // Pack the event ids into a ContinuousVector; packed_size is filled in by Create.
  const size_t id_count = logic_event_ids.size();
  size_t packed_size = 0U;
  auto id_buffer = ContinuousVector::Create<int64_t>(id_count, packed_size);
  auto packed_ids = reinterpret_cast<ContinuousVector *>(id_buffer.get());
  GE_ASSERT_NOTNULL(packed_ids);
  packed_ids->SetSize(id_count);
  auto id_slots = reinterpret_cast<int64_t *>(packed_ids->MutableData());
  for (size_t idx = 0U; idx < id_count; ++idx) {
    id_slots[idx] = logic_event_ids[idx];
  }
  inputs.emplace_back(ValueHolder::CreateConst(id_buffer.get(), packed_size));

  auto gert_events = global_data.GetUniqueValueHolder(bg::kGlobalDataGertEvents);
  GE_ASSERT_NOTNULL(gert_events);
  inputs.emplace_back(gert_events);

  auto rt_events = global_data.GetUniqueValueHolder(bg::kGlobalDataRtEvents);
  GE_ASSERT_NOTNULL(rt_events);
  inputs.emplace_back(rt_events);

  auto stream_allocator =
      global_data.GetOrCreateL2Allocator(logic_stream_id, {kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput});
  GE_ASSERT_NOTNULL(stream_allocator);
  inputs.emplace_back(stream_allocator);
  return inputs;
}
/// Lowers the resource-clean sync events executed at the very end of the main frame.
///
/// Event i of last_sync_events gets the logic event id current_logic_event_num + i. A "LastWaitEvents"
/// node on kDefaultMainStreamId waits for all ids; its SendEventsInput::kAllocator input is replaced by
/// global_data.GetOrCreateAllL2Allocators(). For every event a "LastSendEvents" node is created on the
/// event's src_logic_stream_id, and AddDependency(send, wait) is called for each of them. The builder is
/// handed to FrameSelector::OnMainRootLastResourceClean and returns only the wait holder.
///
/// @param last_sync_events        events to lower; nothing happens when empty
/// @param current_logic_event_num logic event id assigned to the first event
/// @param global_data             lowering global data used to build the nodes
void LoweringLastResourceCleanEvents(const std::vector<EventInfo> &last_sync_events, int64_t current_logic_event_num,
                                     LoweringGlobalData &global_data) {
  if (last_sync_events.empty()) {
    return;
  }
  std::vector<int64_t> logic_event_ids(last_sync_events.size());
  for (size_t i = 0U; i < last_sync_events.size(); ++i) {
    logic_event_ids[i] = current_logic_event_num + static_cast<int64_t>(i);
  }
  auto last_event_sync_builder = [&]() -> std::vector<bg::ValueHolderPtr> {
    auto inputs = PrepareEventInputs(kDefaultMainStreamId, logic_event_ids, global_data);
    const size_t allocator_index = static_cast<size_t>(SendEventsInput::kAllocator);
    // PrepareEventInputs bails out early through its GE_ASSERT checks on failure; in that case the
    // allocator slot does not exist and must not be written.
    GE_ASSERT_TRUE(inputs.size() > allocator_index, "Failed to prepare inputs of LastWaitEvents, input num %zu.",
                   inputs.size());
    inputs[allocator_index] = global_data.GetOrCreateAllL2Allocators();
    auto wait_holder = ValueHolder::CreateVoid<bg::ValueHolder>("LastWaitEvents", inputs);
    for (size_t i = 0U; i < last_sync_events.size(); ++i) {
      // Each stream sends exactly the id the main stream waits for at the same index.
      const std::vector<int64_t> send_logic_event_ids = {logic_event_ids[i]};
      auto send_holder = ValueHolder::CreateVoid<bg::ValueHolder>(
          "LastSendEvents",
          PrepareEventInputs(last_sync_events[i].src_logic_stream_id, send_logic_event_ids, global_data));
      bg::ValueHolder::AddDependency(send_holder, wait_holder);
    }
    return {wait_holder};
  };
  bg::FrameSelector::OnMainRootLastResourceClean(last_event_sync_builder);
}
/// Creates a void "SendEvents" node for the given events on the given logic stream.
///
/// @param logic_stream_id logic stream the node works on
/// @param logic_event_ids logic event ids to send
/// @param global_data     lowering global data used to assemble the node inputs
/// @return the holder returned by ValueHolder::CreateVoid
ValueHolderPtr SendEvents(int64_t logic_stream_id, const std::vector<int64_t> &logic_event_ids,
                          LoweringGlobalData &global_data) {
  const auto send_inputs = PrepareEventInputs(logic_stream_id, logic_event_ids, global_data);
  return ValueHolder::CreateVoid<bg::ValueHolder>("SendEvents", send_inputs);
}
/// Creates a void "WaitEvents" node for the given events on the given logic stream.
///
/// @param logic_stream_id logic stream the node works on
/// @param logic_event_ids logic event ids to wait for
/// @param global_data     lowering global data used to assemble the node inputs
/// @return the holder returned by ValueHolder::CreateVoid
ValueHolderPtr WaitEvents(int64_t logic_stream_id, const std::vector<int64_t> &logic_event_ids,
                          LoweringGlobalData &global_data) {
  const auto wait_inputs = PrepareEventInputs(logic_stream_id, logic_event_ids, global_data);
  return ValueHolder::CreateVoid<bg::ValueHolder>("WaitEvents", wait_inputs);
}
/// Inserts an "AccessMemCrossStream" node for every input address that lives on another logic stream.
///
/// For each entry of input_addrs whose GetLogicStream() differs from the stream id of curr_compute_node:
///   - an "AccessMemCrossStream" node is created from the address holder and a const holding the current
///     stream id (the const is created once and shared by all such inputs),
///   - the new holder takes over the placement of the original address,
///   - a "FreeMemory" guarder is attached to the new holder,
///   - the entry in input_addrs is replaced by the new holder.
/// Entries already on the current stream are left untouched.
///
/// @param curr_compute_node node whose inputs are processed
/// @param input_addrs       in/out; input address holders of the node
/// @return HyperStatus::Success() on success, an error status when the node, its op desc, an address
///         holder or a newly created holder is null
HyperStatus LoweringAccessMemCrossStream(const ge::NodePtr &curr_compute_node,
                                         std::vector<bg::DevMemValueHolderPtr> &input_addrs) {
  if (curr_compute_node == nullptr) {
    return HyperStatus::ErrorStatus("Failed to lower AccessMemCrossStream, the compute node is null");
  }
  const auto op_desc = curr_compute_node->GetOpDescBarePtr();
  if (op_desc == nullptr) {
    return HyperStatus::ErrorStatus("Failed to lower AccessMemCrossStream for node %s, the op desc is null",
                                    curr_compute_node->GetNamePtr());
  }
  const auto curr_stream_id = op_desc->GetStreamId();
  bg::ValueHolderPtr curr_stream_id_holder = nullptr;
  for (size_t i = 0U; i < input_addrs.size(); ++i) {
    // Keep a copy: the slot in input_addrs is overwritten at the end of the iteration.
    auto addr_holder = input_addrs[i];
    if (addr_holder == nullptr) {
      return HyperStatus::ErrorStatus("Failed to lower AccessMemCrossStream for node %s, input %zu address is null",
                                      curr_compute_node->GetNamePtr(), i);
    }
    if (addr_holder->GetLogicStream() == curr_stream_id) {
      continue;
    }
    // Created lazily so that nodes without cross-stream inputs do not get an unused const.
    if (curr_stream_id_holder == nullptr) {
      curr_stream_id_holder = bg::ValueHolder::CreateConst(&curr_stream_id, sizeof(curr_stream_id));
      if (curr_stream_id_holder == nullptr) {
        return HyperStatus::ErrorStatus("Failed to create stream id const for node %s",
                                        curr_compute_node->GetNamePtr());
      }
    }
    GELOGD("LoweringAccessMemCrossStream, node name: %s, input index: %zu, curr_stream_id: %ld, logic stream id: %ld",
           curr_compute_node->GetNamePtr(), i, curr_stream_id, addr_holder->GetLogicStream());
    auto access_holder = bg::DevMemValueHolder::CreateSingleDataOutput(
        "AccessMemCrossStream", {addr_holder, curr_stream_id_holder}, curr_stream_id);
    if (access_holder == nullptr) {
      return HyperStatus::ErrorStatus("Failed to create AccessMemCrossStream for node %s, input %zu",
                                      curr_compute_node->GetNamePtr(), i);
    }
    access_holder->SetPlacement(addr_holder->GetPlacement());
    bg::ValueHolder::CreateVoidGuarder("FreeMemory", access_holder, {});
    input_addrs[i] = access_holder;
  }
  return HyperStatus::Success();
}
}
}