* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "scalar_broadcast_optimization.h"
#include "attr_utils.h"
#include "ascir_ops.h"
#include "ascir_ops_utils.h"
#include "ascgraph_info_complete.h"
#include "graph_utils.h"
#include "node_utils.h"
#include "schedule_utils.h"
#include "common_utils.h"
using namespace ascir;
using namespace af::ascir_op;
using namespace af::ops;
namespace {
constexpr int32_t kFirstInputIndex = 0;
constexpr int32_t kSecondInputIndex = 1;
constexpr int32_t kSupportScalarInputNum = 2;
}
namespace optimize {
static bool IsVfNodeScalarUnsupported(const af::AscNodePtr &next_node, const af::AscNodePtr &first_brc_node) {
if (!ascgen_utils::IsNodeSupportsVectorFunction(next_node)) {
GELOGD("Node %s(%s) does not support VF, continue with other checks.", next_node->GetTypePtr(),
next_node->GetNamePtr());
return false;
}
if (first_brc_node->GetInDataNodesSize() == 0UL) {
GELOGD("First broadcast node %s(%s) has no input, skip optimization.", first_brc_node->GetTypePtr(),
first_brc_node->GetNamePtr());
return true;
}
const auto &source_node = std::dynamic_pointer_cast<af::AscNode>(first_brc_node->GetInDataNodes().at(0UL));
if (source_node == nullptr) {
GELOGD("Source node of first broadcast %s(%s) is nullptr, skip optimization.", first_brc_node->GetTypePtr(),
first_brc_node->GetNamePtr());
return true;
}
if (!IsOps<af::ascir_op::Scalar>(source_node)) {
GELOGD("Source node %s(%s) is not Scalar, skip optimization for VF node %s(%s).", source_node->GetTypePtr(),
source_node->GetNamePtr(), next_node->GetTypePtr(), next_node->GetNamePtr());
return true;
}
if (!ScheduleUtils::IsMicroApiSupportsScalarInput(next_node)) {
GELOGD("VF node %s(%s) micro API does not support scalar input, skip optimization.", next_node->GetTypePtr(),
next_node->GetNamePtr());
return true;
}
return false;
}
Status ScalarBroadcastOptimizationPass::GetNodeScalarInputList(const af::NodePtr &node,
std::vector<bool> &is_scalar_list) {
GE_ASSERT_NOTNULL(node);
GE_ASSERT_NOTNULL(std::dynamic_pointer_cast<af::AscNode>(node));
const auto &asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
is_scalar_list.resize(asc_node->GetInDataNodesSize(), false);
for (size_t i = 0UL; i < is_scalar_list.size(); ++i) {
is_scalar_list[i] = ascgen_utils::IsScalarInput(asc_node->inputs[i].attr.repeats);
}
return ge::SUCCESS;
}
Status ScalarBroadcastOptimizationPass::IsNextNodeSupportScalarInput(const NodeView &brc_node, bool &is_supported,
const af::AscNodePtr &first_brc_node) {
is_supported = false;
if (!IsOps<Broadcast>(brc_node)) {
return ge::SUCCESS;
}
for (const auto &out_anchor : brc_node->GetAllOutDataAnchors()) {
GE_CHECK_NOTNULL(out_anchor);
for (const auto &next_in_anchor : out_anchor->GetPeerInDataAnchorsPtr()) {
GE_CHECK_NOTNULL(next_in_anchor);
const auto &next_node = std::dynamic_pointer_cast<af::AscNode>(next_in_anchor->GetOwnerNode());
GE_CHECK_NOTNULL(next_node);
if (IsVfNodeScalarUnsupported(next_node, first_brc_node)) {
return ge::SUCCESS;
}
if (ascgen_utils::IsNodeSupportsAllScalar(next_node)) {
continue;
}
std::vector<bool> is_scalar_list;
GE_ASSERT_SUCCESS(GetNodeScalarInputList(next_node, is_scalar_list));
const auto idx = static_cast<size_t>(next_in_anchor->GetIdx());
GE_ASSERT_TRUE(idx < is_scalar_list.size(), "Input index(%zu) of %s(%s) out of range(%zu)", idx,
next_node->GetTypePtr(), next_node->GetNamePtr(), is_scalar_list.size());
is_scalar_list[idx] = true;
if (ascgen_utils::IsNodeSupportsScalarInput(next_node, is_scalar_list)) {
GELOGD("Node %s(%s) supports scalar input, index=%zu.", next_node->GetTypePtr(), next_node->GetNamePtr(), idx);
continue;
}
if (is_scalar_list.size() != static_cast<size_t>(kSupportScalarInputNum)) {
return ge::SUCCESS;
}
if (!ascgen_utils::IsNodeSupportsScalarIfExchangeInputs(next_node, is_scalar_list)) {
return ge::SUCCESS;
}
const int32_t swap_input_index = kSecondInputIndex - static_cast<int32_t>(idx);
if (ascgen_utils::IsScalarInput(next_node->inputs[swap_input_index].attr.repeats)) {
GELOGD("The input index 1 of %s[%s] is already scalar, not support swap with %d.", next_node->GetTypePtr(),
next_node->GetNamePtr(), idx);
return ge::SUCCESS;
}
if (ScheduleUtils::HasSameInput(next_node)) {
GELOGD("Node %s(%s) has same input, not support.", next_node->GetTypePtr(), next_node->GetNamePtr());
return ge::SUCCESS;
}
GE_ASSERT_SUCCESS(ScheduleUtils::SwapInputIndex(next_node, kFirstInputIndex, kSecondInputIndex));
}
}
is_supported = true;
return ge::SUCCESS;
}
Status ScalarBroadcastOptimizationPass::RunPass(af::AscGraph &graph) {
for (const auto &node : graph.GetAllNodes()) {
GE_CHECK_NOTNULL(node);
GELOGD("Check node : %s %s, out size: %d, in size: %d", node->GetTypePtr(), node->GetNamePtr(),
node->GetOutDataNodesSize(), node->GetInDataNodesSize());
if (!IsOps<Broadcast>(node) || !ScheduleUtils::IsScalarBroadcastNode(node)) {
continue;
}
GELOGD("Find scalar Broadcast node [%s], start to check", node->GetNamePtr());
std::vector<af::AscNodePtr> continues_brc_nodes;
if (!ScheduleUtils::FindContinuesBroadcastNode(node, continues_brc_nodes) || continues_brc_nodes.empty()) {
continues_brc_nodes.clear();
continue;
}
const auto &last_brc_node = continues_brc_nodes[continues_brc_nodes.size() - 1UL];
bool is_next_node_supported_scalar = false;
const auto &first_brc_node = std::dynamic_pointer_cast<af::AscNode>(node);
GE_ASSERT_SUCCESS(IsNextNodeSupportScalarInput(last_brc_node, is_next_node_supported_scalar, first_brc_node));
if (!is_next_node_supported_scalar) {
continues_brc_nodes.clear();
continue;
}
GE_CHECK_NOTNULL(node->GetInDataAnchor(0));
auto pre_node_out_anchor = node->GetInDataAnchor(0)->GetPeerOutAnchor();
GE_ASSERT_NOTNULL(pre_node_out_anchor);
for (const auto &out_anchor : last_brc_node->GetAllOutDataAnchors()) {
GE_CHECK_NOTNULL(out_anchor);
for (const auto &next_in_anchor : out_anchor->GetPeerInDataAnchors()) {
GE_CHECK_NOTNULL(next_in_anchor);
GE_CHECK_NOTNULL(next_in_anchor->GetOwnerNode());
const auto &next_node = std::dynamic_pointer_cast<af::AscNode>(next_in_anchor->GetOwnerNode());
next_node->inputs[next_in_anchor->GetIdx()].attr = node->inputs[0].attr;
GE_ASSERT_SUCCESS(af::GraphUtils::ReplaceEdgeSrc(out_anchor, next_in_anchor, pre_node_out_anchor));
}
}
for (const auto &continues_brc_node : continues_brc_nodes) {
GELOGD("Scalar Broadcast node [%s] can be optimized, remove it.", continues_brc_node->GetNamePtr());
af::NodeUtils::UnlinkAll(*continues_brc_node);
GE_CHECK_NOTNULL(continues_brc_node->GetOwnerComputeGraph());
af::GraphUtils::RemoveNodeWithoutRelink(continues_brc_node->GetOwnerComputeGraph(), continues_brc_node);
}
}
return ge::SUCCESS;
}
}