#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
namespace linalg {
namespace {
// Constant transform matrices for the F(2 x 2, 3 x 3) configuration, i.e. the
// (m, r) = (2, 3) key. Every table is a flat array that is later wrapped into
// a `rows x cols` constant tensor; the static_asserts below pin each table to
// the dimensions it is registered with so that a mismatch is caught at compile
// time instead of reading past the end of the array.
// clang-format off

// G (4 x 3): left factor of the filter transform.
constexpr float G_2x2_3x3[] = {
   -1,     0,    0,
 1./2, -1./2, 1./2,
 1./2,  1./2, 1./2,
    0,     0,    1
};
static_assert(sizeof(G_2x2_3x3) / sizeof(float) == 4 * 3,
              "G_2x2_3x3 must hold a 4 x 3 matrix");

// GT (3 x 4): transpose of G, right factor of the filter transform.
constexpr float GT_2x2_3x3[] = {
   -1,  1./2, 1./2, 0,
    0, -1./2, 1./2, 0,
    0,  1./2, 1./2, 1
};
static_assert(sizeof(GT_2x2_3x3) / sizeof(float) == 3 * 4,
              "GT_2x2_3x3 must hold a 3 x 4 matrix");

// BT (4 x 4): left factor of the input transform.
constexpr float BT_2x2_3x3[] = {
   -1,    0,   1,   0,
    0,   -1,   1,   0,
    0,    1,   1,   0,
    0,   -1,   0,   1
};
static_assert(sizeof(BT_2x2_3x3) / sizeof(float) == 4 * 4,
              "BT_2x2_3x3 must hold a 4 x 4 matrix");

// B (4 x 4): transpose of BT, right factor of the input transform.
constexpr float B_2x2_3x3[] = {
   -1,    0,   0,   0,
    0,   -1,   1,  -1,
    1,    1,   1,   0,
    0,    0,   0,   1
};
static_assert(sizeof(B_2x2_3x3) / sizeof(float) == 4 * 4,
              "B_2x2_3x3 must hold a 4 x 4 matrix");

// AT (2 x 4): left factor of the output transform.
constexpr float AT_2x2_3x3[] = {
    1,    1,   1,   0,
    0,   -1,   1,   1
};
static_assert(sizeof(AT_2x2_3x3) / sizeof(float) == 2 * 4,
              "AT_2x2_3x3 must hold a 2 x 4 matrix");

// A (4 x 2): transpose of AT, right factor of the output transform.
constexpr float A_2x2_3x3[] = {
    1,    0,
    1,   -1,
    1,    1,
    0,    1
};
static_assert(sizeof(A_2x2_3x3) / sizeof(float) == 4 * 2,
              "A_2x2_3x3 must hold a 4 x 2 matrix");
// clang-format on
// Constant transform matrices for the F(4 x 4, 3 x 3) configuration, i.e. the
// (m, r) = (4, 3) key. Every table is a flat array that is later wrapped into
// a `rows x cols` constant tensor; the static_asserts below pin each table to
// the dimensions it is registered with so that a mismatch is caught at compile
// time instead of reading past the end of the array.
// clang-format off

// G (6 x 3): left factor of the filter transform.
constexpr float G_4x4_3x3[] = {
     1,     0,     0,
 -1./3,  1./3, -1./3,
 -1./3, -1./3, -1./3,
 1./12, -1./6,  1./3,
 1./12,  1./6,  1./3,
     0,     0,     1
};
static_assert(sizeof(G_4x4_3x3) / sizeof(float) == 6 * 3,
              "G_4x4_3x3 must hold a 6 x 3 matrix");

// GT (3 x 6): transpose of G, right factor of the filter transform.
constexpr float GT_4x4_3x3[] = {
 1,  -1./3, -1./3, 1./12, 1./12, 0,
 0,   1./3, -1./3, -1./6,  1./6, 0,
 0,  -1./3, -1./3,  1./3,  1./3, 1
};
static_assert(sizeof(GT_4x4_3x3) / sizeof(float) == 3 * 6,
              "GT_4x4_3x3 must hold a 3 x 6 matrix");

// BT (6 x 6): left factor of the input transform.
constexpr float BT_4x4_3x3[] = {
 1./4,     0, -5./16,      0, 1./16,     0,
    0,  1./4,  -1./4, -1./16, 1./16,     0,
    0, -1./4,  -1./4,  1./16, 1./16,     0,
    0,  1./4,  -1./8,  -1./4,  1./8,     0,
    0, -1./4,  -1./8,   1./4,  1./8,     0,
    0,  1./4,      0, -5./16,     0, 1./16
};
static_assert(sizeof(BT_4x4_3x3) / sizeof(float) == 6 * 6,
              "BT_4x4_3x3 must hold a 6 x 6 matrix");

// B (6 x 6): transpose of BT, right factor of the input transform.
constexpr float B_4x4_3x3[] = {
   1./4,      0,     0,     0,     0,      0,
      0,   1./4, -1./4,  1./4, -1./4,   1./4,
 -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
      0, -1./16, 1./16, -1./4,  1./4, -5./16,
  1./16,  1./16, 1./16,  1./8,  1./8,      0,
      0,      0,     0,     0,     0,  1./16
};
static_assert(sizeof(B_4x4_3x3) / sizeof(float) == 6 * 6,
              "B_4x4_3x3 must hold a 6 x 6 matrix");

// AT (4 x 6): left factor of the output transform. Registered together with
// an integer scalar factor of 32 in the output transform.
constexpr float AT_4x4_3x3[] = {
 1./8,  1./4, 1./4,  1./8, 1./8,    0,
    0, -1./4, 1./4, -1./4, 1./4,    0,
    0,  1./4, 1./4,  1./2, 1./2,    0,
    0, -1./4, 1./4,    -1,    1, 1./2
};
static_assert(sizeof(AT_4x4_3x3) / sizeof(float) == 4 * 6,
              "AT_4x4_3x3 must hold a 4 x 6 matrix");

// A (6 x 4): transpose of AT, right factor of the output transform.
// Registered together with an integer scalar factor of 32.
constexpr float A_4x4_3x3[] = {
  1./8,     0,    0,     0,
  1./4, -1./4, 1./4, -1./4,
  1./4,  1./4, 1./4,  1./4,
  1./8, -1./4, 1./2,    -1,
  1./8,  1./4, 1./2,     1,
     0,     0,    0,  1./2
};
static_assert(sizeof(A_4x4_3x3) / sizeof(float) == 6 * 4,
              "A_4x4_3x3 must hold a 6 x 4 matrix");
// clang-format on
// Constant transform matrices for the F(2 x 2, 5 x 5) configuration, i.e. the
// (m, r) = (2, 5) key. Every table is a flat array that is later wrapped into
// a `rows x cols` constant tensor; the static_asserts below pin each table to
// the dimensions it is registered with so that a mismatch is caught at compile
// time instead of reading past the end of the array.
// clang-format off

// G (6 x 5): left factor of the filter transform.
constexpr float G_2x2_5x5[] = {
     1,     0,      0,      0,      0,
  1./6, -1./6,   1./6,  -1./6,   1./6,
 -1./6, -1./6,  -1./6,  -1./6,  -1./6,
-4./15, 2./15, -1./15,  1./30, -1./60,
 1./60, 1./30,  1./15,  2./15,  4./15,
     0,     0,      0,      0,      1
};
static_assert(sizeof(G_2x2_5x5) / sizeof(float) == 6 * 5,
              "G_2x2_5x5 must hold a 6 x 5 matrix");

// GT (5 x 6): transpose of G, right factor of the filter transform.
constexpr float GT_2x2_5x5[] = {
   1,  1./6, -1./6, -4./15, 1./60, 0,
   0, -1./6, -1./6,  2./15, 1./30, 0,
   0,  1./6, -1./6, -1./15, 1./15, 0,
   0, -1./6, -1./6,  1./30, 2./15, 0,
   0,  1./6, -1./6, -1./60, 4./15, 1
};
static_assert(sizeof(GT_2x2_5x5) / sizeof(float) == 5 * 6,
              "GT_2x2_5x5 must hold a 5 x 6 matrix");

// BT (6 x 6): left factor of the input transform.
constexpr float BT_2x2_5x5[] = {
 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
    0,   1./8,  1./16,  -5./16,   1./8,    0,
    0,  -1./8, -5./16,  -1./16,   1./8,    0,
    0,   1./4,  -1./8,   -1./4,   1./8,    0,
    0,  -1./8,  -1./4,    1./8,   1./4,    0,
    0,   1./8,  3./16,   -1./4, -3./16, 1./8
};
static_assert(sizeof(BT_2x2_5x5) / sizeof(float) == 6 * 6,
              "BT_2x2_5x5 must hold a 6 x 6 matrix");

// B (6 x 6): transpose of BT, right factor of the input transform.
constexpr float B_2x2_5x5[] = {
   1./8,      0,      0,     0,     0,      0,
  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
 -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
      0,      0,      0,     0,     0,   1./8
};
static_assert(sizeof(B_2x2_5x5) / sizeof(float) == 6 * 6,
              "B_2x2_5x5 must hold a 6 x 6 matrix");

// AT (2 x 6): left factor of the output transform. Registered together with
// an integer scalar factor of 16 in the output transform.
constexpr float AT_2x2_5x5[] = {
  1./2,  1, 1,  2, 1,    0,
     0, -1, 1, -1, 2, 1./2
};
static_assert(sizeof(AT_2x2_5x5) / sizeof(float) == 2 * 6,
              "AT_2x2_5x5 must hold a 2 x 6 matrix");

// A (6 x 2): transpose of AT, right factor of the output transform.
// Registered together with an integer scalar factor of 16.
constexpr float A_2x2_5x5[] = {
 1./2,    0,
    1,   -1,
    1,    1,
    2,   -1,
    1,    2,
    0, 1./2
};
static_assert(sizeof(A_2x2_5x5) / sizeof(float) == 6 * 2,
              "A_2x2_5x5 must hold a 6 x 2 matrix");
// clang-format on
/// Key used to look up constant transform matrices: the pair (m, r).
using TransformMapKeyTy = std::pair<int, int>;

/// Keys of the configurations for which constant tables are registered:
/// (2, 3), (4, 3) and (2, 5).
constexpr TransformMapKeyTy F_2_3 = TransformMapKeyTy(2, 3);
constexpr TransformMapKeyTy F_4_3 = TransformMapKeyTy(4, 3);
constexpr TransformMapKeyTy F_2_5 = TransformMapKeyTy(2, 5);
/// Lightweight, non-owning description of a constant transform matrix.
struct TransformMatrix {
  /// Pointer to the first of `rows * cols` float coefficients (not owned).
  const float *table;
  /// Number of rows of the matrix.
  int64_t rows;
  /// Number of columns of the matrix.
  int64_t cols;
  /// Integer factor attached to the matrix; 1 unless specified. The output
  /// transform multiplies its result by the product of the factors of the
  /// matrices it applied.
  int64_t scalarFactor;

  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table{table}, rows{rows}, cols{cols}, scalarFactor{scalarFactor} {}
};
/// Materializes `transform` as an arith::ConstantOp holding a dense
/// `rows x cols` tensor with element type `type`, filled with the
/// `rows * cols` floats that `transform.table` points to.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {
  const int64_t numCoefficients = transform.rows * transform.cols;
  ArrayRef<float> coefficients(transform.table, numCoefficients);

