//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <utility>

namespace mlir {
namespace linalg {
/// Returns true when every element of `attr`, read as a sign-extended
/// integer, is equal to 1. An attribute without elements yields true.
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  for (const APInt &value : attr) {
    if (value.getSExtValue() != 1)
      return false;
  }
  return true;
}

/// Emits `x + y` at `loc`. The op is selected from the type of `x` only:
/// complex.add for complex values, arith.addi for integers, and arith.addf
/// for anything else. Neither operand is converted.
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
  Type operandType = x.getType();
  if (isa<ComplexType>(operandType))
    return builder.create<complex::AddOp>(loc, x, y);
  if (isa<IntegerType>(operandType))
    return builder.create<arith::AddIOp>(loc, x, y);
  return builder.create<arith::AddFOp>(loc, x, y);
}

/// Emits `x * y` at `loc` after converting both operands to `accType`. The
/// op is selected from `accType`: arith.muli for integers, complex.mul for
/// complex values, and arith.mulf for anything else.
static Value createMul(Location loc, Value x, Value y, Type accType,
                       OpBuilder &builder) {
  // Both operands are converted with signed (not unsigned) casts, which is
  // the extension that Linalg named ops specify.
  Value lhs =
      convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
  Value rhs =
      convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
  if (isa<IntegerType>(accType))
    return builder.create<arith::MulIOp>(loc, lhs, rhs);
  if (isa<ComplexType>(accType))
    return builder.create<complex::MulOp>(loc, lhs, rhs);
  return builder.create<arith::MulFOp>(loc, lhs, rhs);
}

// Delinearizes the given composite `index` by the basis specified in `factors`.
// One `arith.constant` of index type is materialized per factor and the
// resulting values are handed to `affine::delinearizeIndex`; the returned
// vector holds the individual indices it produced. `factors` must not be
// empty, and delinearization is expected to succeed (both are asserted).
static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
                                      ArrayRef<int64_t> factors) {
  assert(!factors.empty() && "empty factor list");
  SmallVector<Value> basis;
  basis.reserve(factors.size());
  for (int64_t factor : factors)
    basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(factor)));
  FailureOr<SmallVector<Value>> multiIndex =
      affine::delinearizeIndex(b, loc, index, basis);
  assert(!failed(multiIndex) && "Failed to delinearize img2col index");
  return *multiIndex;
}

// Computes the input index touched by a convolution for one spatial
// dimension: given the output iterator value `oIndex` and the filter iterator
// value `fIndex`, builds `oIndex * stride + fIndex` as a composed
// affine.apply over two symbols and returns its result.
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
                               Value fIndex, int64_t stride) {
  AffineExpr outputSym, filterSym;
  bindSymbols(b.getContext(), outputSym, filterSym);
  AffineExpr convolvedExpr = outputSym * stride + filterSym;
  AffineMap indexMap =
      AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2, convolvedExpr);
  return affine::makeComposedAffineApply(b, loc, indexMap, {oIndex, fIndex});
}

/// Rewrites a `linalg.conv_2d_nhwc_hwcf` op into an im2col gather followed by
/// a matmul-like contraction:
///   1. The filter is collapsed to [fh*fw*ic, oc] and the output to
///      [n, oh*ow, oc].
///   2. A linalg.generic fills a [n, oh*ow, fh*fw*ic] column tensor by
///      extracting elements from the input.
///   3. A second linalg.generic contracts the column tensor with the collapsed
///      filter, accumulating into the collapsed output.
///   4. The result is expanded back to the original output type and replaces
///      `convOp`.
/// The match fails (without modifying the IR) unless the input, filter and
/// output shapes are static and all dilations are 1. On success, returns the
/// pair (img2col linalg.generic, final tensor.expand_shape).
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // The collapsed shapes below are products of output extents, so a dynamic
  // output dimension cannot be handled.
  if (!outputType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the output");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // Output is NHWC, filter is HWCF.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[0];
  int64_t fw = filterShape[1];
  int64_t ic = filterShape[2];

  // Strides along the height and width dimensions.
  int64_t strideH = convOp.getStrides().getValues<int64_t>()[0];
  int64_t strideW = convOp.getStrides().getValues<int64_t>()[1];

  Location loc = convOp.getLoc();

  // Reshape filter and output to the RHS and result of a (B)MNK matmul.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
  auto reshapedFilterType =
      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex = getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex,
                                         fhIndex, strideH);
        Value wIndex = getConvolvedIndex(nestedBuilder, nestedLoc, owIndex,
                                         fwIndex, strideW);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            nestedLoc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. (B x) M x K * K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Multiply in the accumulator type, then accumulate into the output.
        Value mul = createMul(nestedLoc, args[0], args[1], args[2].getType(),
                              nestedBuilder);
        Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Restore the original 4-D output shape.
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

