#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
using namespace mlir;
using namespace mlir::linalg;
/// Returns `true` when `result` of `genericOp` can be treated as dead.
///
/// A result with at least one use is never dead. A use-free result is dead
/// when either
///  - the payload does not read the value of its init operand, or
///  - the matching region output argument has exactly one use, that user is a
///    `linalg.yield` with no uses of its own, and the yield operand at the
///    result's own position is that argument.
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
  const unsigned resultIdx = result.getResultNumber();

  // Any remaining user keeps the result alive.
  if (!result.use_empty())
    return false;

  // An init operand that the payload never reads can simply go away.
  OpOperand *initOperand = genericOp.getDpsInitOperand(resultIdx);
  if (!genericOp.payloadUsesValueFromOperand(initOperand))
    return true;

  // The payload reads the init value: require that the block argument has a
  // single use.
  BlockArgument initArg = genericOp.getRegionOutputArgs()[resultIdx];
  if (!initArg.hasOneUse())
    return false;

  // The single user must not itself be used by anything.
  Operation *soleUser = *initArg.user_begin();
  if (!soleUser->use_empty())
    return false;

  // The single user must be the yield, and it must hand the argument back at
  // this result's own position (otherwise another output consumes the data).
  auto terminator = dyn_cast<linalg::YieldOp>(soleUser);
  return terminator && terminator.getOperand(resultIdx) == initArg;
}
namespace {
/// Rewrite pattern on `linalg.generic` that rebuilds the op with a reduced
/// operand list.
///
/// Inputs whose value the payload does not read are dropped when
/// `canOpOperandsBeDropped` allows it, and inputs repeating an earlier
/// (value, indexing map) pair are folded onto the first occurrence. When
/// `removeOutputs` is set and the op has pure tensor semantics, outputs whose
/// result is dead, or that repeat an earlier (value, indexing map, yielded
/// value) triple without being read by the payload, are dropped as well.
struct DeduplicateAndRemoveDeadOperandsAndResults
    : public OpRewritePattern<GenericOp> {
  DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
                                             bool removeOutputs)
      : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}

  /// Builds the reduced op, moves the payload into it and replaces
  /// `genericOp`. Returns failure when no operand would be removed.
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Operands selected for removal so far; shared by both passes so that
    // `canOpOperandsBeDropped` always sees the full candidate set.
    SmallVector<OpOperand *> droppedOpOperands;

    // Operand values and indexing maps of the op to be created.
    SmallVector<Value> keptInputs;
    SmallVector<Value> keptOutputs;
    SmallVector<AffineMap> keptMaps;

    // Original position -> new position, for inputs and outputs separately.
    // Dropped operands have no entry.
    llvm::SmallDenseMap<unsigned, unsigned> inputRemap =
        deduplicateInputOperands(genericOp, droppedOpOperands, keptInputs,
                                 keptMaps);
    llvm::SmallDenseMap<unsigned, unsigned> outputRemap =
        deduplicateOutputOperands(genericOp, droppedOpOperands, keptOutputs,
                                  keptMaps);

    // Same operand count means nothing was removed.
    const size_t numKept = keptInputs.size() + keptOutputs.size();
    if (numKept == genericOp->getNumOperands())
      return failure();

    // Each kept tensor-typed output contributes one result of the same type.
    SmallVector<Type> resultTypes;
    for (Value out : keptOutputs) {
      Type outTy = out.getType();
      if (isa<TensorType>(outTy))
        resultTypes.push_back(outTy);
    }

    // Create the replacement op with an empty payload; it is filled in by
    // `populateOpPayload` below.
    auto rebuilt = rewriter.create<GenericOp>(
        genericOp.getLoc(), resultTypes, keptInputs, keptOutputs,
        rewriter.getAffineMapArrayAttr(keptMaps),
        genericOp.getIteratorTypes(), genericOp.getDocAttr(),
        genericOp.getLibraryCallAttr(),
        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {});

    // Carry over every attribute that is not one of the op's own declared
    // attribute names.
    ArrayRef<StringRef> declaredAttrNames = genericOp.getAttributeNames();
    for (NamedAttribute attr : genericOp->getAttrs()) {
      if (llvm::is_contained(declaredAttrNames, attr.getName().getValue()))
        continue;
      rebuilt->setAttr(attr.getName(), attr.getValue());
    }

    populateOpPayload(genericOp, rebuilt, inputRemap, outputRemap, rewriter);

    // Results with a mapping are replaced by the matching new result; the
    // others keep a null replacement.
    const unsigned numResults = genericOp->getNumResults();
    SmallVector<Value> resultReplacements(numResults, nullptr);
    for (unsigned resultIdx = 0; resultIdx != numResults; ++resultIdx) {
      auto mapped = outputRemap.find(resultIdx);
      if (mapped != outputRemap.end())
        resultReplacements[resultIdx] = rebuilt.getResult(mapped->second);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }

