/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "InitFuncDef.h"

#include "ascir/Dialect/Asc/IR/Asc.h"
#include "ascir/Dialect/Asc/Utils/Attributes.h"
#include "ascir/Dialect/Asc/Utils/Utils.h"
#include "ascir/Dialect/EmitAsc/IR/EmitAsc.h"
#include "ascir/Dialect/EmitAsc/Utils/Attributes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"

#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h> // automatic casts between containers and python types

#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>

namespace py = pybind11;
using namespace mlir;

namespace {

// Bit width of the integer type used when a numeric address space is encoded as a memref
// memory-space attribute (see the IntegerType::get(..., index64) calls in this file).
constexpr unsigned index64 = 64;

/// Returns the printing flags used when an operation is rendered to text in this file:
/// default flags with debug info enabled.
OpPrintingFlags getOpPrintingFlags()
{
    OpPrintingFlags flags;
    flags.enableDebugInfo();
    return flags;
}

/// Collects the kernel-argument kind of every argument of the first function in `op` that carries
/// the `ascendc::attr::global` unit attribute.
///
/// @param op Module to search.
/// @return std::nullopt when no such function exists; otherwise one entry per function argument,
///         where arguments without an `emitasc::attr::kernelArg` attribute are reported as `Explicit`.
std::optional<SmallVector<emitasc::KernelArgument>> getKernelArgAttrs(ModuleOp op)
{
    func::FuncOp globalFn;
    op.walk([&globalFn](func::FuncOp candidate) -> WalkResult {
        if (!candidate->hasAttrOfType<UnitAttr>(ascendc::attr::global)) {
            return WalkResult::advance();
        }
        // Stop at the first match; later functions are not inspected.
        globalFn = candidate;
        return WalkResult::interrupt();
    });
    if (!globalFn) {
        return std::nullopt;
    }

    const unsigned argCount = globalFn.getNumArguments();
    SmallVector<emitasc::KernelArgument> kinds;
    kinds.reserve(argCount);
    for (unsigned idx = 0; idx != argCount; ++idx) {
        if (auto attr = globalFn.getArgAttrOfType<emitasc::KernelArgumentAttr>(idx, emitasc::attr::kernelArg)) {
            kinds.push_back(attr.getValue());
        } else {
            kinds.push_back(emitasc::KernelArgument::Explicit);
        }
    }
    return kinds;
}

/// Converts a raw 8-bit value to `EnumT` with a plain static_cast; the value is not validated
/// against the enumerators of `EnumT`.
template <typename EnumT>
EnumT symbolizeRaw(uint8_t raw)
{
    return static_cast<EnumT>(raw);
}

/// Exposes the enum types used by the bindings (plus the `dynshape` constant) on module `m`.
/// Enums that offer a `symbolize` static method map a raw integer to the enum via symbolizeRaw().
void bindEnums(py::module& m)
{
    // Sentinel used for dynamic dimensions in shaped types.
    m.attr("dynshape") = py::int_(ShapedType::kDynamic);

    {
        using E = ascendc::AddressSpace;
        py::enum_<E>(m, "AddressSpace", py::module_local())
            .value("ca", E::ca)
            .value("cb", E::cb)
            .value("cbuf", E::cbuf)
            .value("cc", E::cc)
            .value("Default", E::Default)
            .value("fbuf", E::fbuf)
            .value("gm", E::gm)
            .value("ubuf", E::ubuf);
    }

    {
        using E = ascendc::AippInputFormat;
        py::enum_<E>(m, "AippInputFormat", py::module_local())
            .value("YUV420SP_U8", E::YUV420SP_U8)
            .value("XRGB8888_U8", E::XRGB8888_U8)
            .value("RGB888_U8", E::RGB888_U8)
            .value("YUV400_U8", E::YUV400_U8)
            .def_static("symbolize", &symbolizeRaw<E>);
    }

    {
        using E = ascendc::CacheLine;
        py::enum_<E>(m, "CacheLine", py::module_local())
            .value("SINGLE_CACHE_LINE", E::SINGLE_CACHE_LINE)
            .value("ENTIRE_DATA_CACHE", E::ENTIRE_DATA_CACHE)
            .def_static("symbolize", &symbolizeRaw<E>);
    }

    {
        using E = arith::CmpFPredicate;
        py::enum_<E>(m, "CmpFPredicate", py::module_local())
            .value("OEQ", E::OEQ)
            .value("ONE", E::ONE)
            .value("OGE", E::OGE)
            .value("OGT", E::OGT)
            .value("OLE", E::OLE)
            .value("OLT", E::OLT);
    }

    {
        using E = arith::CmpIPredicate;
        py::enum_<E>(m, "CmpIPredicate", py::module_local())
            .value("eq", E::eq)
            .value("ne", E::ne)
            .value("sge", E::sge)
            .value("sgt", E::sgt)
            .value("sle", E::sle)
            .value("slt", E::slt);
    }

    {
        using E = ascendc::DcciDst;
        py::enum_<E>(m, "DcciDst", py::module_local())
            .value("CACHELINE_ALL", E::CACHELINE_ALL)
            .value("CACHELINE_UB", E::CACHELINE_UB)
            .value("CACHELINE_OUT", E::CACHELINE_OUT)
            .value("CACHELINE_ATOMIC", E::CACHELINE_ATOMIC)
            .def_static("symbolize", &symbolizeRaw<E>);
    }

    {
        using E = ascendc::MaskMode;
        py::enum_<E>(m, "MaskMode", py::module_local())
            .value("NORMAL", E::NORMAL)
            .value("COUNTER", E::COUNTER)
            .def_static("symbolize", &symbolizeRaw<E>);
    }

    {
        using E = ascendc::ReduceOrder;
        py::enum_<E>(m, "ReduceOrder", py::module_local())
            .value("ORDER_VALUE_INDEX", E::ORDER_VALUE_INDEX)
            .value("ORDER_INDEX_VALUE", E::ORDER_INDEX_VALUE)
            .value("ORDER_ONLY_VALUE", E::ORDER_ONLY_VALUE)
            .value("ORDER_ONLY_INDEX", E::ORDER_ONLY_INDEX)
            .def_static("symbolize", &symbolizeRaw<E>);
    }

    {
        using E = ascendc::RoundMode;
        py::enum_<E>(m, "RoundMode", py::module_local())
            .value("CAST_NONE", E::CAST_NONE)
            .value("CAST_RINT", E::CAST_RINT)
            .value("CAST_FLOOR", E::CAST_FLOOR)
            .value("CAST_CEIL", E::CAST_CEIL)
            .value("CAST_ROUND", E::CAST_ROUND)
            .value("CAST_TRUNC", E::CAST_TRUNC)
            .value("CAST_ODD", E::CAST_ODD)
            .def_static("symbolize", &symbolizeRaw<E>);
    }

    {
        // No named enumerators are registered for TPosition; only conversion from a raw value.
        using E = ascendc::TPosition;
        py::enum_<E>(m, "TPosition", py::module_local()).def_static("symbolize", &symbolizeRaw<E>);
    }

    {
        using E = ascendc::CMPMODE;
        py::enum_<E>(m, "CMPMODE", py::module_local())
            .value("LT", E::LT)
            .value("GT", E::GT)
            .value("EQ", E::EQ)
            .value("LE", E::LE)
            .value("GE", E::GE)
            .value("NE", E::NE)
            .def_static("symbolize", &symbolizeRaw<E>);
    }

    {
        using E = ascendc::SELMODE;
        py::enum_<E>(m, "SELMODE", py::module_local())
            .value("VSEL_CMPMASK_SPR", E::VSEL_CMPMASK_SPR)
            .value("VSEL_TENSOR_SCALAR_MODE", E::VSEL_TENSOR_SCALAR_MODE)
            .value("VSEL_TENSOR_TENSOR_MODE", E::VSEL_TENSOR_TENSOR_MODE)
            .def_static("symbolize", &symbolizeRaw<E>);
    }
}

/// Exposes `MLIRContext` as Python class `Context` and the module-level `load_dialects(context)`
/// function, which registers the dialects and extensions used by these bindings and loads them.
void bindContextAndDialect(py::module& m)
{
    py::class_<MLIRContext> contextClass(m, "Context", py::module_local());
    contextClass.def(py::init<>());
    contextClass.def("disable_multithreading", [](MLIRContext& self) { self.disableMultithreading(); });

    const auto loadDialects = [](MLIRContext& context) {
        DialectRegistry registry;
        registry.insert<arith::ArithDialect, ascendc::AscendCDialect, emitasc::EmitAscDialect, emitc::EmitCDialect,
                        func::FuncDialect, memref::MemRefDialect, scf::SCFDialect, vector::VectorDialect>();
        // Extensions are attached to the registry before it is appended to the context.
        ascendc::registerExternalModels(registry);
        ascendc::registerInlinerInterfaces(registry);
        emitasc::registerExternalModels(registry);
        func::registerAllExtensions(registry);
        context.appendDialectRegistry(registry);
        context.loadAllAvailableDialects();
    };
    m.def("load_dialects", loadDialects);
}

/// Exposes `mlir::Type` as Python class `Type`.
///
/// Methods:
///  - is_integer / is_index: type predicates.
///  - __eq__ / __ne__: compare against any Python object. An operand that is not a `Type`
///    (including None) compares unequal instead of raising a cast error.
///  - get_py_name: "int<N>" / "uint<N>" for integer types (only explicitly unsigned integers
///    get the "uint" prefix), "float<N>" for float types, "void" for NoneType, otherwise None.
///  - __str__: textual form produced by Type::print.
void bindType(py::module& m)
{
    py::class_<Type>(m, "Type", py::module_local())
        .def("is_integer", [](Type& self) -> bool { return self.isInteger(); })
        .def("is_index", &Type::isIndex)
        .def(
            "__eq__",
            [](Type& self, const py::object& other) -> bool {
                // py::cast<Type*> would throw for foreign objects; check the dynamic type first.
                if (!py::isinstance<Type>(other))
                    return false;
                return py::cast<Type>(other) == self;
            })
        .def(
            "__ne__",
            [](Type& self, const py::object& other) -> bool {
                if (!py::isinstance<Type>(other))
                    return true;
                return py::cast<Type>(other) != self;
            })
        .def(
            "get_py_name",
            [](Type& self) -> std::optional<std::string> {
                if (isa<IntegerType>(self)) {
                    std::string name = self.isUnsignedInteger() ? "uint" : "int";
                    name += std::to_string(self.getIntOrFloatBitWidth());
                    return name;
                }
                if (isa<FloatType>(self)) {
                    std::string name = "float";
                    name += std::to_string(self.getIntOrFloatBitWidth());
                    return name;
                }
                if (isa<NoneType>(self))
                    return "void";
                return std::nullopt;
            })
        .def("__str__", [](Type& self) {
            std::string str;
            llvm::raw_string_ostream os(str);
            self.print(os);
            os.flush();
            return os.str();
        });
}

void bindMemref(py::module& m)
{
    using namespace pybind11::literals;
    m.def("get_element_type", [](const Type& shapedType) -> Type {
        auto type = llvm::dyn_cast_if_present<ShapedType>(shapedType);
        if (!type)
            throw std::runtime_error("get_element_type(): must be shaped type");
        return type.getElementType();
    });

    m.def("get_shape", [](const Type& shapedType) -> std::vector<int64_t> {
        auto type = llvm::dyn_cast_if_present<ShapedType>(shapedType);
        if (!type)
            throw std::runtime_error("get_shape(): must be shaped type");
        return type.getShape().vec();
    });

    m.def("get_vector_type", [](Type& elementType, std::vector<int64_t>& shape) -> Type {
        return VectorType::get(shape, elementType);
    });

    m.def(
        "get_memref_type",
        [](Type& elementType, const std::variant<std::vector<int64_t>, int64_t>& shape,
           std::optional<int64_t> addressSpace) -> Type {
            Attribute memorySpace;
            if (auto as = addressSpace.value_or(0)) {
                memorySpace = IntegerAttr::get(IntegerType::get(elementType.getContext(), index64), as);
            }
            SmallVector<int64_t> sh;
            if (std::holds_alternative<int64_t>(shape)) {
                sh.push_back(std::get<int64_t>(shape));
            } else {
                const auto& shapeVec = std::get<std::vector<int64_t>>(shape);
                sh.append(shapeVec.begin(), shapeVec.end());
            }
            return MemRefType::get(sh, elementType, AffineMap{}, memorySpace);
        },
        "element_type"_a, "shape"_a, "address_space"_a = py::none());
    m.def(
        "get_unranked_memref_type",
        [](Type& elementType, std::optional<int64_t> addressSpace) -> Type {
            Attribute memorySpace;
            if (auto as = addressSpace.value_or(0))
                memorySpace = IntegerAttr::get(IntegerType::get(elementType.getContext(), index64), as);
            return UnrankedMemRefType::get(elementType, memorySpace);
        },
        "element_type"_a, "address_space"_a = py::none());
}

/// Registers AscendC tensor-type factories and an opaque-type accessor on `m`:
///  - get_global_tensor_type(element_type[, shape])
///  - get_local_tensor_type(element_type[, shape])
///  - get_opaque_type_name(type): string value of an `emitc::OpaqueType`; raises RuntimeError
///    when `type` is not an opaque type.
void bindTensorType(py::module& m)
{
    m.def("get_global_tensor_type", [](Type& elementType, std::vector<int64_t>& shape) -> Type {
        return ascendc::GlobalTensorType::get(shape, elementType);
    });

    m.def("get_global_tensor_type", [](Type& elementType) -> Type {
        return ascendc::GlobalTensorType::get(elementType);
    });

    m.def("get_local_tensor_type", [](Type& elementType, std::vector<int64_t>& shape) -> Type {
        return ascendc::LocalTensorType::get(shape, elementType);
    });

    m.def(
        "get_local_tensor_type", [](Type& elementType) -> Type { return ascendc::LocalTensorType::get(elementType); });

    m.def("get_opaque_type_name", [](Type& type) -> std::string {
        // Checked cast: an unchecked cast<> asserts on a type mismatch, which cannot be
        // caught from Python.
        auto opaque = llvm::dyn_cast_if_present<emitc::OpaqueType>(type);
        if (!opaque)
            throw std::runtime_error("get_opaque_type_name(): must be opaque type");
        return opaque.getValue().str();
    });
}

/// Exposes `mlir::Location` as Python class `Location`; `str(loc)` yields the text written by
/// Location::print.
void bindLocation(py::module& m)
{
    py::class_<Location> locationClass(m, "Location", py::module_local());
    locationClass.def("__str__", [](Location& self) {
        std::string text;
        llvm::raw_string_ostream stream(text);
        self.print(stream);
        return stream.str();
    });
}

/// Exposes `mlir::Value` as Python class `Value`.
///
/// Methods:
///  - get_context, get_type, dump
///  - get_defining_op: the producing operation, or None when there is none.
///  - replace_all_uses_with(new_value)
///  - replace_uses_in_block(block, new_value): replaces only the uses whose owning operation
///    lies in `block` or in a block nested (through parent operations) inside it.
///  - id: the address of the underlying value implementation as an integer.
void bindValue(py::module& m)
{
    using ret = py::return_value_policy;

    // True when `inner` is `outer` or is reachable from `outer` by walking up parent operations.
    const auto isWithin = [](Block* inner, Block* outer) -> bool {
        for (Block* cursor = inner; cursor != nullptr;) {
            if (cursor == outer)
                return true;
            Operation* enclosing = cursor->getParentOp();
            cursor = enclosing ? enclosing->getBlock() : nullptr;
        }
        return false;
    };

    py::class_<Value>(m, "Value", py::module_local())
        .def("get_context", &Value::getContext, ret::reference)
        .def(
            "get_defining_op",
            [](Value& self) -> std::optional<Operation*> {
                if (Operation* producer = self.getDefiningOp())
                    return producer;
                return std::nullopt;
            },
            ret::reference)
        .def("replace_all_uses_with", [](Value& self, Value& newValue) { self.replaceAllUsesWith(newValue); })
        .def(
            "replace_uses_in_block",
            [isWithin](Value& self, Block* block, Value& newValue) {
                self.replaceUsesWithIf(newValue, [block, &isWithin](OpOperand& use) -> bool {
                    return isWithin(use.getOwner()->getBlock(), block);
                });
            })
        .def("get_type", &Value::getType)
        .def("dump", &Value::dump)
        .def("id", [](Value& self) { return reinterpret_cast<uint64_t>(self.getImpl()); });
}

/// Registers the `Value` subclasses `OpResult` and `BlockArgument`, and exposes `mlir::Region`
/// as Python class `Region`.
///
/// Region methods:
///  - get_parent_region
///  - get_block(index): raises RuntimeError when `index` is not a valid block position.
///  - size / empty: number of blocks / whether there are none.
///  - id: the address of the region as an integer.
void bindRegion(py::module& m)
{
    using ret = py::return_value_policy;

    py::class_<OpResult, Value>(m, "OpResult", py::module_local());
    py::class_<BlockArgument, Value>(m, "BlockArgument", py::module_local());

    py::class_<Region> regionClass(m, "Region", py::module_local());
    regionClass.def("get_parent_region", &Region::getParentRegion, ret::reference);
    regionClass.def(
        "get_block",
        [](Region& self, unsigned index) -> Block& {
            if (index >= self.getBlocks().size())
                throw std::runtime_error("block index is out of range");
            auto position = self.begin();
            std::advance(position, index);
            return *position;
        },
        ret::reference);
    regionClass.def("size", [](Region& self) { return self.getBlocks().size(); });
    regionClass.def("empty", &Region::empty);
    regionClass.def("id", [](Region& self) { return reinterpret_cast<uint64_t>(&self); });
}

/// Exposes `mlir::Block` as Python class `Block`.
///
/// Methods:
///  - dump, clear, id (address of the block as an integer)
///  - has_terminator: result of Block::mightHaveTerminator.
///  - get_terminator: raises RuntimeError when has_terminator would be false.
///  - add_argument(type): appends an argument with an unknown location.
///  - get_argument(index): raises IndexError when `index` is out of range.
///  - get_arguments: list of all block arguments.
///  - merge_block_before(dst): moves all operations of this block to the start of `dst`;
///    raises RuntimeError when this block has arguments. The block is erased if it has a parent.
///  - erase: raises RuntimeError when the block has no parent region.
void bindBlocks(py::module& m)
{
    using ret = py::return_value_policy;
    py::class_<Block>(m, "Block", py::module_local())
        .def(py::init())
        .def("dump", &Block::dump)
        .def("id", [](Block& self) { return reinterpret_cast<uint64_t>(&self); })
        .def("has_terminator", &Block::mightHaveTerminator)
        .def(
            "get_terminator",
            [](Block& self) -> Operation* {
                // Block::getTerminator asserts on this condition; raise a Python error instead.
                if (!self.mightHaveTerminator())
                    throw std::runtime_error("Block has no terminator");
                return self.getTerminator();
            },
            ret::reference)
        .def(
            "add_argument",
            [](Block& self, Type& type) -> BlockArgument {
                return self.addArgument(type, UnknownLoc::get(type.getContext()));
            })
        .def(
            "get_argument",
            [](Block& self, unsigned idx) -> BlockArgument {
                // Same contract as FuncOp.get_arg: invalid indices raise IndexError.
                if (idx >= self.getNumArguments())
                    throw pybind11::index_error("Block argument index out of range");
                return self.getArgument(idx);
            })
        .def("get_arguments", [](Block& self) -> std::vector<BlockArgument> { return self.getArguments().vec(); })
        .def(
            "merge_block_before",
            [](Block& self, Block& dst) {
                // See RewriterBase::mergeBlocks()
                if (self.getNumArguments() != 0)
                    throw std::runtime_error("Unable to merge block with arguments");
                dst.getOperations().splice(dst.begin(), self.getOperations());
                self.dropAllUses();
                if (self.getParent())
                    self.erase();
            })
        .def("clear", &Block::clear)
        .def("erase", [](Block& self) {
            // Block::erase requires a parent region.
            if (!self.getParent())
                throw std::runtime_error("Unable to erase a block without a parent region");
            self.erase();
        });
}

/// Registers `inline_block_at_end(src, dst, args=None)` on `m`.
///
/// Moves every operation of `src` to the end of `dst`, after replacing all uses of the arguments
/// of `src` with the values in `args`. `src` is erased when it has a parent region.
///
/// Raises RuntimeError when `src` or `dst` is None, when they are the same block, or when the
/// number of replacement values differs from the number of arguments of `src`. These conditions
/// are checked before any IR is modified.
void bindInlineBlock(py::module& m)
{
    using namespace pybind11::literals;

    m.def(
        "inline_block_at_end",
        [](Block* src, Block* dst, const std::optional<std::vector<Value>>& args) {
            // See RewriterBase::inlineBlockBefore()
            // pybind11 converts None to a null pointer for pointer parameters.
            if (src == nullptr || dst == nullptr)
                throw std::runtime_error("inline_block_at_end(): 'src' and 'dst' must not be None");
            // Splicing a list into itself moves nothing, so a non-empty 'src' would fail only after
            // its argument uses had been rewritten, and an empty one would erase 'dst' itself.
            if (src == dst)
                throw std::runtime_error("inline_block_at_end(): 'src' and 'dst' must be different blocks");
            ValueRange argValues({});
            if (args)
                argValues = *args;
            auto before = dst->end();
            if (argValues.size() != src->getNumArguments())
                throw std::runtime_error("incorrect # of argument replacement values");
            // Replace all of the successor arguments with the provided values
            for (auto [arg, newVal] : llvm::zip(src->getArguments(), argValues))
                arg.replaceAllUsesWith(newVal);
            dst->getOperations().splice(before, src->getOperations());
            if (!src->empty()) {
                throw std::runtime_error("expected 'src' to be empty");
            }
            if (src->getParent())
                src->erase();
        },
        "src"_a, "dst"_a, "args"_a = py::none());
}

/// Exposes `mlir::Attribute` (with `dump` and `id`) and its subclass `ArrayAttr`, plus the
/// module-level `get_type_attr(type)` factory that wraps a type in a `TypeAttr`.
void bindAttritube(py::module& m)
{
    py::class_<Attribute> attributeClass(m, "Attribute", py::module_local());
    attributeClass.def("dump", &Attribute::dump);
    // The opaque pointer of the attribute serves as its integer identity.
    attributeClass.def("id", [](Attribute& self) {
        const void* opaque = self.getAsOpaquePointer();
        return reinterpret_cast<uint64_t>(opaque);
    });

    py::class_<ArrayAttr, Attribute>(m, "ArrayAttr", py::module_local());

    m.def("get_type_attr", [](const Type& type) -> Attribute {
        return TypeAttr::get(type);
    });
}

/// Exposes `mlir::Operation` as Python class `Operation`. The holder uses `py::nodelete`, so
/// Python never destroys the wrapped operation.
///
/// Methods:
///  - get_name: operation name as a string.
///  - get_num_operands / get_operand(index)
///  - get_num_results / get_result(index)
///  - get_num_regions / get_region(index)
///    The indexed accessors raise IndexError when `index` is out of range.
///  - get_block: the block containing the operation.
///  - has_unit_attr(name)
///  - get_str_attr / get_bool_attr / get_integer_attr / get_flat_symbol_ref_attr(name):
///    the attribute value, or None when no attribute of that kind with that name is present.
///    get_integer_attr returns the sign-extended value.
void bindOperation(py::module& m)
{
    using ret = py::return_value_policy;
    py::class_<Operation, std::unique_ptr<Operation, py::nodelete>>(m, "Operation", py::module_local())
        .def(
            "get_name",
            [](Operation& self) {
                llvm::StringRef opName = self.getName().getStringRef();
                return opName.str();
            })
        .def("get_num_operands", &Operation::getNumOperands)
        .def(
            "get_operand",
            [](Operation& self, unsigned idx) -> Value {
                if (idx >= self.getNumOperands())
                    throw pybind11::index_error("Op operand index out of range");
                return self.getOperand(idx);
            })
        .def("get_num_results", &Operation::getNumResults)
        .def(
            "get_result",
            [](Operation& self, unsigned idx) -> OpResult {
                if (idx >= self.getNumResults())
                    throw pybind11::index_error("Op result index out of range");
                return self.getResult(idx);
            })
        .def("get_num_regions", &Operation::getNumRegions)
        .def(
            "get_region",
            [](Operation& self, unsigned idx) -> Region& {
                if (idx >= self.getNumRegions())
                    throw pybind11::index_error("Op region index out of range");
                return self.getRegion(idx);
            },
            ret::reference)
        .def("get_block", &Operation::getBlock, ret::reference)
        .def(
            "has_unit_attr",
            [](Operation& self, const std::string& name) -> bool { return self.hasAttrOfType<UnitAttr>(name); })
        .def(
            "get_str_attr",
            [](Operation& self, const std::string& name) -> std::optional<std::string> {
                auto ret = self.getAttrOfType<StringAttr>(name);
                if (!ret)
                    return std::nullopt;
                return ret.getValue().str();
            })
        .def(
            "get_bool_attr",
            [](Operation& self, const std::string& name) -> std::optional<bool> {
                auto ret = self.getAttrOfType<BoolAttr>(name);
                if (!ret)
                    return std::nullopt;
                return ret.getValue();
            })
        .def(
            "get_integer_attr",
            [](Operation& self, const std::string& name) -> py::object {
                auto ret = self.getAttrOfType<IntegerAttr>(name);
                if (!ret)
                    return py::none();
                return py::int_(ret.getValue().getSExtValue());
            })
        .def("get_flat_symbol_ref_attr", [](Operation& self, const std::string& name) -> py::object {
            auto ret = self.getAttrOfType<FlatSymbolRefAttr>(name);
            if (!ret)
                return py::none();
            return py::str(ret.getValue().str());
        });
}

/// Exposes `mlir::OpState`, the base of the typed operation wrappers, as Python class `OpState`.
///
/// Methods:
///  - get_context, dump
///  - set_attr(name, attr)
///  - get_num_results / get_result(index) / get_region(index): the indexed accessors raise
///    IndexError when `index` is out of range.
///  - __str__: textual form printed with the flags from getOpPrintingFlags().
///  - append_operand(value): inserts `value` after the current last operand.
///  - verify: True when mlir::verify succeeds on the operation.
///  - op (read-only property): the underlying `Operation`.
void bindOpstate(py::module& m)
{
    using ret = py::return_value_policy;

    py::class_<OpState> opStateClass(m, "OpState", py::module_local());
    opStateClass.def("get_context", &OpState::getContext, ret::reference);
    opStateClass.def("set_attr", [](OpState& self, std::string& name, Attribute& attr) {
        self->setAttr(name, attr);
    });
    opStateClass.def("get_num_results", [](OpState& self) -> unsigned { return self->getNumResults(); });
    opStateClass.def("get_result", [](OpState& self, unsigned idx) -> Value {
        if (idx < self->getNumResults())
            return self->getResult(idx);
        throw pybind11::index_error("Op result index out of range");
    });
    opStateClass.def(
        "get_region",
        [](OpState& self, unsigned idx) -> Region& {
            if (idx < self->getNumRegions())
                return self->getRegion(idx);
            throw pybind11::index_error("Op region index out of range");
        },
        ret::reference);
    opStateClass.def("dump", [](OpState& self) { self->dump(); });
    opStateClass.def("__str__", [](OpState& self) -> std::string {
        std::string text;
        llvm::raw_string_ostream stream(text);
        self->print(stream, getOpPrintingFlags());
        return text;
    });
    opStateClass.def("append_operand", [](OpState& self, Value& val) {
        const unsigned end = self->getNumOperands();
        self->insertOperands(end, val);
    });
    opStateClass.def("verify", [](OpState& self) -> bool {
        return succeeded(verify(self.getOperation()));
    });
    opStateClass.def_property_readonly("op", &OpState::getOperation, ret::reference);
}

/// Exposes `mlir::ModuleOp` as Python class `ModuleOp` (derived from `OpState`).
///
/// Methods:
///  - dump, erase
///  - get_body: the body block of the module.
///  - has_function(name, type=None): True when symbol `name` resolves to a `func.func` in this
///    module and, if `type` is given, its function type equals `type`.
///  - need_insert_sync: True when the module contains at least one `ascendc::LocalTensorAutoOp`.
void bindModuleop(py::module& m)
{
    using ret = py::return_value_policy;
    using namespace pybind11::literals;

    py::class_<ModuleOp, OpState> moduleClass(m, "ModuleOp", py::module_local());
    moduleClass.def("dump", &ModuleOp::dump);
    moduleClass.def(
        "get_body", [](ModuleOp& self) -> Block* { return self.getBody(); }, ret::reference);
    moduleClass.def(
        "has_function",
        [](ModuleOp& self, const std::string& name, const std::optional<Type>& type) -> bool {
            auto funcOp = dyn_cast_if_present<func::FuncOp>(SymbolTable::lookupSymbolIn(self, name));
            if (!funcOp)
                return false;
            if (!type)
                return true;
            return funcOp.getFunctionType() == *type;
        },
        "name"_a, "type"_a = py::none());
    moduleClass.def("need_insert_sync", [](ModuleOp& self) {
        // The walk is interrupted at the first match, so the result is an existence check.
        return self.walk([](ascendc::LocalTensorAutoOp) { return WalkResult::interrupt(); }).wasInterrupted();
    });
    moduleClass.def("erase", [](ModuleOp& self) { self->erase(); });
}

/// Exposes `mlir::func::FuncOp` as Python class `FuncOp` (derived from `OpState`).
///
/// Methods:
///  - get_arg(index): raises IndexError when `index` is out of range and RuntimeError when the
///    function has no body (there is no entry block to take the argument from).
///  - get_num_args
///  - add_entry_block: raises RuntimeError when the function already has a body.
///  - set_type(type): raises RuntimeError when `type` is not a FunctionType.
///  - set_arg_names(names): wraps each argument location in a NameLoc; raises RuntimeError when
///    the number of names differs from the number of arguments or the function has no body.
///  - get_body: the entry block; raises RuntimeError when the function has no body.
///  - make_aicore: sets the `ascendc::attr::aicore` unit attribute.
///  - make_global: makes the function public and sets the `ascendc::attr::global` unit attribute.
void bindFuncop(py::module& m)
{
    using ret = py::return_value_policy;
    py::class_<func::FuncOp, OpState>(m, "FuncOp", py::module_local())
        .def(
            "get_arg",
            [](func::FuncOp& self, unsigned idx) -> BlockArgument {
                if (idx >= self.getNumArguments())
                    throw pybind11::index_error("Function argument index out of range");
                // The argument count comes from the function type; the arguments themselves
                // live on the entry block, which a body-less function does not have.
                if (self.getFunctionBody().empty())
                    throw std::runtime_error("get_arg(): function has no body");
                return self.getArgument(idx);
            })
        .def("get_num_args", &func::FuncOp::getNumArguments)
        .def(
            "add_entry_block",
            [](func::FuncOp& self) -> Block* {
                if (!self.getFunctionBody().empty())
                    throw std::runtime_error("add_entry_block(): function already has a body");
                return self.addEntryBlock();
            },
            ret::reference)
        .def(
            "set_type",
            [](func::FuncOp& self, const Type& funcType) {
                // dyn_cast_if_present tolerates a null Type, matching the other type checks here.
                auto type = llvm::dyn_cast_if_present<FunctionType>(funcType);
                if (!type)
                    throw std::runtime_error("set_type(): must be FunctionType");
                self.setFunctionType(type);
            })
        .def(
            "set_arg_names",
            [](func::FuncOp& self, const std::vector<std::string>& names) {
                if (names.size() != self.getNumArguments())
                    throw std::runtime_error("Number of names must be equal to number of arguments");
                if (!names.empty() && self.getFunctionBody().empty())
                    throw std::runtime_error("set_arg_names(): function has no body");
                for (unsigned i = 0; i < names.size(); i++) {
                    auto arg = self.getArgument(i);
                    auto name = StringAttr::get(self.getContext(), names[i]);
                    arg.setLoc(NameLoc::get(name, arg.getLoc()));
                }
            })
        .def(
            "get_body",
            [](func::FuncOp& self) -> Block& {
                if (self.getFunctionBody().empty())
                    throw std::runtime_error("get_body(): function has no body");
                return self.getFunctionBody().front();
            },
            ret::reference)
        .def(
            "make_aicore",
            [](func::FuncOp& self) { self->setAttr(ascendc::attr::aicore, UnitAttr::get(self.getContext())); })
        .def("make_global", [](func::FuncOp& self) {
            self.setPublic();
            self->setAttr(ascendc::attr::global, UnitAttr::get(self.getContext()));
        });
}

void bindScfop(py::module& m)
{
    using ret = py::return_value_policy;
    py::class_<scf::ForOp, OpState>(m, "ForOp", py::module_local())
        .def("get_induction_var", &scf::ForOp::getInductionVar)
        .def("get_body", [](scf::ForOp& self) -> Block* { return self.getBody(); }, ret::reference);
    py::class_<scf::IfOp, OpState>(m, "IfOp", py::module_local())
        .def("get_then_block", &scf::IfOp::thenBlock, ret::reference)
        .def("get_else_block", &scf::IfOp::elseBlock, ret::reference)
        .def("get_then_yield", &scf::IfOp::thenYield)
        .def("get_else_yield", &scf::IfOp::elseYield);
    py::class_<scf::YieldOp, OpState>(m, "YieldOp", py::module_local());
    py::class_<scf::WhileOp, OpState>(m, "WhileOp", py::module_local())
        .def("get_before", &scf::WhileOp::getBefore, ret::reference)
        .def("get_after", &scf::WhileOp::getAfter, ret::reference);
    py::class_<scf::ConditionOp, OpState>(m, "ConditionOp", py::module_local());
}

/// Exposes the `emitasc::KernelArgument` enum and `get_kernel_arg_attrs(module)`, which returns
/// the per-argument kinds computed by getKernelArgAttrs() as a Python list, or None when that
/// function reports no result.
void bindKernelArgument(py::module& m)
{
    using Kind = emitasc::KernelArgument;
    py::enum_<Kind>(m, "KernelArgument", py::module_local())
        .value("Explicit", Kind::Explicit)
        .value("FftsAddr", Kind::FftsAddr);

    m.def("get_kernel_arg_attrs", [](ModuleOp& mod) -> py::object {
        const auto kinds = getKernelArgAttrs(mod);
        if (!kinds.has_value()) {
            return py::none();
        }
        py::list converted;
        for (Kind kind : *kinds) {
            converted.append(kind);
        }
        return converted;
    });
}

} // namespace

namespace pybind11 {
namespace asc {
/// Populates module `m` with all IR bindings defined in this file, then calls
/// initBuilderInIRModule(m), which is defined outside this file.
///
/// The call order is significant: several classes are registered with a base class
/// (e.g. OpResult/BlockArgument derive from Value, ModuleOp/FuncOp/scf ops derive from OpState),
/// and pybind11 needs the base to be registered before the derived class.
void initIRModule(py::module&& m)
{
    bindEnums(m);
    bindContextAndDialect(m);
    bindType(m);
    bindMemref(m);
    bindTensorType(m);
    bindLocation(m);
    // Value must precede bindRegion(), which registers the Value-derived classes.
    bindValue(m);
    bindRegion(m);
    bindBlocks(m);
    bindInlineBlock(m);
    bindAttritube(m);
    bindOperation(m);
    // OpState must precede the typed operation wrappers below.
    bindOpstate(m);
    bindModuleop(m);
    bindFuncop(m);
    bindScfop(m);
    bindKernelArgument(m);
    // Registered without methods so that insert points can be passed through Python as opaque objects.
    py::class_<OpBuilder::InsertPoint>(m, "InsertPoint", py::module_local());

    initBuilderInIRModule(m);
}
} // namespace asc
} // namespace pybind11