/// Rewrites a `linalg.depthwise_conv_2d_nhwc_hwc` op into an im2col gather
/// followed by a `linalg.batch_matvec`:
///   1. Input and output are transposed from NHWC to NCHW and the filter from
///      HWC to CHW.
///   2. A linalg.generic gathers the transposed input into a
///      [n, c, oh, ow, fh, fw] column tensor.
///   3. The column tensor, filter and output are collapsed to
///      [n*c, oh*ow, fh*fw], [c, fh*fw] and [n*c, oh*ow] respectively and fed
///      to linalg.batch_matvec.
///   4. The result is expanded, transposed back to NHWC and replaces `convOp`.
/// The match fails (without modifying the IR) unless the input, filter and
/// output shapes are static, all dilations are 1, and the batch extent is 1.
/// On success, returns the pair (img2col linalg.generic, op defining the
/// final transposed result).
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter,
                linalg::DepthwiseConv2DNhwcHwcOp convOp) {
  auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
  auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
  auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // The collapsed shapes below are products of output extents, so a dynamic
  // output dimension cannot be handled.
  if (!outputType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the output");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  // The batch_matvec built below pairs a [n*c, oh*ow, fh*fw] matrix operand
  // with a [c, fh*fw] vector operand; their leading (batch) extents only agree
  // when n == 1. Bail out before creating any IR otherwise.
  // TODO: Support batch sizes other than 1.
  if (outputType.getShape()[0] != 1)
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a unit batch dimension");

  Location loc = convOp.getLoc();

  // Builds a linalg.generic that copies `operand` into a fresh tensor whose
  // dimension i has the extent of `operand`'s dimension indices[i].
  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
    auto operandTensorType = cast<RankedTensorType>(operand.getType());
    auto nloops = indices.size();
    ArrayRef<int64_t> inputShape = operandTensorType.getShape();

    SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
        llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
          return rewriter.getAffineDimExpr(index);
        }));

    SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
        indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));

    Value outputTensor = rewriter.create<tensor::EmptyOp>(
        loc, targetShape, operandTensorType.getElementType());

    SmallVector<utils::IteratorType> loopAttributeTypes(
        nloops, utils::IteratorType::parallel);

    // The output is indexed with the identity map; the input is indexed with
    // the inverse of the requested permutation.
    SmallVector<AffineMap> indexingMaps = {
        inversePermutation(
            AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

    auto transposedOp = rewriter.create<linalg::GenericOp>(
        loc, outputTensor.getType(),
        /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
        loopAttributeTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
        });

    return transposedOp.getResult(0);
  };

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  // Transpose input, filter so channels are outermost
  Value inputT = transposeOperand(input, {0, 3, 1, 2});
  Value filterT = transposeOperand(filter, {2, 0, 1});
  ArrayRef<int64_t> filterTShape =
      cast<RankedTensorType>(filterT.getType()).getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // Keep the extents in 64 bits so neither the values nor the products
  // computed from them are truncated.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t c = outputShape[3];
  int64_t fh = filterTShape[1];
  int64_t fw = filterTShape[2];

  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});

  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);

  // Strides enter the input indexing map as affine constants.
  AffineExpr shSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[0]);
  AffineExpr swSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[1]);

  // im2col[n, c, oh, ow, kh, kw] = inputT[n, c, oh*sh + kh, ow*sw + kw]
  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
                                        owDim * swSym + kwDim};

  auto nloops = colTensorShape.size();

  SmallVector<utils::IteratorType> loopAttributeTypes(
      nloops, utils::IteratorType::parallel);

  SmallVector<AffineMap> indexingMaps = {
      AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
      AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
      loopAttributeTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  // Collapse everything to the operand shapes of a batch_matvec.
  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
      {0, 1}, {2, 3}, {4, 5}};
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2}};
  SmallVector<ReassociationIndices> outputReassocIndices = {{0, 1}, {2, 3}};

  auto reshapedImg2ColTensorType = RankedTensorType::get(
      {n * c, oh * ow, fh * fw}, inputType.getElementType());
  auto reshapedFilterTensorType =
      RankedTensorType::get({c, fh * fw}, filterType.getElementType());
  auto reshapedOutputTensorType =
      RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());

  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
      img2ColTensorReassocIndices);
  Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterTensorType, filterT, filterReassocIndices);
  Value reshapedOutputTensor = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputTensorType, transposedOutputTensor,
      outputReassocIndices);

  auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
      loc, TypeRange{reshapedOutputTensor.getType()},
      ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
      ValueRange{reshapedOutputTensor});

  // Undo the output collapse, then transpose NCHW back to NHWC.
  SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
                                                                      {2, 3}};

  auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
      loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
      batchMatVecReassociationIndice);

  Value transposedResult =
      transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});

  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
  return std::make_pair(img2ColTensor.getOperation(),
                        transposedResult.getDefiningOp());
}

