* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "AutoBlockify/AutoBlockify.h"
#include "AutoBlockify/Utils.h"
#include "Dialect/TritonAscend/IR/TritonAscendDialect.h"
#include "Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "auto-blockify"
using namespace mlir;
using namespace triton;
/// Constructs the rewrite pattern anchored on UnrealizedConversionCastOp.
///
/// \param context          MLIR context handed to the OpRewritePattern base.
/// \param logicalBlockId   Stored in the pattern; matchAndRewrite forwards it
///                         to createBlockifyLoop.
/// \param logicalBlockNum  Stored in the pattern; matchAndRewrite forwards it
///                         to createBlockifyLoop.
/// \param autoBlockifySize Stored in the pattern; matchAndRewrite forwards it
///                         to createBlockifyLoop.
PropagateUnrealizedCastDown::PropagateUnrealizedCastDown(
    MLIRContext *context, Value logicalBlockId, Value logicalBlockNum,
    int autoBlockifySize)
    : OpRewritePattern<UnrealizedConversionCastOp>(context),
      logicalBlockId(logicalBlockId),
      logicalBlockNum(logicalBlockNum),
      autoBlockifySize(autoBlockifySize) {
}
/// Rewrites every user of a two-input UnrealizedConversionCastOp and then
/// erases the cast itself.
///
/// Each user is dispatched, in the order of the chain below, to a dedicated
/// rewrite helper. The order matters: the blockify-loop check runs before any
/// op-specific case, and scf::WhileOp is tested before the more general
/// LoopLikeOpInterface.
///
/// \param op       The cast to propagate. Only casts with exactly two inputs
///                 are matched.
/// \param rewriter Rewriter used for all IR mutations; its insertion point is
///                 restored after each user.
/// \returns failure() if \p op does not have exactly two inputs; the failure
///          produced by op.emitError(...) if a user cast's result type differs
///          from the type of the first input; success() otherwise.
///
/// Reaching a user that none of the cases handles triggers llvm_unreachable.
LogicalResult
PropagateUnrealizedCastDown::matchAndRewrite(UnrealizedConversionCastOp op,
                                             PatternRewriter &rewriter) const {
  if (op.getInputs().size() != 2)
    return failure();
  auto input = op.getInputs()[0];
  // Snapshot the users before rewriting: the helpers below replace users and
  // therefore mutate the use-list. The set also de-duplicates an operation
  // that consumes the cast result through more than one operand.
  SmallPtrSet<Operation *, 8> users(op->user_begin(), op->user_end());
  LLVM_DEBUG({
    auto &os = llvm::dbgs();
    os << "Handling UnrealizedConversionCastOp:\n" << op << "\n";
    os << "Users:\n";
    for (auto *user : users)
      os << *user << "\n";
  });
  for (auto *user : users) {
    // Every rewrite is emitted right before the user it replaces.
    PatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(user);
    if (auto uccOp = dyn_cast<UnrealizedConversionCastOp>(user)) {
      // A cast of the cast result folds away only if it restores the type of
      // the original first input.
      if (uccOp->getResultTypes()[0] != input.getType()) {
        LLVM_DEBUG({
          auto &os = llvm::dbgs();
          os << *user << "\n";
        });
        return op.emitError("UnrealizedConversionCastOp cannot be resolved\n");
      }
      rewriter.replaceOp(user, input);
    } else if (auto blockifyLoop = getBlockifyLoop(user)) {
      handleBlockifyLoop(blockifyLoop.value(), user, rewriter);
    } else if (auto splatOp = dyn_cast<triton::SplatOp>(user)) {
      rewriteSplat(op, splatOp, rewriter);
    } else if (auto expandDimsOp = dyn_cast<triton::ExpandDimsOp>(user)) {
      rewriteExpandDims(op, expandDimsOp, rewriter);
    } else if (auto reduceOp = dyn_cast<triton::ReduceOp>(user)) {
      rewriteReduce(op, reduceOp, rewriter);
    } else if (auto scanOp = dyn_cast<triton::ScanOp>(user)) {
      rewriteScan(op, scanOp, rewriter);
    } else if (auto loadOp = dyn_cast<triton::LoadOp>(user)) {
      rewriteLoad(op, loadOp, rewriter);
    } else if (auto storeOp = dyn_cast<triton::StoreOp>(user)) {
      rewriteStore(op, storeOp, rewriter);
    } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(user)) {
      rewriteAtomicRMW(op, atomicRMWOp, rewriter);
    } else if (auto assertOp = dyn_cast<triton::AssertOp>(user)) {
      rewriteAssert(op, assertOp, rewriter);
    } else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) {
      rewriteExtractSlice(op, extractSliceOp, rewriter);
    } else if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(user)) {
      rewriteInsertSlice(op, insertSliceOp, rewriter);
    } else if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
      rewriteWhile(op, whileOp, rewriter);
    } else if (auto loopOp = dyn_cast<LoopLikeOpInterface>(user)) {
      rewriteLoop(op, loopOp, rewriter);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
      rewriteYield(op, yieldOp, rewriter);
    } else if (auto conditionOp = dyn_cast<scf::ConditionOp>(user)) {
      rewriteCondition(op, conditionOp, rewriter);
    } else if (user->hasTrait<OpTrait::Elementwise>() ||
               isa<triton::BroadcastOp, triton::JoinOp, triton::ReshapeOp,
                   triton::PrintOp, triton::ascend::AnnotationOp>(user)) {
      rewriteGeneraleOp(op, user, rewriter);
    } else if (isa<triton::AtomicCASOp>(user)) {
      // AtomicCAS has no dedicated rewrite: wrap it in a blockify loop first,
      // then handle it like any other op that lives in such a loop.
      auto *newOp =
          createBlockifyLoop(user, op, logicalBlockId, logicalBlockNum,
                             autoBlockifySize, rewriter);
      rewriter.setInsertionPoint(newOp);
      handleBlockifyLoop(*getBlockifyLoop(newOp), newOp, rewriter);
    } else {
      LLVM_DEBUG({
        auto &os = llvm::dbgs();
        os << "Unhandled Op\n" << *user << "\n";
      });
      llvm_unreachable("Unhandled operation");
    }
  }
  LLVM_DEBUG({
    auto &os = llvm::dbgs();
    // Looked up here (rather than up front) so that release builds, where
    // LLVM_DEBUG expands to nothing, do not carry an unused local.
    auto funcOp = op->getParentOfType<triton::FuncOp>();
    os << "After successful conversion\n";
    os << funcOp << "\n";
  });
  rewriter.eraseOp(op);
  return success();
}
/// Creates the pass from its option struct; the options are forwarded
/// unchanged to the AutoBlockifyBase constructor.
AutoBlockifyPass::AutoBlockifyPass(const AutoBlockifyOptions &options)
    : AutoBlockifyBase(options) {
}
/// Recursively checks whether every transitive use of \p v can be handled by
/// the auto-blockify rewrite.
///
/// Values are memoised in `checkedValues`: a value that was already visited
/// is reported as blockifiable without being re-examined, which also
/// terminates recursion through loop-carried values.
///
/// As a side effect, region ops that consume \p v in a position that is not
/// forwarded to a region argument (an scf.if user, a loop operand without a
/// tied iter-arg/result, the condition operand of scf.condition) are tagged
/// with `autoBlockifyRegionOpAttr`.
///
/// \param v Value whose use-list is inspected.
/// \returns false as soon as one unsupported use is found, true otherwise.
bool AutoBlockifyPass::checkBlockifiable(Value v) {
  if (!checkedValues.insert(v).second)
    return true;
  LLVM_DEBUG({
    auto &os = llvm::dbgs();
    os << "Checking blockifiable:\n" << v << "\n";
  });
  for (auto &use : v.getUses()) {
    auto *user = use.getOwner();
    auto opNum = use.getOperandNumber();
    LLVM_DEBUG({
      auto &os = llvm::dbgs();
      os << "User:\n" << *user << "\n";
    });
    // Hard rejections: these users, and any user with a tensor-pointer
    // operand, are never blockified.
    if (isa<cf::CondBranchOp, triton::IntToPtrOp, scf::WhileOp,
            triton::DotOp>(user) ||
        llvm::any_of(user->getOperandTypes(), isTensorPtrType))
      return false;
    if (isa<scf::IfOp>(user)) {
      // Tag the region op, then keep scanning: the remaining uses of `v`
      // still have to be validated. Returning here would skip them for good,
      // because `v` is already recorded in `checkedValues`.
      user->setAttr(autoBlockifyRegionOpAttr, UnitAttr::get(v.getContext()));
      continue;
    } else if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
      // NOTE(review): currently unreachable, scf::WhileOp users are rejected
      // by the check above. Kept so the intended handling is not lost.
      if (!checkBlockifiable(whileOp.getBeforeArguments()[opNum]))
        return false;
    } else if (auto loopOp = dyn_cast<LoopLikeOpInterface>(user)) {
      auto regionIterArg = loopOp.getTiedLoopRegionIterArg(&use);
      auto loopResult = loopOp.getTiedLoopResult(&use);
      if (!regionIterArg || !loopResult) {
        // The operand is not a loop-carried init value: tag the loop and go
        // on with the other uses (see the scf.if case above).
        user->setAttr(autoBlockifyRegionOpAttr, UnitAttr::get(v.getContext()));
        continue;
      }
      if (!checkBlockifiable(regionIterArg) || !checkBlockifiable(loopResult))
        return false;
    } else if (isa<scf::ConditionOp>(user)) {
      auto whileOp = cast<scf::WhileOp>(user->getParentOp());
      if (opNum == 0) {
        // Operand 0 is the loop condition itself: tag the enclosing while
        // and go on with the other uses (see the scf.if case above).
        whileOp->setAttr(autoBlockifyRegionOpAttr,
                         UnitAttr::get(v.getContext()));
        continue;
      }
      // Operand i (i >= 1) is forwarded to after-argument i-1 and result i-1.
      if (!checkBlockifiable(whileOp.getAfterArguments()[opNum - 1]) ||
          !checkBlockifiable(whileOp->getResult(opNum - 1)))
        return false;
    } else if (isa<scf::YieldOp>(user)) {
      // A yield inside a loop is followed back to the matching init operand;
      // a yield whose parent is not loop-like is accepted as is.
      if (auto parentLoop = dyn_cast<LoopLikeOpInterface>(user->getParentOp());
          parentLoop && !checkBlockifiable(parentLoop.getInits()[opNum]))
        return false;
    } else {
      // Any other user: all of its results must be blockifiable in turn.
      for (auto res : user->getResults()) {
        if (!checkBlockifiable(res))
          return false;
      }
    }
  }
  return true;
}
/// Seeds \p func with the blockified program id and rewires all existing
/// program-id queries to it.
///
/// Steps, all emitted at the start of the entry block unless noted:
///  1. `logicalBlockNum` = numX * numY * numZ.
///  2. `logicalBlockId`  = idX * (numY * numZ) + idY * numZ + idZ. The three
///     GetProgramIdOps created here are tagged with `logicalBlockIdAttr` so
///     that step 5 leaves them alone.
///  3. A tensor of `autoBlockifySize` ids, `logicalBlockId + [0, size)`, plus
///     a mask selecting the lanes with `0 <= id < logicalBlockNum`.
///  4. An UnrealizedConversionCastOp that bundles {ids, mask} into one value
///     of the scalar id type.
///  5. Every untagged GetProgramIdOp in the function is replaced, in place,
///     by the matching div/rem decomposition of that cast result.
///  6. Every op carrying `autoBlockifyRegionOpAttr` is passed to
///     createBlockifyLoop, and the attribute is removed from the returned op.
///
/// \param func Function to prepare. The members `logicalBlockId` and
///             `logicalBlockNum` are overwritten for this function.
void AutoBlockifyPass::preProcess(triton::FuncOp func) {
  IRRewriter rewriter(func.getContext());
  rewriter.setInsertionPointToStart(&func.getBody().front());
  auto loc = rewriter.getUnknownLoc();

  // Total number of logical blocks.
  auto xNum =
      rewriter.create<triton::GetNumProgramsOp>(loc, triton::ProgramIDDim::X);
  auto yNum =
      rewriter.create<triton::GetNumProgramsOp>(loc, triton::ProgramIDDim::Y);
  auto zNum =
      rewriter.create<triton::GetNumProgramsOp>(loc, triton::ProgramIDDim::Z);
  auto yzNum = rewriter.create<arith::MulIOp>(loc, yNum, zNum);
  logicalBlockNum = rewriter.create<arith::MulIOp>(loc, yzNum, xNum);

  // Flattened id of the current block; the tagged ops must survive step 5.
  auto xDim =
      rewriter.create<triton::GetProgramIdOp>(loc, triton::ProgramIDDim::X);
  auto yDim =
      rewriter.create<triton::GetProgramIdOp>(loc, triton::ProgramIDDim::Y);
  auto zDim =
      rewriter.create<triton::GetProgramIdOp>(loc, triton::ProgramIDDim::Z);
  xDim->setAttr(logicalBlockIdAttr, rewriter.getUnitAttr());
  yDim->setAttr(logicalBlockIdAttr, rewriter.getUnitAttr());
  zDim->setAttr(logicalBlockIdAttr, rewriter.getUnitAttr());
  auto xFlatten = rewriter.create<arith::MulIOp>(loc, xDim, yzNum);
  auto yFlatten = rewriter.create<arith::MulIOp>(loc, yDim, zNum);
  logicalBlockId = rewriter.create<arith::AddIOp>(loc, xFlatten, yFlatten);
  logicalBlockId = rewriter.create<arith::AddIOp>(loc, logicalBlockId, zDim);

  // Vector of candidate ids: logicalBlockId + [0, autoBlockifySize).
  auto blockifyTensorType =
      RankedTensorType::get({autoBlockifySize}, rewriter.getI32Type());
  auto blockifyRange = rewriter.create<triton::MakeRangeOp>(
      loc, blockifyTensorType, 0, autoBlockifySize);
  auto splatLogicalBlockId = rewriter.create<triton::SplatOp>(
      loc, blockifyRange.getType(), logicalBlockId);
  Value blockifiedId =
      rewriter.create<arith::AddIOp>(loc, splatLogicalBlockId, blockifyRange);

  // Validity mask: a lane is live only if 0 <= id < logicalBlockNum.
  auto splatBlockNum = rewriter.create<triton::SplatOp>(
      loc, blockifyRange.getType(), logicalBlockNum);
  auto upperboundMask = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, blockifiedId, splatBlockNum);
  auto splatZero = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(blockifyTensorType,
                                  rewriter.getI32IntegerAttr(0)));
  auto lowerboundMask = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::sge, blockifiedId, splatZero);
  // Both bounds must hold. OR-ing them would be true for every integer
  // (each value is either >= 0 or < num), i.e. no lane would be masked out.
  Value blockifiedIdMask =
      rewriter.create<arith::AndIOp>(loc, upperboundMask, lowerboundMask);

  // Bundle {ids, mask} behind a value that still has the scalar id type, so
  // that the scalar users created below type-check until the cast is
  // propagated.
  blockifiedId = rewriter
                     .create<UnrealizedConversionCastOp>(
                         loc, logicalBlockId.getType(),
                         ValueRange({blockifiedId, blockifiedIdMask}))
                     ->getResult(0);

  // Collect first, replace afterwards: ops are not replaced during the walk.
  SmallVector<triton::GetProgramIdOp> toReplace;
  func.walk([&](triton::GetProgramIdOp id) {
    if (id->hasAttr(logicalBlockIdAttr))
      return;
    toReplace.push_back(id);
  });
  for (auto id : toReplace) {
    rewriter.setInsertionPoint(id);
    Value newId;
    // Inverse of the flattening above.
    if (id.getAxis() == triton::ProgramIDDim::X) {
      newId = rewriter.create<arith::DivSIOp>(id.getLoc(), blockifiedId, yzNum);
      newId = rewriter.create<arith::RemSIOp>(id.getLoc(), newId, xNum);
    } else if (id.getAxis() == triton::ProgramIDDim::Y) {
      newId = rewriter.create<arith::DivSIOp>(id.getLoc(), blockifiedId, zNum);
      newId = rewriter.create<arith::RemSIOp>(id.getLoc(), newId, yNum);
    } else {
      newId = rewriter.create<arith::RemSIOp>(id.getLoc(), blockifiedId, zNum);
    }
    rewriter.replaceOp(id, newId);
  }

  // Pre-order with skip(): once an op has been handed to createBlockifyLoop,
  // the walk does not descend into it.
  func.walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (op->hasAttr(autoBlockifyRegionOpAttr)) {
      auto *newOp = createBlockifyLoop(
          op, blockifiedId.getDefiningOp<UnrealizedConversionCastOp>(),
          logicalBlockId, logicalBlockNum, autoBlockifySize, rewriter);
      newOp->removeAttr(autoBlockifyRegionOpAttr);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });
}
void AutoBlockifyPass::runOnOperation() {
if (autoBlockifySize == 1)
return;
ModuleOp moduleOp = getOperation();
if (autoBlockifySize <= 0) {
moduleOp->emitWarning("[AutoBlockify V2] AutoBlockifySize cannot be "
"negative integer, skipping.");
return signalPassFailure();
}
MLIRContext *ctx = &getContext();
moduleOp.walk([&](triton::FuncOp func) {
LogicalResult result = success();
func.walk([&](triton::GetProgramIdOp id) {
if (!checkBlockifiable(id.getResult())) {
result = failure();
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (failed(result)) {
func->emitWarning("Cannot apply auto blockify");
return WalkResult::skip();
}
preProcess(func);
LLVM_DEBUG({
auto &os = llvm::dbgs();
os << "After preprocess:\n" << func << "\n";
});
RewritePatternSet patterns(ctx);
patterns.add<PropagateUnrealizedCastDown>(
ctx, logicalBlockId, logicalBlockNum, autoBlockifySize);
if (failed(applyPatternsGreedily(func, std::move(patterns)))) {
moduleOp->emitError("failed to apply Patterns");
signalPassFailure();
return WalkResult::interrupt();
}
IRRewriter rewriter(ctx);
func->walk([&](UnrealizedConversionCastOp op) {
rewriter.setInsertionPoint(op);
auto input = op.getInputs()[0];
auto resType = cast<RankedTensorType>(op->getResultTypes()[0]);
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
Attribute val = constantOp.getValue();
if (auto denseAttr = dyn_cast<DenseElementsAttr>(val))
val = denseAttr.getSplatValue<Attribute>();
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(resType, val));
} else if (auto tensorType =
dyn_cast<RankedTensorType>(input.getType())) {
input = rewriter.create<triton::ExpandDimsOp>(input.getLoc(), input, 0);
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, resType, input);
} else {
rewriter.replaceOpWithNewOp<triton::SplatOp>(op, resType, input);
}
});
func->setAttr(autoBlockifySizeAttr,
rewriter.getI32IntegerAttr(autoBlockifySize));
return WalkResult::skip();
});
PassManager pm(&getContext(), moduleOp.getOperationName());
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());
if (failed(runPipeline(pm, moduleOp))) {
signalPassFailure();
}
}
/// Factory for the auto-blockify module pass.
///
/// \param options Options forwarded to the AutoBlockifyPass constructor.
/// \returns A newly allocated AutoBlockifyPass, owned by the caller.
std::unique_ptr<OperationPass<ModuleOp>>
triton::createAutoBlockifyPass(const AutoBlockifyOptions &options) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<AutoBlockifyPass>(options);
  return pass;
}