#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <deque>
namespace mlir::triton::gpu {
#define GEN_PASS_DEF_TRITONGPUREMOVELAYOUTCONVERSIONS
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
#define DEBUG_TYPE "tritongpu-remove-layout-conversions"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace {
// Forward layout propagation over one function. Candidate encodings are
// collected per value in `layouts`, reduced to a single encoding per value,
// and the IR is then rewritten so that values carry the chosen encoding.
class LayoutPropagation {
public:
// Candidate encodings collected for a single value.
struct LayoutInfo {
LayoutInfo(Attribute encoding) { encodings.insert(encoding); }
LayoutInfo() {}
// Insertion-ordered set of candidate encodings.
llvm::SmallSetVector<Attribute, 8> encodings;
};
LayoutPropagation(FuncOp F) : funcOp(F) {}
// Seeds `layouts` with the function arguments and the results of anchor ops
// (see isLayoutAnchor), each with its current encoding.
void initAnchorLayout();
// Worklist propagation of the recorded encodings to users until no value
// gains a new candidate encoding.
void propagateLayout();
// Pushes the encodings in `info` to the users of `value` and returns the
// values whose candidate set grew.
SmallVector<Value> propagateToUsers(Value value, LayoutInfo &info);
// For every ranked-tensor value in `values`, records the encodings obtained
// from `info` through `op`; values that gained an encoding are appended to
// `changed`.
void setEncoding(ValueRange values, LayoutInfo &info,
SmallVector<Value> &changed, Operation *op);
// Reduces every entry of `layouts` with several candidates to exactly one.
void resolveConflicts();
// Rewrites the function body (region 0 of `funcOp`).
void rewrite();
// Rewrites the ops of `R` and of the regions nested below it.
void rewriteRegion(Region &R);
// Rewrites `op` to its resolved encoding, queues `op` for deletion and
// returns the operation that replaces it.
Operation *rewriteOp(Operation *op);
Operation *rewriteForOp(scf::ForOp forOp);
Operation *rewriteWhileOp(scf::WhileOp whileOp);
Operation *rewriteIfOp(scf::IfOp ifOp);
// The following rewrite operands in place; the op itself is kept.
void rewriteYieldOp(scf::YieldOp yieldOp);
void rewriteConditionOp(scf::ConditionOp conditionOp);
void rewriteReduceToScalar(Operation *reduceOp);
void rewriteAssertOp(AssertOp assertOp);
// Clones `op` so that its tensor results use `encoding`, converting the
// operands to a matching source encoding.
Operation *cloneElementwise(OpBuilder &rewriter, Operation *op,
Attribute encoding);
// Records `newV` as the rewritten form of `old` for the encoding of `newV`.
void map(Value old, Value newV);
// Returns `value` in `encoding`, inserting a ConvertLayoutOp when the
// rewritten value has a different encoding.
Value getValueAs(Value value, Attribute encoding);
// Returns the rewritten counterpart of `value`, or `value` itself when no
// rewrite applies.
Value getRewrittenValue(Value value);
// Prints every tracked value and its candidate encodings to llvm::errs().
void dump();
private:
// Candidate encodings per value, in insertion order.
llvm::MapVector<Value, LayoutInfo> layouts;
// (original value, encoding) -> value produced by the rewrite.
DenseMap<std::pair<Value, Attribute>, Value> rewriteMapping;
// Ops replaced during the rewrite; erased in reverse order at the end of
// rewriteRegion().
SetVector<Operation *> opToDelete;
FuncOp funcOp;
};
// Removes or moves layout conversions in one function by rematerializing the
// computation that feeds them in the target encoding. Already rematerialized
// values are cached so later conversions can reuse them.
class LayoutRematerialization {
public:
LayoutRematerialization(FuncOp F) : funcOp(F) {}
// Records that `newV` is `old` expressed in `encoding`.
void addRematValue(Value old, Attribute encoding, Value newV);
// Returns the recorded rematerialization of `value` in `encoding`, or a null
// Value when none was recorded.
Value getRematValue(Value value, Attribute encoding) const {
return rematMapping.lookup({value, encoding});
}
// Erases every op queued in `opToDelete`, last inserted first.
void cleanup();
// Each no-argument driver below collects the ConvertLayoutOps of the function
// and applies the per-conversion overload to each of them.
void backwardRematerialization();
void backwardRematerialization(ConvertLayoutOp convertOp);
void hoistConvertDotOperand();
void hoistConvertDotOperand(ConvertLayoutOp convertOp);
void hoistConvertOnTopOfExtOrBroadcast();
void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp);
void hoistConvertIntoConditionals();
void hoistConvertIntoConditionals(ConvertLayoutOp convertOp);
// Re-creates the values of `slice` in the encodings given by `layout` and
// replaces the uses of `convertOp` with the rematerialized source. `mapping`
// may pre-map values and receives the old-to-new mapping.
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp, IRMapping &mapping);
// Same as above, starting from an empty mapping.
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp);
// Wraps mlir::getConvertBackwardSlice, additionally offering it previously
// rematerialized values that dominate their user.
LogicalResult
getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding,
SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation);
// Like getConvertBackwardSlice, but fails when the slice is empty or contains
// a value whose defining op is rejected by canBeRemat().
LogicalResult getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr);
private:
// Re-keys cached rematerializations after values were replaced; `values`
// holds (old value, replacement) pairs.
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
// Value -> encoding most recently registered for it via addRematValue().
DenseMap<Value, Attribute> mappedValues;
// (value, encoding) -> rematerialized value.
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
// Ops scheduled for erasure by cleanup().
SetVector<Operation *> opToDelete;
FuncOp funcOp;
// Used for dominance checks before reusing a rematerialized value.
DominanceInfo domInfo;
PostDominanceInfo postDomInfo;
};
// Registers `newV` as the rematerialized form of `old` in `encoding`, and
// remembers `encoding` as the encoding associated with `old`.
void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
                                            Value newV) {
  LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
  std::pair<Value, Attribute> key(old, encoding);
  mappedValues[old] = encoding;
  rematMapping[key] = newV;
}
void LayoutRematerialization::cleanup() {
for (Operation *op : llvm::reverse(opToDelete))
op->erase();
}
// Returns true when `op` is a layout anchor, i.e. its result encoding is used
// as a seed for propagation (see LayoutPropagation::initAnchorLayout).
bool isLayoutAnchor(Operation *op) {
  // Ops whose anchor status depends on a property of the op.
  if (auto gatherOp = dyn_cast<GatherOp>(op))
    return gatherOp.getEfficientLayout();
  if (auto reshape = dyn_cast<ReshapeOp>(op))
    return reshape.getAllowReorder();
  if (isa<LoadOp, StoreOp>(op))
    return isExpensiveLoadOrStore(op);
  // Ops that are always anchors.
  return isa<DotOp, DotScaledOp, nvidia_gpu::WarpGroupDotOp, AtomicRMWOp,
             AtomicCASOp, triton::nvidia_gpu::TMEMLoadOp>(op);
}
void LayoutPropagation::initAnchorLayout() {
auto addAnchor = [&](Value v) {
if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
layouts.insert({v, LayoutInfo(tensorType.getEncoding())});
}
};
for (auto arg : funcOp.getArguments()) {
addAnchor(arg);
}
funcOp.walk([&](Operation *op) {
if (isLayoutAnchor(op)) {
for (auto result : op->getResults()) {
addAnchor(result);
}
}
});
}
// Adds to each ranked-tensor value in `values` the encodings derived from
// `info` through `op`. A ConvertLayoutOp forwards the encoding unchanged; any
// other op goes through inferDstEncoding(). Values whose candidate set grew
// are appended to `changed`.
void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info,
                                    SmallVector<Value> &changed,
                                    Operation *op) {
  for (Value value : values) {
    if (!isa<RankedTensorType>(value.getType()))
      continue;
    bool gainedEncoding = false;
    for (auto srcEncoding : info.encodings) {
      Attribute dstEncoding;
      if (!isa<ConvertLayoutOp>(op))
        dstEncoding = inferDstEncoding(op, srcEncoding);
      else
        dstEncoding = srcEncoding;
      // A null attribute means no destination encoding could be inferred.
      if (!dstEncoding)
        continue;
      if (layouts[value].encodings.insert(dstEncoding))
        gainedEncoding = true;
    }
    if (gainedEncoding)
      changed.push_back(value);
  }
}
// Propagates the candidate encodings in `info` from `value` to the values
// produced by (or tied to) each of its uses. Returns the values that gained a
// new candidate encoding.
SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
LayoutInfo &info) {
SmallVector<Value> changed;
for (OpOperand &use : value.getUses()) {
Operation *user = use.getOwner();
// scf.for operand: propagate to the tied region iter arg and loop result.
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
Value arg = forOp.getTiedLoopRegionIterArg(&use);
Value result = forOp.getTiedLoopResult(&use);
setEncoding({arg, result}, info, changed, user);
continue;
}
// scf.while operand: propagate to the "before" block argument at the same
// position.
if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
Value arg = whileOp.getBeforeArguments()[use.getOperandNumber()];
setEncoding({arg}, info, changed, user);
continue;
}
// scf.yield: propagate to the parent's result at the same index, plus the
// loop-carried block argument for scf.for / scf.while. Yields with any other
// parent are skipped.
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
auto parent = yieldOp->getParentOp();
SmallVector<Value> valuesToPropagate;
if (isa<scf::ForOp, scf::IfOp, scf::WhileOp>(parent))
valuesToPropagate.push_back(parent->getResult(use.getOperandNumber()));
if (auto forOp = dyn_cast<scf::ForOp>(parent))
valuesToPropagate.push_back(
forOp.getRegionIterArg(use.getOperandNumber()));
if (auto whileOp = dyn_cast<scf::WhileOp>(parent))
valuesToPropagate.push_back(
whileOp.getBeforeArguments()[use.getOperandNumber()]);
if (isa<scf::ForOp, scf::IfOp, scf::WhileOp>(parent))
setEncoding(valuesToPropagate, info, changed, user);
continue;
}
// scf.condition: the index is shifted by one relative to the "after"
// arguments and the while results.
// NOTE(review): if `use` is operand 0, `argIndex` wraps around; presumably
// that operand is never a ranked tensor tracked here - confirm.
if (auto conditionOp = dyn_cast<scf::ConditionOp>(user)) {
auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
unsigned argIndex = use.getOperandNumber() - 1;
Value afterArg = whileOp.getAfterArguments()[argIndex];
Value result = whileOp->getResult(argIndex);
setEncoding({afterArg, result}, info, changed, user);
continue;
}
// WarpGroupDotWaitOp: operand i maps to result i.
if (auto dotWaitOp = dyn_cast<nvidia_gpu::WarpGroupDotWaitOp>(user)) {
unsigned opIndex = use.getOperandNumber();
Value result = dotWaitOp->getResult(opIndex);
setEncoding(result, info, changed, user);
continue;
}
// GatherOp without the efficient-layout flag: only the indices operand
// drives the result encoding. Other gather uses fall through to the generic
// check below.
if (auto gatherOp = dyn_cast<GatherOp>(user)) {
if (!gatherOp.getEfficientLayout() &&
&use == &gatherOp.getIndicesMutable()) {
setEncoding(gatherOp.getResult(), info, changed, user);
continue;
}
}
// Generic case: ops that preserve the encoding or whose result encoding is
// handled by setEncoding().
if (user->hasTrait<OpTrait::SameOperandsAndResultEncoding>() ||
user->hasTrait<OpTrait::Elementwise>() ||
isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp,
ConvertLayoutOp>(user)) {
setEncoding(user->getResults(), info, changed, user);
continue;
}
}
return changed;
}
// Runs the worklist-based forward propagation. Every value currently in
// `layouts` is visited, its candidate encodings are pushed to its users, and
// each value that gained a new candidate is visited again until nothing
// changes.
void LayoutPropagation::propagateLayout() {
  SmallVector<Value> queue;
  queue.reserve(layouts.size());
  // Iterate by reference: only the key is needed, and copying an entry would
  // also copy its whole encoding set.
  for (const auto &entry : layouts)
    queue.push_back(entry.first);
  while (!queue.empty()) {
    Value currentValue = queue.pop_back_val();
    // Taken by value on purpose: propagateToUsers() may add entries to
    // `layouts` while `info` is still being read.
    LayoutInfo info = layouts[currentValue];
    SmallVector<Value> changed = propagateToUsers(currentValue, info);
    LLVM_DEBUG({
      DBGS() << "propagateLayout considering " << currentValue << ", which has "
             << info.encodings.size() << " candidate encoding(s):\n";
      for (Attribute encoding : info.encodings)
        DBGS() << "  " << encoding << "\n";
      DBGS() << "changed: " << changed.size() << "\n";
    });
    queue.append(changed.begin(), changed.end());
  }
}
void LayoutPropagation::resolveConflicts() {
for (auto &it : layouts) {
Operation *op = it.first.getDefiningOp();
LayoutInfo &info = it.second;
if (info.encodings.size() <= 1)
continue;
Attribute encoding = *info.encodings.begin();
bool isLoadOrStore =
op && isa<LoadOp, StoreOp, AtomicRMWOp, AtomicCASOp>(op);
for (Attribute e : info.encodings) {
if ((isLoadOrStore && isa<BlockedEncodingAttr>(e)) ||
(!isLoadOrStore && isa<MmaEncodingTrait>(e))) {
encoding = e;
break;
}
}
info.encodings.clear();
info.encodings.insert(encoding);
}
}
void LayoutPropagation::dump() {
for (auto it : layouts) {
llvm::errs() << "Value: ";
OpPrintingFlags flags;
flags.skipRegions();
it.first.print(llvm::errs(), flags);
llvm::errs() << " \n encoding:\n";
for (auto encoding : it.second.encodings) {
encoding.print(llvm::errs());
llvm::errs() << "\n";
}
llvm::errs() << "--\n";
}
}
void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); }
// Returns true when `op` is a ReduceOp whose first result is not a ranked
// tensor.
bool reduceToScalar(Operation *op) {
  if (!isa<ReduceOp>(op))
    return false;
  Type firstResultType = op->getResultTypes()[0];
  return !isa<RankedTensorType>(firstResultType);
}
void LayoutPropagation::rewriteRegion(Region ®ion) {
std::deque<Region *> queue = {®ion};
while (!queue.empty()) {
Region *currentRegion = queue.front();
queue.pop_front();
for (Operation &op : currentRegion->getOps()) {
bool needRewrite = false;
SmallVector<Value> results = op.getResults();
for (Value result : results) {
auto it = layouts.find(result);
if (it == layouts.end())
continue;
LayoutInfo &info = it->second;
assert(info.encodings.size() == 1 &&
"we should have resolved to a single encoding");
auto encoding = cast<RankedTensorType>(result.getType()).getEncoding();
if (encoding == *info.encodings.begin())
continue;
needRewrite = true;
}
if (needRewrite) {
Operation *newOp = rewriteOp(&op);
for (Region &R : newOp->getRegions())
queue.push_back(&R);
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(&op)) {
rewriteYieldOp(yieldOp);
} else if (auto conditionOp = dyn_cast<scf::ConditionOp>(&op)) {
rewriteConditionOp(conditionOp);
} else if (reduceToScalar(&op)) {
rewriteReduceToScalar(&op);
} else if (auto assertOp = dyn_cast<AssertOp>(&op)) {
rewriteAssertOp(assertOp);
} else {
for (OpOperand &operand : op.getOpOperands()) {
auto it = layouts.find(operand.get());
if (it == layouts.end())
continue;
Attribute encoding =
cast<RankedTensorType>(operand.get().getType()).getEncoding();
Value newOperand = getValueAs(operand.get(), encoding);
op.setOperand(operand.getOperandNumber(), newOperand);
}
for (Region &R : op.getRegions())
queue.push_back(&R);
}
}
}
for (Operation *op : llvm::reverse(opToDelete))
op->erase();
}
// Records `newV` as the rewritten form of `old`, keyed by the encoding that
// `newV` carries. `newV` must be a ranked tensor.
void LayoutPropagation::map(Value old, Value newV) {
  Attribute newEncoding = cast<RankedTensorType>(newV.getType()).getEncoding();
  rewriteMapping[{old, newEncoding}] = newV;
}
// Returns the value that replaces `value` after the rewrite. `value` itself
// is returned when it is not a ranked tensor, is not tracked in `layouts`, or
// already has the resolved encoding; otherwise the mapped value for the
// resolved encoding is looked up.
Value LayoutPropagation::getRewrittenValue(Value value) {
  auto tensorType = dyn_cast<RankedTensorType>(value.getType());
  if (!tensorType)
    return value;
  auto entry = layouts.find(value);
  if (entry == layouts.end())
    return value;
  const auto &candidates = entry->second.encodings;
  assert(candidates.size() == 1 &&
         "we should have resolved to a single encoding");
  Attribute resolved = *candidates.begin();
  if (resolved != tensorType.getEncoding())
    return rewriteMapping.at({value, resolved});
  return value;
}
// Returns `value` expressed in `encoding`. Non-tensor values are returned
// unchanged. If the rewritten value does not have the requested encoding, a
// ConvertLayoutOp is created right after it.
Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
  auto tensorType = dyn_cast<RankedTensorType>(value.getType());
  if (!tensorType)
    return value;
  Value current = getRewrittenValue(value);
  Attribute currentEncoding =
      cast<RankedTensorType>(current.getType()).getEncoding();
  if (currentEncoding == encoding)
    return current;
  OpBuilder rewriter(value.getContext());
  rewriter.setInsertionPointAfterValue(current);
  auto targetType = tensorType.cloneWithEncoding(encoding);
  Value converted =
      rewriter.create<ConvertLayoutOp>(value.getLoc(), targetType, current);
  return converted;
}
// Clones `op` at the builder's insertion point so that its tensor results
// carry `encoding`. The operand encoding is taken from the first rewritten
// operand whose inferred destination encoding equals `encoding`, falling back
// to inferSrcEncoding(); all operands are then converted to it.
Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,
                                               Operation *op,
                                               Attribute encoding) {
  Operation *newOp = rewriter.clone(*op);
  Attribute operandEnc;
  if (op->getNumOperands() > 0) {
    for (auto operand : op->getOperands()) {
      auto rewrittenTy =
          dyn_cast<RankedTensorType>(getRewrittenValue(operand).getType());
      if (!rewrittenTy)
        continue;
      auto candidate = rewrittenTy.getEncoding();
      if (!(inferDstEncoding(op, candidate) == encoding))
        continue;
      operandEnc = candidate;
      break;
    }
    if (!operandEnc)
      operandEnc = inferSrcEncoding(op, encoding);
    assert(operandEnc);
  }
  for (OpOperand &operand : op->getOpOperands()) {
    Value converted = getValueAs(operand.get(), operandEnc);
    newOp->setOperand(operand.getOperandNumber(), converted);
  }
  // Only ranked-tensor results are re-typed.
  const unsigned numResults = op->getNumResults();
  for (unsigned idx = 0; idx != numResults; ++idx) {
    auto oldTy = dyn_cast<RankedTensorType>(op->getResult(idx).getType());
    if (oldTy)
      newOp->getResult(idx).setType(oldTy.cloneWithEncoding(encoding));
  }
  return newOp;
}
// Builds a replacement scf.for whose init args use the encodings resolved
// for the corresponding results, moves the old body into it, and redirects
// or maps the old results and block arguments to the new ones.
Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) {
  OpBuilder rewriter(forOp);
  SmallVector<Value> newInits;
  for (auto [init, result] :
       llvm::zip(forOp.getInitArgs(), forOp.getResults())) {
    Value newInit = init;
    if (layouts.count(result))
      newInit = getValueAs(init, *layouts[result].encodings.begin());
    newInits.push_back(newInit);
  }
  auto newForOp = rewriter.create<scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), newInits);
  newForOp->setAttrs(forOp->getAttrs());
  // Move the old loop body in front of whatever the new body already holds.
  auto &newBodyOps = newForOp.getBody()->getOperations();
  newBodyOps.splice(newBodyOps.begin(), forOp.getBody()->getOperations());

  // Same type: plain replacement. Different type: remember the mapping so
  // users pick up the new value when they are rewritten.
  auto remap = [this](Value oldVal, Value newVal) {
    if (oldVal.getType() != newVal.getType()) {
      map(oldVal, newVal);
      return;
    }
    oldVal.replaceAllUsesWith(newVal);
  };
  for (auto [oldResult, newResult] :
       llvm::zip(forOp.getResults(), newForOp.getResults()))
    remap(oldResult, newResult);
  for (auto [oldArg, newArg] : llvm::zip(forOp.getBody()->getArguments(),
                                         newForOp.getBody()->getArguments()))
    remap(oldArg, newArg);
  return newForOp.getOperation();
}
// Builds a replacement scf.while whose operands and results use the resolved
// encodings, moves the ops of both old regions into freshly created blocks,
// and redirects or maps the old results and block arguments to the new ones.
Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) {
  SmallVector<Value> operands;
  SmallVector<Type> returnTypes;
  OpBuilder rewriter(whileOp);
  // Operands follow the encoding resolved for the matching "before" argument.
  for (auto [operand, arg] :
       llvm::zip(whileOp->getOperands(), whileOp.getBeforeArguments())) {
    Value convertedOperand = operand;
    if (layouts.count(arg))
      convertedOperand = getValueAs(operand, *layouts[arg].encodings.begin());
    operands.push_back(convertedOperand);
  }
  for (Value ret : whileOp.getResults()) {
    auto it = layouts.find(ret);
    if (it == layouts.end()) {
      returnTypes.push_back(ret.getType());
      continue;
    }
    // Only ranked tensors are ever recorded in `layouts`; use cast<> so a
    // violation asserts instead of dereferencing a null type.
    auto origType = cast<RankedTensorType>(ret.getType());
    auto newType = origType.cloneWithEncoding(it->second.encodings[0]);
    returnTypes.push_back(newType);
  }
  auto newWhileOp =
      rewriter.create<scf::WhileOp>(whileOp.getLoc(), returnTypes, operands);
  SmallVector<Type> argsTypesBefore;
  for (Value operand : operands)
    argsTypesBefore.push_back(operand.getType());
  SmallVector<Location> bbArgLocsBefore(argsTypesBefore.size(),
                                        whileOp.getLoc());
  SmallVector<Location> bbArgLocsAfter(returnTypes.size(), whileOp.getLoc());
  // The "before" block takes the converted operand types; the "after" block
  // takes the new result types.
  rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore,
                       bbArgLocsBefore);
  rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter);
  for (unsigned i = 0, e = whileOp.getNumRegions(); i < e; ++i) {
    newWhileOp->getRegion(i).front().getOperations().splice(
        newWhileOp->getRegion(i).front().getOperations().begin(),
        whileOp->getRegion(i).front().getOperations());
  }
  // Same type: plain replacement. Different type: remember the mapping so
  // users pick up the new value when they are rewritten.
  auto remapArg = [&](Value oldVal, Value newVal) {
    if (oldVal.getType() == newVal.getType())
      oldVal.replaceAllUsesWith(newVal);
    else
      map(oldVal, newVal);
  };
  for (auto [oldResult, newResult] :
       llvm::zip(whileOp.getResults(), newWhileOp.getResults()))
    remapArg(oldResult, newResult);
  for (auto [oldArg, newArg] :
       llvm::zip(whileOp.getBeforeArguments(), newWhileOp.getBeforeArguments()))
    remapArg(oldArg, newArg);
  for (auto [oldArg, newArg] :
       llvm::zip(whileOp.getAfterArguments(), newWhileOp.getAfterArguments()))
    remapArg(oldArg, newArg);
  return newWhileOp.getOperation();
}
// Builds a replacement scf.if whose tracked results use the resolved
// encodings, moves both region bodies over, and redirects or maps the old
// results to the new ones.
Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) {
  OpBuilder rewriter(ifOp);
  SmallVector<Type> newResultTypes(ifOp->getResultTypes());
  for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i) {
    auto it = layouts.find(ifOp->getResult(i));
    if (it == layouts.end())
      continue;
    auto origType = cast<RankedTensorType>(ifOp->getResult(i).getType());
    Attribute encoding = *(it->second.encodings.begin());
    newResultTypes[i] = origType.cloneWithEncoding(encoding);
  }
  auto newIfOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), newResultTypes,
                                            ifOp.getCondition(), true, true);
  newIfOp.getThenRegion().takeBody(ifOp.getThenRegion());
  newIfOp.getElseRegion().takeBody(ifOp.getElseRegion());
  for (auto [oldResult, newResult] :
       llvm::zip(ifOp.getResults(), newIfOp.getResults())) {
    // Same type: plain replacement. Different type: remember the mapping so
    // users pick up the new value when they are rewritten.
    if (oldResult.getType() == newResult.getType()) {
      oldResult.replaceAllUsesWith(newResult);
      continue;
    }
    map(oldResult, newResult);
  }
  return newIfOp.getOperation();
}
// Converts each tensor operand of `yieldOp` to the encoding expected by its
// parent: the matching result for scf.for / scf.if, the matching "before"
// argument for scf.while, and the operand's own type otherwise.
void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) {
  Operation *parentOp = yieldOp->getParentOp();
  for (OpOperand &operand : yieldOp->getOpOperands()) {
    const unsigned idx = operand.getOperandNumber();
    Type expectedType = operand.get().getType();
    if (isa<scf::ForOp, scf::IfOp>(parentOp))
      expectedType = parentOp->getResult(idx).getType();
    if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp))
      expectedType = whileOp.getBeforeArguments()[idx].getType();
    auto expectedTensor = dyn_cast<RankedTensorType>(expectedType);
    if (!expectedTensor)
      continue;
    yieldOp->setOperand(idx,
                        getValueAs(operand.get(), expectedTensor.getEncoding()));
  }
}
// Converts the forwarded operands of `conditionOp` (every operand after the
// first) to the encoding of the matching result of the parent scf.while.
void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) {
  auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
  const unsigned numOperands = conditionOp->getNumOperands();
  for (unsigned idx = 1; idx < numOperands; ++idx) {
    // Operand `idx` corresponds to while result `idx - 1`.
    Type resultType = whileOp->getResult(idx - 1).getType();
    auto resultTensor = dyn_cast<RankedTensorType>(resultType);
    if (!resultTensor)
      continue;
    Value forwarded = conditionOp->getOpOperand(idx).get();
    conditionOp->setOperand(idx,
                            getValueAs(forwarded, resultTensor.getEncoding()));
  }
}
// Rewrites the operands of a reduce-to-scalar op in place. The encoding
// resolved for the first tracked operand is applied to all operands; if no
// operand is tracked the op is left untouched.
void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) {
  Attribute srcEncoding;
  for (Value operand : reduceOp->getOperands()) {
    auto it = layouts.find(operand);
    if (it != layouts.end()) {
      srcEncoding = it->second.encodings[0];
      break;
    }
  }
  if (!srcEncoding)
    return;
  for (OpOperand &operand : reduceOp->getOpOperands()) {
    Value newOperand = getValueAs(operand.get(), srcEncoding);
    reduceOp->setOperand(operand.getOperandNumber(), newOperand);
  }
}
// Rewrites operand 0 of `assertOp` to its resolved encoding. Nothing happens
// when that operand is not tracked in `layouts`.
void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) {
  Value condition = assertOp->getOperand(0);
  auto layoutIt = layouts.find(condition);
  if (layoutIt == layouts.end())
    return;
  Attribute resolved = layoutIt->second.encodings[0];
  assertOp->setOperand(0, getValueAs(condition, resolved));
}
// Rewrites `op` so that its results use the resolved encoding and returns
// the op that takes its place. `op` is queued for deletion. Control-flow ops
// are delegated to their dedicated rewriters; any op that no branch handles
// is a fatal error.
Operation *LayoutPropagation::rewriteOp(Operation *op) {
  opToDelete.insert(op);
  if (auto forOp = dyn_cast<scf::ForOp>(op))
    return rewriteForOp(forOp);
  if (auto whileOp = dyn_cast<scf::WhileOp>(op))
    return rewriteWhileOp(whileOp);
  if (auto ifOp = dyn_cast<scf::IfOp>(op))
    return rewriteIfOp(ifOp);

  OpBuilder rewriter(op);
  Value firstResult = op->getResult(0);
  Attribute encoding = *layouts[firstResult].encodings.begin();
  // Type of the first result with the resolved encoding applied.
  auto retypedFirstResult = [&]() {
    auto resultType = cast<RankedTensorType>(firstResult.getType());
    return resultType.cloneWithEncoding(encoding);
  };

  // A conversion is re-created from its (possibly rewritten) source.
  if (auto convertOp = dyn_cast<ConvertLayoutOp>(op)) {
    Attribute srcEncoding = convertOp.getSrc().getType().getEncoding();
    auto srcLayoutIt = layouts.find(convertOp.getSrc());
    if (srcLayoutIt != layouts.end())
      srcEncoding = *(srcLayoutIt->second.encodings.begin());
    Value src = getValueAs(convertOp.getSrc(), srcEncoding);
    auto newType = retypedFirstResult();
    auto cvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), newType, src);
    map(firstResult, cvt.getResult());
    return cvt.getOperation();
  }

  // Ops that can absorb a conversion: clone as-is and convert the clone.
  if (canFoldIntoConversion(op, encoding)) {
    Operation *cloned = rewriter.clone(*op);
    auto newType = retypedFirstResult();
    auto cvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), newType,
                                                cloned->getResult(0));
    map(firstResult, cvt.getResult());
    return cvt.getOperation();
  }

  const bool isCloneable =
      op->hasTrait<OpTrait::SameOperandsAndResultEncoding>() ||
      op->hasTrait<OpTrait::Elementwise>() ||
      isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp, GatherOp,
          ConvertLayoutOp, nvidia_gpu::WarpGroupDotWaitOp>(op);
  if (isCloneable) {
    Operation *newOp = cloneElementwise(rewriter, op, encoding);
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (oldResult.getType() != newResult.getType())
        map(oldResult, newResult);
      else
        oldResult.replaceAllUsesWith(newResult);
    }
    return newOp;
  }
  llvm::report_fatal_error("unexpected op in rewrite");
  return nullptr;
}
// Returns whether `op` may be rematerialized: atomics, DotOp, scf.while and
// scf.condition never are; loads/stores only when they are not expensive;
// gathers only without the efficient-layout flag; everything else is allowed.
bool canBeRemat(Operation *op) {
  if (isa<AtomicRMWOp, AtomicCASOp, DotOp>(op))
    return false;
  if (isa<scf::WhileOp, scf::ConditionOp>(op))
    return false;
  if (isa<LoadOp, StoreOp>(op))
    return !isExpensiveLoadOrStore(op);
  if (auto gather = dyn_cast<GatherOp>(op))
    return !gather.getEfficientLayout();
  return true;
}
// Re-keys the remat cache after IR values were replaced. For every
// (old, new) pair whose old value has a cached rematerialization, the cache
// entry is moved to the new value; if the cached value itself appears as an
// old value in `values`, its replacement is stored instead.
void LayoutRematerialization::updateRematMapping(
    SmallVector<std::tuple<Value, Value>> &values) {
  for (auto [oldValue, newValue] : values) {
    auto mappedIt = mappedValues.find(oldValue);
    if (mappedIt == mappedValues.end())
      continue;
    Attribute encoding = mappedIt->second;
    auto rematIt = rematMapping.find({oldValue, encoding});
    assert(rematIt != rematMapping.end());
    Value remat = rematIt->second;
    rematMapping.erase(rematIt);
    mappedValues.erase(mappedIt);
    // The cached value may have been replaced as well.
    auto replacedIt = llvm::find_if(
        values, [&remat](const std::tuple<Value, Value> &replacement) {
          return std::get<0>(replacement) == remat;
        });
    if (replacedIt != values.end())
      remat = std::get<1>(*replacedIt);
    rematMapping[{newValue, encoding}] = remat;
    mappedValues[newValue] = encoding;
  }
}
// Rematerializes the values in `slice` using the encodings given by `layout`,
// then replaces all uses of `convertOp` with the rematerialized version of
// its source. `mapping` may already contain old-to-new value mappings and is
// extended as ops are rewritten. Replaced ops are queued in `opToDelete`.
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp,
IRMapping &mapping) {
SetVector<Operation *> opsToRewrite;
// For each yield, the operand indices that need an extra yielded value.
DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
SetVector<Value> valuesWithExistingRemat;
// Phase 1: decide which ops must be rewritten.
for (Value v : slice) {
auto layoutIt = layout.find(v);
assert(layoutIt != layout.end());
// Reuse a previously rematerialized value when one is cached.
if (Value remat = getRematValue(v, layoutIt->second)) {
mapping.map(v, remat);
valuesWithExistingRemat.insert(v);
continue;
}
if (v.getDefiningOp()) {
opsToRewrite.insert(v.getDefiningOp());
// An scf.if result also needs both of its yields extended.
if (auto ifOp = v.getDefiningOp<scf::IfOp>()) {
unsigned operandIdx = cast<OpResult>(v).getResultNumber();
opsToRewrite.insert(ifOp.thenYield().getOperation());
yieldOperandsMap[ifOp.thenYield()].push_back(operandIdx);
opsToRewrite.insert(ifOp.elseYield().getOperation());
yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx);
}
} else {
// Block argument: rewrite the owning loop and its terminator.
BlockArgument blockArg = cast<BlockArgument>(v);
Operation *parentOp = blockArg.getOwner()->getParentOp();
// NOTE(review): cast<> asserts on a mismatch rather than returning null,
// so this `if` does not act as a check - confirm whether dyn_cast was
// intended.
if (auto loopOp = cast<LoopLikeOpInterface>(parentOp)) {
opsToRewrite.insert(loopOp.getOperation());
OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg);
auto yieldOp = blockArg.getOwner()->getTerminator();
yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber());
opsToRewrite.insert(yieldOp);
}
}
}
slice.set_subtract(valuesWithExistingRemat);
// Order the ops so that, presumably, producers are rewritten before their
// users (relies on multiRootTopologicalSort, defined elsewhere).
opsToRewrite = multiRootTopologicalSort(opsToRewrite);
// (old value, new value) pairs filled in by the replace*WithNewSignature
// helpers and applied after the rewrite.
SmallVector<std::tuple<Value, Value>> replacements;
SmallVector<Operation *> deadOps;
IRRewriter builder(slice.begin()->getContext());
// Phase 2: rewrite the ops.
for (Operation *op : opsToRewrite) {
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
// Each pair is (old result number, index of the added iter arg / result).
SmallVector<std::pair<size_t, size_t>> argMapping;
SmallVector<Value> newOperands;
for (auto arg : forOp.getRegionIterArgs()) {
if (slice.count(arg)) {
OpOperand &initVal = *forOp.getTiedLoopInit(arg);
argMapping.push_back(std::make_pair(
forOp.getTiedLoopResult(&initVal).getResultNumber(),
forOp.getInitArgs().size() + newOperands.size()));
newOperands.push_back(mapping.lookup(initVal.get()));
}
}
scf::ForOp newForOp = replaceForOpWithNewSignature(
builder, forOp, newOperands, replacements);
deadOps.push_back(forOp.getOperation());
Block &loopBody = *newForOp.getBody();
for (auto m : argMapping) {
mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second));
// Block arguments are offset by the induction variables.
int numIndVars = newForOp.getNumInductionVars();
mapping.map(loopBody.getArgument(m.first + numIndVars),
loopBody.getArgument(m.second + numIndVars));
LLVM_DEBUG({
DBGS() << "mapping forOp "
<< loopBody.getArgument(m.first + numIndVars) << " to "
<< loopBody.getArgument(m.second + numIndVars) << '\n';
});
Value oldArg = loopBody.getArgument(m.first + numIndVars);
addRematValue(newForOp.getResult(m.first), layout[oldArg],
newForOp.getResult(m.second));
addRematValue(oldArg, layout[oldArg],
loopBody.getArgument(m.second + numIndVars));
}
continue;
}
if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
// Add one result, in the target encoding, per result that is in the slice.
SmallVector<Type> newTypes;
for (auto res : ifOp.getResults()) {
if (slice.count(res)) {
auto it = layout.find(res);
assert(it != layout.end());
auto oldType = cast<RankedTensorType>(res.getType());
auto newType = oldType.cloneWithEncoding(it->second);
newTypes.push_back(newType);
}
}
scf::IfOp newIfOp =
replaceIfOpWithNewSignature(builder, ifOp, newTypes, replacements);
unsigned oldIdx = 0;
// The added results start after the original ones.
unsigned newIdx = ifOp.getNumResults();
for (auto res : ifOp.getResults()) {
if (slice.count(res)) {
mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx));
addRematValue(ifOp.getResult(oldIdx), layout[res],
newIfOp.getResult(newIdx));
++newIdx;
}
++oldIdx;
}
deadOps.push_back(ifOp.getOperation());
continue;
}
builder.setInsertionPoint(op);
if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
// Append the rematerialized values, in ascending operand order, after the
// existing yield operands. The old yield is erased immediately.
auto yieldOperands = llvm::to_vector(yieldOp.getOperands());
SmallVector<int> operandsToRewrite = yieldOperandsMap[op];
std::sort(operandsToRewrite.begin(), operandsToRewrite.end());
for (int operandIdx : operandsToRewrite) {
yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx)));
}
builder.create<scf::YieldOp>(op->getLoc(), yieldOperands);
op->erase();
continue;
}
if (isa<arith::ConstantOp>(op)) {
// Constants are cloned unchanged and followed by an explicit conversion.
Operation *newOp = builder.clone(*op);
auto tensorType = cast<RankedTensorType>(op->getResult(0).getType());
auto newType = tensorType.cloneWithEncoding(layout[op->getResult(0)]);
auto cvt = builder.create<ConvertLayoutOp>(op->getLoc(), newType,
newOp->getResult(0));
mapping.map(op->getResult(0), cvt.getResult());
addRematValue(op->getResult(0), layout[op->getResult(0)],
cvt.getResult());
continue;
}
// Generic op: clone with remapped operands and re-type the results that
// have a target encoding.
Operation *newOp = builder.clone(*op, mapping);
for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) {
auto it = layout.find(old);
if (it == layout.end())
continue;
auto newType =
cast<RankedTensorType>(old.getType()).cloneWithEncoding(it->second);
newV.setType(newType);
addRematValue(old, it->second, newV);
}
}
// Phase 3: drop the conversion and apply the deferred replacements.
convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc()));
opToDelete.insert(convertOp);
updateRematMapping(replacements);
for (auto &kv : replacements) {
builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
}
for (Operation *op : deadOps)
opToDelete.insert(op);
}
// Convenience overload: rewrites the slice starting from an empty value
// mapping.
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
                                           DenseMap<Value, Attribute> &layout,
                                           ConvertLayoutOp convertOp) {
  IRMapping emptyMapping;
  rewriteSlice(slice, layout, convertOp, emptyMapping);
}
// Computes the backward slice for `root` via mlir::getConvertBackwardSlice,
// supplying a callback that offers a cached rematerialized value whenever it
// properly dominates the operand's owner.
LogicalResult LayoutRematerialization::getConvertBackwardSlice(
    OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
    DenseMap<Value, Attribute> &layout,
    std::function<bool(Operation *)> stopPropagation) {
  auto getExistingConversion = [&](OpOperand &value,
                                   Attribute encoding) -> Value {
    Value remat = getRematValue(value.get(), encoding);
    if (remat && domInfo.properlyDominates(remat, value.getOwner()))
      return remat;
    return Value();
  };
  return mlir::getConvertBackwardSlice(root, slice, rootEncoding, layout,
                                       stopPropagation, getExistingConversion);
}
// Computes the backward slice for `root` and succeeds only if the slice is
// non-empty and every op defining a slice value passes canBeRemat().
LogicalResult LayoutRematerialization::getRematerializableSlice(
    OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
    DenseMap<Value, Attribute> &layout,
    std::function<bool(Operation *)> stopPropagation) {
  LogicalResult sliceResult = getConvertBackwardSlice(
      root, rootEncoding, slice, layout, stopPropagation);
  if (sliceResult.failed())
    return failure();
  if (slice.empty())
    return failure();
  for (Value sliceValue : slice) {
    Operation *defOp = sliceValue.getDefiningOp();
    // Block arguments have no defining op and are not checked.
    if (defOp && !canBeRemat(defOp))
      return failure();
  }
  return success();
}
void LayoutRematerialization::backwardRematerialization() {
SmallVector<ConvertLayoutOp> convertOps;
funcOp.walk(
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
for (ConvertLayoutOp convertOp : convertOps) {
backwardRematerialization(convertOp);
if (!opToDelete.contains(convertOp)) {
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
convertOp.getResult());
}
}
}
void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
SmallVector<ConvertLayoutOp> convertOps;
funcOp.walk(
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
for (ConvertLayoutOp convertOp : convertOps) {
hoistConvertOnTopOfExtOrBroadcast(convertOp);
if (!opToDelete.contains(convertOp)) {
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
convertOp.getResult());
}
}
}
void LayoutRematerialization::hoistConvertIntoConditionals() {
SmallVector<ConvertLayoutOp> convertOps;
funcOp.walk(
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
for (ConvertLayoutOp convertOp : convertOps) {
hoistConvertIntoConditionals(convertOp);
if (!opToDelete.contains(convertOp)) {
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
convertOp.getResult());
}
}
}
// Returns true for the arith/math ops that the rematerialization cost model
// treats as expensive.
static bool isExpensiveMathOp(Operation *op) {
  // Floating-point division.
  if (isa<arith::DivFOp>(op))
    return true;
  // Bit counting.
  if (isa<math::CtPopOp, math::CountLeadingZerosOp,
          math::CountTrailingZerosOp>(op))
    return true;
  // Exponentials and logarithms.
  if (isa<math::ExpOp, math::Exp2Op, math::ExpM1Op, math::LogOp, math::Log2Op,
          math::Log10Op, math::Log1pOp>(op))
    return true;
  // Trigonometric functions and their inverses.
  if (isa<math::SinOp, math::CosOp, math::TanOp, math::AsinOp, math::AcosOp,
          math::AtanOp, math::Atan2Op>(op))
    return true;
  // Hyperbolic functions and their inverses.
  if (isa<math::SinhOp, math::CoshOp, math::TanhOp, math::AsinhOp,
          math::AcoshOp, math::AtanhOp>(op))
    return true;
  // Powers, roots and error functions.
  return isa<math::PowFOp, math::SqrtOp, math::RsqrtOp, math::CbrtOp,
             math::ErfOp, math::ErfcOp>(op);
}
// Returns the size of `result` in bytes: element count times element bit
// width, shifted right by three. The element count and bit width are each
// raised to the given minimum first. Non-tensor values and element types
// that are neither int nor float count as zero before the minimums apply.
static int64_t getByteCount(Value result, int64_t minElementCount = 0,
                            int64_t minBitWidth = 0) {
  int64_t numElements = 0;
  int64_t bitWidth = 0;
  if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType())) {
    numElements = tensorTy.getNumElements();
    auto elemType = tensorTy.getElementType();
    if (elemType.isIntOrFloat())
      bitWidth = elemType.getIntOrFloatBitWidth();
  }
  const int64_t effectiveElements =
      numElements < minElementCount ? minElementCount : numElements;
  const int64_t effectiveBits = bitWidth < minBitWidth ? minBitWidth : bitWidth;
  return (effectiveElements * effectiveBits) >> 3;
}
// Tries to remove `convertOp` by rematerializing the backward slice of its
// source in the target encoding. The rewrite is skipped when the target is a
// dot-operand encoding, when no rematerializable slice exists, when the slice
// contains a non-associative reduce, or when the estimated rematerialization
// cost exceeds the estimated conversion cost.
void LayoutRematerialization::backwardRematerialization(
ConvertLayoutOp convertOp) {
RankedTensorType targetType = convertOp.getType();
// Conversions to a dot-operand encoding are not handled here.
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
return;
Value oldV = convertOp.getSrc();
LDBG("check backward remat with source " << oldV << " encoding "
<< targetType.getEncoding());
// Fast path: reuse a cached rematerialization that dominates this conversion.
Value newV = getRematValue(oldV, targetType.getEncoding());
if (newV && domInfo.properlyDominates(newV, convertOp)) {
convertOp.replaceAllUsesWith(newV);
opToDelete.insert(convertOp);
LDBG("found remat'ed value" << newV);
return;
}
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
LogicalResult result = getRematerializableSlice(
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
if (result.failed()) {
LDBG("  getRematerializableSlice failed");
return;
}
// Ops that define values of the slice (block arguments contribute none).
SetVector<Operation *> sliceOps;
for (Value v : slice) {
if (Operation *op = v.getDefiningOp()) {
sliceOps.insert(op);
}
}
// Memoized, recursive check: an op is "single use" when every user of its
// results is `convertOp` or a slice op that is itself single use.
DenseMap<Operation *, bool> isSingleUse;
std::function<bool(Operation *)> isOpSingleUse;
isOpSingleUse = [&](Operation *op) -> bool {
auto it = isSingleUse.find(op);
if (it != isSingleUse.end()) {
return it->second;
}
bool singleUse = true;
for (Value result : op->getResults()) {
for (Operation *user : result.getUsers()) {
if (user == convertOp) {
continue;
}
if (sliceOps.contains(user)) {
if (!isOpSingleUse(user)) {
singleUse = false;
break;
}
} else {
singleUse = false;
break;
}
}
if (!singleUse) {
break;
}
}
isSingleUse[op] = singleUse;
return singleUse;
};
// Cost of keeping the conversion: byte count of the source (at least 32
// elements of at least 32 bits) scaled by 32.
// NOTE(review): the weights 32 / 8 / 1 used below look like heuristic
// constants; their derivation is not visible here.
int64_t convertLayoutBytes = getByteCount(convertOp.getSrc(), 32, 32);
int64_t convertLayoutCost = 32 * convertLayoutBytes;
int64_t rematerialisationCost = 0;
for (Operation *op : sliceOps) {
auto dialect = op->getDialect();
if (isOpSingleUse(op)) {
// Single-use ops add no cost.
continue;
} else if (isa<arith::ConstantOp>(op)) {
// Constants add no cost.
continue;
} else if (isa<LoadOp>(op) || isa<LocalLoadOp>(op)) {
for (Value result : op->getResults()) {
rematerialisationCost += 8 * getByteCount(result);
}
} else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1;
for (Value result : op->getResults()) {
rematerialisationCost += multiplier * getByteCount(result);
}
} else if (isa<ReduceOp>(op)) {
auto reduceOp = dyn_cast<ReduceOp>(op);
ReduceOpHelper helper(reduceOp);
// A non-associative reduce in the slice aborts the rematerialization.
if (!helper.isAssociative()) {
LDBG("  skipped rematerialization due to non-associative reduce in the "
"slice");
return;
}
rematerialisationCost += helper.getIntraWarpSizeWithUniqueData();
rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData();
}
// Ops of any other kind add no cost.
}
LLVM_DEBUG({
DBGS() << "  convert layout cost: " << convertLayoutCost << "\n";
DBGS() << "  rematerialisation cost: " << rematerialisationCost << "\n";
});
if (rematerialisationCost > convertLayoutCost) {
LDBG("  skipped rematerialization due to higher cost");
return;
}
LLVM_DEBUG({
DBGS() << "  remat convert op " << convertOp << '\n';
for (Value v : slice)
DBGS() << "    " << v << '\n';
});
rewriteSlice(slice, layout, convertOp);
}
void LayoutRematerialization::hoistConvertDotOperand() {
SmallVector<ConvertLayoutOp> convertOps;
funcOp.walk(
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
for (ConvertLayoutOp convertOp : convertOps) {
hoistConvertDotOperand(convertOp);
if (!opToDelete.contains(convertOp)) {
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
convertOp.getResult());
}
}
}
// Tries to hoist `convertOp` upwards through its backward slice so that the
// layout conversion is applied directly after the loads feeding it.
//
// The hoist is attempted only when `canBePipelined` holds, and the backward
// slice is cut at the first op rejected by `noDataMovement`. Every value in
// the slice must be defined by a LoadOp/DescriptorLoadOp, an
// arith::ConstantOp, or an op accepted by `noDataMovement`; otherwise the
// function returns without rewriting the slice.
void LayoutRematerialization::hoistConvertDotOperand(
ConvertLayoutOp convertOp) {
auto targetType = convertOp.getType();
// True when the parent op of the conversion contains at least one
// DotOpInterface op whose operand 0 is a ranked tensor with a
// DotOperandEncodingAttr whose parent implements MmaEncodingTrait, and at
// least one such op post-dominates the conversion.
// Note: the lambda parameter shadows the outer `convertOp`.
auto canBePipelined = [&](ConvertLayoutOp convertOp) {
auto parent = convertOp->getParentOp();
if (!parent)
return false;
SmallVector<Operation *> dotLikeOps;
parent->walk([&](Operation *op) {
if (!isa<mlir::triton::DotOpInterface>(op))
return;
auto opType = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
if (!opType)
return;
auto dotEnc = dyn_cast<DotOperandEncodingAttr>(opType.getEncoding());
if (!dotEnc)
return;
if (isa<MmaEncodingTrait>(dotEnc.getParent()))
dotLikeOps.push_back(op);
});
if (dotLikeOps.empty())
return false;
return llvm::any_of(dotLikeOps, [&](Operation *dot) {
return postDomInfo.postDominates(dot, convertOp);
});
};
if (!canBePipelined(convertOp))
return;
// Ops the conversion may be hoisted across: memory-effect-free elementwise
// ops, BroadcastOp, Fp4ToFpOp, ConvertLayoutOp, and ops for which isView()
// is true.
auto noDataMovement = [](Operation *op) {
return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
isa<BroadcastOp, Fp4ToFpOp, ConvertLayoutOp>(op) || isView(op);
};
// The slice traversal stops at any op that `noDataMovement` rejects.
auto stop = std::not_fn(noDataMovement);
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
LogicalResult result = getConvertBackwardSlice(
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, stop);
if (result.failed())
return;
// `mapping` records load result -> newly created conversion; `innerSlice`
// collects the non-load values that will be rewritten.
IRMapping mapping;
OpBuilder builder(convertOp.getContext());
SetVector<Value> innerSlice;
for (Value v : slice) {
// Values without a defining op (block arguments) abort the hoist.
if (!v.getDefiningOp()) {
LLVM_DEBUG(
{ DBGS() << " Block arguments not supported. Got " << v << "\n"; });
return;
}
// Non-load values are kept for rewriting when they are constants or ops
// accepted by `noDataMovement`; anything else aborts the hoist.
// NOTE(review): conversions already created for loads in earlier
// iterations are not erased on these early returns; presumably they are
// removed later as dead ops -- confirm.
if (!isa<LoadOp, DescriptorLoadOp>(v.getDefiningOp())) {
auto op = v.getDefiningOp();
if (isa<arith::ConstantOp>(op) || noDataMovement(op)) {
innerSlice.insert(v);
continue;
} else {
LLVM_DEBUG({
DBGS() << " Leaves must be Load, DescriptorLoad or Constant. Got "
<< v << "\n";
});
return;
}
}
// Load leaf: insert a conversion right after the load, converting result 0
// to the encoding recorded for it in `layout`, and map the load result to
// the converted value.
Operation *loadOp = v.getDefiningOp();
builder.setInsertionPointAfter(loadOp);
auto type = dyn_cast<RankedTensorType>(loadOp->getResult(0).getType());
// Loads whose result 0 is not a ranked tensor are skipped.
if (!type)
continue;
auto newType = type.cloneWithEncoding(layout[loadOp->getResult(0)]);
auto newConvertOp = builder.create<ConvertLayoutOp>(
convertOp.getLoc(), newType, loadOp->getResult(0));
mapping.map(loadOp->getResult(0), newConvertOp.getResult());
}
// Nothing between the loads and the conversion to rewrite.
if (innerSlice.empty()) {
return;
}
LLVM_DEBUG({
DBGS() << " Hoisting " << convertOp << '\n';
for (Value v : innerSlice)
DBGS() << " " << v << '\n';
});
rewriteSlice(innerSlice, layout, convertOp, mapping);
}
// Tries to move `convertOp` above a single "ext or broadcast" op in its
// backward slice: the conversion is applied to that op's operand 0 instead,
// and the rest of the slice is rewritten in the target layout.
//
// Conversions targeting a DotOperandEncodingAttr are ignored. The function
// returns without changes when no slice can be computed, when a source
// encoding cannot be inferred, when no ext/broadcast op needs a conversion,
// or when more than one such op would need one.
void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
ConvertLayoutOp convertOp) {
RankedTensorType targetType = convertOp.getType();
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
return;
// Matches arith ExtSI/ExtUI/ExtF, BroadcastOp and ExpandDimsOp, plus an
// FpToFpOp whose source element bit width is smaller than its result's.
auto isExtOrBroadcastOp = [](Operation *op) {
if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp, BroadcastOp,
ExpandDimsOp>(op)) {
return true;
}
if (auto fpToFpOp = dyn_cast<FpToFpOp>(op)) {
auto srcType = cast<RankedTensorType>(fpToFpOp.getOperand().getType());
return getElementBitWidth(srcType) <
getElementBitWidth(cast<RankedTensorType>(fpToFpOp.getType()));
}
return false;
};
// Backward slice rooted at the conversion source; `isExtOrBroadcastOp` is
// passed to getRematerializableSlice as its extra predicate argument.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
LogicalResult result = getRematerializableSlice(
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
isExtOrBroadcastOp);
if (result.failed())
return;
Operation *extOrBroadcastOp = nullptr;
// The size is captured once: values appended to `slice` inside the loop are
// not visited by it.
unsigned sliceSize = slice.size();
for (unsigned i = 0; i < sliceSize; i++) {
Value v = slice[i];
Operation *op = v.getDefiningOp();
if (!op)
continue;
if (isExtOrBroadcastOp(op)) {
// Try to continue rematerialization through operand 0 of this op, using
// the source encoding inferred from the layout assigned to `v`.
SetVector<Value> tempSlice;
DenseMap<Value, Attribute> tempLayout;
Attribute srcEncoding = inferSrcEncoding(op, layout[v]);
if (!srcEncoding)
return;
// Note: this `result` shadows the outer variable of the same name.
LogicalResult result = getRematerializableSlice(
op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);
// Treat the sub-slice as failed if it assigns a value an encoding that
// differs from the one already recorded in `layout`.
for (auto [val, enc] : tempLayout) {
auto preexistingLayout = layout.find(val);
if (preexistingLayout != layout.end() &&
preexistingLayout->second != enc) {
result = failure();
break;
}
}
// The sub-slice is rematerializable: merge it, no conversion is needed
// at this op.
if (result.succeeded()) {
slice.insert(tempSlice.begin(), tempSlice.end());
layout.insert(tempLayout.begin(), tempLayout.end());
continue;
}
// At most one op may require a hoisted conversion; a second one aborts.
if (extOrBroadcastOp != nullptr)
return;
extOrBroadcastOp = op;
}
}
if (extOrBroadcastOp == nullptr)
return;
Attribute dstEncoding = layout[extOrBroadcastOp->getResult(0)];
Attribute srcEncoding = inferSrcEncoding(extOrBroadcastOp, dstEncoding);
if (!srcEncoding)
return;
// Insert, before the selected op, a conversion of its operand 0 to
// `srcEncoding`, followed by a clone of the op that consumes the converted
// operand and whose result 0 is retyped to `dstEncoding`.
OpBuilder builder(extOrBroadcastOp);
auto tensorType =
cast<RankedTensorType>(extOrBroadcastOp->getOperand(0).getType());
auto newType = tensorType.cloneWithEncoding(srcEncoding);
auto newConvertOp = builder.create<ConvertLayoutOp>(
convertOp.getLoc(), newType, extOrBroadcastOp->getOperand(0));
Operation *newExtOrBroadcast = builder.clone(*extOrBroadcastOp);
newExtOrBroadcast->setOperand(0, newConvertOp.getResult());
auto oldExtOrBroadcastType =
cast<RankedTensorType>(extOrBroadcastOp->getResult(0).getType());
Type newExtOrBroadcastType =
oldExtOrBroadcastType.cloneWithEncoding(dstEncoding);
newExtOrBroadcast->getResult(0).setType(newExtOrBroadcastType);
// The original result is mapped to the clone's result and removed from the
// slice before the remaining slice is rewritten.
IRMapping mapping;
mapping.map(extOrBroadcastOp->getResult(0), newExtOrBroadcast->getResult(0));
slice.remove(extOrBroadcastOp->getResult(0));
rewriteSlice(slice, layout, convertOp, mapping);
}
// Tries to hoist `convertOp` into one branch of scf.if ops found in its
// backward slice.
//
// For each scf.if result in the slice, the backward slice is taken along both
// the then- and else-yield operands. If only one side is rematerializable and
// the scf.if is directly nested in an scf.for, a conversion is inserted on
// the other side's yield operand. Nothing is changed unless at least one such
// hoist was found.
void LayoutRematerialization::hoistConvertIntoConditionals(
ConvertLayoutOp convertOp) {
// Initial backward slice rooted at the conversion source; `isIfOp` is passed
// to getRematerializableSlice as its extra predicate argument.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
if (failed(getRematerializableSlice(convertOp.getSrcMutable(),
convertOp.getType().getEncoding(), slice,
layout, isIfOp)))
return;
// (scf.if result, yield operand) pairs: the conversion is inserted on the
// value carried by the yield operand.
SmallVector<std::pair<Value, OpOperand *>> hoistAbove;
// scf.if results that get a conversion placed right after the scf.if.
SmallVector<OpResult> terminals;
// `slice` grows inside the loop, so its size is re-read on every iteration
// and newly added values are visited as well.
for (unsigned i = 0; i != slice.size(); ++i) {
Value v = slice[i];
auto ifOp = v.getDefiningOp<scf::IfOp>();
if (!ifOp)
continue;
Attribute rootLayout = layout.at(v);
unsigned resIdx = cast<OpResult>(v).getResultNumber();
// Yield operands of both branches that correspond to result `resIdx`.
auto thenYield =
cast<scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
auto elseYield =
cast<scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
OpOperand &thenRes = thenYield.getResultsMutable()[resIdx];
OpOperand &elseRes = elseYield.getResultsMutable()[resIdx];
SetVector<Value> thenSlice, elseSlice;
DenseMap<Value, Attribute> thenLayout, elseLayout;
LogicalResult thenResult = getRematerializableSlice(
thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
LogicalResult elseResult = getRematerializableSlice(
elseRes, rootLayout, elseSlice, elseLayout, isIfOp);
// Both branches rematerializable: merge both sub-slices, no hoist needed.
if (succeeded(thenResult) && succeeded(elseResult)) {
slice.insert(thenSlice.begin(), thenSlice.end());
slice.insert(elseSlice.begin(), elseSlice.end());
layout.insert(thenLayout.begin(), thenLayout.end());
layout.insert(elseLayout.begin(), elseLayout.end());
continue;
}
// Neither branch rematerializable: this result terminates the slice.
if (failed(thenResult) && failed(elseResult)) {
terminals.push_back(cast<OpResult>(v));
continue;
}
// Exactly one branch succeeded, but the scf.if is not directly inside an
// scf.for: treat the result as a terminal instead of hoisting.
if (!isa<scf::ForOp>(ifOp->getParentOp())) {
terminals.push_back(cast<OpResult>(v));
continue;
}
// Merge the rematerializable branch and schedule a conversion on the
// yield operand of the other branch.
if (succeeded(thenResult)) {
hoistAbove.emplace_back(v, &elseRes);
slice.insert(thenSlice.begin(), thenSlice.end());
layout.insert(thenLayout.begin(), thenLayout.end());
} else {
hoistAbove.emplace_back(v, &thenRes);
slice.insert(elseSlice.begin(), elseSlice.end());
layout.insert(elseLayout.begin(), elseLayout.end());
}
}
if (hoistAbove.empty())
return;
IRMapping mapping;
// Creates a conversion of `v` to `encoding` at the builder's insertion point,
// maps `v` to the converted value and removes `v` from the slice.
auto hoistRemat = [&](OpBuilder &b, Value v, Attribute encoding) {
auto tensorType = cast<RankedTensorType>(v.getType());
auto newType = tensorType.cloneWithEncoding(encoding);
Value newCvt = b.create<ConvertLayoutOp>(convertOp.getLoc(), newType, v);
mapping.map(v, newCvt);
slice.remove(v);
};
// Terminals: convert right after the op defining the value.
for (Value v : terminals) {
OpBuilder b(v.getContext());
b.setInsertionPointAfter(v.getDefiningOp());
hoistRemat(b, v, layout.at(v));
}
// Hoisted edges: the builder is constructed from the op owning the yield
// operand, and the yielded value is converted to the layout recorded for
// the scf.if result.
for (auto [result, edge] : hoistAbove) {
OpBuilder b(edge->getOwner());
hoistRemat(b, edge->get(), layout.at(result));
}
rewriteSlice(slice, layout, convertOp, mapping);
}
// Runs backward rematerialization followed by cleanup on each FuncOp
// visited by walking `module`, using a separate LayoutRematerialization
// instance per function.
void backwardRematerialization(ModuleOp module) {
  module.walk([](FuncOp func) {
    LayoutRematerialization remat(func);
    remat.backwardRematerialization();
    remat.cleanup();
  });
}
// Runs the three convert-hoisting transformations on each FuncOp visited by
// walking `module`, in this order:
//   1. hoistConvertOnTopOfExtOrBroadcast
//   2. hoistConvertIntoConditionals
//   3. hoistConvertDotOperand
// Each step is followed by cleanup(), and the LayoutRematerialization object
// is re-created from the function before the second and third steps.
void hoistConvert(ModuleOp module) {
  module.walk([](FuncOp funcOp) {
    LayoutRematerialization layoutRemat(funcOp);
    layoutRemat.hoistConvertOnTopOfExtOrBroadcast();
    layoutRemat.cleanup();

    layoutRemat = LayoutRematerialization(funcOp);
    layoutRemat.hoistConvertIntoConditionals();
    layoutRemat.cleanup();

    layoutRemat = LayoutRematerialization(funcOp);
    layoutRemat.hoistConvertDotOperand();
    layoutRemat.cleanup();
  });
}
}
// Pass that drives layout-conversion removal over a module. It runs, in
// order: forward layout propagation per function, backward
// rematerialization, convert hoisting, and a final greedy cleanup, with
// ConvertLayoutOp canonicalization after the first two stages.
class TritonGPURemoveLayoutConversionsPass
    : public impl::TritonGPURemoveLayoutConversionsBase<
          TritonGPURemoveLayoutConversionsPass> {
public:
  // Greedily applies the ConvertLayoutOp canonicalization patterns to the
  // module and signals pass failure if the greedy driver reports failure.
  void cleanupConvertOps() {
    ModuleOp moduleOp = getOperation();
    MLIRContext *ctx = &getContext();

    RewritePatternSet convertPatterns(ctx);
    ConvertLayoutOp::getCanonicalizationPatterns(convertPatterns, ctx);
    if (failed(applyPatternsGreedily(moduleOp, std::move(convertPatterns))))
      signalPassFailure();

    LLVM_DEBUG({
      DBGS() << "Module after canonicalizing:\n";
      moduleOp.dump();
    });
  }

  // Executes all stages of the pass on the current module.
  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    MLIRContext *ctx = &getContext();

    // Stage 1: propagate anchor layouts forward within each function.
    moduleOp.walk([](FuncOp func) {
      LayoutPropagation propagation(func);
      propagation.initAnchorLayout();
      propagation.propagateLayout();
      propagation.resolveConflicts();
      propagation.rewrite();
    });
    LLVM_DEBUG({
      DBGS() << "Module after propagating layouts forward:\n";
      moduleOp.dump();
    });
    cleanupConvertOps();

    // Stage 2: backward rematerialization of conversions.
    backwardRematerialization(moduleOp);
    LLVM_DEBUG({
      DBGS() << "Module after backward remat:\n";
      moduleOp.dump();
    });
    cleanupConvertOps();

    // Stage 3: hoist the conversions that remain.
    hoistConvert(moduleOp);
    LLVM_DEBUG({
      DBGS() << "Module after hoisting converts:\n";
      moduleOp.dump();
    });

    // Stage 4: final cleanup using for-op dead argument elimination and the
    // scf.for / scf.if / ConvertLayoutOp canonicalization patterns.
    RewritePatternSet finalPatterns(ctx);
    populateForOpDeadArgumentElimination(finalPatterns);
    scf::ForOp::getCanonicalizationPatterns(finalPatterns, ctx);
    scf::IfOp::getCanonicalizationPatterns(finalPatterns, ctx);
    ConvertLayoutOp::getCanonicalizationPatterns(finalPatterns, ctx);
    if (failed(applyPatternsGreedily(moduleOp, std::move(finalPatterns))))
      signalPassFailure();

    LLVM_DEBUG({
      DBGS() << "Module after final cleanups:\n";
      moduleOp.dump();
    });
  }
};
}