  auto matrixType = RankedTensorType::get(
      SmallVector<int64_t>{transform.rows, transform.cols}, type);
  auto matrixAttr = DenseFPElementsAttr::get(matrixType, coefficients);
  return builder.create<arith::ConstantOp>(loc, matrixAttr);
}
/// Extracts a rank-2 `extractHeight x extractWidth` slice out of `source`.
///
/// The slice starts at `loopNorFIndex` / `loopCorFIndex` along dimensions
/// `loopNorFIdx` / `loopCorFIdx` (size 1 each) and at `heightOffset` /
/// `widthOffset` along dimensions `heightIdx` / `widthIdx`. All strides are 1.
/// Every dimension of `source` is expected to be one of the four given
/// dimension indices; any other dimension would be left without an offset.
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto srcType = cast<ShapedType>(source.getType());
  const int64_t rank = srcType.getRank();
  OpFoldResult unit = builder.getIndexAttr(1);

  // Offsets: one entry per source dimension, filled through the index
  // arguments.
  SmallVector<OpFoldResult> sliceOffsets(rank);
  sliceOffsets[loopNorFIdx] = loopNorFIndex;
  sliceOffsets[loopCorFIdx] = loopCorFIndex;
  sliceOffsets[heightIdx] = heightOffset;
  sliceOffsets[widthIdx] = widthOffset;