/// Rewrites a `linalg.conv_2d_nchw_fchw` op into an im2col gather followed by
/// a matmul-like contraction:
///   1. The filter is collapsed to [oc, ic*fh*fw] and the output to
///      [n, oc, oh*ow].
///   2. A linalg.generic fills a [n, ic*fh*fw, oh*ow] column tensor by
///      extracting elements from the input.
///   3. A second linalg.generic contracts the collapsed filter with the column
///      tensor, accumulating into the collapsed output.
///   4. The result is expanded back to the original output type and replaces
///      `convOp`.
/// The match fails (without modifying the IR) unless the input, filter and
/// output shapes are static and all dilations are 1. On success, returns the
/// pair (img2col linalg.generic, final tensor.expand_shape).
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // The collapsed shapes below are products of output extents, so a dynamic
  // output dimension cannot be handled.
  if (!outputType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the output");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  auto filterShape = filterType.getShape();
  auto outputShape = outputType.getShape();

  // Output is NCHW, filter is FCHW.
  int64_t n = outputShape[0];
  int64_t oc = outputShape[1];
  int64_t oh = outputShape[2];
  int64_t ow = outputShape[3];
  int64_t ic = filterShape[1];
  int64_t fh = filterShape[2];
  int64_t fw = filterShape[3];

  // Strides along the height and width dimensions.
  int64_t strideH = convOp.getStrides().getValues<int64_t>()[0];
  int64_t strideW = convOp.getStrides().getValues<int64_t>()[1];

  auto loc = convOp.getLoc();
  MLIRContext *context = rewriter.getContext();

  // Reshape filter and output to the LHS and result of a (B)MNK matmul. The
  // collapsed filter keeps the filter's own element type, which may differ
  // from the input's.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, ic * fh * fw}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
  auto reshapedOutputType =
      RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  // Convert the input to a (BKN) tensor.
  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);

  SmallVector<AffineMap, 4> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, k, n).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 0);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 1);
        Value nIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
        auto icIndex = kIndices[0];
        auto fhIndex = kIndices[1];
        auto fwIndex = kIndices[2];

        SmallVector<Value> nIndices = unrollIndex(
            nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = nIndices[0];
        auto owIndex = nIndices[1];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex = getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex,
                                         fhIndex, strideH);
        Value wIndex = getConvolvedIndex(nestedBuilder, nestedLoc, owIndex,
                                         fwIndex, strideW);

        // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
        SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            nestedLoc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. M x K * (B x) K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Multiply in the accumulator type, then accumulate into the output.
        Value mul = createMul(nestedLoc, args[0], args[1], args[2].getType(),
                              nestedBuilder);
        Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Restore the original 4-D output shape.
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

