* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/passes/feature/auto_fuse_pass.h"
#include "op_desc_utils.h"
#include "graph/optimize/symbolic/infer_symbolic_shape/symbolic_shape_inference.h"
#include "graph/optimize/symbolic/codegen/guard_codegen.h"
#include "common/checker.h"
#include "ge_common/ge_api_error_codes.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_context.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_type_utils.h"
#include "graph/attribute_group/attr_group_symbolic_desc.h"
#include "graph/attribute_group/attr_group_shape_env.h"
#include "graph/utils/node_utils.h"
#include "autofuse_frame/autofuse_frames.h"
#include "common/ge_types.h"
#include "mmpa/mmpa_api.h"
#include "graph/optimize/symbolic/shape_env_guarder.h"
#include "common/compile_profiling/ge_trace_wrapper.h"
#include "common/plugin/ge_make_unique_util.h"
#include "graph/passes/base_pass.h"
#include "algebraic_simplification_pass.h"
#include "graph/passes/standard_optimize/prune_pass.h"
#include "graph/passes/standard_optimize/constant_folding/constant_folding_pass.h"
#include "graph/optimize/autofuse/autofuse/pattern_fusion/pattern_fusion.h"
namespace ge {
namespace {
/// Stamps the AI Core engine settings onto every autofuse node directly owned by `graph`.
///
/// For each direct node whose op desc satisfies OpTypeUtils::IsAutofuseNode:
///   - engine name and kernel lib name are both set to kEngineNameAiCore,
///   - ATTR_NAME_IMPLY_TYPE is set to domi::ImplyType::TVM,
///   - kAttrLowingFunc is set to "kAutoFuseLoweringFunc".
/// Nodes that are not autofuse nodes are left untouched.
///
/// @param graph graph whose direct nodes are scanned
/// @return SUCCESS after all direct nodes were visited; GE_ASSERT_NOTNULL guards against a node
///         without an op desc
Status MarkEngineAttrForAutofuseNode(const ComputeGraphPtr &graph) {
  for (const auto &cur_node : graph->GetDirectNode()) {
    const auto desc = cur_node->GetOpDesc();
    GE_ASSERT_NOTNULL(desc);
    if (!OpTypeUtils::IsAutofuseNode(desc)) {
      continue;
    }
    desc->SetOpEngineName(kEngineNameAiCore);
    desc->SetOpKernelLibName(kEngineNameAiCore);
    // Results of the attribute setters are not checked.
    (void)AttrUtils::SetInt(desc, ATTR_NAME_IMPLY_TYPE, static_cast<int32_t>(domi::ImplyType::TVM));
    (void)ge::AttrUtils::SetStr(desc, ge::kAttrLowingFunc, "kAutoFuseLoweringFunc");
  }
  return SUCCESS;
}
/// Counter implementation that hands out consecutive ids, starting at 0.
/// The id lives in a std::atomic, so every NextId() call is one atomic read-modify-write.
class AutofuseCounter : public Counter {
 public:
  AutofuseCounter() = default;
  ~AutofuseCounter() override = default;

  /// Returns the current id and advances the counter by one.
  int64_t NextId() override {
    // Post-increment on std::atomic yields the value held before the increment.
    return next_id_++;
  }

 private:
  std::atomic<int64_t> next_id_{0L};
};
}
/// Runs the auto-fuse pipeline on `graph`.
///
/// Steps, in order:
///   1. PreProcess(graph)
///   2. SymbolicShapeInference::Infer(graph)
///   3. LoweringAndCanFuseWithCounter(graph, counter) with a freshly created AutofuseCounter
///   4. MarkEngineAttrForAutofuseNode(graph)
/// GE_DUMP is invoked before step 1, after step 2 and after step 4.
///
/// @param graph graph to process; a null graph is rejected up front
/// @return SUCCESS when every step passes; a failing GE_ASSERT_* check ends the run early
Status AutoFusePass::Run(ComputeGraphPtr graph) {
  // graph is dereferenced below (graph->GetName()), so reject a null pointer before touching it.
  GE_ASSERT_NOTNULL(graph);
  GE_DUMP(graph, "AutoFuser_BeforeAutoFuse");

  GE_TRACE_START(PreProcess);
  GE_ASSERT_SUCCESS(PreProcess(graph));
  GE_COMPILE_TRACE_TIMESTAMP_END(PreProcess, ("AutoFusePass::PreProcess::" + graph->GetName()).c_str());

  SymbolicShapeInference symbolic_shape_inference;
  GE_TRACE_START(Infer);
  GE_ASSERT_SUCCESS(symbolic_shape_inference.Infer(graph));
  GE_COMPILE_TRACE_TIMESTAMP_END(Infer, ("SymbolicShapeInference::Infer::" + graph->GetName()).c_str());
  GE_DUMP(graph, "AutoFuser_AfterPreprocess");

  auto root_graph = ge::GraphUtils::FindRootGraph(graph);
  GE_ASSERT_NOTNULL(root_graph);
  // The guarder is built from the root graph's ShapeEnvAttr group and stays alive until this
  // function returns, i.e. across lowering/fusion and engine marking.
  ShapeEnvGuarder guarder(root_graph->GetAttrsGroup<ShapeEnvAttr>());

  auto autofuse_counter = MakeUnique<AutofuseCounter>();
  GE_ASSERT_NOTNULL(autofuse_counter);
  GE_ASSERT_SUCCESS(LoweringAndCanFuseWithCounter(graph, autofuse_counter.get()));
  GE_ASSERT_SUCCESS(MarkEngineAttrForAutofuseNode(graph));
  GE_DUMP(graph, "AutoFuser_AfterAllFuse");
  return SUCCESS;
}
/// Prepares `graph` ahead of symbolic inference and fusion.
///
/// Runs AlgebraicSimplificationPass first, then PatternFusion::RunEarlyPasses with callbacks
/// that delegate pruning to PrunePass and constant folding to ConstantFoldingPass.
///
/// @param graph graph to preprocess
/// @return SUCCESS when both stages pass; a failing GE_ASSERT_SUCCESS check ends the call early
Status AutoFusePass::PreProcess(const ComputeGraphPtr &graph) {
  GE_ASSERT_SUCCESS(AlgebraicSimplificationPass::Run(graph));

  GraphPasses early_pass_hooks;
  // Each hook builds its pass object on demand; neither lambda captures any state.
  early_pass_hooks.prune_graph_func = [](const ComputeGraphPtr &target_graph) -> Status {
    PrunePass prune_pass;
    return prune_pass.Run(target_graph);
  };
  early_pass_hooks.constant_folding_func = [](NodePtr &target_node) -> Status {
    ConstantFoldingPass folding_pass;
    return folding_pass.Run(target_node);
  };

  GE_ASSERT_SUCCESS(ge::PatternFusion::RunEarlyPasses(graph, early_pass_hooks));
  return SUCCESS;
}
// Registers the pass option named "AutoFusePass" with level OoLevel::kO3.
// NOTE(review): presumably this ties the pass to the O3 optimization level — confirm against
// the REG_PASS_OPTION definition.
REG_PASS_OPTION("AutoFusePass").LEVELS(OoLevel::kO3);
}