#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include <optional>
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// Rewrite pattern that splits a `linalg.generic` whose body holds at least two
/// statements besides the terminator into two `linalg.generic` ops:
///   - a "peeled" op whose body holds only the first statement of the
///     original body, and
///   - a "residual" op whose body holds the remaining statements and reads the
///     peeled statement's values through additional input operands.
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Builds (with an empty region) the generic op that will receive the first
  /// statement of `genericOp`'s body.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Builds (with an empty region) the generic op that will receive the
  /// remaining statements of `genericOp`'s body; the extra results of
  /// `peeledGenericOp` become its additional inputs.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
} // namespace
/// Returns the loop ranges (iteration-domain sizes) of `op`, folded where
/// possible. Any IR needed to compute them is created right before `op`; the
/// builder's insertion point is restored on return.
static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
                                                       GenericOp op) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);
  const Location opLoc = op.getLoc();
  // Flat list of all operand dimension sizes.
  auto flatOperandDims = cast<LinalgOp>(op.getOperation())
                             .createFlatListOfOperandDims(b, opLoc);
  // Map from the flat operand dimensions to the loop ranges.
  const AffineMap shapesToLoops = op.getShapesToLoopsMap();
  IRRewriter irRewriter(b);
  return affine::makeComposedFoldedMultiResultAffineApply(
      irRewriter, opLoc, shapesToLoops, flatOperandDims);
}
/// Permutes `values` according to the permutation `map`: the value at
/// position `i` is stored at the position named by the `i`-th result of `map`.
/// `map` must be a permutation (asserted).
SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
                                        AffineMap map) {
  assert(map.isPermutation());
  SmallVector<OpFoldResult> permutedValues(values.size());
  ArrayRef<AffineExpr> exprs = map.getResults();
  for (unsigned srcPos = 0, e = exprs.size(); srcPos < e; ++srcPos) {
    unsigned dstPos = cast<AffineDimExpr>(exprs[srcPos]).getPosition();
    permutedValues[dstPos] = values[srcPos];
  }
  return permutedValues;
}
/// Creates an `arith` constant holding zero of the scalar `elementType` at the
/// current insertion point of `b`. Only integer, index and float types are
/// supported (asserted).
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
  assert(elementType.isIntOrIndexOrFloat() &&
         "expected scalar type while computing zero value");
  // Integer types.
  if (isa<IntegerType>(elementType))
    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
  // The `index` type.
  if (elementType.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, 0);
  // Whatever is left must be a float type: build 0.0 in its semantics.
  auto fpType = cast<FloatType>(elementType);
  APFloat zeroValue = APFloat::getZero(fpType.getFloatSemantics());
  return b.create<arith::ConstantFloatOp>(loc, zeroValue, fpType);
}
/// Creates the generic op that will hold the first ("peeled") statement of
/// `genericOp`'s body. The new op has an empty region; the caller moves the
/// statement in and builds its terminator.
///
/// The returned op keeps the inputs, outputs, results and iterator types of
/// `genericOp`, and appends one extra output/result per result of the peeled
/// statement:
///   - if that scalar result is yielded by `genericOp`, the extra output reuses
///     the init operand, result type and indexing map of the first original
///     result position it is yielded at;
///   - otherwise a new `tensor.empty` spanning the loop ranges is used as the
///     init, with an identity indexing map.
GenericOp
DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
                                         PatternRewriter &rewriter) const {
  Operation &peeledOp = genericOp.getBody()->front();
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  const Location loc = genericOp.getLoc();
  // Shape of the tensors created for peeled results the original op does not
  // yield.
  SmallVector<OpFoldResult> loopRange =
      getGenericOpLoopRange(rewriter, genericOp);

  // Returns the first position at which `scalar` is an operand of the first
  // `linalg.yield` among its users, or std::nullopt if no user is a yield.
  auto findYieldedPosition = [](Value scalar) -> std::optional<unsigned> {
    for (Operation *user : scalar.getUsers()) {
      auto yield = dyn_cast<YieldOp>(user);
      if (!yield)
        continue;
      for (OpOperand &operand : yield->getOpOperands())
        if (operand.get() == scalar)
          return operand.getOperandNumber();
      assert(false && "unable to find use of a value in its user");
      return std::nullopt;
    }
    return std::nullopt;
  };

  // Start from the original outputs/result types; extras are appended below.
  SmallVector<Value> outputs = genericOp.getOutputs();
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());

  for (Value scalarResult : peeledOp.getResults()) {
    if (std::optional<unsigned> pos = findYieldedPosition(scalarResult)) {
      // Reuse the destination of the original result this value is yielded as.
      OpResult origResult = cast<OpResult>(genericOp.getResult(*pos));
      outputs.push_back(genericOp.getDpsInitOperand(*pos)->get());
      resultTypes.push_back(origResult.getType());
      indexingMaps.push_back(
          genericOp.getIndexingMapMatchingResult(origResult));
      continue;
    }
    // Not yielded: create a fresh destination covering the whole domain.
    AffineMap identity = rewriter.getMultiDimIdentityMap(loopRange.size());
    Value freshInit = rewriter.create<tensor::EmptyOp>(loc, loopRange,
                                                       scalarResult.getType());
    outputs.push_back(freshInit);
    resultTypes.push_back(freshInit.getType());
    indexingMaps.push_back(identity);
  }

  return rewriter.create<GenericOp>(
      loc, resultTypes, genericOp.getInputs(), outputs,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}
