#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "lower-contract-to-arm-neon"
using namespace mlir;
using namespace mlir::arm_neon;
namespace {
/// Returns a type that has the same container as `container` but with
/// `element` as its element type. When `container` is not a ShapedType there
/// is no container to mirror, so `element` is returned unchanged.
static Type matchContainerType(Type element, Type container) {
  auto shapedContainer = dyn_cast<ShapedType>(container);
  if (!shapedContainer)
    return element;
  return shapedContainer.clone(element);
}
/// Rewrites a `vector.contract` whose LHS and RHS are both produced by
/// `arith.extsi` from vectors with elements of at most 8 bits into a sequence
/// of `arm_neon.smmla` ops.
///
/// The contraction is tiled into [2, 2, 8] (matmul) or [1, 2, 8] / [2, 8]
/// (vecmat) pieces. Each tile is flattened to the 1-D operand shapes consumed
/// by `arm_neon.smmla` (16 input elements per operand, 4 accumulator
/// elements), and the per-tile results are inserted back into a zero
/// initialized vector of the original result type, which finally replaces the
/// contraction.
///
/// The pattern never mutates the IR before all of its preconditions have been
/// checked, so returning `failure()` always leaves the IR untouched.
///
/// NOTE(review): the indexing maps of the contraction are not validated here;
/// the code presumably expects a non-transposed RHS of shape [N, K] — confirm
/// against callers/tests.
class LowerContractionToSMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Tile sizes are inferred from the operand types. For vecmat the LHS may
    // be 1-D; the RHS must be at least 2-D. The LHS must have at least one
    // dimension because its innermost dimension is inspected below.
    mlir::VectorType lhsType = op.getLhsType();
    mlir::VectorType rhsType = op.getRhsType();
    if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() < 1 ||
        rhsType.getRank() < 2)
      return failure();

    // Everything emitted below uses fixed tile shapes and static offsets,
    // which is not meaningful for scalable vectors.
    if (lhsType.isScalable() || rhsType.isScalable())
      return failure();

    int64_t dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
    int64_t dimN = rhsType.getDimSize(0);
    int64_t dimK = rhsType.getDimSize(1);
    bool isVecmat = dimM == 1;

    // The innermost dimensions of LHS and RHS must agree.
    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
        rhsType.getDimSize(rhsType.getRank() - 1)) {
      return failure();
    }

    // Only shapes that are a whole number of tiles are supported: M and N
    // must be multiples of 2 (M == 1 is allowed for vecmat) and K a multiple
    // of 8.
    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
      return failure();
    }

    // At most three iterators; the last one must be a reduction and all the
    // others must be parallel.
    auto iteratorTypes = op.getIteratorTypesArray();
    if (iteratorTypes.empty() || iteratorTypes.size() > 3 ||
        iteratorTypes.back() != vector::IteratorType::reduction) {
      return failure();
    }
    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
                     [](vector::IteratorType iteratorType) {
                       return iteratorType != vector::IteratorType::parallel;
                     })) {
      return failure();
    }

    // Both inputs must be defined by arith.extsi.
    arith::ExtSIOp origLhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
    arith::ExtSIOp origRhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
    if (!origLhsExtOp || !origRhsExtOp) {
      return failure();
    }

    // Both extsi inputs must be vectors with elements of at most 8 bits.
    // Validate *both* before creating any op so that a failing match never
    // leaves newly created IR behind.
    auto lhsExtInType =
        dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType());
    auto rhsExtInType =
        dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType());
    if (!lhsExtInType || !rhsExtInType ||
        lhsExtInType.getElementTypeBitWidth() > 8 ||
        rhsExtInType.getElementTypeBitWidth() > 8) {
      return failure();
    }

    // ---- Match succeeded; IR is only modified from here on. ----

    // Re-extend the original narrow inputs to i8 element vectors.
    Type targetLhsExtTy =
        matchContainerType(rewriter.getI8Type(), lhsExtInType);
    Value extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(
        loc, targetLhsExtTy, origLhsExtOp.getIn());
    Type targetRhsExtTy =
        matchContainerType(rewriter.getI8Type(), rhsExtInType);
    Value extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(
        loc, targetRhsExtTy, origRhsExtOp.getIn());

    // Zero-initialized vector of the full result type; every tile result is
    // inserted into it.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));

    // Tile shape and loop order over the iteration space. With three
    // iterators a leading M tile dimension (1 for vecmat, 2 otherwise) is
    // added.
    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
    SmallVector<int64_t> smmlaShape{2, 8};
    SmallVector<int64_t> loopOrder{0, 1};
    if (unrolledSize.size() == 3) {
      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
      loopOrder.push_back(2);
    }

    // Loop-invariant: indexing maps of LHS, RHS and accumulator.
    auto indexingMaps = op.getIndexingMapsArray();
    AffineMap lhsPermutationMap = indexingMaps[0];
    AffineMap rhsPermutationMap = indexingMaps[1];
    AffineMap accPermutationMap = indexingMaps[2];

    // Extracts the tile of `operand` selected by `operandOffsets`; the tile
    // shape is the smmla tile shape mapped through `permutationMap`.
    auto extractOperand = [&](Value operand, AffineMap permutationMap,
                              ArrayRef<int64_t> operandOffsets) {
      SmallVector<int64_t> operandShape =
          applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
      SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
      return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, operand, operandOffsets, operandShape, operandStrides);
    };

    // Result of the previous smmla; reused as accumulator for every tile
    // whose last offset is non-zero. This assumes the tile range visits the
    // consecutive last-dimension offsets of one output tile back to back.
    Value kAcc;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
      // Slice out the LHS, RHS and accumulator tiles for this iteration.
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledAcc =
          extractOperand(op.getAcc(), accPermutationMap, accOffsets);

      auto inputElementType =
          cast<ShapedType>(tiledLhs.getType()).getElementType();
      auto accElementType =
          cast<ShapedType>(tiledAcc.getType()).getElementType();
      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
      auto outputExpandedType = VectorType::get({2, 2}, accElementType);

      // For vecmat the LHS and accumulator tiles are smaller than what the
      // flattened smmla operands need, so place them at offset 0 of a
      // zero-filled 2x8 / 2x2 vector.
      if (isVecmat) {
        auto expandForSMMLA = [&](Value tiledOperand,
                                  VectorType expandedTypeType) {
          auto emptyOperand = rewriter.create<arith::ConstantOp>(
              loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
          SmallVector<int64_t> insertOffsets(
              cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
          SmallVector<int64_t> insertStrides(
              cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
              loc, tiledOperand, emptyOperand, insertOffsets, insertStrides);
        };
        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
      }

      // Flatten the input tiles to 1-D vectors for arm_neon.smmla.
      auto collapsedInputType =
          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
      auto collapsedOutputType =
          VectorType::get(outputExpandedType.getNumElements(), accElementType);

      // The first tile along the last dimension starts from the original
      // accumulator tile; later ones chain from the previous smmla result.
      bool initialKAcc = offsets.back() == 0;
      Value collapsedRes;
      if (!initialKAcc) {
        collapsedRes = kAcc;
      } else {
        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
      }

      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
          collapsedRhs);

      // Reshape the flat smmla result back to the (possibly expanded)
      // accumulator tile shape.
      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
          kAcc.getLoc(), tiledAcc.getType(), kAcc);

      // For vecmat only the first row of the expanded tile carries the
      // result; the second row comes from the zero padding.
      if (isVecmat) {
        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
      }

      // Write the tile result into the full-size result vector.
      SmallVector<int64_t> strides(
          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tiledRes, result, accOffsets, strides);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
}
/// Adds the vector.contract -> arm_neon.smmla lowering pattern to `patterns`,
/// constructed with the set's context and a benefit of 1.
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LowerContractionToSMMLAPattern>(patterns.getContext(),
                                               /*benefit=*/1);
}