* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <algorithm>
#include <cerrno>
#include <cstring>
#include <fstream>
#include <map>
#include <queue>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
#include "graph/utils/graph_utils_ex.h"
#include "graph/utils/op_type_utils.h"
#include "register/graph_optimizer/fusion_common/graph_pass_util.h"
#include "graph/ge_context.h"
#include "common/util.h"
#include "common/framework_types_internal.h"
#include "nlohmann/json.hpp"
#include "fusion_utils.h"
namespace ge {
namespace fusion {
namespace {
// Keys/values of the fusion switch json file. The parser in this file reads
// json[kSwitchKey][kGraphFusionKey] as an object of "<pass name>": "<state>"
// entries, where the state string kSwitchOn enables a pass.
constexpr const char_t *kSwitchKey = "Switch";
constexpr const char_t *kGraphFusionKey = "GraphFusion";
constexpr const char_t *kSwitchOn = "on";
constexpr const char_t *kSwitchOff = "off";
std::string SwitchToString(const std::map<string, bool> &pass_name_2_switches) {
std::stringstream ss;
for (const auto &pass_name_2_switch : pass_name_2_switches) {
ss << "[" << pass_name_2_switch.first << ",";
std::string switch_str = pass_name_2_switch.second ? kSwitchOn : kSwitchOff;
ss << switch_str << "]";
}
return ss.str();
}
// Parses the json file at file_real_path into json_obj.
// Returns SUCCESS on success. Returns FAILED (after logging) when the file cannot be
// opened, when parsing throws, or when the stream is not in a good state afterwards.
Status ReadJsonFile(const std::string &file_real_path, nlohmann::json &json_obj) {
  // std::ifstream closes the file in its destructor, so no explicit close guard is needed.
  std::ifstream file_stream(file_real_path);
  // Opening does not throw, so the check lives outside the try block. The previous message
  // claimed the file was "already opened", which is the opposite of this failure.
  GE_WARN_ASSERT(file_stream.is_open(), "Failed to open json file[%s], error msg = %s", file_real_path.c_str(),
                 strerror(errno));
  try {
    file_stream >> json_obj;
  } catch (const std::exception &e) {
    // nlohmann::json reports malformed content by throwing.
    GELOGW("Failed to read json file[%s], err msg: %s", file_real_path.c_str(), e.what());
    return FAILED;
  }
  GE_CHK_BOOL_RET_STATUS(file_stream.good(), FAILED, "Failed to read json file[%s], error msg = %s",
                         file_real_path.c_str(), strerror(errno));
  GELOGD("Read json file[%s] success", file_real_path.c_str());
  return SUCCESS;
}
// Collects the per-pass switches found at fusion_switch_config_json[kSwitchKey][kGraphFusionKey]
// into fusion_switch_map (pass name -> true iff the value equals kSwitchOn).
// Missing sections are silently ignored. A non-object GraphFusion section, or an entry
// whose value is not a string, is reported with a warning instead of throwing.
// Entries with an empty pass name are skipped. Existing keys in fusion_switch_map are
// kept (emplace does not overwrite).
void CollectFusionSwitchMap(const nlohmann::json &fusion_switch_config_json,
                            std::map<std::string, bool> &fusion_switch_map) {
  if (fusion_switch_config_json.is_null()) {
    return;
  }
  // find() returns end() when the receiver is not an object, so no explicit type check is needed here.
  const auto switch_iter = fusion_switch_config_json.find(kSwitchKey);
  if (switch_iter == fusion_switch_config_json.end()) {
    return;
  }
  const auto graph_fusion_iter = switch_iter->find(kGraphFusionKey);
  if (graph_fusion_iter == switch_iter->end()) {
    return;
  }
  if (!graph_fusion_iter->is_object()) {
    GELOGW("The third level of json file should be object.");
    return;
  }
  for (auto &item_json : graph_fusion_iter->items()) {
    const std::string key_inner = item_json.key();
    if (key_inner.empty()) {
      continue;
    }
    // Converting a non-string json value to std::string throws nlohmann::json::type_error,
    // which nothing up the call chain catches; skip such entries instead.
    if (!item_json.value().is_string()) {
      GELOGW("The value of fusion switch[%s] should be a string, skip it.", key_inner.c_str());
      continue;
    }
    const std::string value_inner = item_json.value();
    fusion_switch_map.emplace(key_inner, value_inner == kSwitchOn);
  }
}
// Loads the fusion switch json file and adds its GraphFusion switches to pass_name_2_switch.
// A null top level (json "null") is accepted and contributes nothing. Any other
// non-object top level is rejected with FAILED, as is a read/parse failure.
Status ParseFusionSwitchFile(const string &fusion_switch_json_file, std::map<string, bool> &pass_name_2_switch) {
  nlohmann::json fusion_config_json;
  GE_WARN_ASSERT(ReadJsonFile(fusion_switch_json_file, fusion_config_json) == SUCCESS,
                 "Failed to read json from file %s.", fusion_switch_json_file.c_str());
  const bool top_level_acceptable = (fusion_config_json == nullptr) || fusion_config_json.is_object();
  if (top_level_acceptable) {
    CollectFusionSwitchMap(fusion_config_json, pass_name_2_switch);
    return SUCCESS;
  }
  GELOGE(GRAPH_FAILED, "[GraphOpt][Init][CheckCfgFileFormat] Top level of fusion config file should be object.");
  return FAILED;
}
std::vector<NodePtr> ToNodePtrs(const std::vector<GNode> &nodes) {
std::vector<NodePtr> node_ptrs;
for (const auto &node : nodes) {
node_ptrs.emplace_back(NodeAdapter::GNode2Node(node));
}
return node_ptrs;
}
}
// Marks pass_name on every node of the replacement graph. For each node it calls
// fe::GraphPassUtil::StoreAndUpdataOriginFusionPassName with the nodes held by the
// subgraph's boundary.
// Returns SUCCESS, or a failure status when an input or op_desc is null, the boundary
// cannot be initialized, or marking any node fails.
Status FusionUtils::MarkPassNameOnReplacementNodes(const std::unique_ptr<Graph> &replacement,
                                                   const std::unique_ptr<SubgraphBoundary> &subgraph,
                                                   const std::string &pass_name) {
  GE_ASSERT_NOTNULL(replacement);
  GE_ASSERT_NOTNULL(subgraph);
  const auto replacement_graph = GraphUtilsEx::GetComputeGraph(*replacement);
  GE_ASSERT_NOTNULL(replacement_graph);
  InnerSubgraphBoundary inner_boundary;
  std::string boundary_invalid_reason;
  // The reason is free text: pass it as an argument of a fixed "%s" format. It must never
  // be the format string itself, because a '%' inside it would be parsed as a conversion.
  GE_ASSERT_SUCCESS(inner_boundary.Init(*subgraph, boundary_invalid_reason), "%s", boundary_invalid_reason.c_str());
  for (const auto &node : replacement_graph->GetAllNodes()) {
    GE_ASSERT_NOTNULL(node);
    const auto op_desc = node->GetOpDesc();
    GE_ASSERT_NOTNULL(op_desc, "Node[%s][%s] has null op_desc.", node->GetNamePtr(), node->GetTypePtr());
    GE_ASSERT_SUCCESS(
        fe::GraphPassUtil::StoreAndUpdataOriginFusionPassName(op_desc, inner_boundary.GetNodes(), pass_name),
        "Failed to mark pass name[%s] on node[%s][%s].", pass_name.c_str(), node->GetNamePtr(), node->GetTypePtr());
  }
  return SUCCESS;
}
std::string FusionUtils::ToString(const std::unique_ptr<Graph> &graph) {
std::stringstream ss;
auto compute_grpah = GraphUtilsEx::GetComputeGraph(*graph);
ss << "[";
for (const auto &node : compute_grpah->GetDirectNode()) {
if (OpTypeUtils::IsDataNode(node->GetType()) || OpTypeUtils::IsGraphOutputNode(node->GetType())) {
continue;
}
ss << "{" << node->GetTypePtr() << "}";
}
ss << "]";
return ss.str();
}
std::string FusionUtils::GetFusionSwitchFileFromOption() {
std::string fusion_switch_file_path;
ge::graphStatus status = GetContext().GetOption(FUSION_SWITCH_FILE, fusion_switch_file_path);
if (status != ge::GRAPH_SUCCESS) {
GELOGD("Cannot get option value [%s].", FUSION_SWITCH_FILE.c_str());
return "";
}
GELOGD("The [%s] in ge context is %s.", FUSION_SWITCH_FILE.c_str(), fusion_switch_file_path.c_str());
return fusion_switch_file_path;
}
// Builds the pass-name -> enabled map from the fusion switch file configured in the GE
// context. Yields an empty map when no file is configured, its real path cannot be
// resolved, or parsing fails.
std::map<string, bool> FusionUtils::ParseFusionSwitch() {
  std::map<string, bool> pass_name_2_switches;
  const auto fusion_switch_file_path = GetFusionSwitchFileFromOption();
  if (fusion_switch_file_path.empty()) {
    return pass_name_2_switches;
  }
  const auto fusion_switch_real_path = RealPath(fusion_switch_file_path.c_str());
  if (fusion_switch_real_path.empty()) {
    GELOGD("Fusion switch config real path of %s is empty", fusion_switch_file_path.c_str());
    return pass_name_2_switches;
  }
  GELOGD("Fusion switch config real path is %s", fusion_switch_real_path.c_str());
  // On failure the macro returns early with an empty result.
  GE_WARN_ASSERT(ParseFusionSwitchFile(fusion_switch_real_path, pass_name_2_switches) == SUCCESS);
  GELOGD("[FusionSwitch] is %s", SwitchToString(pass_name_2_switches).c_str());
  return pass_name_2_switches;
}
bool FusionUtils::WillCauseCycleIfFuse(const std::unique_ptr<MatchResult> &match_result) {
if (match_result == nullptr) {
return false;
}
auto matched_nodes = ToNodePtrs(match_result->GetMatchedNodes());
matched_nodes.erase(std::remove_if(matched_nodes.begin(), matched_nodes.end(),
[](const NodePtr &node) {
return node == nullptr || node->GetOwnerComputeGraph() == nullptr;
}),
matched_nodes.end());
if (matched_nodes.empty()) {
return false;
}
std::unordered_set<const Node *> fused;
for (const auto &node : matched_nodes) {
fused.insert(node.get());
}
std::unordered_set<const Node *> visited;
std::queue<const Node *> to_visit;
for (const auto &node : matched_nodes) {
for (const auto *const out_node : node->GetOutNodesPtr()) {
if ((out_node == nullptr) || (fused.count(out_node) > 0)) {
continue;
}
if (visited.insert(out_node).second) {
to_visit.push(out_node);
}
}
}
while (!to_visit.empty()) {
const auto *const cur = to_visit.front();
to_visit.pop();
for (const auto *const out_node : cur->GetOutNodesPtr()) {
if (out_node == nullptr) {
continue;
}
if (fused.count(out_node) > 0) {
return true;
}
if (visited.insert(out_node).second) {
to_visit.push(out_node);
}
}
}
return false;
}
void FusionUtils::RecordFusionStatistic(const uint64_t session_id, const std::string graph_id, const std::string pass_name,
const int match_times, const int effect_times) {
fe::FusionStatisticRecorder &fusion_statistic_instance = fe::FusionStatisticRecorder::Instance();
auto fusion_info = fe::FusionInfo(session_id, graph_id, pass_name, match_times, effect_times);
fusion_statistic_instance.UpdateGraphFusionMatchTimes(fusion_info);
fusion_statistic_instance.UpdateGraphFusionEffectTimes(fusion_info);
}
}
}