#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Pass/PassManager.h"
#include "ascend/include/AutoBlockify/Passes.h"
#include "ascend/include/TritonToStructured/Passes.h"
#include "ascend/include/TritonToAnnotation/Passes.h"
#include "ascend/include/TritonToLinalg/Passes.h"
#include "ascend/include/Dialect/TritonAscend/IR/TritonAscendDialect.h"
#include "ascend/include/DiscreteMaskAccessConversion/Passes.h"
#include "ascend/include/TritonToUnstructure/Passes.h"
#include "ascend/include/TritonToHIVM/Passes.h"
#include "ascend/include/TritonToHFusion/Passes.h"
#include "ascend/include/TritonToLLVM/Passes.h"
#include "ascend/include/DynamicCVPipeline/Passes.h"
#include "ascend/include/DynamicCVPipeline/Common/BufferCountManager.h"
#include "ascend/include/TritonAffinityOpt/Passes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "ir.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
using namespace ir;
using namespace mlir;
void init_triton_ascend_ir(py::module &&m) {
auto *builder_cls = ir::getBuilderClass();
builder_cls
->def("create_extract_scalar",
[](TritonOpBuilder &self, Value &src, std::vector<Value> &indices) -> Value {
llvm::SmallVector<Value> arg_indices;
for (const auto &i : indices) {
auto iTy = i.getType();
if (!iTy.isIndex()) {
auto v = self.create<arith::IndexCastOp>(
self.getBuilder().getIndexType(), i);
arg_indices.push_back(v);
} else {
arg_indices.push_back(i);
}
}
auto ret = self.create<tensor::ExtractOp>(src, arg_indices);
return ret;
})
.def("create_extract_slice",
[](TritonOpBuilder &self, Value &ful, std::vector<Value> &offs_vec,
std::vector<int> &sizs_vec, std::vector<int> &strd_vec) -> Value {
llvm::SmallVector<Value> offsets;
llvm::SmallVector<int64_t> staticOffsets;
for (const auto &o : offs_vec) {
auto oTy = o.getType();
if (!oTy.isIndex()) {
auto v = self.create<arith::IndexCastOp>(
self.getBuilder().getIndexType(), o);
offsets.push_back(v);
} else {
offsets.push_back(o);
}
staticOffsets.push_back(ShapedType::kDynamic);
}
llvm::SmallVector<Value> sizes;
llvm::SmallVector<int64_t> staticSizes;
llvm::SmallVector<int64_t> retSizes;
for (const auto &s : sizs_vec) {
staticSizes.push_back(s);
retSizes.push_back(s);
}
llvm::SmallVector<Value> strides;
llvm::SmallVector<int64_t> staticStrides;
for (const auto &s : strd_vec) {
auto v = self.create<arith::ConstantIndexOp>(s);
strides.push_back(v);
staticStrides.push_back(ShapedType::kDynamic);
}
auto retTy = RankedTensorType::get(retSizes,
cast<RankedTensorType>(ful.getType()).getElementType());
return self.create<tensor::ExtractSliceOp>(retTy, ful,
offsets, sizes, strides,
staticOffsets, staticSizes, staticStrides);
})
.def("create_insert_slice",
[](TritonOpBuilder &self, Value &ful, Value &sub,
std::vector<Value> &offs_vec, std::vector<int> &sizs_vec,
std::vector<int> &strd_vec) -> Value {
llvm::SmallVector<Value> offsets;
llvm::SmallVector<int64_t> staticOffsets;
for (const auto &o : offs_vec) {
auto oTy = o.getType();
if (!oTy.isIndex()) {
auto v = self.create<arith::IndexCastOp>(
self.getBuilder().getIndexType(), o);
offsets.push_back(v);
} else {
offsets.push_back(o);
}
staticOffsets.push_back(ShapedType::kDynamic);
}
llvm::SmallVector<Value> sizes;
llvm::SmallVector<int64_t> staticSizes;
llvm::SmallVector<int64_t> retSizes;
for (const auto &s : sizs_vec) {
staticSizes.push_back(s);
retSizes.push_back(s);
}
llvm::SmallVector<Value> strides;
llvm::SmallVector<int64_t> staticStrides;
for (const auto &s : strd_vec) {
auto v = self.create<arith::ConstantIndexOp>(s);
strides.push_back(v);
staticStrides.push_back(ShapedType::kDynamic);
}
auto retTy = RankedTensorType::get(
retSizes,
cast<RankedTensorType>(ful.getType()).getElementType());
auto ret = self.create<tensor::InsertSliceOp>(sub, ful,
offsets, sizes, strides,
staticOffsets, staticSizes, staticStrides);
return ret;
})
.def("create_custom_op_for_inter_core_sync",
[](TritonOpBuilder &self, std::string &op_name,
std::string &mode_or_sender, int id) -> void {
auto args = self.getBuilder().getArrayAttr(
{self.getBuilder().getStringAttr(mode_or_sender),
self.getBuilder().getI32IntegerAttr(id)}
);
self.create<triton::ascend::CustomOp>(op_name, args, ValueRange());
})
.def("create_index_select_simd",
[](TritonOpBuilder &self, Value &src, Value &index, int32_t dim,
std::vector<Value> &srcShape, std::vector<Value> &srcOffset,
std::vector<int32_t> &readShape, std::vector<int32_t> &returnShape) -> Value {
auto &builder = self.getBuilder();
auto loc = self.getLastLoc();
Type elemType;
if (auto ptrTy = dyn_cast<triton::PointerType>(src.getType())) {
elemType = ptrTy.getPointeeType();
} else {
llvm::report_fatal_error("index_select_simd: src must be pointer type");
}
llvm::SmallVector<int64_t> retShape;
for (const auto &s : returnShape) {
retShape.push_back(s);
}
auto retTensorType = RankedTensorType::get(retShape, elemType);
llvm::SmallVector<Value> srcShapeIndex;
for (auto val : srcShape) {
if (!val.getType().isIndex()) {
val = self.create<arith::IndexCastOp>(builder.getIndexType(), val);
}
srcShapeIndex.push_back(val);
}
llvm::SmallVector<Value> srcOffsetIndex;
for (auto val : srcOffset) {
if (!val.getType().isIndex()) {
val = self.create<arith::IndexCastOp>(builder.getIndexType(), val);
}
srcOffsetIndex.push_back(val);
}
auto dimAttr = builder.getI32IntegerAttr(dim);
auto readShapeAttr = builder.getDenseI32ArrayAttr(readShape);
auto indexSelectSimdOp = builder.create<triton::ascend::IndexSelectSimdOp>(
loc,
retTensorType,
src,
index,
dimAttr,
srcShapeIndex,
srcOffsetIndex,
readShapeAttr
);
return indexSelectSimdOp.getResult();
})
.def("create_index_put",
[](TritonOpBuilder &self, Value &ptr, Value &index,
Value &value, const int32_t dim, const int64_t indexBoundary,
std::vector<Value> &endOffset, std::vector<Value> &startOffset,
std::vector<Value> &dstStride) -> void {
auto dimI32Ty = self.getBuilder().getI32Type();
auto dim_val = self.create<arith::ConstantIntOp>(dimI32Ty, static_cast<int64_t>(dim));
auto BoundI64Ty = self.getBuilder().getI64Type();
auto bound_val = self.create<arith::ConstantIntOp>(BoundI64Ty, static_cast<int64_t>(indexBoundary));
self.create<triton::ascend::IndexPutOp>(
ptr,
index,
value,
dim_val,
bound_val,
endOffset,
startOffset,
dstStride
);
})
.def("create_gather_out_to_ub",
[](TritonOpBuilder &self, Value &src, Value &index, const int64_t indexBoundary,
const int32_t dim, std::vector<Value> &srcStride, std::vector<Value> &endOffset,
std::vector<Value> &startOffset, std::optional<Value> &other) -> Value {
auto elemTy = cast<PointerType>(src.getType()).getPointeeType();
auto idxTy = cast<RankedTensorType>(index.getType());
auto idxShape = idxTy.getShape();
std::vector<int64_t> retShape(idxShape.begin(), idxShape.end());
auto resType = RankedTensorType::get(retShape, elemTy);
auto BoundI64Ty = self.getBuilder().getI64Type();
auto bound_val = self.create<arith::ConstantIntOp>(BoundI64Ty, static_cast<int64_t>(indexBoundary));
auto dimI32Ty = self.getBuilder().getI32Type();
auto dim_val = self.create<arith::ConstantIntOp>(dimI32Ty, static_cast<int64_t>(dim));
return self.create<triton::ascend::GatherOutToUbOp>(
resType,
src,
index,
bound_val,
dim_val,
srcStride,
endOffset,
startOffset,
other.value_or(Value())
);
})
.def("create_scatter_ub_to_out",
[](TritonOpBuilder &self, Value &ptr, Value &value, Value &index,
const int64_t indexBoundary, const int32_t dim, std::vector<Value> &dstStride,
std::vector<Value> &endOffset, std::vector<Value> &startOffset) -> void {
auto idxTy = cast<RankedTensorType>(index.getType());
auto BoundI64Ty = self.getBuilder().getI64Type();
auto bound_val = self.create<arith::ConstantIntOp>(BoundI64Ty, static_cast<int64_t>(indexBoundary));
auto dimI32Ty = self.getBuilder().getI32Type();
auto dim_val = self.create<arith::ConstantIntOp>(dimI32Ty, static_cast<int64_t>(dim));
self.create<triton::ascend::ScatterUbToOutOp>(
ptr,
value,
index,
bound_val,
dim_val,
dstStride,
endOffset,
startOffset
);
})
.def("create_conv1d",
[](TritonOpBuilder &self, Value input, Value weight, py::object bias,
int64_t stride, int64_t padding_size, int64_t dilation, int64_t groups,
Type output_type) -> Value {
Value biasValue;
if (!bias.is_none()) {
biasValue = bias.cast<Value>();
} else {
biasValue = Value();
}
auto &builder = self.getBuilder();
auto strideAttr = builder.getI64IntegerAttr(stride);
auto paddingSizeAttr = builder.getI64IntegerAttr(padding_size);
auto dilationAttr = builder.getI64IntegerAttr(dilation);
auto groupsAttr = builder.getI64IntegerAttr(groups);
auto op = self.create<triton::ascend::Conv1dOp>(
output_type,
input,
weight,
biasValue,
strideAttr,
paddingSizeAttr,
dilationAttr,
groupsAttr
);
return op.getResult();
},
py::arg("input"), py::arg("weight"), py::arg("bias"),
py::arg("stride"), py::arg("padding_size"), py::arg("dilation"),
py::arg("groups"), py::arg("output_type")
)
.def("create_sort",
[](TritonOpBuilder &self, Value src, int64_t dim, bool descending) -> Value {
auto &builder = self.getBuilder();
auto loc = self.getLastLoc();
auto dimAttr = builder.getI64IntegerAttr(dim);
auto descendingAttr = builder.getBoolAttr(descending);
auto op = builder.create<triton::ascend::SortOp>(loc, src, dimAttr, descendingAttr);
return op->getResult(0);
})
.def("create_flip",
[](TritonOpBuilder &self, Value src, int64_t dim) -> Value {
auto &builder = self.getBuilder();
auto loc = self.getLastLoc();
auto dimAttr = builder.getI64IntegerAttr(dim);
auto op = builder.create<triton::ascend::FlipOp>(loc, src, dimAttr);
return op->getResult(0);
})
.def("create_annotation",
[](TritonOpBuilder &self, Value &ptr, const std::string &attrKey,
Attribute &attrVal) {
auto annotationOp = self.create<triton::ascend::AnnotationOp>(ptr);
annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey),
attrVal);
});
}
/// Registers the Ascend TTIR pass helpers on module `m`.
///
/// Each `add_*` function takes an `mlir::PassManager` (plus any pass options)
/// and appends the corresponding pass to it. `set_buffer_count` instead
/// forwards to the `BufferCountManager` singleton.
void init_triton_ascend_passes_ttir(py::module &&m) {
  m.def("add_auto_blockify",
        [](mlir::PassManager &passManager, int autoBlockifySize) {
          AutoBlockifyOptions options;
          options.autoBlockifySize = autoBlockifySize;
          passManager.addPass(mlir::triton::createAutoBlockifyPass(options));
        });

  m.def("add_triton_to_structure",
        [](mlir::PassManager &passManager, bool enableMaskFallbackConversion,
           bool optimizeDynamicOffset) {
          passManager.addPass(mlir::triton::createTritonToStructuredPass(
              enableMaskFallbackConversion, optimizeDynamicOffset));
        });

  m.def("add_triton_to_annotation", [](mlir::PassManager &passManager) {
    passManager.addPass(mlir::triton::createTritonToAnnotationPass());
  });

  m.def("add_triton_to_linalg",
        [](mlir::PassManager &passManager, bool globalKernel, bool namedOps,
           bool enableNd2nzOnVector, bool enableSelectAnalysis,
           bool compileOn91095) {
          passManager.addPass(mlir::triton::createTritonToLinalgPass(
              globalKernel, namedOps, enableNd2nzOnVector,
              enableSelectAnalysis, compileOn91095));
        });

  m.def("add_triton_to_unstructure",
        [](mlir::PassManager &passManager, bool compileOn91095,
           bool forceSimtTemplate) {
          TritonToUnstructureOptions options;
          options.compileOn91095 = compileOn91095;
          options.forceSimtTemplate = forceSimtTemplate;
          passManager.addPass(
              mlir::triton::createTritonToUnstructurePass(options));
        });

  m.def("add_triton_to_hfusion", [](mlir::PassManager &passManager) {
    passManager.addPass(mlir::triton::createTritonToHFusionPass());
  });

  m.def("add_discrete_mask_access_conversion",
        [](mlir::PassManager &passManager, bool compileOn91095,
           bool forceSimtTemplate, bool enableSyncBlockLock) {
          DiscreteMaskAccessConversionOptions options;
          options.compileOn91095 = compileOn91095;
          options.forceSimtTemplate = forceSimtTemplate;
          options.enableSyncBlockLock = enableSyncBlockLock;
          passManager.addPass(
              mlir::triton::createDiscreteMaskAccessConversionPass(options));
        });

  m.def("add_triton_to_hivm", [](mlir::PassManager &passManager) {
    passManager.addPass(mlir::triton::createTritonToHIVMPass());
  });

  m.def("add_triton_to_llvm", [](mlir::PassManager &passManager) {
    passManager.addPass(mlir::triton::createTritonToLLVMPass());
  });

  m.def("add_bubble_up_operation", [](mlir::PassManager &passManager) {
    passManager.addPass(mlir::triton::createBubbleUpOperationPass());
  });

  m.def("add_dynamic_cv_pipeline",
        [](mlir::PassManager &passManager, bool compileOn91095) {
          AddDynamicCVPipelineOptions options;
          options.compileOn91095 = compileOn91095;
          passManager.addPass(
              mlir::triton::createAddDynamicCVPipelinePass(options));
        });

  m.def("add_dag_sync", [](mlir::PassManager &passManager) {
    passManager.addPass(mlir::triton::createDAGSyncPass());
  });

  m.def("add_dag_scope", [](mlir::PassManager &passManager) {
    passManager.addPass(mlir::triton::createDAGScopePass());
  });

  m.def("add_dag_ssbuffer", [](mlir::PassManager &passManager) {
    passManager.addPass(mlir::triton::createDAGSSBufferPass());
  });

  // The integer `type` is converted to BufferCountManager::DepType with an
  // unchecked static_cast before being handed to the singleton.
  m.def("set_buffer_count", [](int type, int count) {
    using DepType = mlir::triton::BufferCountManager::DepType;
    auto dependencyKind = static_cast<DepType>(type);
    mlir::triton::BufferCountManager::getInstance().setBufferCount(
        dependencyKind, count);
  });
}
// Forward declaration; the definition is not visible in this file
// (presumably provided by another translation unit — confirm at link time).
// Called from init_triton_ascend with the "ir" submodule.
void init_ascend_ir(py::module &&m);
/// Populates module `m` with the Triton Ascend bindings:
///   - `load_dialects(context)`: registers TritonAscendDialect on the context
///     and loads all available dialects;
///   - submodule `passes.ttir`: filled by init_triton_ascend_passes_ttir;
///   - submodule `ascend_ir`: passed to init_triton_ascend_ir;
///   - submodule `ir`: passed to init_ascend_ir.
void init_triton_ascend(py::module &&m) {
  // Parent submodule under which the pass registration helpers live.
  auto passesModule = m.def_submodule("passes");

  m.def("load_dialects", [](mlir::MLIRContext &ctx) {
    mlir::DialectRegistry ascendRegistry;
    ascendRegistry.insert<mlir::triton::ascend::TritonAscendDialect>();
    ctx.appendDialectRegistry(ascendRegistry);
    ctx.loadAllAvailableDialects();
  });

  init_triton_ascend_passes_ttir(passesModule.def_submodule("ttir"));
  init_triton_ascend_ir(m.def_submodule("ascend_ir"));
  init_ascend_ir(m.def_submodule("ir"));
}