//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose TOSA TransposeConv operation to a series of TOSA Ops specifically
// (1) Convert a Dilated TransposeConv2D to Conv2D including reversing/reshaping
// etc.. of the weights (2) Convert a Strided TransposeConv2D to Conv2D
// including transposing/reversing/reshaping etc..
//     of the weights and input/output tenors and reversing/reshaping etc .. of
//     the weights
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

template <typename TosaOp, typename... Args>
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
                        Args &&...args) {
  auto op = rewriter.create<TosaOp>(loc, resultTy, args...);

  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(
              op.getContext(), op.getLoc(), op->getOperands(),
              op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
              op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge =
      mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge =
      mlir::tosa::ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Compute the new type based on the joined version.
  auto newKnowledge =
      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  auto newTy = newKnowledge.getType();
  result.setType(newTy);
  return op;
}

class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    llvm::ArrayRef<int64_t> stride = op.getStride();
    llvm::ArrayRef<int64_t> pad = op.getOutPad();

    // If striding is all 1 we can modify padding and reverse the kernel along
    // the x/y direction to make it a regular convolution. This is much simpler
    // then handling striding....
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 + pad[0];
    convPad[1] = kernelHeight - 1 + pad[1];
    convPad[2] = kernelWidth - 1 + pad[2];
    convPad[3] = kernelWidth - 1 + pad[3];

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));

    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getDenseI64ArrayAttr(convPad),
          rewriter.getDenseI64ArrayAttr(stride),
          rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getDenseI64ArrayAttr(convPad),
          rewriter.getDenseI64ArrayAttr(stride),
          rewriter.getDenseI64ArrayAttr({1, 1}));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::ArrayRef<int64_t> pad = op.getOutPad();
    llvm::ArrayRef<int64_t> stride = op.getStride();

    // If striding is all 1 we can modify padding and reverse the kernel along
    // the x/y direction to make it a regular convolution. This is much simpler
    // then handling striding....

    // If strides are all 1 we dont need to use this one.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return rewriter.notifyMatchFailure(op, "non-one stride found.");

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight so that it is modulo of the striding.
    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
    weightPadding[5] =
        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
    Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));

    } else {
      weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                             UnrankedTensorType::get(weightETy),
                                             weight, weightPaddingVal);
    }

    weightTy = cast<ShapedType>(weight.getType());
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI64ArrayAttr(weightReshapeDims0));

    // Transpose the factored-out stride to the output channels.
    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

    weight = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        transposeWeightVal);

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
    ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(1));
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(2));

    // We need to pad the input far enough that we can pull all values.
    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);

    Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
    } else {
      input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                            UnrankedTensorType::get(inputETy),
                                            input, inputPaddingVal);
    }

    // We use a zero bias as we need to broadcast the bias.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    // Perform the convolution using the zero bias.
    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   *op.getQuantizationInfo())
                   .getResult();
    } else {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
                   .getResult();
    }

    // Factor the resulting width / height.
    ShapedType convTy = cast<ShapedType>(conv2d.getType());
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getDenseI64ArrayAttr(convReshapeDims0));

    // Transpose the factored-out stride to the output channels.
    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

    conv2d = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        transposeConvVal);

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getDenseI64ArrayAttr(convReshapeDims1));

    // Determine the amount to slice / pad from the result start.
    int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
    int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
    int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
    int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);

    // Try to slice the targetted result size, cap to the convolutions width.
    int64_t resultSliceHeight =
        std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
                          resultTy.getDimSize(1) - resultPadTop);
    int64_t resultSliceWidth =
        std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
                          resultTy.getDimSize(2) - resultPadLeft);

    llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
                                                resultSliceLeft, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
                                            convReshapeDims1.end());
    sliceSize[1] = resultSliceHeight;
    sliceSize[2] = resultSliceWidth;

    auto slice = createOpAndInfer<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     rewriter.getDenseI64ArrayAttr(sliceBegin),
                     rewriter.getDenseI64ArrayAttr(sliceSize))
                     .getResult();

    llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    resultPadding[2] = resultPadTop;
    resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
    resultPadding[4] = resultPadLeft;
    resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

    DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);

    Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);

    Value resultPad = createOpAndInfer<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), slice,
        resultPaddingVal);

    if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}