#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>
#include <utility>
namespace mlir {
#define GEN_PASS_DEF_LINALGDETENSORIZEPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::linalg;
/// Materialization callback that wraps a single non-tensor value into a
/// rank-0 tensor by creating a `tensor::FromElementsOp`.
///
/// \param builder Builder used to create the new op.
/// \param type    Requested result type (not consulted by this callback).
/// \param inputs  Values to materialize from; exactly one is expected.
/// \param loc     Location attached to the created op.
/// \returns the result of the created op, or a null `Value` when the input
///          already has a tensor type.
static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                           ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  Value scalar = inputs[0];
  Type scalarType = scalar.getType();

  // A tensor-typed input is not handled here; signal that with a null value.
  if (isa<TensorType>(scalarType))
    return Value();

  // Build the rank-0 tensor type whose element type is the input's type.
  auto rankZeroType = RankedTensorType::get({}, scalarType);
  return builder.create<tensor::FromElementsOp>(loc, rankZeroType, scalar);
}
namespace {
/// Returns true only for ranked tensor types whose rank is zero.
bool canBeDetensored(TensorType tensorType) {
  // Unranked tensors are never candidates.
  if (!tensorType.hasRank())
    return false;
  return tensorType.getRank() == 0;
}
bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
return genericOp &&
llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
return !typeConverter.isLegal(opOperand.get().getType());
});
}
/// Conversion pattern that replaces a `GenericOp` with the contents of its
/// region, inlined into the block that contained the op.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Inlines the body of `op` in place of the op itself. The op's results are
  /// replaced by the operands of the body's `YieldOp`, and the body's entry
  /// block arguments are replaced by the adaptor's (converted) operands.
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Capture everything needed from the op's region before it is moved.
    Region &body = op.getRegion();
    Block *parentBlock = op->getBlock();
    Block *bodyEntry = &body.front();
    YieldOp terminator = dyn_cast<YieldOp>(body.back().getTerminator());

    // Split the parent block right before the op so that the region's blocks
    // can be placed between the two halves.
    Block *tailBlock = rewriter.splitBlock(parentBlock, op->getIterator());
    rewriter.inlineRegionBefore(body, tailBlock);

    // The op's results are now whatever the body yielded.
    rewriter.replaceOp(op, terminator->getOperands());

    // Fold the intermediate blocks back into the parent block: first the
    // former body entry (substituting its arguments), then the split-off tail.
    rewriter.mergeBlocks(bodyEntry, parentBlock, adaptor.getOperands());
    rewriter.mergeBlocks(tailBlock, parentBlock, ValueRange{});

    // The yield has served its purpose and must not remain mid-block.
    rewriter.eraseOp(terminator);
    return success();
  }
};
/// Conversion pattern that rewrites the signature of every non-entry block of
/// a function: block arguments listed in `blockArgsToDetensor` get the type
/// produced by the pattern's type converter, all other arguments keep their
/// current type.
struct FunctionNonEntryBlockConversion
    : public OpInterfaceConversionPattern<FunctionOpInterface> {
  /// \param ctx                 Context the pattern is registered in.
  /// \param converter           Type converter used for the selected arguments.
  /// \param blockArgsToDetensor Block arguments whose types must be converted;
  ///                            taken by value and moved into the pattern.
  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                  DenseSet<BlockArgument> blockArgsToDetensor)
      : OpInterfaceConversionPattern(converter, ctx),
        blockArgsToDetensor(std::move(blockArgsToDetensor)) {}

  /// Applies a signature conversion to each block of `op`'s body except the
  /// first one. Always succeeds.
  LogicalResult
  matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.startOpModification(op);
    Region &region = op.getFunctionBody();

    // Early-increment iteration: applying a signature conversion may replace
    // the block currently being visited.
    for (Block &block :
         llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
      TypeConverter::SignatureConversion conversion(block.getNumArguments());

      for (BlockArgument blockArgument : block.getArguments()) {
        unsigned idx = blockArgument.getArgNumber();

        if (blockArgsToDetensor.count(blockArgument))
          conversion.addInputs(idx, {getTypeConverter()->convertType(
                                        block.getArgumentTypes()[idx])});
        else
          conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
      }

      rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
    }

    rewriter.finalizeOpModification(op);
    return success();
  }