/// Rewrites a `linalg.conv_2d_nhwc_fhwc` op into an im2col gather followed by
/// a contraction made of "row-wise" dot products:
///   1. The filter is collapsed to [oc, fh*fw*ic] and the output to
///      [n, oh*ow, oc].
///   2. A linalg.generic fills a [n, oh*ow, fh*fw*ic] column tensor by
///      extracting elements from the input.
///   3. A second linalg.generic contracts the column tensor with the collapsed
///      (untransposed) filter, accumulating into the collapsed output.
///   4. The result is expanded back to the original output type and replaces
///      `convOp`.
/// The match fails (without modifying the IR) unless the input, filter and
/// output shapes are static and all dilations are 1. On success, returns the
/// pair (img2col linalg.generic, final tensor.expand_shape).
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // The collapsed shapes below are products of output extents, so a dynamic
  // output dimension cannot be handled.
  if (!outputType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the output");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // Output is NHWC, filter is FHWC.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[1];
  int64_t fw = filterShape[2];
  int64_t ic = filterShape[3];

  // Strides along the height and width dimensions.
  int64_t strideH = convOp.getStrides().getValues<int64_t>()[0];
  int64_t strideW = convOp.getStrides().getValues<int64_t>()[1];

  Location loc = convOp.getLoc();

  // Reshape filter and output to the RHS and result of a "row-wise" matrix
  // multiplication.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex = getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex,
                                         fhIndex, strideH);
        Value wIndex = getConvolvedIndex(nestedBuilder, nestedLoc, owIndex,
                                         fwIndex, strideW);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            nestedLoc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because we didn't transpose the filters we don't actually have a batched
  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
  // products.
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Multiply in the accumulator type, then accumulate into the output.
        Value mul = createMul(nestedLoc, args[0], args[1], args[2].getType(),
                              nestedBuilder);
        Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Restore the original 4-D output shape.
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

namespace {

/// Rewrite pattern that forwards `linalg::Conv2DNhwcHwcfOp` to the matching
/// `rewriteInIm2Col` overload.
class ConvertConv2DNhwcHwcf final
    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    // Only the success/failure status is needed; the returned operations are
    // dropped.
    return failure(failed(rewriteInIm2Col(rewriter, convOp)));
  }
};

/// Rewrite pattern that forwards `linalg::DepthwiseConv2DNhwcHwcOp` to the
/// matching `rewriteInIm2Col` overload.
class ConvertDepthwiseConv2DNhwcHwc final
    : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    // Only the success/failure status is needed; the returned operations are
    // dropped.
    return failure(failed(rewriteInIm2Col(rewriter, convOp)));
  }
};

/// Rewrite pattern that forwards `linalg::Conv2DNchwFchwOp` to the matching
/// `rewriteInIm2Col` overload.
class ConvertConv2DNchwFchw final
    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
                                PatternRewriter &rewriter) const override {
    // Only the success/failure status is needed; the returned operations are
    // dropped.
    return failure(failed(rewriteInIm2Col(rewriter, convOp)));
  }
};

/// Rewrite pattern that forwards `linalg::Conv2DNhwcFhwcOp` to the matching
/// `rewriteInIm2Col` overload.
class ConvertConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    // Only the success/failure status is needed; the returned operations are
    // dropped.
    return failure(failed(rewriteInIm2Col(rewriter, convOp)));
  }
};
} // end anonymous namespace

/// Adds the four convolution-to-im2col rewrite patterns defined in this file
/// to `patterns`, constructing each with the context owned by `patterns`.
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
  patterns.add<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
               ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(
      patterns.getContext());
}
} // end namespace linalg
} // end namespace mlir