/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "InitFuncDef.h"

#include "ascir/Dialect/Asc/Transforms/Passes.h"

#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/SourceMgr.h"

#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // automatic casts between containers and python types

// Registers a Python function named NAME on the module object `m` (a variable
// with that exact name must be in scope where the macro is expanded). The
// function takes a PassManager and appends CONSTRUCTOR() to it at the top level.
#define DEFINE_ADD_PASS(NAME, CONSTRUCTOR) m.def(NAME, [](PassManager& pm) { pm.addPass(CONSTRUCTOR()); })

// Like DEFINE_ADD_PASS, but the pass is added with addNestedPass<NEST>, i.e. it
// is scheduled on nested operations of type NEST instead of at the top level.
#define DEFINE_ADD_PASS_ON(NEST, NAME, CONSTRUCTOR) \
    m.def(NAME, [](PassManager& pm) { pm.addNestedPass<NEST>(CONSTRUCTOR()); })

namespace py = pybind11;
using namespace mlir;

namespace {

/// Registers the `PassManager` Python class on `m`.
///
/// Exposed API:
///   - `PassManager(context)`: wraps `mlir::PassManager(MLIRContext*)`.
///   - `get_pipeline_str()`: textual form of the passes added so far.
///   - `run(module)`: runs the pipeline on a `ModuleOp`; raises `RuntimeError`
///     ("Failed to run passes") if the pipeline reports failure.
///   - `enable_verifier(enable=True)`: toggles running the MLIR verifier after passes.
///   - `enable_printing()`: dumps IR (with debug info) to stderr around passes.
void definePassManager(py::module& m)
{
    using namespace pybind11::literals;

    py::class_<PassManager>(m, "PassManager", py::module_local())
        // The pass manager keeps the MLIRContext* it was constructed with (it is
        // read back via getContext() in `run`). keep_alive<1, 2> makes the new
        // PassManager (argument 1, `self`) hold a reference to the context
        // argument (2), so the context cannot be garbage-collected while the
        // pass manager still points at it.
        .def(py::init<MLIRContext*>(), py::keep_alive<1, 2>())
        .def(
            "get_pipeline_str",
            [](PassManager& self) -> std::string {
                std::string result;
                llvm::raw_string_ostream os(result);
                self.printAsTextualPipeline(os);
                os.flush();
                return result;
            })
        .def(
            "run",
            [](PassManager& self, ModuleOp& mod) {
                // Scoped diagnostic handler: it is attached to the context only
                // for the duration of this run so that pass diagnostics get
                // reported (this constructor overload writes to llvm::errs()).
                llvm::SourceMgr sourceMgr;
                SourceMgrDiagnosticHandler handler(sourceMgr, self.getContext());
                if (self.run(mod.getOperation()).failed())
                    // pybind11 translates std::runtime_error into Python's RuntimeError.
                    throw std::runtime_error("Failed to run passes");
            })
        .def(
            "enable_verifier", [](PassManager& self, bool enable) { self.enableVerifier(enable); }, "enable"_a = true)
        .def("enable_printing", [](PassManager& self) {
            OpPrintingFlags flags;
            flags.enableDebugInfo(true);
            // IR is printed before every pass. After-pass IR is printed only when
            // a pass fails. Only the operation the pass runs on is printed, not the
            // enclosing module.
            self.enableIRPrinting(
                /*shouldPrintBeforePass=*/[](Pass*, Operation*) { return true; },
                /*shouldPrintAfterPass=*/[](Pass*, Operation*) { return true; },
                /*printModuleScope=*/false,
                /*printAfterOnlyOnChange=*/false,
                /*printAfterOnlyOnFailure=*/true,
                /*out=*/llvm::errs(),
                /*opPrintingFlags=*/flags);
        });
}

/// Creates the `common` submodule of `mod`.
///
/// Each `add_*` function takes a PassManager and appends one upstream MLIR pass
/// to it at the top level, using that pass's default-constructed form.
void defineCommonPasses(py::module& mod)
{
    py::module common = mod.def_submodule("common");

    common.def("add_canonicalizer", [](PassManager& pm) { pm.addPass(createCanonicalizerPass()); });
    common.def("add_cse", [](PassManager& pm) { pm.addPass(createCSEPass()); });
    common.def("add_inliner", [](PassManager& pm) { pm.addPass(createInlinerPass()); });
    common.def("add_licm", [](PassManager& pm) { pm.addPass(createLoopInvariantCodeMotionPass()); });
    common.def("add_print_ir", [](PassManager& pm) { pm.addPass(createPrintIRPass()); });
    common.def("add_reconcile_unrealized_casts",
               [](PassManager& pm) { pm.addPass(createReconcileUnrealizedCastsPass()); });
    common.def("add_sccp", [](PassManager& pm) { pm.addPass(createSCCPPass()); });
    common.def("add_strip_debug_info", [](PassManager& pm) { pm.addPass(createStripDebugInfoPass()); });
    common.def("add_symbol_dce", [](PassManager& pm) { pm.addPass(createSymbolDCEPass()); });
}

/// Creates the `ascendc` submodule of `mod`.
///
/// Each `add_*` function takes a PassManager and appends one AscendC pass. Some
/// passes are added at the top level; the others are nested on `func::FuncOp`
/// via addNestedPass.
void defineAscendCPasses(py::module& mod)
{
    py::module ascend = mod.def_submodule("ascendc");
    using namespace ascendc;

    ascend.def("add_noop_pass", [](PassManager& pm) { pm.addNestedPass<func::FuncOp>(createNoopPass()); });
    ascend.def("add_detect_kernel_type", [](PassManager& pm) { pm.addPass(createDetectKernelTypePass()); });
    ascend.def("add_declare_py_struct", [](PassManager& pm) { pm.addPass(createDeclarePyStructPass()); });
    ascend.def("add_define_cube_only", [](PassManager& pm) { pm.addPass(createDefineCubeOnlyPass()); });
    ascend.def("add_erase_sync", [](PassManager& pm) { pm.addNestedPass<func::FuncOp>(createEraseSyncPass()); });
    ascend.def("add_generate_boilerplate",
               [](PassManager& pm) { pm.addPass(createGenerateBoilerplatePass()); });
    ascend.def("add_hoist_que_bind",
               [](PassManager& pm) { pm.addNestedPass<func::FuncOp>(createHoistQueBindPass()); });
    ascend.def("add_hoist_ub_allocation",
               [](PassManager& pm) { pm.addNestedPass<func::FuncOp>(createHoistUBAllocationPass()); });
    ascend.def("add_input_output_tensor",
               [](PassManager& pm) { pm.addNestedPass<func::FuncOp>(createInputOutputTensorPass()); });
    ascend.def("add_insert_sync", [](PassManager& pm) { pm.addNestedPass<func::FuncOp>(createInsertSyncPass()); });
    ascend.def("add_materialize_tensor",
               [](PassManager& pm) { pm.addNestedPass<func::FuncOp>(createMaterializeTensorPass()); });
    ascend.def("add_legalize_kernel_args", [](PassManager& pm) { pm.addPass(createLegalizeKernelArgsPass()); });
    ascend.def("add_privatize_func", [](PassManager& pm) { pm.addPass(createPrivatizeFuncPass()); });
    ascend.def("add_detect_enable_debug", [](PassManager& pm) { pm.addPass(createDetectEnableDebugPass()); });
    ascend.def("add_unify_pipe", [](PassManager& pm) { pm.addNestedPass<func::FuncOp>(createUnifyPipePass()); });
    ascend.def("add_verify_sync", [](PassManager& pm) { pm.addNestedPass<func::FuncOp>(createVerifySyncPass()); });
}

} // namespace

namespace pybind11::asc {
/// Populates `m` with the PassManager class and the `common` and `ascendc`
/// pass submodules.
///
/// The class is registered before the pass helpers that take it as an argument.
/// pybind11 renders argument types in docstrings when `.def` runs, so this order
/// lets those signatures show the Python name `PassManager`.
void initPassesModule(py::module&& m)
{
    py::module& root = m;
    definePassManager(root);
    defineCommonPasses(root);
    defineAscendCPasses(root);
}
} // namespace pybind11::asc

// The pass-registration macros defined at the top of this file are local
// helpers; undefine them once all bindings have been written.
#undef DEFINE_ADD_PASS
#undef DEFINE_ADD_PASS_ON