* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "AutoBlockify/AutoBlockify.h"
#include "AutoBlockify/Utils.h"
#include "Utils/Utils.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "auto-blockify-rewrite-operation"
using namespace mlir;
using namespace triton;
/// Re-points every operand of `op` that is produced by an
/// UnrealizedConversionCastOp at a per-iteration piece of the cast's first
/// input, selected with the induction variable of `blockifyLoop`.
///
/// Operands that are not defined by such a cast are kept untouched. When the
/// cast's first input has rank > 1, a tensor::ExtractSliceOp of size 1 along
/// dimension 0 (offset = induction variable, full extent elsewhere) is created
/// with the operand's original type as result type. Otherwise a
/// tensor::ExtractOp at the induction variable is created, followed by an
/// arith::IndexCastOp when the operand's type is `index`. Finally the operand
/// list of `op` is replaced in place.
void PropagateUnrealizedCastDown::handleBlockifyLoop(
    scf::ForOp blockifyLoop, Operation *op, PatternRewriter &rewriter) const {
  Value iv = blockifyLoop.getInductionVar();

  // Produces the replacement for a single operand (or the operand itself when
  // it does not come from an unrealized cast).
  auto materialize = [&](Value operand) -> Value {
    auto castOp = operand.getDefiningOp<UnrealizedConversionCastOp>();
    if (!castOp)
      return operand;

    Value source = castOp.getInputs()[0];
    auto sourceType = cast<RankedTensorType>(source.getType());
    const int64_t rank = sourceType.getRank();

    if (rank <= 1) {
      // Rank-1 source: read the single element for this iteration.
      Value element = rewriter.create<tensor::ExtractOp>(source.getLoc(),
                                                         source, ValueRange{iv});
      if (!isa<IndexType>(operand.getType()))
        return element;
      return rewriter.create<arith::IndexCastOp>(
          source.getLoc(), rewriter.getIndexType(), element);
    }

    // Higher-rank source: slice out row `iv` of the leading dimension.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    offsets.front() = iv;
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes;
    sizes.push_back(rewriter.getIndexAttr(1));
    for (int64_t extent : sourceType.getShape().drop_front())
      sizes.push_back(rewriter.getIndexAttr(extent));

    return rewriter.create<tensor::ExtractSliceOp>(
        source.getLoc(), cast<RankedTensorType>(operand.getType()), source,
        offsets, sizes, strides);
  };

  SmallVector<Value> replacedOperands;
  for (Value operand : op->getOperands())
    replacedOperands.push_back(materialize(operand));

  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(replacedOperands); });
}
/// Generic fallback: re-creates `generalOp` by operation name with rewritten
/// operands and expanded result types.
///
/// Each operand is passed through rewriteValue() and each result type through
/// getExpandedType(); the attribute list of `generalOp` is carried over
/// verbatim. The new operation and the cast's second input (the mask) are then
/// handed to replaceValue() to take the place of `generalOp`.
///
/// NOTE(review): only operands, result types and attributes are forwarded to
/// the builder, so regions and successors of `generalOp` are not reproduced —
/// presumably this path is only reached for region-free ops; confirm with the
/// dispatcher.
///
/// \param op        The unrealized cast being propagated downwards.
/// \param generalOp The user operation to rewrite.
/// \param rewriter  Rewriter used to create and replace operations.
void PropagateUnrealizedCastDown::rewriteGeneraleOp(
    UnrealizedConversionCastOp op, Operation *generalOp,
    PatternRewriter &rewriter) const {
  Value mask = op.getInputs()[1];

  SmallVector<Value> newOperands;
  newOperands.reserve(generalOp->getNumOperands());
  for (Value operand : generalOp->getOperands())
    newOperands.push_back(rewriteValue(operand, op, rewriter));

  SmallVector<Type> newResultTypes;
  newResultTypes.reserve(generalOp->getNumResults());
  for (Type resType : generalOp->getResultTypes())
    newResultTypes.push_back(getExpandedType(resType, op));

  Operation *newOp =
      rewriter.create(generalOp->getLoc(), generalOp->getName().getIdentifier(),
                      newOperands, newResultTypes, generalOp->getAttrs());
  replaceValue(newOp, generalOp, mask, rewriter);
}
/// Rewrites a triton::SplatOp fed by the cast `op`.
///
/// Starting from the cast's first input, one trailing dimension is appended
/// for every dimension of the splat's result shape: a triton::ExpandDimsOp
/// inserts a new axis at the current rank, then a triton::BroadcastOp produces
/// a tensor whose shape is the accumulated shape plus that dimension's extent.
/// The operation defining the final value, together with the cast's second
/// input, is passed to replaceValue() to stand in for `splatOp`.
void PropagateUnrealizedCastDown::rewriteSplat(
    UnrealizedConversionCastOp op, triton::SplatOp splatOp,
    PatternRewriter &rewriter) const {
  Value current = op.getInputs()[0];
  Value mask = op.getInputs()[1];
  auto splatType = cast<RankedTensorType>(splatOp.getResult().getType());
  auto shape =
      llvm::to_vector(cast<RankedTensorType>(current.getType()).getShape());

  for (int64_t extent : splatType.getShape()) {
    // The new axis always goes after every dimension accumulated so far.
    Value withNewAxis = rewriter.create<triton::ExpandDimsOp>(
        current.getLoc(), current, shape.size());
    shape.push_back(extent);
    auto widenedType =
        RankedTensorType::get(shape, getElementTypeOrSelf(withNewAxis));
    current = rewriter.create<triton::BroadcastOp>(withNewAxis.getLoc(),
                                                   widenedType, withNewAxis);
  }

  replaceValue(current.getDefiningOp(), splatOp, mask, rewriter);
}
/// Rewrites a triton::ExpandDimsOp fed by the cast `op`.
///
/// A new ExpandDimsOp is created on the cast's first input with the axis
/// shifted by one (presumably to account for a leading dimension carried by
/// that input — confirm against getExpandedType()). Attributes present on the
/// original op but missing on the new one are copied over, and replaceValue()
/// is invoked with the cast's second input as mask.
void PropagateUnrealizedCastDown::rewriteExpandDims(
    UnrealizedConversionCastOp op, triton::ExpandDimsOp expandDimsOp,
    PatternRewriter &rewriter) const {
  Value source = op.getInputs()[0];
  Value mask = op.getInputs()[1];

  auto shiftedAxis = expandDimsOp.getAxis() + 1;
  auto replacement = rewriter.create<triton::ExpandDimsOp>(
      expandDimsOp.getLoc(), source, shiftedAxis);

  // Carry over only the attributes the builder did not already populate.
  for (NamedAttribute namedAttr : expandDimsOp->getAttrs()) {
    if (replacement->hasAttr(namedAttr.getName()))
      continue;
    replacement->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  replaceValue(replacement, expandDimsOp, mask, rewriter);
}
/// Rewrites a triton::ReduceOp that consumes the cast `op`.
///
/// Every source operand goes through rewriteValue(); a new ReduceOp is built
/// over those values with the reduction axis incremented by one, and the
/// combine region of the original op is cloned into it. Attributes that the
/// new op lacks are copied from the original before replaceValue() is called
/// with the cast's second input as mask.
void PropagateUnrealizedCastDown::rewriteReduce(
    UnrealizedConversionCastOp op, triton::ReduceOp reduceOp,
    PatternRewriter &rewriter) const {
  Value mask = op.getInputs()[1];

  SmallVector<Value> rewrittenSrcs;
  for (Value src : reduceOp.getSrcs())
    rewrittenSrcs.push_back(rewriteValue(src, op, rewriter));

  auto replacement = rewriter.create<triton::ReduceOp>(
      reduceOp.getLoc(), rewrittenSrcs, reduceOp.getAxis() + 1);

  // The freshly built op has an empty combine region; fill it with a copy of
  // the original one.
  auto &combineRegion = replacement.getCombineOp();
  rewriter.cloneRegionBefore(reduceOp.getCombineOp(), combineRegion,
                             combineRegion.end());

  for (NamedAttribute namedAttr : reduceOp->getAttrs()) {
    if (replacement->hasAttr(namedAttr.getName()))
      continue;
    replacement->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  replaceValue(replacement, reduceOp, mask, rewriter);
}
/// Rewrites a triton::ScanOp that consumes the cast `op`.
///
/// Mirrors the reduce rewrite: sources are passed through rewriteValue(), a
/// new ScanOp is created with axis + 1 and the original `reverse` flag, the
/// combine region is cloned, missing attributes are copied, and replaceValue()
/// is called with the cast's second input as mask.
void PropagateUnrealizedCastDown::rewriteScan(UnrealizedConversionCastOp op,
                                              triton::ScanOp scanOp,
                                              PatternRewriter &rewriter) const {
  Value mask = op.getInputs()[1];

  SmallVector<Value> rewrittenSrcs;
  for (Value src : scanOp.getSrcs())
    rewrittenSrcs.push_back(rewriteValue(src, op, rewriter));

  auto replacement = rewriter.create<triton::ScanOp>(
      scanOp.getLoc(), rewrittenSrcs, scanOp.getAxis() + 1,
      scanOp.getReverse());

  // Populate the new (empty) combine region from the original scan.
  auto &combineRegion = replacement.getCombineOp();
  rewriter.cloneRegionBefore(scanOp.getCombineOp(), combineRegion,
                             combineRegion.end());

  for (NamedAttribute namedAttr : scanOp->getAttrs()) {
    if (replacement->hasAttr(namedAttr.getName()))
      continue;
    replacement->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  replaceValue(replacement, scanOp, mask, rewriter);
}
/// Rewrites a triton::LoadOp that depends on the cast `op`.
///
/// Pointer, `other` and mask operands are passed through rewriteValue(). When
/// the load has no `other` value, a zero-filled constant of the expanded
/// result type is created so that one is always supplied. The load mask is
/// combined with the cast's second input through createMask(), every
/// boundary-check index is incremented by one, and a new LoadOp is created
/// with the original padding / cache / eviction / volatile settings. Missing
/// attributes are copied before replaceValue() is called.
void PropagateUnrealizedCastDown::rewriteLoad(UnrealizedConversionCastOp op,
                                              triton::LoadOp loadOp,
                                              PatternRewriter &rewriter) const {
  Value castMask = op.getInputs()[1];

  // Keep this order: rewriteValue() may create IR.
  Value newPtr = rewriteValue(loadOp.getPtr(), op, rewriter);
  Value newOther = rewriteValue(loadOp.getOther(), op, rewriter);
  Value newMask = rewriteValue(loadOp.getMask(), op, rewriter);

  auto oldResult = loadOp.getResult();
  auto expandedType = getExpandedType(oldResult.getType(), op);

  if (!newOther) {
    // No fallback value on the original load: use an all-zero tensor.
    auto zeroElement = rewriter.getZeroAttr(getElementTypeOrSelf(oldResult));
    newOther = rewriter.create<arith::ConstantOp>(
        rewriter.getUnknownLoc(),
        DenseElementsAttr::get(expandedType, zeroElement));
  }

  newMask = createMask(newMask, castMask, expandedType.getShape(), rewriter);

  auto shiftedBoundaryCheck = llvm::map_to_vector(
      loadOp.getBoundaryCheck(), [](int32_t dim) { return dim + 1; });

  auto replacement = rewriter.create<triton::LoadOp>(
      loadOp.getLoc(), newPtr, newMask, newOther, shiftedBoundaryCheck,
      loadOp.getPadding(), loadOp.getCache(), loadOp.getEvict(),
      loadOp.getIsVolatile());

  for (NamedAttribute namedAttr : loadOp->getAttrs()) {
    if (replacement->hasAttr(namedAttr.getName()))
      continue;
    replacement->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  replaceValue(replacement, loadOp, castMask, rewriter);
}
/// Rewrites a triton::StoreOp that depends on the cast `op`.
///
/// Pointer, stored value and mask are passed through rewriteValue(); the mask
/// is then combined with the cast's second input via createMask() using the
/// shape of the rewritten pointer tensor. Boundary-check indices are shifted by
/// one and a new StoreOp with the original cache / eviction settings replaces
/// the old one after missing attributes have been copied.
void PropagateUnrealizedCastDown::rewriteStore(
    UnrealizedConversionCastOp op, triton::StoreOp storeOp,
    PatternRewriter &rewriter) const {
  Value castMask = op.getInputs()[1];

  // Keep this order: rewriteValue() may create IR.
  Value newPtr = rewriteValue(storeOp.getPtr(), op, rewriter);
  Value newValue = rewriteValue(storeOp.getValue(), op, rewriter);
  Value newMask = rewriteValue(storeOp.getMask(), op, rewriter);

  auto ptrShape = cast<RankedTensorType>(newPtr.getType()).getShape();
  newMask = createMask(newMask, castMask, ptrShape, rewriter);

  auto shiftedBoundaryCheck = llvm::map_to_vector(
      storeOp.getBoundaryCheck(), [](int32_t dim) { return dim + 1; });

  auto replacement = rewriter.create<triton::StoreOp>(
      storeOp.getLoc(), newPtr, newValue, newMask, shiftedBoundaryCheck,
      storeOp.getCache(), storeOp.getEvict());

  for (NamedAttribute namedAttr : storeOp->getAttrs()) {
    if (replacement->hasAttr(namedAttr.getName()))
      continue;
    replacement->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  rewriter.replaceOp(storeOp, replacement);
}
/// Rewrites a triton::AtomicRMWOp that depends on the cast `op`.
///
/// Pointer, value and mask go through rewriteValue(); the mask is merged with
/// the cast's second input by createMask() using the shape of the expanded
/// result type. A new AtomicRMWOp with the same RMW kind, memory semantics and
/// scope is created, missing attributes are copied, and replaceValue() is
/// called with the cast's second input as mask.
void PropagateUnrealizedCastDown::rewriteAtomicRMW(
    UnrealizedConversionCastOp op, triton::AtomicRMWOp atomicRMWOp,
    PatternRewriter &rewriter) const {
  Value castMask = op.getInputs()[1];

  // Keep this order: rewriteValue() may create IR.
  Value newPtr = rewriteValue(atomicRMWOp.getPtr(), op, rewriter);
  Value newVal = rewriteValue(atomicRMWOp.getVal(), op, rewriter);
  Value newMask = rewriteValue(atomicRMWOp.getMask(), op, rewriter);

  auto expandedType = getExpandedType(atomicRMWOp.getResult().getType(), op);
  newMask = createMask(newMask, castMask, expandedType.getShape(), rewriter);

  auto replacement = rewriter.create<triton::AtomicRMWOp>(
      atomicRMWOp.getLoc(), expandedType, atomicRMWOp.getAtomicRmwOp(), newPtr,
      newVal, newMask, atomicRMWOp.getSem(), atomicRMWOp.getScope());

  for (NamedAttribute namedAttr : atomicRMWOp->getAttrs()) {
    if (replacement->hasAttr(namedAttr.getName()))
      continue;
    replacement->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  replaceValue(replacement, atomicRMWOp, castMask, rewriter);
}
/// Rewrites a triton::AssertOp whose condition comes from the cast `op`.
///
/// The new assertion condition is `broadcast(mask XOR 1) OR input`, where
/// `input` / `mask` are the cast's first / second inputs and the broadcast to
/// the input's shape is delegated to createMask(). In effect, positions where
/// the mask bit is 0 always satisfy the assertion. The original message and
/// any missing attributes are preserved.
void PropagateUnrealizedCastDown::rewriteAssert(
    UnrealizedConversionCastOp op, triton::AssertOp assertOp,
    PatternRewriter &rewriter) const {
  Value cond = op.getInputs()[0];
  Value mask = op.getInputs()[1];
  auto condShape = cast<RankedTensorType>(cond.getType()).getShape();
  auto maskType = cast<RankedTensorType>(mask.getType());

  // Constant tensor of ones with the mask's type, used to invert the mask.
  auto oneElement = rewriter.getIntegerAttr(getElementTypeOrSelf(mask), 1);
  auto onesAttr = DenseElementsAttr::get(maskType, oneElement);
  auto ones = rewriter.create<arith::ConstantOp>(mask.getLoc(), onesAttr);

  Value invertedMask =
      rewriter.create<arith::XOrIOp>(cond.getLoc(), mask, ones);
  Value passThrough = createMask(nullptr, invertedMask, condShape, rewriter);
  Value guardedCond =
      rewriter.create<arith::OrIOp>(passThrough.getLoc(), passThrough, cond);

  auto replacement = rewriter.create<triton::AssertOp>(
      assertOp.getLoc(), guardedCond, assertOp.getMessage());

  for (NamedAttribute namedAttr : assertOp->getAttrs()) {
    if (replacement->hasAttr(namedAttr.getName()))
      continue;
    replacement->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  rewriter.replaceOp(assertOp, replacement);
}
/// Rewrites a tensor::ExtractSliceOp that depends on the cast `op`.
///
/// The source is passed through rewriteValue() and the slice description is
/// widened by one leading dimension covering the full extent of the rewritten
/// source's dimension 0 (offset 0, stride 1). The same widened description is
/// used to slice the cast's second input (the mask); the sliced mask is then
/// given to replaceValue().
///
/// NOTE(review): the mask is sliced with the same number of
/// offsets/sizes/strides as the rewritten source, which assumes the mask has
/// the same rank — confirm against how the mask is produced.
void PropagateUnrealizedCastDown::rewriteExtractSlice(
    UnrealizedConversionCastOp op, tensor::ExtractSliceOp extractSliceOp,
    PatternRewriter &rewriter) const {
  Value mask = op.getInputs()[1];
  Value newSource = rewriteValue(extractSliceOp.getSource(), op, rewriter);
  auto newSourceType = cast<RankedTensorType>(newSource.getType());

  // Leading entry first, followed by the original slice description.
  SmallVector<OpFoldResult> offsets;
  offsets.push_back(rewriter.getIndexAttr(0));
  llvm::append_range(offsets, extractSliceOp.getMixedOffsets());

  SmallVector<OpFoldResult> sizes;
  sizes.push_back(rewriter.getIndexAttr(newSourceType.getShape()[0]));
  llvm::append_range(sizes, extractSliceOp.getMixedSizes());

  SmallVector<OpFoldResult> strides;
  strides.push_back(rewriter.getIndexAttr(1));
  llvm::append_range(strides, extractSliceOp.getMixedStrides());

  auto replacement = rewriter.create<tensor::ExtractSliceOp>(
      extractSliceOp.getLoc(), newSource, offsets, sizes, strides);
  auto slicedMask = rewriter.create<tensor::ExtractSliceOp>(
      mask.getLoc(), mask, offsets, sizes, strides);

  for (NamedAttribute namedAttr : extractSliceOp->getAttrs()) {
    if (replacement->hasAttr(namedAttr.getName()))
      continue;
    replacement->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  replaceValue(replacement, extractSliceOp, slicedMask, rewriter);
}
/// Rewrites a tensor::InsertSliceOp that depends on the cast `op`.
///
/// Source and destination are passed through rewriteValue(). The slice
/// description gets one extra leading dimension spanning the full extent of the
/// rewritten source's dimension 0 (offset 0, stride 1). Missing attributes are
/// copied onto the new op and replaceValue() is called with the cast's second
/// input as mask.
void PropagateUnrealizedCastDown::rewriteInsertSlice(
    UnrealizedConversionCastOp op, tensor::InsertSliceOp insertSliceOp,
    PatternRewriter &rewriter) const {
  Value mask = op.getInputs()[1];

  // Keep this order: rewriteValue() may create IR.
  Value newSource = rewriteValue(insertSliceOp.getSource(), op, rewriter);
  Value newDest = rewriteValue(insertSliceOp.getDest(), op, rewriter);
  auto newSourceType = cast<RankedTensorType>(newSource.getType());

  // Builds `[head, tail...]`.
  auto prepend = [](OpFoldResult head, ArrayRef<OpFoldResult> tail) {
    SmallVector<OpFoldResult> combined;
    combined.reserve(tail.size() + 1);
    combined.push_back(head);
    combined.append(tail.begin(), tail.end());
    return combined;
  };

  auto offsets =
      prepend(rewriter.getIndexAttr(0), insertSliceOp.getMixedOffsets());
  auto sizes = prepend(rewriter.getIndexAttr(newSourceType.getShape()[0]),
                       insertSliceOp.getMixedSizes());
  auto strides =
      prepend(rewriter.getIndexAttr(1), insertSliceOp.getMixedStrides());

  auto replacement = rewriter.create<tensor::InsertSliceOp>(
      insertSliceOp.getLoc(), newSource, newDest, offsets, sizes, strides);

  for (NamedAttribute namedAttr : insertSliceOp->getAttrs()) {
    if (replacement->hasAttr(namedAttr.getName()))
      continue;
    replacement->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  replaceValue(replacement, insertSliceOp, mask, rewriter);
}
/// Rewrites an scf::WhileOp that takes the result of the cast `op` as an init
/// value.
///
/// Every init equal to the cast result is replaced by the cast's first input
/// and its position is recorded. A new WhileOp with the original result types
/// is created: the "before" body is cloned in full after mapRegionIterArg()
/// has set up the mapping for the recorded positions; the "after" body is
/// cloned without its terminator and mapYieldedValue() handles the yield.
/// Missing attributes are copied and the new op replaces the old one.
void PropagateUnrealizedCastDown::rewriteWhile(
    UnrealizedConversionCastOp op, scf::WhileOp whileOp,
    PatternRewriter &rewriter) const {
  Value castInput = op.getInputs()[0];
  Value mask = op.getInputs()[1];
  Value castResult = op->getResult(0);

  SmallVector<int64_t> indices;
  SmallVector<Value> newInits;
  IRMapping mapping;
  for (auto item : llvm::enumerate(whileOp.getInits())) {
    Value init = item.value();
    if (init != castResult) {
      newInits.push_back(init);
      continue;
    }
    indices.push_back(item.index());
    newInits.push_back(castInput);
  }

  // "before" region: remap arguments, then clone every op (terminator
  // included).
  auto buildBefore = [&](OpBuilder &b, Location /*loc*/, ValueRange args) {
    mapRegionIterArg(mapping, whileOp.getBeforeArguments(), args, indices, mask,
                     b);
    for (Operation &nested : *whileOp.getBeforeBody())
      b.clone(nested, mapping);
  };

  // "after" region: clone the body, then let mapYieldedValue() emit the yield.
  auto buildAfter = [&](OpBuilder &b, Location /*loc*/, ValueRange args) {
    mapRegionIterArg(mapping, whileOp.getAfterArguments(), args, {}, mask, b);
    for (Operation &nested : whileOp.getAfterBody()->without_terminator())
      b.clone(nested, mapping);
    auto oldYield = cast<scf::YieldOp>(whileOp.getAfterBody()->getTerminator());
    mapYieldedValue(mapping, oldYield, indices, op, b);
  };

  auto replacement = rewriter.create<scf::WhileOp>(
      whileOp.getLoc(), whileOp->getResultTypes(), newInits, buildBefore,
      buildAfter);

  for (NamedAttribute namedAttr : whileOp->getAttrs()) {
    if (replacement->hasAttr(namedAttr.getName()))
      continue;
    replacement->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  rewriter.replaceOp(whileOp, replacement);
}
/// Rewrites a loop that takes the result of the cast `op` as an init value.
///
/// Inits equal to the cast result are swapped for the cast's first input and
/// their positions recorded. Only scf::ForOp is supported; any other loop kind
/// hits llvm_unreachable. The new ForOp reuses the original bounds and step,
/// clones the body (minus terminator) under a mapping prepared by
/// mapRegionIterArg(), and delegates the yield to mapYieldedValue(). Missing
/// attributes are copied and replaceValue() is called with the recorded
/// positions.
void PropagateUnrealizedCastDown::rewriteLoop(UnrealizedConversionCastOp op,
                                              LoopLikeOpInterface loopOp,
                                              PatternRewriter &rewriter) const {
  Value castInput = op.getInputs()[0];
  Value mask = op.getInputs()[1];
  Value castResult = op->getResult(0);

  SmallVector<int64_t> indices;
  SmallVector<Value> newInits;
  IRMapping mapping;
  for (auto item : llvm::enumerate(loopOp.getInits())) {
    Value init = item.value();
    if (init != castResult) {
      newInits.push_back(init);
      continue;
    }
    indices.push_back(item.index());
    newInits.push_back(castInput);
  }

  auto forOp = dyn_cast<scf::ForOp>(loopOp.getOperation());
  if (!forOp)
    llvm_unreachable("Unhandled loopOp");

  auto buildBody = [&](OpBuilder &b, Location /*loc*/, Value iv,
                       ValueRange args) {
    mapping.map(forOp.getInductionVar(), iv);
    mapRegionIterArg(mapping, forOp.getRegionIterArgs(), args, indices, mask,
                     b);
    for (Operation &nested : forOp.getBody()->without_terminator())
      b.clone(nested, mapping);
    auto oldYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    mapYieldedValue(mapping, oldYield, indices, op, b);
  };

  LoopLikeOpInterface newOp = rewriter.create<scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), newInits, buildBody);

  for (NamedAttribute namedAttr : forOp->getAttrs()) {
    if (newOp->hasAttr(namedAttr.getName()))
      continue;
    newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  replaceValue(newOp, loopOp, mask, rewriter, indices);
}
/// Re-creates `ifOp` by cloning its then/else blocks into a new scf::IfOp built
/// through the region-builder overload, then calls replaceValue() for the
/// result positions in `indices`.
///
/// `op` is taken by reference: if the cast itself was cloned as part of one of
/// the blocks, `op` is updated to refer to the clone before replaceValue() is
/// invoked. The mask is read from the cast before any cloning happens.
void PropagateUnrealizedCastDown::rewriteIf(UnrealizedConversionCastOp &op,
                                            scf::IfOp ifOp,
                                            ArrayRef<int64_t> indices,
                                            PatternRewriter &rewriter) const {
  IRMapping mapping;
  Value mask = op.getInputs()[1];

  // Returns a region builder that clones all ops of `block` (terminator
  // included) using the shared mapping.
  auto makeCloner = [&mapping](Block *block) {
    return [&mapping, block](OpBuilder &b, Location /*loc*/) {
      for (Operation &nested : *block)
        b.clone(nested, mapping);
    };
  };

  auto cloneThen = makeCloner(ifOp.thenBlock());
  scf::IfOp newOp;
  if (Block *elseBlock = ifOp.elseBlock()) {
    auto cloneElse = makeCloner(elseBlock);
    newOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), ifOp.getCondition(),
                                       cloneThen, cloneElse);
  } else {
    newOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), ifOp.getCondition(),
                                       cloneThen, nullptr);
  }

  for (NamedAttribute namedAttr : ifOp->getAttrs()) {
    if (newOp->hasAttr(namedAttr.getName()))
      continue;
    newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  // Follow the cast into the new op if it lived inside the old one.
  if (mapping.contains(op))
    op = cast<UnrealizedConversionCastOp>(mapping.lookup(op));

  replaceValue(newOp, ifOp, mask, rewriter, indices);
}
/// Rewrites an scf::YieldOp that yields the result of the cast `op`.
///
/// The positions at which the cast result is yielded are collected first. The
/// handling then depends on the parent of the yield:
///
///  * Loop parent (LoopLikeOpInterface): a single-input cast from the cast's
///    first input back to the original type is created and yielded instead, so
///    the yield types stay unchanged. In front of the loop, each affected init
///    value is wrapped in a pair of casts (init -> input type -> original type
///    with the mask as second input) and the loop's init operand is updated in
///    place.
///
///  * scf::IfOp parent: the cast's first input is yielded directly. The yield
///    of the sibling branch (if any) gets single-input casts to the input's
///    type at the same positions, and the `if` is rebuilt through rewriteIf().
///
/// Yields with any other parent are left untouched.
///
/// \param op       The cast being propagated; may be re-pointed by rewriteIf().
/// \param yieldOp  The yield consuming the cast result. It is replaced, so the
///                 handle must not be used by the caller afterwards.
/// \param rewriter Rewriter used for all IR modifications.
void PropagateUnrealizedCastDown::rewriteYield(
    UnrealizedConversionCastOp &op, scf::YieldOp yieldOp,
    PatternRewriter &rewriter) const {
  auto input = op.getInputs()[0];
  auto mask = op.getInputs()[1];
  auto res = op->getResult(0);

  SmallVector<int64_t> indices;
  auto newOperands = llvm::to_vector(yieldOp.getOperands());
  for (auto [idx, opr] : llvm::enumerate(newOperands)) {
    if (opr == res)
      indices.push_back(idx);
  }

  Operation *parentOp = yieldOp->getParentOp();
  if (auto loopOp = dyn_cast<LoopLikeOpInterface>(parentOp)) {
    // Keep the loop-carried type: cast the expanded value back before yielding.
    auto uccOp = rewriter.create<UnrealizedConversionCastOp>(
        op.getLoc(), res.getType(), ValueRange({input}));
    for (auto curIdx : indices)
      newOperands[curIdx] = uccOp->getResult(0);
    auto newOp = rewriter.create<scf::YieldOp>(yieldOp.getLoc(), newOperands);
    for (auto attr : yieldOp->getAttrs()) {
      if (!newOp->hasAttr(attr.getName()))
        newOp->setAttr(attr.getName(), attr.getValue());
    }
    rewriter.replaceOp(yieldOp, newOp);

    // Wrap the matching init values in a cast pair placed before the loop.
    rewriter.setInsertionPoint(loopOp);
    for (auto curIdx : indices) {
      auto &initArg = loopOp.getInitsMutable()[curIdx];
      auto initVal = initArg.get();
      uccOp = rewriter.create<UnrealizedConversionCastOp>(
          initVal.getLoc(), input.getType(), ValueRange({initVal}));
      uccOp = rewriter.create<UnrealizedConversionCastOp>(
          initVal.getLoc(), initVal.getType(),
          ValueRange({uccOp->getResult(0), mask}));
      rewriter.modifyOpInPlace(loopOp,
                               [&]() { initArg.set(uccOp->getResult(0)); });
    }
  } else if (auto ifOp = dyn_cast<scf::IfOp>(parentOp)) {
    // Decide which branch this yield terminates BEFORE replacing it: once
    // replaceOp() has run, the old handle may refer to an erased operation and
    // comparing it against the branch terminators is no longer meaningful.
    const bool yieldIsInThen = ifOp.thenYield() == yieldOp;

    for (auto curIdx : indices)
      newOperands[curIdx] = input;
    auto newOp = rewriter.create<scf::YieldOp>(yieldOp.getLoc(), newOperands);
    for (auto attr : yieldOp->getAttrs()) {
      if (!newOp->hasAttr(attr.getName()))
        newOp->setAttr(attr.getName(), attr.getValue());
    }
    rewriter.replaceOp(yieldOp, newOp);

    // Pick the terminator of the other branch; elseYield() must not be called
    // when the `if` has no else block.
    scf::YieldOp siblingYield;
    if (!yieldIsInThen)
      siblingYield = ifOp.thenYield();
    else if (ifOp.elseBlock())
      siblingYield = ifOp.elseYield();

    if (siblingYield) {
      rewriter.setInsertionPoint(siblingYield);
      newOperands = llvm::to_vector(siblingYield.getOperands());
      for (auto curIdx : indices) {
        auto uccOp = rewriter.create<UnrealizedConversionCastOp>(
            op.getLoc(), input.getType(), ValueRange({newOperands[curIdx]}));
        newOperands[curIdx] = uccOp->getResult(0);
      }
      rewriter.replaceOpWithNewOp<scf::YieldOp>(siblingYield, newOperands);
    }

    rewriter.setInsertionPoint(ifOp);
    rewriteIf(op, ifOp, indices, rewriter);
  }
}
// Rewrites an scf::ConditionOp (terminator of the "before" region of an
// scf::WhileOp) that forwards the result of the cast `op`.
//
// The forwarded argument is replaced by the cast's first input. The type of
// the corresponding while result and of the corresponding "after" region
// argument is changed in place to getExpandedType(oldType, op), and in both
// places a new cast (value + mask -> old type) is inserted so that existing
// users keep seeing the old type.
//
// NOTE(review): if the cast result is not among the condition's arguments,
// `curIdx` stays -1 and is used as an index below; callers presumably
// guarantee a match — confirm. If it appears more than once, only the last
// position is rewritten.
void PropagateUnrealizedCastDown::rewriteCondition(
UnrealizedConversionCastOp op, scf::ConditionOp conditionOp,
PatternRewriter &rewriter) const {
auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
auto input = op.getInputs()[0];
auto mask = op.getInputs()[1];
auto res = op->getResult(0);
// Locate the forwarded argument that is the cast result (last match wins).
int64_t curIdx = -1;
auto args = llvm::to_vector(conditionOp.getArgs());
for (auto [idx, opr] : llvm::enumerate(args)) {
if (opr == res)
curIdx = idx;
}
// Forward the cast's first input instead of the cast result.
args[curIdx] = input;
auto newOp = rewriter.create<scf::ConditionOp>(
conditionOp.getLoc(), conditionOp.getCondition(), args);
// Copy attributes the new op does not already carry.
for (auto attr : conditionOp->getAttrs()) {
if (!newOp->hasAttr(attr.getName()))
newOp->setAttr(attr.getName(), attr.getValue());
}
rewriter.replaceOp(conditionOp, newOp);
// From here on `res` refers to the while op's result at the same position.
res = whileOp->getResult(curIdx);
auto oldResType = res.getType();
auto newResType = getExpandedType(oldResType, op);
rewriter.modifyOpInPlace(whileOp, [&]() { res.setType(newResType); });
// Cast the retyped result back to the old type right after the loop and
// redirect every other user to that cast.
rewriter.setInsertionPointAfter(whileOp);
auto newUccOp = rewriter.create<UnrealizedConversionCastOp>(
res.getLoc(), oldResType, ValueRange({res, mask}));
rewriter.replaceAllUsesExcept(res, newUccOp->getResult(0), newUccOp);
// Same treatment for the matching block argument of the "after" region.
auto arg = whileOp.getAfterArguments()[curIdx];
auto oldArgType = arg.getType();
auto newArgType = getExpandedType(oldArgType, op);
rewriter.modifyOpInPlace(whileOp, [&]() { arg.setType(newArgType); });
// The cast goes at the very start of the "after" body so it precedes all
// users of the argument.
rewriter.setInsertionPointToStart(whileOp.getAfterBody());
newUccOp = rewriter.create<UnrealizedConversionCastOp>(
arg.getLoc(), oldArgType, ValueRange({arg, mask}));
rewriter.replaceAllUsesExcept(arg, newUccOp->getResult(0), newUccOp);
}