private:
  /// Block arguments selected for type conversion.
  const DenseSet<BlockArgument> blockArgsToDetensor;
};
/// Type converter that maps tensors accepted by `canBeDetensored` to their
/// element type and leaves every other type unchanged. It also registers the
/// materializations used to move between the tensor and element forms.
class DetensorizeTypeConverter : public TypeConverter {
public:
  DetensorizeTypeConverter() {
    // NOTE(review): the two conversions are registered in the same order as
    // before; TypeConverter's lookup order depends on it, so do not swap them.

    // Generic conversion: hand the type back untouched.
    addConversion([](Type anyType) { return anyType; });

    // Tensor conversion: detensorable tensors become their element type.
    addConversion([](TensorType tensor) -> Type {
      if (!canBeDetensored(tensor))
        return tensor;
      return tensor.getElementType();
    });

    // Target materialization: read the element out of the first input with a
    // `tensor::ExtractOp` that takes no indices.
    addTargetMaterialization([](OpBuilder &b, Type resultType,
                                ValueRange values, Location where) -> Value {
      Value source = values[0];
      return b.create<tensor::ExtractOp>(where, source, ValueRange{});
    });

    // Source and argument materializations share the same callback.
    addSourceMaterialization(sourceMaterializationCallback);
    addArgumentMaterialization(sourceMaterializationCallback);
  }
};
struct LinalgDetensorize
: public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
using impl::LinalgDetensorizePassBase<
LinalgDetensorize>::LinalgDetensorizePassBase;
LinalgDetensorize() = default;
/// Abstract strategy that decides which operations and block arguments of a
/// function are detensored.
class CostModel {
public:
  virtual ~CostModel() = default;

  /// Fills `opsToDetensor` and `blockArgsToDetensor` for `func`.
  /// `typeConverter` is what implementations use to query type legality.
  virtual void compute(FunctionOpInterface func,
                       DetensorizeTypeConverter typeConverter,
                       DenseSet<Operation *> &opsToDetensor,
                       DenseSet<BlockArgument> &blockArgsToDetensor) = 0;

