/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "broadcast_const_to_store.h"
#include "attr_utils.h"
#include "ascir_ops.h"
#include "ascir_ops_utils.h"
#include "ascgraph_info_complete.h"
#include "graph_utils.h"
#include "graph/ascendc_ir/utils/asc_tensor_utils.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "node_utils.h"
using namespace ascir;
using namespace af::ascir_op;
using namespace af::ops;
namespace optimize {
/**
 * Inserts a Broadcast node between a const tensor and each Store node that consumes it directly.
 *
 * For every node in `graph` that is a Store whose input 0 is a const tensor with a non-null owner,
 * a Broadcast node named "scalar_broadcast_<store name>" is added and the edge
 * const -> Store is replaced by const -> Broadcast -> Store. The Broadcast takes its scheduling
 * attributes from the Store and its output axis/repeats/strides/dtype from the Store's output 0.
 *
 * @param graph Graph that is inspected and rewritten in place.
 * @return ge::SUCCESS when every matching Store has been rewired; ge::FAILED when an edge could not
 *         be removed or added. A null Broadcast node is handled by GE_ASSERT_NOTNULL.
 */
Status BroadcastConstToStorePass::RunPass(af::AscGraph &graph) {
  for (const auto &node : graph.GetAllNodes()) {
    // Only Store nodes fed directly by a const tensor need an explicit broadcast.
    if (!IsOps<Store>(node) || !af::ascir::AscTensorUtils::IsConstTensor(node->inputs[0])) {
      continue;
    }
    const auto const_node = af::ascir::AscTensorUtils::GetOwner(node->inputs[0]);
    if (const_node == nullptr) {
      continue;
    }

    const std::string node_name = "scalar_broadcast_" + node->GetName();
    Broadcast scalar_broadcast(node_name.c_str());
    af::AscNodePtr broadcast_node = graph.AddNode(scalar_broadcast);
    GE_ASSERT_NOTNULL(broadcast_node);

    // Rewire const -> Store into const -> Broadcast -> Store. Each step is checked so that a
    // failed edge operation is reported instead of silently leaving the graph half-rewired.
    // NOTE(review): assumes the edge helpers signal success with a value equal to ge::SUCCESS - confirm.
    if (af::GraphUtils::RemoveEdge(const_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)) != ge::SUCCESS) {
      return ge::FAILED;
    }
    if (af::GraphUtils::AddEdge(broadcast_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)) != ge::SUCCESS) {
      return ge::FAILED;
    }
    if (af::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), broadcast_node->GetInDataAnchor(0)) != ge::SUCCESS) {
      return ge::FAILED;
    }

    // The Broadcast inherits the Store's scheduling attributes and is tagged as a compute/broadcast API.
    scalar_broadcast.attr.sched = node->attr.sched;
    scalar_broadcast.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
    scalar_broadcast.attr.api.type = af::ApiType::kAPITypeCompute;

    // The Broadcast output mirrors the Store's output 0. The dtype used to be written twice
    // (first from the Store input, then from the Store output); only the final value is kept.
    *scalar_broadcast.y.axis = node->outputs[0].attr.axis;
    *scalar_broadcast.y.repeats = node->outputs[0].attr.repeats;
    *scalar_broadcast.y.strides = node->outputs[0].attr.strides;
    scalar_broadcast.y.dtype = node->outputs[0].attr.dtype;
  }
  return ge::SUCCESS;
}
}