private:
  /// Whether outputs may be removed in addition to inputs.
  bool removeOutputs;

  /// Walks the input operands of `genericOp`, appending the ones to keep (and
  /// their indexing maps) to `newInputOperands` / `newIndexingMaps` and the
  /// removed ones to `droppedOpOperands`. Returns a map from original input
  /// position to new input position; inputs dropped as unused have no entry,
  /// duplicates map to the position of their first occurrence.
  llvm::SmallDenseMap<unsigned, unsigned>
  deduplicateInputOperands(GenericOp genericOp,
                           SmallVector<OpOperand *> &droppedOpOperands,
                           SmallVector<Value> &newInputOperands,
                           SmallVector<AffineMap> &newIndexingMaps) const {
    llvm::SmallDenseMap<unsigned, unsigned> remap;
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> firstSeenAt;

    for (const auto &indexed :
         llvm::enumerate(genericOp.getDpsInputOperands())) {
      const unsigned origPos = indexed.index();
      OpOperand *operand = indexed.value();

      // Tentatively drop an operand the payload does not read; undo if the
      // op reports that the accumulated set cannot be dropped.
      if (!genericOp.payloadUsesValueFromOperand(operand)) {
        droppedOpOperands.push_back(operand);
        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
          continue;
        droppedOpOperands.pop_back();
      }

      // Either register this (value, map) pair at the next free position or
      // find the position of its earlier occurrence.
      const std::pair<Value, AffineMap> key(
          operand->get(), genericOp.getMatchingIndexingMap(operand));
      auto [slot, isFirstOccurrence] =
          firstSeenAt.try_emplace(key, newInputOperands.size());
      remap[origPos] = slot->second;

      if (!isFirstOccurrence) {
        droppedOpOperands.push_back(operand);
        continue;
      }
      newInputOperands.push_back(key.first);
      newIndexingMaps.push_back(key.second);
    }
    return remap;
  }

  /// Output counterpart of `deduplicateInputOperands`. Appends kept outputs
  /// and their indexing maps to `newOutputOperands` / `newIndexingMaps`,
  /// records removed ones in `droppedOpOperands`, and returns the map from
  /// original output position to new output position.
  llvm::SmallDenseMap<unsigned, unsigned>
  deduplicateOutputOperands(GenericOp genericOp,
                            SmallVector<OpOperand *> &droppedOpOperands,
                            SmallVector<Value> &newOutputOperands,
                            SmallVector<AffineMap> &newIndexingMaps) const {
    llvm::SmallDenseMap<unsigned, unsigned> remap;

    // Outputs are only pruned for pure tensor semantics with `removeOutputs`
    // enabled; otherwise every output is kept in order.
    const bool mayPruneOutputs =
        genericOp.hasPureTensorSemantics() && removeOutputs;
    if (!mayPruneOutputs) {
      unsigned origPos = 0;
      for (OpOperand &init : genericOp.getDpsInitsMutable()) {
        remap[origPos++] = newOutputOperands.size();
        newOutputOperands.push_back(init.get());
        newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&init));
      }
      return remap;
    }

    llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
        seenOutputs;
    auto terminator = cast<YieldOp>(genericOp.getBody()->getTerminator());

    for (const auto &indexed :
         llvm::enumerate(genericOp.getDpsInitsMutable())) {
      const unsigned origPos = indexed.index();
      OpOperand *init = &indexed.value();
      AffineMap initMap = genericOp.getMatchingIndexingMap(init);
      auto key = std::make_tuple(init->get(), initMap,
                                 terminator->getOperand(origPos));

      // A dead result lets us tentatively drop the operand; undo if the op
      // reports that the accumulated set cannot be dropped.
      if (isResultValueDead(genericOp, genericOp.getTiedOpResult(init))) {
        droppedOpOperands.push_back(init);
        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
          continue;
        droppedOpOperands.pop_back();
      }

      // An output the payload does not read is redundant when an earlier
      // kept output has the same value, indexing map and yielded value.
      if (!genericOp.payloadUsesValueFromOperand(init)) {
        auto seen = seenOutputs.find(key);
        if (seen != seenOutputs.end()) {
          remap[origPos] = seen->second;
          droppedOpOperands.push_back(init);
          continue;
        }
      }

      // Kept output.
      const unsigned newPos = newOutputOperands.size();
      remap[origPos] = newPos;
      seenOutputs[key] = newPos;
      newOutputOperands.push_back(init->get());
      newIndexingMaps.push_back(initMap);
    }
    return remap;
  }

  /// Moves the payload of `genericOp` into the (empty) body of `newOp`,
  /// remapping the original block arguments to the new ones according to the
  /// two position maps. When the number of outputs changed, the original
  /// yield is first replaced by one that only yields the kept values.
  void populateOpPayload(
      GenericOp genericOp, GenericOp newOp,
      const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
      const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
      PatternRewriter &rewriter) const {
    Block *origBody = &genericOp.getRegion().front();
    Block *newBody = &newOp.getRegion().front();
    assert(newBody->empty() && "expected new op to have an empty payload");

    // One replacement per original block argument; arguments of dropped
    // operands keep a null entry.
    SmallVector<Value> argReplacements(origBody->getNumArguments(), nullptr);

    // For every original operand that has a mapping, record the new block
    // argument that stands in for the original one.
    auto recordArgMapping =
        [&](ArrayRef<OpOperand *> origOperands,
            ArrayRef<OpOperand *> newOperands,
            const llvm::SmallDenseMap<unsigned, unsigned> &positions) {
          for (unsigned pos = 0, end = origOperands.size(); pos != end;
               ++pos) {
            auto mapped = positions.find(pos);
            if (mapped == positions.end())
              continue;
            const unsigned origArgNo = origOperands[pos]->getOperandNumber();
            const unsigned newArgNo =
                newOperands[mapped->second]->getOperandNumber();
            argReplacements[origArgNo] = newBody->getArgument(newArgNo);
          }
        };

    // Collects pointers to the init operands of `op`, in order.
    auto collectInits = [](GenericOp op) {
      SmallVector<OpOperand *> inits;
      for (OpOperand &init : op.getDpsInitsMutable())
        inits.push_back(&init);
      return inits;
    };

    SmallVector<OpOperand *> origInputs = genericOp.getDpsInputOperands();
    SmallVector<OpOperand *> newInputs = newOp.getDpsInputOperands();
    recordArgMapping(origInputs, newInputs, origInsToNewInsPos);

    SmallVector<OpOperand *> origOutputs = collectInits(genericOp);
    SmallVector<OpOperand *> newOutputs = collectInits(newOp);
    recordArgMapping(origOutputs, newOutputs, origOutsToNewOutsPos);

    // Shrink the yield when outputs were removed.
    if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
      OpBuilder::InsertionGuard guard(rewriter);
      YieldOp oldYield = cast<YieldOp>(origBody->getTerminator());
      rewriter.setInsertionPoint(oldYield);

      SmallVector<Value> keptYields(newOp.getNumDpsInits(), nullptr);
      for (const auto &indexed : llvm::enumerate(oldYield.getValues())) {
        auto mapped = origOutsToNewOutsPos.find(indexed.index());
        if (mapped != origOutsToNewOutsPos.end())
          keptYields[mapped->second] = indexed.value();
      }
      rewriter.replaceOpWithNewOp<YieldOp>(oldYield, keptYields);
    }

    rewriter.mergeBlocks(origBody, newBody, argReplacements);
  }
};
/// Rewrite pattern on tensor-semantics `linalg.generic` that removes "unused
/// cycles": a dead result whose region output argument feeds exactly one op,
/// whose single use is in turn the `linalg.yield` operand at that same output
/// position. That intermediate op is removed and the yield is made to forward
/// the output argument directly.
struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  /// Returns success if at least one cycle was removed, failure otherwise
  /// (including when the op does not have pure tensor semantics).
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();

    bool hasRemovedCycles = false;
    for (const auto &outputOpOperand :
         llvm::enumerate(genericOp.getDpsInits())) {
      const unsigned outputIdx = outputOpOperand.index();

      // The result tied to this output must be dead.
      Value result = genericOp.getResult(outputIdx);
      if (!result.use_empty())
        continue;

      // The output block argument must be used exactly once.
      BlockArgument outputArg = genericOp.getRegionOutputArgs()[outputIdx];
      if (!outputArg.hasOneUse())
        continue;

      // The op consuming the argument must itself have exactly one use
      // (counted over all of its results).
      Operation *cycleOp = *outputArg.user_begin();
      if (!cycleOp->hasOneUse())
        continue;

      // That single use must be in the yield.
      Operation *cycleUserOp = *cycleOp->user_begin();
      if (!isa<linalg::YieldOp>(cycleUserOp))
        continue;

      // The yield operand at this output's position must be the op's first
      // result; otherwise the data flows into a different output.
      Value cycleResult = cycleOp->getResult(0);
      if (cycleUserOp->getOperand(outputIdx) != cycleResult)
        continue;

      // Forward the block argument to the yield and delete the op. The only
      // use of `cycleOp` is the one of result 0 checked above, so erasing it
      // is safe whatever its result count. `replaceOp` with a single value is
      // avoided because it requires exactly one result on the replaced op.
      rewriter.replaceAllUsesWith(cycleResult, outputArg);
      rewriter.eraseOp(cycleOp);
      // Signal a modification of the generic op to the rewriter.
      rewriter.modifyOpInPlace(genericOp, [] {});
      hasRemovedCycles = true;
    }

    return success(hasRemovedCycles);
  }
};
/// Rewrite pattern on `linalg.generic` that redirects the uses of an input
/// block argument to the block argument of a later operand (input or output)
/// that has the same operand value and the same indexing map.
struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  /// Returns success if the uses of at least one input block argument were
  /// redirected, failure otherwise. The operand list itself is not changed.
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Loop-invariant data. `getIndexingMapsArray()` returns a vector by
    // value, so it is fetched once instead of twice per inner iteration.
    Block *body = genericOp.getBody();
    const SmallVector<AffineMap> indexingMaps =
        genericOp.getIndexingMapsArray();
    const int64_t numInputs = genericOp.getNumDpsInputs();
    const int numOperands = static_cast<int>(genericOp->getNumOperands());

    // Input position -> position of the operand whose block argument
    // replaces it.
    DenseMap<int, int> replacements;
    for (int i = 0; i < numInputs; ++i) {
      // Block arguments without uses need no replacement.
      if (body->getArgument(i).use_empty())
        continue;
      // Scan from the last operand backwards so that the replacement is the
      // last operand matching in both value and indexing map.
      for (int j = numOperands - 1; j > i; --j) {
        if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
            indexingMaps[i] == indexingMaps[j]) {
          replacements[i] = j;
          break;
        }
      }
    }

    if (replacements.empty())
      return failure();

    rewriter.modifyOpInPlace(genericOp, [&]() {
      for (auto [before, after] : replacements) {
        BlockArgument bbArg = body->getArgument(before);
        BlockArgument replacement = body->getArgument(after);
        rewriter.replaceAllUsesWith(bbArg, replacement);
      }
    });
    return success();
  }
};
}
/// Adds to `patterns` the operand/result deduplication pattern with output
/// removal enabled, together with the unused-cycle removal pattern.
void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
      ctx, /*removeOutputs=*/true);
  patterns.insert<RemoveUnusedCycleInGenericOp>(ctx);
}
/// Adds to `patterns` the operand deduplication pattern with output removal
/// disabled, together with the duplicate input block-argument folding
/// pattern.
void mlir::linalg::populateEraseUnnecessaryInputsPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
      ctx, /*removeOutputs=*/false);
  patterns.insert<FoldDuplicateInputBbArgs>(ctx);
}