* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "stream_info.h"
#include <algorithm>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
#include "external/ge/ge_api_types.h"
#include "graph_metadef/graph/debug/ge_attr_define.h"
#include "graph/compute_graph.h"
#include "common/framework_types_internal.h"
namespace ge {
namespace {
// Renders `vec` as its decimal values joined by "," (no trailing separator).
// An empty vector yields an empty string.
std::string VecIntToStr(const std::vector<uint32_t> &vec) {
  std::stringstream joined;
  const char *separator = "";
  for (const uint32_t value : vec) {
    joined << separator << value;
    // Every element after the first is preceded by a comma.
    separator = ",";
  }
  return joined.str();
}
}  // namespace
// Wraps every node returned by graph->GetAllNodes() in a StreamNode, records the
// graph name and then calls Build() to derive the per-stream layout.
//
// A null `graph` produces an empty StreamGraph. Nodes that are null or carry no
// op desc are skipped, because every StreamNode accessor dereferences
// origin_node_->GetOpDesc() without a check.
StreamGraph::StreamGraph(const ComputeGraphPtr &graph) {
  if (graph == nullptr) {
    return;
  }
  name_ = graph->GetName();
  for (const auto &node : graph->GetAllNodes()) {
    if ((node == nullptr) || (node->GetOpDesc() == nullptr)) {
      continue;
    }
    nodes_.emplace_back(std::make_shared<StreamNode>(node));
  }
  Build();
}
// Returns the DOT identifier of `node`: "S<stream id>N<id in stream>".
std::string StreamGraph::DotId(const StreamNodePtr node) const {
  const std::string stream_part = "S" + std::to_string(node->GetStreamId());
  const std::string node_part = "N" + std::to_string(node->GetIdInStream());
  return stream_part + node_part;
}
// Stores `node` as the graph node this StreamNode describes.
StreamNode::StreamNode(const NodePtr node)
    : origin_node_(node) {
}
// Returns the op desc's id converted to size_t.
// Assumes origin_node_ and its op desc are non-null; neither is checked here.
size_t StreamNode::GetTopoId() const {
  const auto op_desc = origin_node_->GetOpDesc();
  return static_cast<size_t>(op_desc->GetId());
}
// Reads the ATTR_NAME_STREAM_LABEL string attribute from the wrapped op desc.
// The status returned by GetStr is deliberately discarded, so the local stays
// default-constructed (empty) unless GetStr writes to it.
std::string StreamNode::GetStreamLabel() const {
  std::string label;
  const auto op_desc = origin_node_->GetOpDesc();
  static_cast<void>(ge::AttrUtils::GetStr(op_desc, ATTR_NAME_STREAM_LABEL, label));
  return label;
}
// Reads the ATTR_NAME_ACTIVE_STREAM_LIST integer-list attribute from the wrapped
// op desc. The status returned by GetListInt is deliberately discarded, so the
// local stays default-constructed (empty) unless GetListInt writes to it.
std::vector<uint32_t> StreamNode::GetActivatedStreams() const {
  std::vector<uint32_t> streams;
  const auto op_desc = origin_node_->GetOpDesc();
  static_cast<void>(ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, streams));
  return streams;
}
// Returns the stream id stored on the wrapped node's op desc.
// Assumes origin_node_ and its op desc are non-null; neither is checked here.
int64_t StreamNode::GetStreamId() const {
  const auto op_desc = origin_node_->GetOpDesc();
  return op_desc->GetStreamId();
}
// Returns the name stored on the wrapped node's op desc.
// Assumes origin_node_ and its op desc are non-null; neither is checked here.
std::string StreamNode::GetName() const {
  const auto op_desc = origin_node_->GetOpDesc();
  return op_desc->GetName();
}
// Returns the operator type stored on the wrapped node's op desc.
// Assumes origin_node_ and its op desc are non-null; neither is checked here.
std::string StreamNode::GetType() const {
  const auto op_desc = origin_node_->GetOpDesc();
  return op_desc->GetType();
}
const std::vector<size_t> &StreamNode::GetSendEventIds() const {
return send_event_ids_;
}
const std::vector<size_t> &StreamNode::GetRecvEventIds() const {
return recv_event_ids_;
}
const std::vector<uint32_t> &StreamNode::GetLabelIndexes() const {
return label_indexes_;
}
size_t StreamNode::GetIdInStream() const {
return idx_;
}
bool StreamNode::IsInBranch() const {
return in_branch_;
}
void StreamNode::SetIdInStream(size_t id) {
idx_ = id;
}
void StreamNode::AppendSendEventId(size_t event_id) {
send_event_ids_.push_back(event_id);
}
void StreamNode::SetRecvEventIds(const std::vector<size_t> &recv_event_ids) {
recv_event_ids_ = recv_event_ids;
}
std::string StreamGraph::DotStreams() const {
std::stringstream content;
for (size_t stream_id = 0; stream_id < stream_num_; stream_id++) {
std::stringstream ss;
size_t all_nodes_num = stream_nodes_[stream_id].size();
size_t node_num_without_events = stream_nodes_without_events_[stream_id].size();
std::vector<string> node_dot_ids;
ss << "<table border=\"0\" cellborder=\"1\" cellspacing=\"0\">";
ss << "<tr><td><b>Stream" << stream_id << "</b></td></tr>";
ss << "<tr><td>" << all_nodes_num << "(" << node_num_without_events << "+"
<< (all_nodes_num - node_num_without_events) << ")nodes</td></tr>";
if (!stream_labels_[stream_id].empty()) {
ss << "<tr><td><i>Label:" << stream_labels_[stream_id] << "</i></td></tr>";
}
ss << "</table>";
std::string color;
if (stream_in_branch_.find(stream_id) != stream_in_branch_.end()) {
color = ",color=\"blue\"";
}
content << "{\n";
content << "rank=\"same\";\n";
content << "edge[style=\"solid\",penwidth=2" << color << "];\n";
content << "Stream" + std::to_string(stream_id) + "[label=<" + ss.str() + ">" + color + "];\n";
content << "End" + std::to_string(stream_id) + "[shape=\"point\",width=0.1,height=0.1];\n";
for (auto node : stream_nodes_without_events_[stream_id]) {
std::string dot_id = DotId(node);
node_dot_ids.emplace_back(dot_id);
content << dot_id << "[label=<" << node->GetName() << "<br/>(id:" << node->GetTopoId() << ")";
const auto &activated_streams = node->GetActivatedStreams();
if (!activated_streams.empty()) {
content << "<br/>(active streams:" << VecIntToStr(activated_streams) << ")";
}
const auto &label_indexes = node->GetLabelIndexes();
if (!label_indexes.empty()) {
content << "<br/>(labels:" << VecIntToStr(label_indexes) << ")";
}
content << ">];\n";
}
content << "Stream" << stream_id;
for (auto node_dot_id : node_dot_ids) {
content << "->" << node_dot_id;
}
content << "->End" << stream_id << ";\n";
content << "}\n";
}
return content.str();
}
std::string StreamGraph::DotEdges() const {
std::stringstream content;
for (auto node : nodes_) {
for (size_t send_event_id : node->GetSendEventIds()) {
if (recv_event_node_map_.find(send_event_id) == recv_event_node_map_.end()) {
continue;
}
auto recv_node = recv_event_node_map_.at(send_event_id);
content << DotId(node) << "->" << DotId(recv_node);
content << "[label=\"Event" << send_event_id << "\",arrowhead=\"normal\"];\n";
}
}
return content.str();
}
// Emits a "{ ... }" DOT group describing stream activation:
// for every node with a non-empty activated-stream list it writes an
// "AN<id in stream>" node plus, per activated stream, an "AS<stream>" node and
// an "active" edge between them. Edges and stream nodes that originate from a
// STREAMSWITCH node are rendered in blue.
// Returns an empty string when stream_in_branch_ is empty.
std::string StreamGraph::DotActives() const {
  if (stream_in_branch_.empty()) {
    return "";
  }
  std::stringstream dot;
  dot << "{\n";
  for (const auto &node : nodes_) {
    const auto activated = node->GetActivatedStreams();
    if (activated.empty()) {
      continue;
    }
    const bool is_switch = (node->GetType() == STREAMSWITCH);
    const char *const edge_color = is_switch ? ",color=\"blue\"" : "";
    const char *const text_color = is_switch ? ",fontcolor=\"blue\"" : "";

    dot << "AN" << node->GetIdInStream() << "[label=\"" << node->GetName() << "(id: " << node->GetTopoId()
        << ",stream: " << node->GetStreamId() << ")\"];\n";
    for (const uint32_t target_stream : activated) {
      dot << "AS" << target_stream << "[label=\"Stream" << target_stream << "\"" << text_color << "];\n";
      dot << "AN" << node->GetIdInStream() << "-> AS" << target_stream
          << " [label=\"active\", arrowhead=\"normal\"" << edge_color << text_color << "];\n";
    }
  }
  dot << "}\n";
  return dot.str();
}
// Returns the opening of the DOT document: "digraph <name_>{" followed by the
// global layout direction and the default node and edge attributes.
std::string StreamGraph::DotHeaderStr() const {
  std::stringstream header;
  header << "digraph " << name_ << "{\n";
  header << "rankdir=\"LR\";\n";
  header << "node[shape=\"plaintext\",width=0,height=0];\n";
  header << "edge[arrowhead=\"none\",style=\"dashed\"];\n";
  header << "\n";
  return header.str();
}
// Returns the text that closes the block opened by DotHeaderStr().
std::string StreamGraph::DotTailStr() const {
  return std::string("}\n");
}
// Assembles the complete DOT document: header, per-stream groups, event edges,
// activation group and closing brace, in that order.
std::string StreamGraph::ToDotString() const {
  std::string dot = DotHeaderStr();
  dot += DotStreams();
  dot += DotEdges();
  dot += DotActives();
  dot += DotTailStr();
  return dot;
}
void StreamGraph::Build() {
if (nodes_.empty()) {
return;
}
for (size_t idx = 0; idx < nodes_.size(); ++idx) {
nodes_[idx]->SetIdInStream(idx);
}
auto max_stream_node = std::max_element(
nodes_.begin(), nodes_.end(),
[](const StreamNodePtr &lhs, const StreamNodePtr &rhs) { return lhs->GetStreamId() < rhs->GetStreamId(); });
if (max_stream_node != nodes_.end() && (*max_stream_node != nullptr)) {
stream_num_ = (*max_stream_node)->GetStreamId() + 1;
} else {
return;
}
if (stream_nodes_.size() < stream_num_) {
stream_nodes_.resize(stream_num_);
}
for (auto &node : nodes_) {
if (node->GetStreamId() >= 0) {
stream_nodes_[node->GetStreamId()].emplace_back(node);
}
}
for (auto &nodes_per_stream : stream_nodes_) {
StreamNodePtr pre_node = nullptr;
std::vector<size_t> recv_events;
for (auto &node : nodes_per_stream) {
if (node->GetType() == SEND) {
if (pre_node) {
if (auto pos = node->GetName().find("_Send_"); pos != std::string::npos) {
int event_id = std::stoi(node->GetName().substr(pos + 6));
pre_node->AppendSendEventId(event_id);
}
}
} else if (node->GetType() == RECV) {
if (auto pos = node->GetName().find("_Recv_"); pos != std::string::npos) {
int event_id = std::stoi(node->GetName().substr(pos + 6));
recv_events.push_back(event_id);
}
} else {
node->SetRecvEventIds(recv_events);
recv_events.clear();
pre_node = node;
}
}
}
nodes_.erase(
std::remove_if(nodes_.begin(), nodes_.end(),
[](const StreamNodePtr &node) { return (node->GetType() == SEND) || (node->GetType() == RECV); }),
nodes_.end());
for (auto &node : nodes_) {
for (auto id : node->GetSendEventIds()) {
send_event_node_map_[id] = node;
}
for (auto id : node->GetRecvEventIds()) {
recv_event_node_map_[id] = node;
}
}
stream_nodes_without_events_.resize(stream_num_);
stream_labels_.resize(stream_num_);
for (auto &node : nodes_) {
if (node->GetStreamId() < 0) {
continue;
}
stream_nodes_without_events_[node->GetStreamId()].push_back(node);
if (!node->GetStreamLabel().empty() && stream_labels_[node->GetStreamId()].empty()) {
stream_labels_[node->GetStreamId()] = node->GetStreamLabel();
}
if (node->GetType() == STREAMSWITCH || (node->GetType() == STREAMACTIVE && node->IsInBranch())) {
for (uint32_t activated_stream : node->GetActivatedStreams()) {
stream_in_branch_.insert(activated_stream);
}
}
}
}
}