* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* Copyright (c) Microsoft Corporation.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "ascend/include/TritonToLinalg/TritonToLinalgPass.h"
#include "ascend/include/TritonToLinalg/ArgMinMaxConverter.h"
#include "ascend/include/TritonToLinalg/FunctionConverter.h"
#include "ascend/include/TritonToLinalg/LoadStoreConverter.h"
#include "ascend/include/TritonToLinalg/TritonOpConverter.h"
#include "ascend/include/Dialect/TritonAscend/IR/TritonAscendDialect.h"
#include "ascend/include/TritonToLinalg/DescriptorConverter.h"
#include "ascend/include/TritonToLinalg/HoistBroadcast.h"
#include "ascend/include/TritonToLinalg/UseAnalysis.h"
#include "ascend/include/TritonToLinalg/ImplicitPermute.h"
#include "ascend/include/TritonToLinalg/MarkTensorKindPass.h"
#include "ascend/include/TritonToStructured/CannonicalizerConverter.h"
#include "ascend/include/Utils/InterleaveOptimization.h"
#include "ascend/include/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "bishengir/Dialect/HFusion/IR/HFusion.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "bishengir/Dialect/HIVM/IR/HIVM.h"
#include "bishengir/Dialect/Annotation/IR/Annotation.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include <cassert>
#include <cstdint>
#include <optional>
#define DEBUG_TYPE "triton-to-linalg"
using namespace mlir;
using namespace triton;
int nd2nzFlag = 0;
bool compileOn91095Flag = false;
bool existDotFlag = false;
class CustomOpConverter : public OpConversionPattern<hivm::CustomOp> {
public:
using OpConversionPattern<hivm::CustomOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(hivm::CustomOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override
{
llvm::SmallDenseMap<Value, BlockData> known;
BlockDataParser::rewriteCustomOp(op, adaptor, rewriter, known);
return success();
}
};
static bool isSIMTOp(Operation *op)
{
if (auto custom_op = dyn_cast<hivm::CustomOp>(op)) {
return custom_op.getCoreType() == hivm::TCoreType::VECTOR &&
custom_op.getVFMode() == hivm::VFMode::SIMT;
}
return isa<
triton::ascend::IndexPutOp,
triton::ascend::GatherOutToUbOp,
triton::ascend::ScatterUbToOutOp,
triton::ascend::IndirectLoadOp,
triton::ascend::IndirectStoreOp
>(op);
}
TritonTypeConverter::TritonTypeConverter() {
addConversion([](Type type) { return type; });
addConversion([](triton::PointerType ptrType) {
Type elem = ptrType.getPointeeType();
if (auto it = dyn_cast<IntegerType>(elem); it && it.getWidth() == 1) {
elem = IntegerType::get(ptrType.getContext(), 8);
LLVM_DEBUG({
llvm::dbgs() << "[TritonTypeConverter] Normalize i1 pointer to i8 "
"memref. elemType="
<< elem << "\n";
});
}
return MemRefType::get({ShapedType::kDynamic}, elem);
});
addConversion([](TensorType tensorType) -> Type {
auto elemType = tensorType.getElementType();
if (auto ptrType = dyn_cast<triton::PointerType>(elemType)) {
elemType = ptrType.getPointeeType();
}
if (auto it = dyn_cast<IntegerType>(elemType); it && it.getWidth() == 1) {
elemType = IntegerType::get(tensorType.getContext(), 8);
LLVM_DEBUG({
llvm::dbgs() << "[TritonTypeConverter] Normalize i1 tensor to i8 "
"memref. elemType="
<< elemType << "\n";
});
}
return MemRefType::get(tensorType.getShape(), elemType);
});
}
void TritonToLinalgPass::addProgramInfo(triton::FuncOp func,
bool globalKernel) {
OpBuilder b(func);
auto origFuncType = func.getFunctionType();
auto origInputTypes = origFuncType.getInputs();
SmallVector<Type> newInputTypes(origInputTypes);
newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type());
auto newFuncType =
b.getFunctionType(newInputTypes, origFuncType.getResults());
func.setFunctionType(newFuncType);
if (func.getAllArgAttrs()) {
SmallVector<DictionaryAttr> newArgAttrs;
func.getAllArgAttrs(newArgAttrs);
newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr());
func.setAllArgAttrs(newArgAttrs);
}
for (unsigned i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) {
func.getBody().front().addArgument(b.getI32Type(), func.getLoc());
}
if (globalKernel) {
func->setAttr(globalKernelAttr, b.getStringAttr(""));
} else {
func->setAttr(globalKernelAttr, b.getStringAttr("local"));
}
}
LogicalResult
TritonToLinalgPass::convertMultipleBlockControlFlow(Operation *funcOp,
OpBuilder &builder) {
if (!isa<func::FuncOp>(funcOp)) {
funcOp->emitError("convertMultipleBlockControlFlow can only process func::FuncOp!");
return failure();
}
SmallVector<Operation *> candidate;
SmallVector<Block *> eraseBlocks;
for (Block &block : dyn_cast<func::FuncOp>(funcOp).getBody()) {
auto curTerminator = block.getTerminator();
if (isa<cf::CondBranchOp>(curTerminator)) {
candidate.push_back(curTerminator);
} else if (isa<triton::ReturnOp>(curTerminator)) {
if (candidate.empty()) {
curTerminator->emitError("funcOp has more than one Block but got an early 'tt.return' Op.");
return failure();
}
} else if (!isa<cf::BranchOp>(curTerminator)) {
funcOp->emitError("funcOp has more than one Block but found unsupported Terminator: ")
<< *curTerminator;
return failure();
}
if (!block.isEntryBlock())
eraseBlocks.push_back(&block);
}
LLVM_DEBUG({
llvm::dbgs() << "Found " << candidate.size()
<< " candidate cond_branch operations to convert.\n";
});
if (candidate.empty()) {
funcOp->emitError("funcOp has more than one Block but no candidate Terminator was found!");
return failure();
}
llvm::BitVector visitFlag(candidate.size(), false);
std::function<void(Operation *, Operation *)> convertToSCF =
[&](Operation *op, Operation *insertPosOp) -> void {
auto condBranchOp = dyn_cast_if_present<cf::CondBranchOp>(op);
auto iter = llvm::find(candidate, condBranchOp);
if (!(condBranchOp && iter != candidate.end())) {
op->emitError("convertToSCF must process with condBranchOp in candidates!");
return;
}
visitFlag.set(iter - candidate.begin());
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointAfter(insertPosOp);
builder.create<scf::IfOp>(
condBranchOp->getLoc(), condBranchOp.getCondition(),
[&](OpBuilder &builder, Location loc) {
SmallVector<Operation *> movedOps = llvm::map_to_vector(
condBranchOp.getTrueDest()->without_terminator(),
[](Operation &op) { return &op; });
for (auto *innerOp : movedOps) {
innerOp->moveBefore(builder.getInsertionBlock(),
builder.getInsertionPoint());
}
auto blockTerm = condBranchOp.getTrueDest()->getTerminator();
if (auto nextCond = dyn_cast<cf::CondBranchOp>(blockTerm)) {
if (movedOps.empty()) {
blockTerm->emitError("movedOps can not be empty before entering convertToSCF (then)!");
return;
}
convertToSCF(nextCond, movedOps.back());
} else if (!isa<cf::BranchOp, triton::ReturnOp>(blockTerm)) {
blockTerm->emitError("Unsupported terminator in then branch after structuring");
}
builder.create<scf::YieldOp>(loc);
},
[&](OpBuilder &builder, Location loc) {
SmallVector<Operation *> movedOps = llvm::map_to_vector(
condBranchOp.getFalseDest()->without_terminator(),
[](Operation &op) { return &op; });
for (auto *innerOp : movedOps) {
innerOp->moveBefore(builder.getInsertionBlock(),
builder.getInsertionPoint());
}
auto blockTerm = condBranchOp.getFalseDest()->getTerminator();
if (auto nextCond = dyn_cast<cf::CondBranchOp>(blockTerm)) {
if (movedOps.empty()) {
blockTerm->emitError("movedOps can not be empty before entering convertToSCF (else)!");
return;
}
convertToSCF(nextCond, movedOps.back());
} else if (!isa<cf::BranchOp, triton::ReturnOp>(blockTerm)) {
blockTerm->emitError("Unsupported terminator in else branch after structuring");
}
builder.create<scf::YieldOp>(loc);
});
};
Block::iterator insertOp(candidate.front());
if (insertOp == candidate.front()->getBlock()->begin()) {
convertToSCF(candidate.front(), candidate.front());
} else {
--insertOp;
convertToSCF(candidate.front(), &(*insertOp));
}
if (!visitFlag.all()) {
funcOp->emitError("Not all cf.cond_br converted!");
return failure();
}
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(candidate.front());
builder.create<triton::ReturnOp>(candidate.front()->getLoc());
for (Operation *eachTerm : candidate)
eachTerm->erase();
for (Block *block : llvm::reverse(eraseBlocks))
block->erase();
return success();
}
void TritonToLinalgPass::convertTTFunc(triton::FuncOp func,
const bool existDot,
const bool existSIMTOp) {
OpBuilder builder(func);
auto name = func.getName();
auto type = func.getFunctionType();
SmallVector<DictionaryAttr> argAttrs, resAttrs;
func.getAllArgAttrs(argAttrs);
func.getAllResultAttrs(resAttrs);
SmallVector<Type> inputTypes{type.getInputs()};
SmallVector<Type> retTypes{type.getResults()};
if (func.getSymVisibility() == "public" && !func.isDeclaration()) {
for (size_t i = 0; i < func.getNumArguments(); ++i) {
auto arg = func.getArgument(i);
if (!isa<BaseMemRefType>(arg.getType()) ||
dyn_cast<BaseMemRefType>(arg.getType()).getElementTypeBitWidth() !=
1) {
continue;
}
SmallVector<Operation *> argVaildUser{arg.getUsers()};
llvm::erase_if(argVaildUser, [](Operation *op) -> bool {
return isOpTriviallyDead(op);
});
if (!argVaildUser.empty()) {
LLVM_DEBUG({
auto &os = llvm::dbgs();
os << arg << " has users:\n";
int cnt = 0;
for (auto it : argVaildUser) {
os << "users[" << cnt++ << "] = " << *it;
}
});
if (llvm::all_of(argVaildUser, [](Operation *userOp) {
return isa<UnrealizedConversionCastOp>(userOp);
})) {
auto castOp = cast<UnrealizedConversionCastOp>(*argVaildUser.begin());
if (castOp.getInputs().size() == 1 &&
castOp.getOutputs().size() == 1) {
arg.setType(castOp.getOutputs()[0].getType());
inputTypes[i] = arg.getType();
}
} else {
func->emitError(Twine("Unsupported use of func arg at index ") +
Twine(i));
}
} else {
auto memType = dyn_cast<BaseMemRefType>(arg.getType())
.cloneWith(std::nullopt, builder.getI8Type());
arg.setType(memType);
inputTypes[i] = arg.getType();
}
}
}
auto castType = FunctionType::get(func.getContext(), inputTypes, retTypes);
auto funcFunc = builder.create<func::FuncOp>(func.getLoc(), name, castType);
funcFunc.setAllArgAttrs(argAttrs);
funcFunc.setAllResultAttrs(resAttrs);
auto kernelAttr = func->getAttr(globalKernelAttr);
if (kernelAttr) {
funcFunc->setAttr(globalKernelAttr, kernelAttr);
}
std::string kernelMixMode = "aiv";
if (existDot) {
kernelMixMode = "mix";
}
funcFunc->setAttr(kernelMixModeName, builder.getStringAttr(kernelMixMode));
std::string parallelMode = "simd";
if (existSIMTOp) {
parallelMode = "mix_simd_simt";
}
funcFunc->setAttr(kernelParallelModeName, builder.getStringAttr(parallelMode));
auto autoBlockifyAttr = func->getAttr("auto_blockify_size");
if (autoBlockifyAttr)
funcFunc->setAttr("auto_blockify_size", autoBlockifyAttr);
auto &funcFuncBody = funcFunc.getBody();
auto &funcBody = func.getBody();
IRMapping map;
funcBody.cloneInto(&funcFuncBody, map);
if (!funcFuncBody.hasOneBlock()) {
if (failed(convertMultipleBlockControlFlow(funcFunc, builder))) {
llvm_unreachable("Encounter unsupported control flow");
}
}
for (Block &block : funcFuncBody.getBlocks()) {
auto term = block.getTerminator();
builder.setInsertionPoint(term);
builder.create<func::ReturnOp>(func.getLoc(), term->getOperands());
term->erase();
}
func.erase();
}
void TritonToLinalgPass::addDynamicLegal(
ConversionTarget &target, TritonTypeConverter &tritonTypeConverter) {
target.addLegalDialect<
func::FuncDialect, arith::ArithDialect, math::MathDialect,
linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect,
cf::ControlFlowDialect, tensor::TensorDialect, LLVM::LLVMDialect,
bufferization::BufferizationDialect, memref::MemRefDialect,
annotation::AnnotationDialect, hivm::HIVMDialect,
hfusion::HFusionDialect>();
target.addLegalOp<ModuleOp>();
target.addDynamicallyLegalOp<mlir::UnrealizedConversionCastOp>(
[](mlir::Operation *op) {
if (op->use_empty()) {
return false;
} else {
return true;
}
});
target.addDynamicallyLegalOp<triton::FuncOp>([&](triton::FuncOp op) {
return tritonTypeConverter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<hivm::CustomOp>([&](hivm::CustomOp op) {
return all_of(op->getOperandTypes(), [](Type t) {
if (isa<triton::PointerType>(t)) {
return false;
}
if (auto shapedType = dyn_cast<ShapedType>(t)) {
return !isa<triton::PointerType>(shapedType.getElementType());
}
return true;
});
});
target.addDynamicallyLegalOp<arith::ConstantOp>([](arith::ConstantOp op) {
auto res = op.getResult();
if (!isa<RankedTensorType>(res.getType())) {
return true;
}
if (auto denseAttr = dyn_cast<DenseElementsAttr>(op.getValue())) {
if (!denseAttr.isSplat() ||
!isa<FloatType, IntegerType>(denseAttr.getElementType())) {
return true;
}
if (res.hasOneUse() && isa<tensor::ReshapeOp>(*res.user_begin())) {
return true;
}
return false;
}
return true;
});
target.addDynamicallyLegalOp<scf::ForOp, scf::YieldOp>([](Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type t) {
if (isa<triton::PointerType>(t)) {
return false;
}
if (auto shapedType = dyn_cast<ShapedType>(t)) {
return shapedType.getElementType().isIntOrFloat();
}
assert(t.isIntOrIndexOrFloat());
return true;
});
});
target.addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect>(
[this](Operation *op) {
if (op->hasAttr("MetaUse")) {
return false;
}
if (isa<arith::ConstantOp>(op)) {
return true;
}
bool operateOnTensors =
llvm::all_of(op->getOperandTypes(),
[](Type type) { return isa<RankedTensorType>(type); });
return this->namedOps || !operateOnTensors;
});
}
void TritonToLinalgPass::populateTritonToLinalgCanonicalizationPatterns(RewritePatternSet &patterns)
{
patterns.add<LoadStoreConverter::LoadStoreCanonicalizer<triton::LoadOp>,
LoadStoreConverter::LoadStoreCanonicalizer<triton::StoreOp>,
LoadStoreConverter::LoadStoreCanonicalizer<triton::AtomicRMWOp>,
LoadStoreConverter::LoadStoreCanonicalizer<triton::AtomicCASOp>>(patterns.getContext());
patterns.add<TTOpConverters::BitcastCanonicalizer>(patterns.getContext());
patterns.add<TTOpConverters::FpToFpCanonicalizer>(patterns.getContext());
patterns.add<LoadStoreConverter::ScalarStoreCanonicalizer>(patterns.getContext());
patterns.add<LoadStoreConverter::ScalarAtomicRMWCanonicalizer>(patterns.getContext());
patterns.add<LoadStoreConverter::ScalarAtomicCASCanonicalizer>(patterns.getContext());
patterns.add<LoadStoreConverter::AtomicMaxMinCanonicalizer>(patterns.getContext());
patterns.add<
TTOpConverters::ScalarMathCanonicalizer<math::AbsFOp>,
TTOpConverters::ScalarMathCanonicalizer<math::CeilOp>, TTOpConverters::ScalarMathCanonicalizer<math::CosOp>,
TTOpConverters::ScalarMathCanonicalizer<math::ErfOp>, TTOpConverters::ScalarMathCanonicalizer<math::ExpOp>,
TTOpConverters::ScalarMathCanonicalizer<math::Exp2Op>,
TTOpConverters::ScalarMathCanonicalizer<math::FloorOp>,
TTOpConverters::ScalarMathCanonicalizer<math::LogOp>,
TTOpConverters::ScalarMathCanonicalizer<math::Log2Op>,
TTOpConverters::ScalarMathCanonicalizer<math::RsqrtOp>, TTOpConverters::ScalarMathCanonicalizer<math::SinOp>,
TTOpConverters::ScalarMathCanonicalizer<math::SqrtOp>,
TTOpConverters::ScalarMathCanonicalizer<arith::AddFOp>, TTOpConverters::ScalarMathCanonicalizer<arith::SubFOp>,
TTOpConverters::ScalarMathCanonicalizer<arith::MulFOp>,
TTOpConverters::ScalarMathCanonicalizer<arith::NegFOp>, TTOpConverters::ScalarMathCanonicalizer<arith::RemFOp>,
TTOpConverters::ScalarMathCanonicalizer<arith::MaxNumFOp>,
TTOpConverters::ScalarMathCanonicalizer<arith::MaximumFOp>,
TTOpConverters::ScalarMathCanonicalizer<arith::MinNumFOp>,
TTOpConverters::ScalarMathCanonicalizer<arith::MinimumFOp>
>(patterns.getContext());
patterns.add<TTOpConverters::ReduceSingleCanonicalizer>(patterns.getContext());
if (this->enableSelectAnalysis) {
patterns.add<TTOpConverters::SelectCanonicalizer>(patterns.getContext());
}
}
void TritonToLinalgPass::populateTritonToLinalgConversionPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
unsigned int launchGridRank) {
nd2nzFlag = this->enableNd2nzOnVector;
populateFunctionOpInterfaceTypeConversionPattern<triton::FuncOp>(
patterns, typeConverter);
patterns.add<triton::MetaUseEraser>(patterns.getContext());
patterns.add<LoadStoreConverter::StoreConverter>(patterns.getContext());
patterns.add<LoadStoreConverter::AddPtrConverter>(patterns.getContext());
patterns.add<FunctionConverter::GetProgramIDConverter>(patterns.getContext());
patterns.add<FunctionConverter::GetNumProgramsConverter>(
patterns.getContext());
patterns.add<LoadStoreConverter::LoadConverter>(patterns.getContext());
patterns.add<LoadStoreConverter::AtomicRMWConverter>(patterns.getContext());
patterns.add<LoadStoreConverter::AtomicCASConverter>(patterns.getContext());
patterns.add<TTOpConverters::MakeRangeConverter>(patterns.getContext());
patterns.add<TTOpConverters::SplatConverter>(patterns.getContext());
patterns.add<TTOpConverters::UnsplatConverter>(patterns.getContext());
patterns.add<TTOpConverters::ClampFConverter>(patterns.getContext());
patterns.add<TTOpConverters::PreciseDivConverter>(patterns.getContext());
patterns.add<TTOpConverters::ArgMinConverter>(patterns.getContext());
patterns.add<TTOpConverters::ArgMaxConverter>(patterns.getContext());
patterns.add<TTOpConverters::ReduceConverter>(patterns.getContext());
patterns.add<TTOpConverters::ScanConverter>(patterns.getContext());
patterns.add<TTOpConverters::ReshapeConverter>(patterns.getContext());
patterns.add<TTOpConverters::ExpandDimsConverter>(patterns.getContext());
patterns.add<TTOpConverters::BroadcastConverter>(patterns.getContext());
patterns.add<TTOpConverters::DenseConstantConverter>(patterns.getContext());
patterns.add<TTOpConverters::ExternElementwiseClOpConverter>(
patterns.getContext());
patterns.add<TTOpConverters::TritonMulhiuiConverter>(patterns.getContext());
patterns.add<TTOpConverters::TritonPreciseSqrtConverter>(
patterns.getContext());
patterns.add<TTOpConverters::MakeTensorPtrConverter>(patterns.getContext());
patterns.add<TTOpConverters::AdvanceConverter>(patterns.getContext());
patterns.add<TTOpConverters::TransposeConverter>(patterns.getContext());
patterns.add<TTOpConverters::SplitConverter>(patterns.getContext());
patterns.add<TTOpConverters::JoinConverter>(patterns.getContext());
patterns.add<TTOpConverters::CatConverter>(patterns.getContext());
patterns.add<TTOpConverters::BitcastConverter>(patterns.getContext());
patterns.add<TTOpConverters::LoopConverter<scf::ForOp>>(patterns.getContext());
patterns.add<TTOpConverters::LoopConverter<scf::WhileOp>>(patterns.getContext());
patterns.add<TTOpConverters::YieldConverter>(patterns.getContext());
patterns.add<TTOpConverters::DeviceAssertConverter>(patterns.getContext());
patterns.add<TTOpConverters::DevicePrintConverter>(patterns.getContext());
patterns.add<TTOpConverters::MatmulConverter>(patterns.getContext());
patterns.add<TTOpConverters::DotScaledConverter>(patterns.getContext());
patterns.add<TTOpConverters::PtrToIntConverter>(patterns.getContext());
patterns.add<TTOpConverters::IndirectLoadConverter>(patterns.getContext());
patterns.add<TTOpConverters::IndirectStoreConverter>(patterns.getContext());
patterns.add<TTOpConverters::GatherOutToUbConverter>(patterns.getContext());
patterns.add<TTOpConverters::ScatterUbToOutConverter>(patterns.getContext());
patterns.add<TTOpConverters::IndexSelectSimdConverter>(patterns.getContext());
patterns.add<TTOpConverters::IndexPutConverter>(patterns.getContext());
patterns.add<TTOpConverters::SortOpConverter>(patterns.getContext());
patterns.add<TTOpConverters::FlipOpConverter>(patterns.getContext());
patterns.add<TTOpConverters::GatherConverter>(patterns.getContext());
patterns.add<CustomOpConverter>(patterns.getContext());
if (!this->namedOps) {
linalg::populateElementwiseToLinalgConversionPatterns(patterns);
}
}
void TritonToLinalgPass::getDependentDialects(DialectRegistry ®istry) const {
registry.insert<func::FuncDialect, arith::ArithDialect, math::MathDialect,
linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect,
tensor::TensorDialect, bufferization::BufferizationDialect,
memref::MemRefDialect, hfusion::HFusionDialect,
hivm::HIVMDialect, annotation::AnnotationDialect>();
}
LogicalResult TritonToLinalgPass::processDescriptorOperations(ModuleOp moduleOp)
{
mlir::ConversionTarget target(getContext());
target.addDynamicallyLegalDialect<mlir::arith::ArithDialect, mlir::scf::SCFDialect, triton::TritonDialect>(
[](mlir::Operation *op) {
return !DescriptorConverter::hasATensorDescriptorType(op->getOperandTypes()) &&
!DescriptorConverter::hasATensorDescriptorType(op->getResultTypes());
});
target.addDynamicallyLegalOp<triton::FuncOp>([](triton::FuncOp funcOp) {
return !DescriptorConverter::hasATensorDescriptorType(funcOp.getFunctionType().getInputs()) &&
!DescriptorConverter::hasATensorDescriptorType(funcOp.getFunctionType().getResults());
});
target.addLegalOp<triton::MakeTensorDescOp>();
target.addIllegalOp<triton::DescriptorLoadOp, triton::DescriptorStoreOp>();
mlir::RewritePatternSet patterns(&getContext());
patterns.add<DescriptorConverter::DescriptorLoadConverter>(patterns.getContext());
patterns.add<DescriptorConverter::DescriptorStoreConverter>(patterns.getContext());
mlir::ConversionConfig config;
config.buildMaterializations = true;
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns), config))) {
moduleOp->emitError("failed to convert tensor descriptor operations");
return failure();
}
return success();
}
LogicalResult TritonToLinalgPass::processPtrBroadcastOperations(ModuleOp moduleOp)
{
mlir::ConversionTarget target(getContext());
target.addLegalOp<triton::SplatOp>();
target.addLegalOp<triton::AddPtrOp>();
target.addDynamicallyLegalOp<triton::BroadcastOp>([](triton::BroadcastOp op) {
if (op->hasAttr("MetaUse")) {
return true;
}
auto resultType = dyn_cast<RankedTensorType>(op.getType());
HoistBroadcast::BroadcastHoister hoister(op);
return !(isa<triton::PointerType>(resultType.getElementType()) && hoister.canBroadcast());
});
mlir::RewritePatternSet patterns(&getContext());
patterns.add<HoistBroadcast::BroadcastConverter>(patterns.getContext());
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
moduleOp->emitError("failed to convert ptr broadcast operations");
return failure();
}
return success();
}
LogicalResult TritonToLinalgPass::processImplicitPermuteOperations(ModuleOp moduleOp)
{
mlir::RewritePatternSet patterns(&getContext());
patterns.add<ImplicitPermute::LoadConverter>(patterns.getContext());
patterns.add<ImplicitPermute::StoreConverter>(patterns.getContext());
patterns.add<ImplicitPermute::AtomicRMWConverter>(patterns.getContext());
patterns.add<ImplicitPermute::AtomicCASConverter>(patterns.getContext());
patterns.add<CannonicalizerConverter::SplatCmpConverter>(patterns.getContext());
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
LLVM_DEBUG({
llvm::dbgs() << "ImplicitPermute: rewrite MemOp failed\n";
});
}
mlir::PassManager pm(&getContext(), moduleOp.getOperationName());
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());
return runPipeline(pm, getOperation());
}
LogicalResult TritonToLinalgPass::processLegalStrideOperations(ModuleOp moduleOp)
{
mlir::ConversionTarget target(getContext());
target.addLegalOp<arith::ConstantOp>();
target.addDynamicallyLegalOp<memref::ReinterpretCastOp>(
[](memref::ReinterpretCastOp op) {
return !LoadStoreConverter::ReinterpretCastStrideCanonicalizer::hasFixableZeroStride(op);
});
mlir::RewritePatternSet patterns(&getContext());
patterns.add<LoadStoreConverter::ReinterpretCastStrideCanonicalizer>(
patterns.getContext());
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
moduleOp->emitError(
"failed to legalize reinterpret_cast dynamic stride(0) with size(1)");
return failure();
}
return success();
}
void TritonToLinalgPass::runOnOperation() {
compileOn91095Flag = this->compileOn91095;
auto moduleOp = getOperation();
bool existDot = false;
moduleOp.walk([&](triton::DotOp dotOp) {
existDot = true;
return WalkResult::interrupt();
});
moduleOp.walk([&](triton::DotScaledOp dotScaledOp) {
existDot = true;
return WalkResult::interrupt();
});
existDotFlag = existDot;
bool existSIMTOp = false;
moduleOp.walk([&](Operation *op) {
if (isSIMTOp(op)) {
existSIMTOp = true;
LLVM_DEBUG({
auto &os = llvm::dbgs();
os << "Found SIMT op in function: ";
os << op->getName();
os << "\n";
});
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (failed(processDescriptorOperations(moduleOp))) {
signalPassFailure();
}
if (failed(processImplicitPermuteOperations(moduleOp))) {
LLVM_DEBUG({
llvm::dbgs() << "Failed to process implicit permute operations\n";
});
signalPassFailure();
}
{
PassManager pm(&getContext(), moduleOp.getOperationName());
pm.addPass(triton::createMarkTensorKindPass());
if (failed(runPipeline(pm, moduleOp))) {
moduleOp->emitError("failed to run LoopCanonicalizerPass");
signalPassFailure();
return;
}
}
RewritePatternSet canonicalizerPatterns(&getContext());
this->populateTritonToLinalgCanonicalizationPatterns(canonicalizerPatterns);
if (failed(applyPatternsGreedily(moduleOp,
std::move(canonicalizerPatterns)))) {
moduleOp->emitError("failed to apply Canonicalizer Patterns");
signalPassFailure();
}
{
PassManager pm(&getContext(), moduleOp.getOperationName());
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());
if (failed(runPipeline(pm, moduleOp))) {
moduleOp->emitError("failed to pre-clean dead control-flow before use analysis");
signalPassFailure();
return;
}
}
moduleOp.walk([this](triton::FuncOp op) {
if (failed(runUseAnalysis(op))) {
signalPassFailure();
}
});
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
TritonTypeConverter tritonTypeConverter{};
this->addDynamicLegal(target, tritonTypeConverter);
auto loopOpLegalFn = [](LoopLikeOpInterface op) {
return !op.getOperation()->hasAttr("UnhandledLoopOp");
};
target.addIllegalOp<triton::ScanOp>();
target.addDynamicallyLegalOp<scf::ForOp>(loopOpLegalFn);
target.addDynamicallyLegalOp<scf::WhileOp>(loopOpLegalFn);
if (failed(processPtrBroadcastOperations(moduleOp))) {
signalPassFailure();
}
this->populateTritonToLinalgConversionPatterns(tritonTypeConverter, patterns,
LAUNCH_GRID_RANK);
for (auto func : getOperation().getOps<triton::FuncOp>()) {
addProgramInfo(func, globalKernel);
}
moduleOp.walk([this](LoopLikeOpInterface loopOp) {
auto *op = loopOp.getOperation();
if (!op->hasAttr("ExtractedLoadOrStore"))
op->setAttr("UnhandledLoopOp", UnitAttr::get(op->getContext()));
for (auto res: loopOp->getResults()) {
if (auto tensorType = dyn_cast<RankedTensorType>(res.getType());
tensorType && !isa<triton::PointerType>(tensorType.getElementType())) {
IRRewriter rewriter(op->getContext());
rewriter.setInsertionPointAfter(op);
auto newVal = rewriter.create<tensor::CastOp>(op->getLoc(), res.getType(), res);
rewriter.replaceAllUsesExcept(res, newVal, newVal);
}
}
});
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
moduleOp->emitError("failed to apply Conversion Patterns");
signalPassFailure();
}
if (failed(processLegalStrideOperations(moduleOp))) {
signalPassFailure();
}
moduleOp.walk(
[&](triton::FuncOp func) { this->convertTTFunc(func, existDot, existSIMTOp); });
PassManager pm(&getContext(), moduleOp.getOperationName());
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());
if (failed(runPipeline(pm, getOperation()))) {
signalPassFailure();
}
SmallVector<hivm::PointerCastOp> castOps;
moduleOp.walk([&](hivm::PointerCastOp op) { castOps.push_back(op); });
for (auto op : castOps) {
SmallVector<Operation *> userOps(op->getUsers().begin(),
op->getUsers().end());
IRRewriter rewriter(&getContext());
rewriter.setInsertionPointAfter(op);
Value addr = op.getAddrs()[0];
auto elementType =
cast<MemRefType>(op.getResult().getType()).getElementType();
Value elementTypeSize;
if (auto intType = dyn_cast<IntegerType>(elementType)) {
elementTypeSize = rewriter.create<arith::ConstantOp>(op.getLoc(), rewriter.getIntegerAttr(addr.getType(), intType.getWidth() / 8));
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
elementTypeSize = rewriter.create<arith::ConstantOp>(op.getLoc(), rewriter.getIntegerAttr(addr.getType(), floatType.getWidth() / 8));
} else {
llvm_unreachable("Cannot get memory size");
}
for (auto userOp : userOps) {
auto reinterpretCastOp = cast<memref::ReinterpretCastOp>(userOp);
auto sizes = reinterpretCastOp.getStaticSizes();
auto staticStrides = reinterpretCastOp.getStaticStrides();
auto strides = reinterpretCastOp.getStrides();
if(reinterpretCastOp.getStaticOffsets().size() != 1)
userOp->emitError("IntToPtrOp must converted to PointerCastOp of memref<?xdtype> type");
int64_t castOpSize = 0;
SmallVector<int64_t> dynamicSizes;
for (const auto &[size, stride] : llvm::zip_equal(sizes, staticStrides)) {
assert(!ShapedType::isDynamic(size));
if (ShapedType::isDynamic(stride))
dynamicSizes.push_back(size);
else
castOpSize = size * stride;
}
rewriter.setInsertionPoint(reinterpretCastOp);
Value dynamicSize = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(castOpSize));
for (const auto &[size, stride] :
llvm::zip_equal(dynamicSizes, strides)) {
Value axisSize = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(size));
axisSize =
rewriter.create<arith::MulIOp>(op.getLoc(), stride, axisSize);
dynamicSize = rewriter.create<arith::AddIOp>(op.getLoc(), dynamicSize,
axisSize);
}
Value offsetValue;
auto staticOffset = reinterpretCastOp.getStaticOffsets()[0];
if (ShapedType::isDynamic(staticOffset)) {
offsetValue = reinterpretCastOp.getOffsets()[0];
if (offsetValue.getType() != addr.getType())
offsetValue = rewriter.create<arith::IndexCastOp>(
op.getLoc(), addr.getType(), offsetValue);
} else {
offsetValue = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(addr.getType(), staticOffset));
}
offsetValue = rewriter.create<arith::MulIOp>(op.getLoc(), offsetValue, elementTypeSize);
Value realAddr = rewriter.create<arith::AddIOp>(op.getLoc(), addr, offsetValue);
auto memrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
auto newCastOp = rewriter.create<hivm::PointerCastOp>(
op.getLoc(), memrefType, realAddr, dynamicSize);
auto markOp = rewriter.create<annotation::MarkOp>(op.getLoc(),
newCastOp.getResult());
markOp->setAttr(hivm::AddressSpaceAttr::getMnemonic(),
{hivm::AddressSpaceAttr::get(rewriter.getContext(),
hivm::AddressSpace::GM)});
auto oldMemrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
auto newMemrefType = MemRefType::get(oldMemrefType.getShape(), oldMemrefType.getElementType(),
StridedLayoutAttr::get(&getContext(), ShapedType::kDynamic, staticStrides));
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
reinterpretCastOp,
newMemrefType, newCastOp,
ValueRange({}), reinterpretCastOp.getSizes(),
reinterpretCastOp.getStrides(), SmallVector<int64_t>({0}),
reinterpretCastOp.getStaticSizes(),
reinterpretCastOp.getStaticStrides());
}
rewriter.eraseOp(op);
}
llvm::DenseMap<BlockArgument, SmallVector<Operation *>> interleaveCandidate;
llvm::DenseMap<BlockArgument, SmallVector<Operation *>>
interleaveCandidateWithMask;
moduleOp.walk([&](bufferization::MaterializeInDestinationOp materializeOp) {
if (auto reinterpretCastOp =
materializeOp.getDest()
.getDefiningOp<memref::ReinterpretCastOp>()) {
if (llvm::isa<BlockArgument>(reinterpretCastOp.getSource()) &&
reinterpretCastOp.getStaticStrides().back() == 2) {
interleaveCandidate[llvm::cast<BlockArgument>(
reinterpretCastOp.getSource())]
.push_back(materializeOp);
}
}
if (auto subviewOp =
materializeOp.getDest().getDefiningOp<memref::SubViewOp>()) {
if (!llvm::isa<tensor::ExtractSliceOp>(
materializeOp.getSource().getDefiningOp()))
return WalkResult::advance();
if (auto reinterpretCastOp =
subviewOp.getSource()
.getDefiningOp<memref::ReinterpretCastOp>()) {
if (llvm::isa<BlockArgument>(reinterpretCastOp.getSource()) &&
reinterpretCastOp.getStaticStrides().back() == 2) {
interleaveCandidateWithMask[llvm::cast<BlockArgument>(
reinterpretCastOp.getSource())]
.push_back(materializeOp);
}
}
}
return WalkResult::advance();
});
for (auto [blockArg, materializeVec] : interleaveCandidate) {
if (materializeVec.size() != 2)
continue;
auto result = InterleaveStatusOptimization(materializeVec);
}
for (auto [blockArg, materializeVec] : interleaveCandidateWithMask) {
if (materializeVec.size() != 2)
continue;
auto result = InterleaveStatusWithMaskOptimization(materializeVec);
}
for (auto func : getOperation().getOps<func::FuncOp>()) {
if (!func->hasAttr("global_kernel"))
continue;
auto context = func.getContext();
constexpr int64_t syncBlockLockArgIdx = 0;
NamedAttribute syncBlockLockArgAttr(StringAttr::get(context, "syncBlockLock"),
UnitAttr::get(context));
MemRefType syncBlockLockArgType =
MemRefType::get(SmallVector<int64_t>(1, ShapedType::kDynamic),
IntegerType::get(context, 8));
llvm::LogicalResult syncBlockLockArg = func.insertArgument(syncBlockLockArgIdx,
syncBlockLockArgType,
nullptr, func->getLoc());
func->setAttr("SyncBlockLockArgIdx",
IntegerAttr::get(IntegerType::get(&getContext(), 64), 0));
constexpr int64_t workspaceArgIdx = 1;
MemRefType workspaceArgType =
MemRefType::get(SmallVector<int64_t>(1, ShapedType::kDynamic),
IntegerType::get(context, 8));
NamedAttribute workspaceArgAttr(StringAttr::get(context, "workspace"),
UnitAttr::get(context));
llvm::LogicalResult workspaceArg = func.insertArgument( workspaceArgIdx,
workspaceArgType,
nullptr, func->getLoc());
func->setAttr("WorkspaceArgIdx",
IntegerAttr::get(IntegerType::get(&getContext(), 64), 1));
}
moduleOp.walk([&](Operation *op) {
auto loc = op->getLoc();
if (isa<UnknownLoc>(loc)) {
llvm::SmallPtrSet<Operation *, 16> stopOps;
traverseForwardUpdateUserChainIf(
op,
[](Operation *curOp) { return false; },
[](Operation *curOp) { return !isa<UnknownLoc>(curOp->getLoc()); },
nullptr, stopOps);
if (stopOps.empty()) {
op->emitWarning() << *op << " and its users all have no location!";
} else {
Operation *goodOp = *stopOps.begin();
op->setLoc(goodOp->getLoc());
}
}
return WalkResult::advance();
});
}
std::unique_ptr<OperationPass<ModuleOp>>
triton::createTritonToLinalgPass(bool globalKernel,
bool namedOps,
bool enableNd2nzOnVector,
bool enableSelectAnalysis,
bool compileOn91095) {
return std::make_unique<TritonToLinalgPass>(globalKernel, namedOps,
enableNd2nzOnVector,
enableSelectAnalysis,
compileOn91095);
}
std::unique_ptr<OperationPass<ModuleOp>> triton::createTritonToLinalgPass() {
return std::make_unique<TritonToLinalgPass>();
}