  // Sizes: 1 everywhere except along the height and width dimensions.
  SmallVector<OpFoldResult> sliceSizes(rank, unit);
  sliceSizes[heightIdx] = builder.getIndexAttr(extractHeight);
  sliceSizes[widthIdx] = builder.getIndexAttr(extractWidth);

  SmallVector<OpFoldResult> sliceStrides(rank, unit);

  // The unit dimensions are dropped from the result type.
  auto sliceType = RankedTensorType::get({extractHeight, extractWidth},
                                         srcType.getElementType());
  return builder.create<tensor::ExtractSliceOp>(
      loc, sliceType, source, sliceOffsets, sliceSizes, sliceStrides);
}
/// Extracts a full rank-2 (height x width) slice out of `source`.
///
/// The slice covers the whole extent of dimensions `heightIdx` and `widthIdx`
/// (offset 0) and a single element along `tileHIdx`, `tileWIdx`, `loopNorFIdx`
/// and `loopCorFIdx`, positioned at the corresponding index values. All
/// strides are 1.
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto srcType = cast<ShapedType>(source.getType());
  ArrayRef<int64_t> srcShape = srcType.getShape();
  const int64_t rank = srcType.getRank();
  const int64_t sliceHeight = srcShape[heightIdx];
  const int64_t sliceWidth = srcShape[widthIdx];

  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult unit = builder.getIndexAttr(1);

  // Offsets default to 0 (height/width) and are overridden for the four
  // indexed dimensions.
  SmallVector<OpFoldResult> sliceOffsets(rank, zero);
  sliceOffsets[tileHIdx] = tileHIndex;
  sliceOffsets[tileWIdx] = tileWIndex;
  sliceOffsets[loopNorFIdx] = loopNorFIndex;
  sliceOffsets[loopCorFIdx] = loopCorFIndex;

  // Sizes: 1 everywhere except along the height and width dimensions.
  SmallVector<OpFoldResult> sliceSizes(rank, unit);
  sliceSizes[heightIdx] = builder.getIndexAttr(sliceHeight);
  sliceSizes[widthIdx] = builder.getIndexAttr(sliceWidth);

  SmallVector<OpFoldResult> sliceStrides(rank, unit);

  auto sliceType = RankedTensorType::get({sliceHeight, sliceWidth},
                                         srcType.getElementType());
  return builder.create<tensor::ExtractSliceOp>(
      loc, sliceType, source, sliceOffsets, sliceSizes, sliceStrides);
}
/// Inserts the rank-2 tensor `source` (`height x width`) into `dest` and
/// returns the updated tensor.
///
/// The slice is placed at `loopNorFIndex` / `loopCorFIndex` along dimensions
/// `loopNorFIdx` / `loopCorFIdx` (size 1 each) and at `heightOffset` /
/// `widthOffset` along dimensions `heightIdx` / `widthIdx`. All strides are 1.
/// Every dimension of `dest` is expected to be one of the four given
/// dimension indices; any other dimension would be left without an offset.
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  const int64_t rank = cast<ShapedType>(dest.getType()).getRank();
  OpFoldResult unit = builder.getIndexAttr(1);

  SmallVector<OpFoldResult> destOffsets(rank);
  destOffsets[loopNorFIdx] = loopNorFIndex;
  destOffsets[loopCorFIdx] = loopCorFIndex;
  destOffsets[heightIdx] = heightOffset;
  destOffsets[widthIdx] = widthOffset;

  // Sizes: 1 everywhere except along the height and width dimensions.
  SmallVector<OpFoldResult> destSizes(rank, unit);
  destSizes[heightIdx] = builder.getIndexAttr(height);
  destSizes[widthIdx] = builder.getIndexAttr(width);

  SmallVector<OpFoldResult> destStrides(rank, unit);

  return builder.create<tensor::InsertSliceOp>(loc, source, dest, destOffsets,
                                               destSizes, destStrides);
}
/// Inserts the rank-2 tensor `source` (`height x width`) into `dest` and
/// returns the updated tensor.
///
/// The slice starts at offset 0 along dimensions `heightIdx` / `widthIdx` and
/// occupies a single element along `tileHIdx`, `tileWIdx`, `loopNorFIdx` and
/// `loopCorFIdx`, positioned at the corresponding index values. All strides
/// are 1.
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  const int64_t rank = cast<ShapedType>(dest.getType()).getRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult unit = builder.getIndexAttr(1);

  // Offsets default to 0 (height/width) and are overridden for the four
  // indexed dimensions.
  SmallVector<OpFoldResult> destOffsets(rank, zero);
  destOffsets[tileHIdx] = tileHIndex;
  destOffsets[tileWIdx] = tileWIndex;
  destOffsets[loopNorFIdx] = loopNorFIndex;
  destOffsets[loopCorFIdx] = loopCorFIndex;

  // Sizes: 1 everywhere except along the height and width dimensions.
  SmallVector<OpFoldResult> destSizes(rank, unit);
  destSizes[heightIdx] = builder.getIndexAttr(height);
  destSizes[widthIdx] = builder.getIndexAttr(width);

  SmallVector<OpFoldResult> destStrides(rank, unit);

  return builder.create<tensor::InsertSliceOp>(loc, source, dest, destOffsets,
                                               destSizes, destStrides);
}
/// Lowers the Winograd filter transform to a loop nest of matmuls.
///
/// `filter` is read as (F, H, W, C) and `retValue` is the destination tensor
/// the result is inserted into, laid out as (H, W, C, F). For every (f, c)
/// pair the (H, W) slice g is extracted, multiplied as G x g (when
/// `leftTransform` is set) and then by GT on the right (when `rightTransform`
/// is set), and the result is inserted into the destination.
///
/// Returns a null Value, without creating any operation, when the filter
/// height/width is neither `r` nor 1 or when no constant matrix is registered
/// for (m, r).
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to the G transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GMatrices = {
          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to the GT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GTMatrices = {
          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  // Resolve the constant matrices before any IR is created. The loop body
  // must always yield a value for the loop nest, so an unsupported (m, r) has
  // to be rejected here rather than from inside the body.
  TransformMapKeyTy key = {m, r};
  const TransformMatrix *GMatrix = nullptr;
  if (leftTransform) {
    auto it = GMatrices.find(key);
    if (it == GMatrices.end())
      return Value();
    GMatrix = &it->second;
  }
  const TransformMatrix *GTMatrix = nullptr;
  if (rightTransform) {
    auto it = GTMatrices.find(key);
    if (it == GTMatrices.end())
      return Value();
    GTMatrix = &it->second;
  }

  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    if (leftTransform) {
      retRows = GMatrix->rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType);

      Value G = create2DTransformMatrix(builder, loc, *GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix->cols}, elementType);
      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType);

      Value GT = create2DTransformMatrix(builder, loc, *GTMatrix, elementType);
      // Multiply the running result by GT on the right.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) into (H, W, C, F).
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;
    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  // Iterate over all (f, c) pairs, threading the destination tensor through
  // the loop nest as the iteration argument.
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
/// Lowers the Winograd input transform to a loop nest of matmuls.
///
/// `input` is read as (N, H, W, C) and `retValue` is the destination tensor,
/// laid out as (alphaH, alphaW, tileH, tileW, N, C). For every
/// (tileH, tileW, n, c) the (alphaH, alphaW) window d starting at
/// (tileH * m, tileW * m) is extracted, multiplied as BT x d (when
/// `leftTransform` is set) and then by B on the right (when `rightTransform`
/// is set), and the result is inserted into the destination.
///
/// Returns a null Value, without creating any operation, when the input
/// height/width does not match the tile count or when no constant matrix is
/// registered for (m, r).
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, int64_t m, int64_t r,
                     bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to the BT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BTMatrices = {
          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
      };

  // Map from (m, r) to the B transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BMatrices = {
          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
      };

  auto inputType = cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  auto valueType = cast<ShapedType>(retValue.getType());
  auto valueShape = valueType.getShape(); // alphaH, alphaW, tileH, tileW, N, C
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  // A transformed dimension must cover exactly `tile * m + (r - 1)` elements;
  // a dimension of size 1 is accepted as is.
  if ((inputH != (tileH * m) + (r - 1)) && inputH != 1)
    return Value();
  if ((inputW != (tileW * m) + (r - 1)) && inputW != 1)
    return Value();

  // Resolve the constant matrices before any IR is created. The loop body
  // must always yield a value for the loop nest, so an unsupported (m, r) has
  // to be rejected here rather than from inside the body.
  TransformMapKeyTy key = {m, r};
  const TransformMatrix *BTMatrix = nullptr;
  if (leftTransform) {
    auto it = BTMatrices.find(key);
    if (it == BTMatrices.end())
      return Value();
    BTMatrix = &it->second;
  }
  const TransformMatrix *BMatrix = nullptr;
  if (rightTransform) {
    auto it = BMatrices.find(key);
    if (it == BMatrices.end())
      return Value();
    BMatrix = &it->second;
  }

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value CIter = ivs[3];

    // Window offsets: tile index * m.
    auto context = builder.getContext();
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset =
        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
    Value widthOffset =
        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);

    // Extract (H, W) from (N, H, W, C).
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    int64_t retRows = 1;
    int64_t retCols = 1;
    Value matmulRetValue = extractInput;
    if (leftTransform) {
      retRows = BTMatrix->rows;
      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType);

      // NOTE(review): the constant is always built with an f32 element type
      // here, whereas the matmul result uses the input element type - confirm
      // this is intended for non-f32 inputs.
      Value BT = create2DTransformMatrix(builder, loc, *BTMatrix,
                                         builder.getF32Type());
      // Multiply BT x d.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      retCols = BMatrix->cols;
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType);

      // NOTE(review): f32 constant regardless of the input element type, see
      // the note on BT above.
      Value B = create2DTransformMatrix(builder, loc, *BMatrix,
                                        builder.getF32Type());
      // Multiply the running result by B on the right.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) into (H, W, tileH, tileW, N, C).
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, /*tileHIdx=*/2, /*tileWIdx=*/3,
        /*loopNorFIdx=*/4, /*loopCorFIdx=*/5, /*heightIdx=*/0,
        /*widthIdx=*/1);

    return {combinedVal};
  };

  // Iterate over all (tileH, tileW, n, c), threading the destination tensor
  // through the loop nest as the iteration argument.
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, cUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
/// Multiplies the transformed input by the transformed filter with a single
/// linalg::BatchMatmulOp.
///
/// The rank-4 filter is collapsed to rank 3 by merging its two leading
/// dimensions, the rank-6 input is collapsed to rank 3 by merging dimensions
/// {0, 1} and {2, 3, 4}, the two are multiplied (input x filter), and the
/// rank-3 product is expanded back to rank 6 using the input's five leading
/// extents and the filter's trailing extent. Both operands must have static
/// shapes.
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  auto filterTy = cast<ShapedType>(transformedFilter.getType());
  assert(filterTy.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> fShape = filterTy.getShape();

  // Filter: (d0, d1, d2, d3) -> (d0 * d1, d2, d3).
  SmallVector<ReassociationIndices> filterGroups = {{0, 1}, {2}, {3}};
  auto collapsedFilterTy = RankedTensorType::get(
      {fShape[0] * fShape[1], fShape[2], fShape[3]}, filterTy.getElementType());
  Value collapsedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedFilterTy, transformedFilter, filterGroups);

  auto inputTy = cast<ShapedType>(transformedInput.getType());
  assert(inputTy.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> iShape = inputTy.getShape();
  const int64_t batchExtent = iShape[0] * iShape[1];
  const int64_t rowExtent = iShape[2] * iShape[3] * iShape[4];

  // Input: (d0, d1, d2, d3, d4, d5) -> (d0 * d1, d2 * d3 * d4, d5).
  SmallVector<ReassociationIndices> inputGroups = {{0, 1}, {2, 3, 4}, {5}};
  auto collapsedInputTy = RankedTensorType::get(
      {batchExtent, rowExtent, iShape[5]}, inputTy.getElementType());
  Value collapsedInput = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedInputTy, transformedInput, inputGroups);

  // Batched product: (batch, rows, d5) x (batch, d2, d3) -> (batch, rows, d3).
  auto productTy = RankedTensorType::get({batchExtent, rowExtent, fShape[3]},
                                         outputElementType);
  Value productInit = rewriter.create<tensor::EmptyOp>(
      loc, productTy.getShape(), outputElementType);
  auto product = rewriter.create<linalg::BatchMatmulOp>(
      loc, productTy, ValueRange({collapsedInput, collapsedFilter}),
      ValueRange{productInit});

  // Undo the collapse on the product.
  SmallVector<ReassociationIndices> outputGroups = {{0, 1}, {2, 3, 4}, {5}};
  auto expandedTy = RankedTensorType::get(
      {iShape[0], iShape[1], iShape[2], iShape[3], iShape[4], fShape[3]},
      outputElementType);
  return rewriter.create<tensor::ExpandShapeOp>(
      loc, expandedTy, product.getResult(0), outputGroups);
}
/// Lowers the Winograd output transform to a loop nest of matmuls.
///
/// `value` is read as (alphaH, alphaW, tileH, tileW, N, F) and `output` is the
/// destination tensor, laid out as (N, H, W, F). For every (tileH, tileW, n, f)
/// the (alphaH, alphaW) slice y is extracted, multiplied as AT x y (when
/// `leftTransform` is set) and then by A on the right (when `rightTransform`
/// is set), scaled by the product of the matrices' scalar factors when that
/// product is not 1, and inserted into the destination at offset
/// (tileH * m, tileW * m).
///
/// Returns a null Value, without creating any operation, when the leading
/// dimensions of `value` do not match alphaH/alphaW (or 1) or when no constant
/// matrix is registered for (m, r).
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to the AT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      ATMatrices = {
          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
      };

  // Map from (m, r) to the A transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      AMatrices = {
          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
      };

  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto valueShape = valueType.getShape(); // H, W, tileH, tileW, N, F
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t valueN = valueShape[4];
  int64_t valueF = valueShape[5];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  if (valueH != alphaH && valueH != 1)
    return Value();
  if (valueW != alphaW && valueW != 1)
    return Value();

  // Resolve the constant matrices before any IR is created. The loop body
  // must always yield a value for the loop nest, so an unsupported (m, r) has
  // to be rejected here rather than from inside the body.
  TransformMapKeyTy key = {m, r};
  const TransformMatrix *ATMatrix = nullptr;
  if (leftTransform) {
    auto it = ATMatrices.find(key);
    if (it == ATMatrices.end())
      return Value();
    ATMatrix = &it->second;
  }
  const TransformMatrix *AMatrix = nullptr;
  if (rightTransform) {
    auto it = AMatrices.find(key);
    if (it == AMatrices.end())
      return Value();
    AMatrix = &it->second;
  }

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value FIter = ivs[3];

    // Extract (H, W) from (H, W, tileH, tileW, N, F).
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, /*tileHIdx=*/2, /*tileWIdx=*/3,
                            /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
                            /*heightIdx=*/0, /*widthIdx=*/1);

    int64_t retRows = 1;
    int64_t retCols = 1;
    int64_t leftScalarFactor = 1;
    int64_t rightScalarFactor = 1;
    Value matmulRetValue = extractValue;
    if (leftTransform) {
      leftScalarFactor = ATMatrix->scalarFactor;
      retRows = ATMatrix->rows;
      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType);

      Value AT = create2DTransformMatrix(builder, loc, *ATMatrix, elementType);
      // Multiply AT x y.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      rightScalarFactor = AMatrix->scalarFactor;
      auto matmulType =
          RankedTensorType::get({retRows, AMatrix->cols}, elementType);
      retCols = AMatrix->cols;
      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType);

      Value A = create2DTransformMatrix(builder, loc, *AMatrix, elementType);
      // Multiply the running result by A on the right.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (leftScalarFactor * rightScalarFactor != 1) {
      // Multiply the result by the combined scalar factor: broadcast the
      // scalar to a (retRows, retCols) tensor, then multiply element-wise.
      Value scalarFactor = builder.create<arith::ConstantOp>(
          loc,
          FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType);

      // The scalar input uses a map with no results; the output uses the
      // 2-D identity map.
      auto identityAffineMap = builder.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, init.getContext()), identityAffineMap};
      auto broadcastedScalar =
          builder
              .create<linalg::GenericOp>(
                  loc, matmulType, ValueRange{scalarFactor}, ValueRange{init},
                  affineMaps,
                  llvm::ArrayRef<utils::IteratorType>{
                      utils::IteratorType::parallel,
                      utils::IteratorType::parallel},
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
                  })
              .getResult(0);

      matmulRetValue = builder
                           .create<linalg::MulOp>(
                               loc, matmulType,
                               ValueRange{broadcastedScalar, matmulRetValue},
                               ValueRange{init})
                           .getResult(0);
    }

    // Destination offsets: tile index * m.
    auto context = builder.getContext();
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset =
        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
    Value widthOffset =
        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);

    // Insert (H, W) into (N, H, W, F).
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0,
                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
                         /*widthIdx=*/2);

    return {combinedVal};
  };

  // Iterate over all (tileH, tileW, n, f), threading the destination tensor
  // through the loop nest as the iteration argument.
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, fUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
  return loops.results[0];
}
/// Pads `value` on the high side up to `alignedShape`, using the zero value
/// of its element type as padding, and returns the padded tensor.
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {
  Type elemTy = cast<ShapedType>(value.getType()).getElementType();
  auto paddedTy = RankedTensorType::get(alignedShape, elemTy);

  // Zero constant of the tensor's element type, used as the padding value.
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, elemTy, rewriter.getZeroAttr(elemTy));

  return linalg::makeComposedPadHighOp(rewriter, loc, paddedTy, value, zero,
                                       /*nofold=*/false);
}
/// Extracts the leading sub-tensor of type `extractedType` out of the rank-4
/// tensor `value`: offsets are all 0, strides are all 1 and the sizes are the
/// shape of `extractedType`.
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
  constexpr unsigned kRank = 4;
  SmallVector<OpFoldResult, 4> sliceOffsets(kRank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult, 4> sliceStrides(kRank, rewriter.getIndexAttr(1));

  SmallVector<OpFoldResult> sliceSizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedType.getShape()));

  return rewriter.create<tensor::ExtractSliceOp>(
      loc, extractedType, value, sliceOffsets, sliceSizes, sliceStrides);
}
/// Returns true when every element of `attr`, read as a signed integer, is 1
/// (vacuously true for an empty attribute).
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  for (const APInt &element : attr) {
    if (element.getSExtValue() != 1)
      return false;
  }
  return true;
}
/// Rewrites `convOp` into a Winograd filter transform, a Winograd input
/// transform, a batched matrix multiplication and a Winograd output transform
/// for the F(m, r) configuration.
///
/// The rewrite only applies when the input, filter and output have static
/// shapes, all strides and dilations are 1, the filter is (r x r), (r x 1) or
/// (1 x r), and (m, r) is one of the supported configurations. When the
/// output extent is not a multiple of the tile size, the input and output are
/// zero-padded to whole tiles and the final result is sliced back to the
/// original output shape.
///
/// On success `convOp` is replaced and the operation defining the replacement
/// value is returned; otherwise a match failure is reported and no IR is
/// modified.
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     int64_t m, int64_t r) {
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  // The tile counts below are computed from the output extents, so they must
  // be static as well.
  if (!outputType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the output");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  ArrayRef<int64_t> filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape(); // N, H, W, F
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  // Only F(m x m, r x r), F(m x 1, r x 1) and F(1 x m, 1 x r) are supported.
  const bool isSupportedFilter = (filterH == r && filterW == r) ||
                                 (filterH == r && filterW == 1) ||
                                 (filterH == 1 && filterW == r);
  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // Configurations for which constant transform matrices are available.
  static constexpr TransformMapKeyTy validConfigs[] = {F_2_3, F_4_3, F_2_5};
  TransformMapKeyTy key = {m, r};
  if (!llvm::is_contained(validConfigs, key))
    return rewriter.notifyMatchFailure(
        convOp, "unsupported Winograd F(m, r) configuration");

  // All the criteria are satisfied; from here on IR is created.
  Location loc = convOp.getLoc();

  // A unit filter height (width) needs no transform along that dimension.
  bool leftTransform = filterH != 1;
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Filter transform ---
  Type filterElementType = filterType.getElementType();
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                                    filterElementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      loc, retType, filter, retValue, m, r);

  // --- Input transform ---
  // Pad the input so that it covers a whole number of tiles.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                              inputElementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      loc, retType, input, retValue, m, r);

  // --- Batched matrix multiplication ---
  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Output transform ---
  // Pad the output to whole tiles as well; the padding is sliced off again
  // after the transform.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      loc, outputType, matmulRet, output, m, r);

  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}