/// Creates the generic op that will hold every statement of `genericOp`'s body
/// except the peeled one. The new op has an empty region; the caller moves the
/// statements in.
///
/// Its inputs are the original inputs followed by the results that
/// `peeledGenericOp` has beyond the original result count; its outputs, result
/// types and iterator types are those of `genericOp`.
GenericOp
DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
                                           GenericOp peeledGenericOp,
                                           PatternRewriter &rewriter) const {
  SmallVector<Value> inputs = genericOp.getInputs();
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *input : genericOp.getDpsInputOperands())
    indexingMaps.push_back(genericOp.getMatchingIndexingMap(input));

  // Each extra result of the peeled op becomes an input, read with the same
  // indexing map it was written with.
  for (unsigned idx = genericOp.getNumResults(),
                end = peeledGenericOp.getNumResults();
       idx < end; ++idx) {
    OpResult extra = cast<OpResult>(peeledGenericOp.getResult(idx));
    inputs.push_back(extra);
    indexingMaps.push_back(
        peeledGenericOp.getIndexingMapMatchingResult(extra));
  }

  for (OpOperand &init : genericOp.getDpsInitsMutable())
    indexingMaps.push_back(genericOp.getMatchingIndexingMap(&init));

  return rewriter.create<GenericOp>(
      genericOp->getLoc(), genericOp->getResultTypes(), inputs,
      genericOp.getOutputs(), rewriter.getAffineMapArrayAttr(indexingMaps),
      genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}
