#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
}
using namespace mlir;
/// Returns true if `op` is an elementwise-mappable operation that the pattern
/// in this file can rewrite into a `linalg.generic`.
///
/// The op must satisfy all of the following:
///   * it carries the ElementwiseMappable traits,
///   * it has at least one operand and at least one result,
///   * every operand and every result is a ranked tensor.
///
/// The non-empty and ranked-result requirements protect the rewrite. It reads
/// the iteration rank from the first result (as a `RankedTensorType`) and
/// derives init-tensor sizes from the first operand. Without these checks,
/// `all_of` over an empty operand list is vacuously true, and such ops would
/// be accepted and then crash the rewrite.
///
/// Mixed scalar/tensor operands are not handled; every operand must be a
/// ranked tensor.
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return false;

  // Reject ops with no operand to take the iteration shape from, or no result
  // to take the rank from.
  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
    return false;

  return llvm::all_of(op->getOperandTypes(),
                      llvm::IsaPred<RankedTensorType>) &&
         llvm::all_of(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);
}
/// Returns one init ("outs") value per result of `op`, used as the
/// destination operands of the `linalg.generic` that replaces `op`.
///
/// For each result type, the first operand of `op` with exactly that type is
/// reused. Otherwise a new `tensor.empty` with exactly the result type is
/// created. Its static dimensions come from the result type, and each dynamic
/// dimension is read from the first operand of `op`.
///
/// Precondition: `op` satisfies `isElementwiseMappableOpOnRankedTensors`, so
/// it has at least one operand and all of its results are ranked tensors.
static SmallVector<Value, 4>
getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
  assert(isElementwiseMappableOpOnRankedTensors(op));
  Location loc = op->getLoc();
  ValueRange operands = op->getOperands();
  TypeRange rankedTensorTypes = op->getResultTypes();

  SmallVector<Value, 4> res;
  res.reserve(rankedTensorTypes.size());
  for (Type t : rankedTensorTypes) {
    // Prefer reusing an operand whose type matches the result type exactly.
    auto it =
        llvm::find_if(operands, [&](Value v) { return v.getType() == t; });
    if (it != operands.end()) {
      res.push_back(*it);
      continue;
    }

    // Build the init tensor with exactly the result type. Copying the
    // operand's sizes wholesale could disagree with the result type when the
    // two differ in static-vs-dynamic dimensions or in encoding.
    // linalg.generic rejects that, because init and result types must match.
    auto resultType = cast<RankedTensorType>(t);
    SmallVector<Value> dynamicSizes;
    for (int64_t dim = 0, rank = resultType.getRank(); dim < rank; ++dim) {
      if (!resultType.isDynamicDim(dim))
        continue;
      dynamicSizes.push_back(getValueOrCreateConstantIndexOp(
          b, loc, tensor::getMixedSize(b, loc, operands.front(), dim)));
    }
    res.push_back(b.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynamicSizes,
        resultType.getEncoding()));
  }
  return res;
}
namespace {
/// Rewrites any operation accepted by `isElementwiseMappableOpOnRankedTensors`
/// into a `linalg.generic` whose region re-creates the same operation on the
/// scalar block arguments.
///
/// Every operand and result is accessed through the identity indexing map of
/// the first result's rank, and every loop is parallel. Init tensors come from
/// `getOrCreateOperandsMatchingResultTypes`.
struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    if (!isElementwiseMappableOpOnRankedTensors(op))
      return rewriter.notifyMatchFailure(
          op, "requires elementwise op on ranked tensors");

    // Operands and results share one iteration space, sized by the rank of
    // the first result.
    int64_t rank =
        cast<RankedTensorType>(op->getResult(0).getType()).getRank();
    unsigned numInputs = op->getNumOperands();
    unsigned numMaps = op->getNumResults() + numInputs;

    AffineMap identity = rewriter.getMultiDimIdentityMap(rank);
    SmallVector<AffineMap, 3> indexingMaps(numMaps, identity);
    SmallVector<utils::IteratorType, 6> iteratorTypes(
        rank, utils::IteratorType::parallel);
    SmallVector<Value, 4> inits =
        getOrCreateOperandsMatchingResultTypes(rewriter, op);

    auto buildBody = [&](OpBuilder &builder, Location loc,
                         ValueRange regionArgs) {
      // The scalar op produces the element types of the original results.
      SmallVector<Type, 6> scalarResultTypes;
      for (Type resultType : op->getResultTypes())
        scalarResultTypes.push_back(
            cast<TensorType>(resultType).getElementType());

      // The leading block arguments mirror the original operands; the
      // trailing ones carry the init values and are not forwarded.
      Operation *scalarOp = builder.create(
          loc, op->getName().getIdentifier(),
          regionArgs.take_front(numInputs), scalarResultTypes,
          op->getAttrs());
      builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, op->getResultTypes(), op->getOperands(), inits, indexingMaps,
        iteratorTypes, buildBody);
    return success();
  }
};
} // namespace
/// Adds the pattern that converts elementwise-mappable ops on ranked tensors
/// into `linalg.generic` ops to `patterns`.
void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(ctx);
}
namespace {
/// Pass that runs a partial conversion over its anchor operation. It turns
/// every op accepted by `isElementwiseMappableOpOnRankedTensors` into a
/// `linalg.generic`.
class ConvertElementwiseToLinalgPass
    : public impl::ConvertElementwiseToLinalgPassBase<
          ConvertElementwiseToLinalgPass> {
  using impl::ConvertElementwiseToLinalgPassBase<
      ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;

  void runOnOperation() final {
    MLIRContext *ctx = &getContext();
    Operation *root = getOperation();

    RewritePatternSet patterns(ctx);
    mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);

    // Exactly the ops the pattern would convert are illegal; every other op
    // is left legal.
    ConversionTarget target(*ctx);
    target.markUnknownOpDynamicallyLegal([](Operation *op) {
      return !isElementwiseMappableOpOnRankedTensors(op);
    });

    LogicalResult converted =
        applyPartialConversion(root, target, std::move(patterns));
    if (failed(converted))
      signalPassFailure();
  }
};
} // namespace