/// Replaces `op` with the loop-based lowering produced by filterTransform.
/// A filter dimension (index 1 or 2 of the filter shape) of size 1 disables
/// the transform on the corresponding side. Returns failure, leaving `op` in
/// place, when filterTransform produces no value.
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {
  Value filter = op.getFilter();
  ArrayRef<int64_t> shape = cast<ShapedType>(filter.getType()).getShape();
  const bool needsLeft = shape[1] != 1;
  const bool needsRight = shape[2] != 1;

  Value lowered =
      filterTransform(rewriter, op.getLoc(), filter, op.getOutput(), op.getM(),
                      op.getR(), needsLeft, needsRight);
  if (lowered) {
    rewriter.replaceOp(op, lowered);
    return lowered.getDefiningOp();
  }
  return failure();
}
/// Replaces `op` with the loop-based lowering produced by inputTransform.
/// An input dimension (index 1 or 2 of the input shape) of size 1 disables
/// the transform on the corresponding side. Returns failure, leaving `op` in
/// place, when inputTransform produces no value.
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {
  Value input = op.getInput();
  ArrayRef<int64_t> shape = cast<ShapedType>(input.getType()).getShape();
  const bool needsLeft = shape[1] != 1;
  const bool needsRight = shape[2] != 1;

  Value lowered =
      inputTransform(rewriter, op.getLoc(), input, op.getOutput(), op.getM(),
                     op.getR(), needsLeft, needsRight);
  if (lowered) {
    rewriter.replaceOp(op, lowered);
    return lowered.getDefiningOp();
  }
  return failure();
}
/// Replaces `op` with the loop-based lowering produced by outputTransform.
/// A leading value dimension (index 0 or 1 of the value shape) of size 1
/// disables the transform on the corresponding side. Returns failure, leaving
/// `op` in place, when outputTransform produces no value.
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {
  Value value = op.getValue();
  ArrayRef<int64_t> shape = cast<ShapedType>(value.getType()).getShape();
  const bool needsLeft = shape[0] != 1;
  const bool needsRight = shape[1] != 1;

  Value lowered =
      outputTransform(rewriter, op.getLoc(), value, op.getOutput(), op.getM(),
                      op.getR(), needsLeft, needsRight);
  if (lowered) {
    rewriter.replaceOp(op, lowered);
    return lowered.getDefiningOp();
  }
  return failure();
}
/// Rewrite pattern that decomposes linalg::WinogradFilterTransformOp through
/// decomposeWinogradFilterTransformHelper.
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> lowered =
        decomposeWinogradFilterTransformHelper(rewriter, op);
    return failure(failed(lowered));
  }
};
/// Rewrite pattern that decomposes linalg::WinogradInputTransformOp through
/// decomposeWinogradInputTransformHelper.
class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> lowered =
        decomposeWinogradInputTransformHelper(rewriter, op);
    return failure(failed(lowered));
  }
};
/// Rewrite pattern that decomposes linalg::WinogradOutputTransformOp through
/// decomposeWinogradOutputTransformHelper.
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> lowered =
        decomposeWinogradOutputTransformHelper(rewriter, op);
    return failure(failed(lowered));
  }
};
/// Rewrite pattern that converts linalg::Conv2DNhwcFhwcOp to the Winograd
/// form for a fixed F(m, r) configuration, through winogradConv2DHelper.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Creates the pattern for the configuration given by `m` and `r`.
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
      : OpRewritePattern(context), m(m), r(r) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    // Only the success/failure state matters here; the returned operation is
    // not needed.
    return success(succeeded(winogradConv2DHelper(rewriter, convOp, m, r)));
  }

private:
  /// The `m` and `r` values forwarded to winogradConv2DHelper.
  int64_t m;
  int64_t r;
};
}
/// Public entry point: applies the Winograd Conv2D rewrite to `op` for the
/// F(m, r) configuration. Forwards to winogradConv2DHelper and returns its
/// result unchanged.
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                      int64_t r) {
  FailureOr<Operation *> rewritten = winogradConv2DHelper(rewriter, op, m, r);
  return rewritten;
}
/// Adds the Winograd Conv2D rewrite pattern for the F(m, r) configuration to
/// `patterns`.
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {
  patterns.insert<WinogradConv2DNhwcFhwc>(patterns.getContext(), m, r);
}
/// Adds the patterns that decompose the Winograd filter, input and output
/// transform ops to `patterns`.
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(patterns.getContext());
}
}
}