/// Splits `genericOp` into two `linalg.generic` ops:
///   - the "peeled" op computes only the first statement of the original body
///     and additionally returns every result of that statement;
///   - the "residual" op computes the remaining statements, reading the peeled
///     statement's values through extra input operands.
/// Uses of each original result are redirected to whichever new op produces
/// the corresponding yielded value.
///
/// Applies only to all-parallel generic ops with pure tensor semantics whose
/// inits are accessed through permutation maps, whose body has at least two
/// statements besides the terminator, and whose relevant values are int,
/// index or float (see the checks below). All checks run before any IR is
/// created, because a pattern must not report failure after it has started
/// rewriting.
LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
                                   PatternRewriter &rewriter) const {
  // Only operations whose iterator types are all parallel are handled.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
  }
  // Only pure tensor semantics are handled (buffer operands are not).
  if (!genericOp.hasPureTensorSemantics()) {
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
  }

  if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
        return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand not "
                   "accessed using a permutation");
  }

  // With a single statement plus the yield there is nothing to split.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() <= 2) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has less than 3 statements");
  }

  // The peeled statement's results become elements of new tensor results, so
  // they must be scalars.
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        &(*body->getOperations().begin()),
        "expected return type to be only int, index or float");
  }

  // Every yielded value not produced by the peeled statement is replaced by a
  // `getZero` placeholder in the peeled op's terminator (see below).
  // `getZero` can only build int, index or float zeros, so other yielded types
  // (e.g. complex or vector) must be rejected now, before any IR is created.
  {
    Operation *firstOp = &body->front();
    Operation *origTerminator = body->getTerminator();
    if (llvm::any_of(origTerminator->getOperands(), [&](Value yielded) {
          return yielded.getDefiningOp() != firstOp &&
                 !yielded.getType().isIntOrIndexOrFloat();
        })) {
      return rewriter.notifyMatchFailure(
          genericOp, "expected yielded values not produced by the first "
                     "statement to be only int, index or float");
    }
  }

  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
  GenericOp residualGenericOp =
      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);

  // Move the first statement into the peeled op, and everything else
  // (including the original yield) into the residual op.
  Block *peeledGenericOpBody = peeledGenericOp.getBody();
  Block *residualGenericOpBody = residualGenericOp.getBody();
  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
         "expected split generic ops to have empty region");
  peeledGenericOpBody->getOperations().splice(
      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
                                                body->getOperations());

  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
  auto *yieldOp = residualGenericOpBody->getTerminator();
  {
    // Terminate the peeled op. For each original result position, forward the
    // value if the peeled statement produces it, otherwise yield a zero
    // placeholder. Then yield every result of the peeled statement for the
    // extra results.
    OpBuilder::InsertionGuard yieldGuard(rewriter);
    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
    SmallVector<Value> yieldedVals;
    for (auto origYield : yieldOp->getOperands()) {
      if (origYield.getDefiningOp() == peeledScalarOperation) {
        yieldedVals.push_back(origYield);
      } else {
        // Materialize the zero before the peeled op, not inside it. A new op
        // inside its body would make it satisfy the statement-count check
        // above again, and the pattern would re-apply indefinitely.
        OpBuilder::InsertionGuard zeroGuard(rewriter);
        rewriter.setInsertionPoint(peeledGenericOp);
        yieldedVals.push_back(
            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
      }
    }
    yieldedVals.append(llvm::to_vector(
        llvm::map_range(peeledScalarOperation->getResults(),
                        [](OpResult opr) -> Value { return opr; })));
    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
  }

  // Redirect uses of the original block arguments inside each new body to the
  // matching argument of that body. Block arguments follow operand order
  // (inputs, then outputs):
  //   - the peeled op keeps the original inputs and outputs at their original
  //     positions (its extra outputs come last), so indices carry over;
  //   - the residual op inserts the peeled values as extra inputs between the
  //     original inputs and outputs, so output arguments shift by that count.
  unsigned origNumInputs = genericOp.getNumDpsInputs();
  unsigned numExtraResidualInputs =
      residualGenericOp.getNumDpsInputs() - origNumInputs;
  for (const auto &origBlockArg :
       llvm::enumerate(genericOp.getBody()->getArguments())) {
    unsigned argIndex = origBlockArg.index();
    unsigned residualArgIndex = argIndex < origNumInputs
                                    ? argIndex
                                    : argIndex + numExtraResidualInputs;
    Value residualOpReplacementArg =
        residualGenericOpBody->getArgument(residualArgIndex);
    rewriter.replaceUsesWithIf(
        origBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == residualGenericOpBody;
        });

    Value peeledOpReplacementArg = peeledGenericOpBody->getArgument(argIndex);
    rewriter.replaceUsesWithIf(
        origBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == peeledGenericOpBody;
        });
  }

  // Record the replacement for each original result before the residual body
  // is rewritten. Values produced by the peeled statement come from the peeled
  // op; everything else comes from the residual op.
  SmallVector<Value> replacements;
  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
    OpResult opr = dyn_cast<OpResult>(yieldValue.value());
    if (!opr || opr.getOwner() != peeledScalarOperation)
      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
    else
      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
  }

  // Inside the residual body, read the peeled statement's results from the
  // extra input block arguments, which directly follow the original inputs.
  {
    SmallVector<Value> scalarReplacements;
    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
    scalarReplacements.reserve(peeledScalarOpNumResults);
    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
      scalarReplacements.push_back(
          residualGenericOpBody->getArgument(num + origNumInputs));
    bool allUsesReplaced = false;
    rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements,
                                      residualGenericOpBody, &allUsesReplaced);
    // The peeled op's own terminator still uses these results.
    assert(!allUsesReplaced &&
           "peeled scalar operation is erased when it wasnt expected to be");
  }

  rewriter.replaceOp(genericOp, replacements);
  return success();
}
/// Adds the pattern that splits a multi-statement `linalg.generic` into a
/// peeled and a residual op. When `removeDeadArgsAndResults` is set, the
/// patterns that erase unused operands and results are added as well
/// (presumably to clean up the operands/results the split leaves unused).
void mlir::linalg::populateDecomposeLinalgOpsPattern(
    RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<DecomposeLinalgOp>(ctx);
  if (!removeDeadArgsAndResults)
    return;
  populateEraseUnusedOperandsAndResultsPatterns(patterns);
}