* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "InitFuncDef.h"
#include "ascir/Dialect/Asc/Transforms/Passes.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/SourceMgr.h"
#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
// Registers a Python function called NAME on `m`. The function takes a PassManager and
// appends the pass returned by CONSTRUCTOR() to it via addPass.
// Requirements at the expansion site: a variable named `m` must be in scope, and the
// caller supplies the trailing semicolon.
#define DEFINE_ADD_PASS(NAME, CONSTRUCTOR) m.def(NAME, [](PassManager& pm) { pm.addPass(CONSTRUCTOR()); })
// Variant of DEFINE_ADD_PASS that adds the pass with pm.addNestedPass<NEST> instead of
// pm.addPass; NEST is the operation type passed as the template argument.
// The same expansion-site requirements apply (`m` in scope, caller adds the semicolon).
#define DEFINE_ADD_PASS_ON(NEST, NAME, CONSTRUCTOR) \
m.def(NAME, [](PassManager& pm) { pm.addNestedPass<NEST>(CONSTRUCTOR()); })
namespace py = pybind11;
using namespace mlir;
namespace {
// Registers the module-local Python class "PassManager" on `m`, wrapping mlir::PassManager.
//
// Python-visible members:
//   PassManager(context)         - builds a pass manager from an MLIRContext pointer.
//   get_pipeline_str() -> str    - text produced by printAsTextualPipeline.
//   run(module)                  - runs the pipeline on the module's operation and throws
//                                  std::runtime_error("Failed to run passes") on failure.
//   enable_verifier(enable=True) - forwards the flag to enableVerifier.
//   enable_printing()            - turns on IR printing to llvm::errs() with debug info.
void definePassManager(py::module& m)
{
    using namespace pybind11::literals;

    py::class_<PassManager>(m, "PassManager", py::module_local())
        // The pass manager is handed a raw MLIRContext pointer and later reads it back through
        // getContext(). keep_alive<1, 2> ties the Python context argument (index 2) to the
        // lifetime of the new PassManager object (index 1), so the context cannot be
        // garbage-collected while the pass manager may still dereference it.
        .def(py::init<MLIRContext*>(), py::keep_alive<1, 2>())
        .def(
            "get_pipeline_str",
            [](PassManager& self) -> std::string {
                std::string result;
                llvm::raw_string_ostream os(result);
                self.printAsTextualPipeline(os);
                // Make sure everything written to the stream has reached `result`.
                os.flush();
                return result;
            })
        .def(
            "run",
            [](PassManager& self, ModuleOp& mod) {
                // The handler lives only for this call; it is attached to the pass
                // manager's context while the pipeline runs.
                llvm::SourceMgr sourceMgr;
                SourceMgrDiagnosticHandler handler(sourceMgr, self.getContext());
                if (self.run(mod.getOperation()).failed())
                    throw std::runtime_error("Failed to run passes");
            })
        .def(
            "enable_verifier", [](PassManager& self, bool enable) { self.enableVerifier(enable); },
            "enable"_a = true)
        .def("enable_printing", [](PassManager& self) {
            OpPrintingFlags flags;
            flags.enableDebugInfo(true);
            // Argument names below follow the upstream enableIRPrinting declaration.
            self.enableIRPrinting(
                /*shouldPrintBeforePass=*/[](Pass*, Operation*) { return true; },
                /*shouldPrintAfterPass=*/[](Pass*, Operation*) { return true; },
                /*printModuleScope=*/false,
                /*printAfterOnlyOnChange=*/false,
                /*printAfterOnlyOnFailure=*/true,
                /*out=*/llvm::errs(),
                /*opPrintingFlags=*/flags);
        });
}
// Creates the "common" submodule under `mod` and registers one `add_*` function per pass
// listed below. Each registered function takes a PassManager and appends that pass to it.
void defineCommonPasses(py::module& mod)
{
    // DEFINE_ADD_PASS expands to `m.def(...)`, so the submodule handle has to be named `m`.
    py::module m = mod.def_submodule("common");

    DEFINE_ADD_PASS("add_canonicalizer", createCanonicalizerPass);
    DEFINE_ADD_PASS("add_cse", createCSEPass);
    DEFINE_ADD_PASS("add_inliner", createInlinerPass);
    DEFINE_ADD_PASS("add_licm", createLoopInvariantCodeMotionPass);
    DEFINE_ADD_PASS("add_print_ir", createPrintIRPass);
    DEFINE_ADD_PASS("add_reconcile_unrealized_casts", createReconcileUnrealizedCastsPass);
    DEFINE_ADD_PASS("add_sccp", createSCCPPass);
    DEFINE_ADD_PASS("add_strip_debug_info", createStripDebugInfoPass);
    DEFINE_ADD_PASS("add_symbol_dce", createSymbolDCEPass);
}
// Creates the "ascendc" submodule under `mod` and registers one `add_*` function per pass
// factory from the `ascendc` namespace. Entries declared with DEFINE_ADD_PASS add the pass
// directly to the PassManager; entries declared with DEFINE_ADD_PASS_ON add it through
// addNestedPass<func::FuncOp>.
void defineAscendCPasses(py::module& mod)
{
    using namespace ascendc;

    // Both macros expand to `m.def(...)`, so the submodule handle has to be named `m`.
    py::module m = mod.def_submodule("ascendc");

    DEFINE_ADD_PASS_ON(func::FuncOp, "add_noop_pass", createNoopPass);
    DEFINE_ADD_PASS("add_detect_kernel_type", createDetectKernelTypePass);
    DEFINE_ADD_PASS("add_declare_py_struct", createDeclarePyStructPass);
    DEFINE_ADD_PASS("add_define_cube_only", createDefineCubeOnlyPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_erase_sync", createEraseSyncPass);
    DEFINE_ADD_PASS("add_generate_boilerplate", createGenerateBoilerplatePass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_hoist_que_bind", createHoistQueBindPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_hoist_ub_allocation", createHoistUBAllocationPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_input_output_tensor", createInputOutputTensorPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_insert_sync", createInsertSyncPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_materialize_tensor", createMaterializeTensorPass);
    DEFINE_ADD_PASS("add_legalize_kernel_args", createLegalizeKernelArgsPass);
    DEFINE_ADD_PASS("add_privatize_func", createPrivatizeFuncPass);
    DEFINE_ADD_PASS("add_detect_enable_debug", createDetectEnableDebugPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_unify_pipe", createUnifyPipePass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_verify_sync", createVerifySyncPass);
}
}
namespace pybind11::asc {

// Populates module `m` with the pass-related bindings: the PassManager class first, then
// the "common" and "ascendc" submodules holding the `add_*` registration functions.
void initPassesModule(py::module&& m)
{
    // `m` is a named rvalue reference, so it binds as an lvalue to the helpers below.
    definePassManager(m);
    defineCommonPasses(m);
    defineAscendCPasses(m);
}

} // namespace pybind11::asc
#undef DEFINE_ADD_PASS
#undef DEFINE_ADD_PASS_ON