  /// For every block argument in `blockArgsToDetensor`, records on each
  /// predecessor terminator the index of the operand forwarded to that
  /// argument.
  ///
  /// \returns a map from terminator operation to the set of its operand
  ///          indices that feed a selected block argument.
  static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
      const DenseSet<BlockArgument> &blockArgsToDetensor) {
    DenseMap<Operation *, DenseSet<int>> operandsPerBranch;

    for (BlockArgument arg : blockArgsToDetensor) {
      Block *owner = arg.getOwner();

      // The iterator (not just the block) is needed for the successor index.
      for (auto predIt = owner->pred_begin(), predEnd = owner->pred_end();
           predIt != predEnd; ++predIt) {
        auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
        auto forwarded =
            branch.getSuccessorOperands(predIt.getSuccessorIndex());

        // Only record the operand when the predecessor forwards something to
        // this block and `isOperandProduced` is false for this argument.
        bool feedsArg = !forwarded.empty() &&
                        !forwarded.isOperandProduced(arg.getArgNumber());
        if (feedsArg)
          operandsPerBranch[branch].insert(
              forwarded.getOperandIndex(arg.getArgNumber()));
      }
    }

    return operandsPerBranch;
  }
};
// Cost model that starts from the operands of every cf::CondBranchOp and
// cf::BranchOp in the function and follows values through their users, block
// arguments and defining ops to select the GenericOps and block arguments to
// detensor.
class ControlFlowDetectionModel : public CostModel {
public:
// Fills `opsToDetensor` and `blockArgsToDetensor` for `func`. If a selected
// block argument has a predecessor whose terminator is not a
// BranchOpInterface, both sets are cleared and the function returns early.
void compute(FunctionOpInterface func,
DetensorizeTypeConverter typeConverter,
DenseSet<Operation *> &opsToDetensor,
DenseSet<BlockArgument> &blockArgsToDetensor) override {
// Seed the work list with every operand of every conditional and
// unconditional branch found in the function.
SmallVector<Value> workList;
func->walk([&](cf::CondBranchOp condBr) {
llvm::append_range(workList, condBr.getOperands());
});
func->walk([&](cf::BranchOp br) {
llvm::append_range(workList, br.getOperands());
});
// Values and defining ops that were already processed, so that each one is
// handled at most once.
DenseSet<Value> visitedValues;
DenseSet<Operation *> visitedOps;
// For each operand of `terminator` equal to `value`, enqueue the successor
// block argument it feeds, unless that argument is already selected. Does
// nothing when `terminator` is null.
auto updateWorkListWithSuccessorArguments =
[&](Value value, BranchOpInterface terminator) {
if (!terminator)
return;
for (auto operandIdx :
llvm::seq<unsigned>(0, terminator->getOperands().size())) {
Value operand = terminator->getOperand(operandIdx);
if (operand == value) {
auto succBlockArg =
terminator.getSuccessorBlockArgument(operandIdx);
if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
workList.push_back(*succBlockArg);
}
}
};
while (!workList.empty()) {
Value currentItem = workList.pop_back_val();
// Skip values that were already visited.
if (!visitedValues.insert(currentItem).second)
continue;
// If the terminator of the value's block forwards this value to a
// successor, enqueue the receiving block arguments.
updateWorkListWithSuccessorArguments(
currentItem, dyn_cast<BranchOpInterface>(
currentItem.getParentBlock()->getTerminator()));
// Enqueue every result of every user of the current value.
for (auto *user : currentItem.getUsers())
llvm::append_range(workList, user->getResults());
// Case 1: the value is a block argument.
if (dyn_cast<BlockArgument>(currentItem)) {
BlockArgument currentItemBlockArgument =
cast<BlockArgument>(currentItem);
Block *ownerBlock = currentItemBlockArgument.getOwner();
// Arguments of the first block of the owning region are never selected.
if (&*ownerBlock->getParent()->begin() == ownerBlock)
continue;
// Select the argument, then enqueue the operand that each predecessor
// forwards to it.
blockArgsToDetensor.insert(currentItemBlockArgument);
for (PredecessorIterator pred = ownerBlock->pred_begin();
pred != ownerBlock->pred_end(); ++pred) {
BranchOpInterface predTerminator =
dyn_cast<BranchOpInterface>((*pred)->getTerminator());
// A predecessor that does not end in a BranchOpInterface cannot be
// analysed: discard everything selected so far and give up.
if (!predTerminator) {
opsToDetensor.clear();
blockArgsToDetensor.clear();
return;
}
auto ownerBlockOperands =
predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
// Nothing to follow when the predecessor forwards no operands, or when
// `isOperandProduced` reports true for this argument.
if (ownerBlockOperands.empty() ||
ownerBlockOperands.isOperandProduced(
currentItemBlockArgument.getArgNumber()))
continue;
workList.push_back(
ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
}
continue;
}
// Case 2: the value is not a block argument, so look at its defining op.
// Each defining op is processed only once.
Operation *currentItemDefiningOp = currentItem.getDefiningOp();
if (!visitedOps.insert(currentItemDefiningOp).second)
continue;
// GenericOps accepted by `shouldBeDetensored` are selected and their inputs
// are enqueued; already-selected or rejected ones are skipped.
if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
if (opsToDetensor.count(genericOp))
continue;
if (!shouldBeDetensored(genericOp, typeConverter)) {
continue;
}
opsToDetensor.insert(genericOp);
llvm::append_range(workList, genericOp.getInputs());
continue;
}
// Do not look through tensor::FromElementsOp.
if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
continue;
// For any other op whose results are all integer or float typed, enqueue
// its operands.
if (llvm::all_of(
currentItemDefiningOp->getResultTypes(),
[&](Type resultType) { return resultType.isIntOrFloat(); }))
llvm::append_range(workList, currentItemDefiningOp->getOperands());
}
// Post-processing: drop any selected block argument that receives, from at
// least one predecessor, a value defined by a GenericOp that was not
// selected. Removals are collected first so the set is not modified while
// it is being iterated.
DenseSet<BlockArgument> blockArgsToRemove;
for (auto &blockArg : blockArgsToDetensor) {
Block *block = blockArg.getParentBlock();
for (PredecessorIterator pred = block->pred_begin();
pred != block->pred_end(); ++pred) {
BranchOpInterface terminator =
dyn_cast<BranchOpInterface>((*pred)->getTerminator());
auto blockOperands =
terminator.getSuccessorOperands(pred.getSuccessorIndex());
if (blockOperands.empty() ||
blockOperands.isOperandProduced(blockArg.getArgNumber()))
continue;
Operation *definingOp =
blockOperands[blockArg.getArgNumber()].getDefiningOp();
if (isa_and_nonnull<GenericOp>(definingOp) &&
opsToDetensor.count(definingOp) == 0) {
blockArgsToRemove.insert(blockArg);
break;
}
}
}
for (auto &blockArg : blockArgsToRemove) {
blockArgsToDetensor.erase(blockArg);
}
}
};
/// Cost model that selects every `GenericOp` accepted by
/// `shouldBeDetensored` and every argument of every non-entry block.
class AggressiveDetensoringModel : public CostModel {
public:
  /// Fills `opsToDetensor` and `blockArgsToDetensor` for `func`.
  void compute(FunctionOpInterface func,
               DetensorizeTypeConverter typeConverter,
               DenseSet<Operation *> &opsToDetensor,
               DenseSet<BlockArgument> &blockArgsToDetensor) override {
    // Select all eligible generic ops anywhere in the function.
    func->walk([&](GenericOp candidate) {
      if (!shouldBeDetensored(candidate, typeConverter))
        return;
      opsToDetensor.insert(candidate);
    });

    // Select all arguments of all blocks after the first one.
    for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1)) {
      auto arguments = block.getArguments();
      blockArgsToDetensor.insert(arguments.begin(), arguments.end());
    }
  }
};
void runOnOperation() override {
MLIRContext *context = &getContext();
DetensorizeTypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
DenseSet<Operation *> opsToDetensor;
DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
DenseSet<BlockArgument> blockArgsToDetensor;
FunctionOpInterface funcOp = getOperation();
if (funcOp.getFunctionBody().empty())
return;
IRRewriter rewriter(funcOp->getContext());
Block *entryBlock = &funcOp.getFunctionBody().front();
Block *postEntryBlock =
rewriter.splitBlock(entryBlock, entryBlock->begin());
rewriter.setInsertionPointToStart(entryBlock);
auto branch =
rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);
if (aggressiveMode.getValue()) {
AggressiveDetensoringModel costModel;
costModel.compute(funcOp, typeConverter, opsToDetensor,
blockArgsToDetensor);
} else {
ControlFlowDetectionModel costModel;
costModel.compute(funcOp, typeConverter, opsToDetensor,
blockArgsToDetensor);
}
detensorableBranchOps =
CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
target.addDynamicallyLegalOp<GenericOp>(
[&](GenericOp op) { return !opsToDetensor.count(op); });
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
Region &body = funcOp.getFunctionBody();
return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
return !llvm::any_of(
blockArgsToDetensor, [&](BlockArgument blockArgument) {
return blockArgument.getOwner() == &block &&
!typeConverter.isLegal(blockArgument.getType());
});
});
}
if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
isLegalForReturnOpTypeConversionPattern(op, typeConverter,
true))
return true;
if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
if (!detensorableBranchOps.count(branchOp))
return true;
for (auto operandIdx : detensorableBranchOps[branchOp])
if (!typeConverter.isLegal(
branchOp->getOperand(operandIdx).getType()))
return false;
return true;
}
return false;
});
patterns.add<DetensorizeGenericOp>(typeConverter, context);
patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
blockArgsToDetensor);
auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
int operandIdx) -> bool {
return detensorableBranchOps.count(branchOp) &&
detensorableBranchOps[branchOp].count(operandIdx);
};
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
shouldConvertBranchOperand);
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
RewritePatternSet canonPatterns(context);
tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(canonPatterns))))
signalPassFailure();
rewriter.eraseOp(branch);
rewriter.mergeBlocks(postEntryBlock, entryBlock);
}
};
}