//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Test pass, anchored on func::FuncOp, that exposes a collection of Linalg
/// rewrite-pattern sets behind command-line options.
///
/// Each boolean `test-*` option selects one pattern set; runOnOperation()
/// checks the options in a fixed order and runs only the first one that is
/// set.
struct TestLinalgTransforms
    : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms)

  TestLinalgTransforms() = default;
  // Only the PassWrapper base is copied; the option members below are not
  // copied and are instead re-created by their default member initializers.
  // NOTE(review): presumably required because Option/ListOption members are
  // not copyable -- confirm against the pass infrastructure.
  TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}

  /// Inserts the dialects this pass declares a dependency on into `registry`.
  /// NOTE(review): presumably these are the dialects whose operations the
  /// tested patterns may create -- confirm when adding new pattern sets.
  void getDependentDialects(DialectRegistry &registry) const override {
    // clang-format off
    registry.insert<affine::AffineDialect,
                    bufferization::BufferizationDialect,
                    memref::MemRefDialect,
                    scf::SCFDialect,
                    linalg::LinalgDialect,
                    vector::VectorDialect,
                    gpu::GPUDialect>();
    // clang-format on
  }
  /// Command-line argument under which the pass is exposed.
  StringRef getArgument() const final {
    return "test-linalg-transform-patterns";
  }
  /// One-line description of the pass.
  StringRef getDescription() const final {
    return "Test Linalg transformation patterns by applying them greedily.";
  }

  /// Runs the pattern set selected by the first enabled `test-*` option; see
  /// the out-of-line definition for the exact order.
  void runOnOperation() override;

  //===--------------------------------------------------------------------===//
  // Pattern-set selectors. All default to false.
  //===--------------------------------------------------------------------===//
  Option<bool> testPatterns{*this, "test-patterns",
                            llvm::cl::desc("Test a mixed set of patterns"),
                            llvm::cl::init(false)};
  Option<bool> testVectorTransferForwardingPatterns{
      *this, "test-vector-transfer-forwarding-patterns",
      llvm::cl::desc(
          "Test a fused pass that forwards memref.copy to vector.transfer"),
      llvm::cl::init(false)};
  Option<bool> testGenericToVectorPattern{
      *this, "test-linalg-to-vector-patterns",
      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                     "in vector.contract form"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizePadTensor{
      *this, "test-generalize-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizeTensorPackOp{
      *this, "test-generalize-tensor-pack",
      llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
                     "of tensor and Linalg ops"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizeTensorUnPackOp{
      *this, "test-generalize-tensor-unpack",
      llvm::cl::desc(
          "Test transform that generalizes unpack ops into a sequence "
          "of tensor and Linalg ops"),
      llvm::cl::init(false)};
  Option<bool> testSwapSubTensorPadTensor{
      *this, "test-swap-subtensor-padtensor",
      llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
                     "tensor.pad(subtensor)"),
      llvm::cl::init(false)};

  //===--------------------------------------------------------------------===//
  // Tiling/peeling parameters.
  // NOTE(review): none of the four options below is read by any function in
  // the visible part of this file, and their descriptions refer to
  // `test-tile-pattern`, which is not an option declared here. They may be
  // leftovers; confirm before relying on or removing them.
  //===--------------------------------------------------------------------===//
  ListOption<int64_t> peeledLoops{
      *this, "peeled-loops",
      llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
  ListOption<int64_t> tileSizes{
      *this, "tile-sizes",
      llvm::cl::desc("Linalg tile sizes for test-tile-pattern")};
  Option<bool> skipPartial{
      *this, "skip-partial",
      llvm::cl::desc("Skip loops inside partial iterations during peeling"),
      llvm::cl::init(false)};
  Option<std::string> loopType{
      *this, "loop-type",
      llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                     "tiled_loop"),
      llvm::cl::init("for")};

  //===--------------------------------------------------------------------===//
  // Further pattern-set selectors. All default to false.
  //===--------------------------------------------------------------------===//
  Option<bool> testBubbleUpExtractSliceOpPattern{
      *this, "test-bubble-up-extract-slice-op-pattern",
      llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
                     "extract_slice + linalgOp"),
      llvm::cl::init(false)};
  Option<bool> testSwapExtractSliceWithFill{
      *this, "test-swap-extract-slice-with-fill-pattern",
      llvm::cl::desc(
          "Test patterns to swap tensor.extract_slice(linalg.fill())"),
      llvm::cl::init(false)};
  Option<bool> testEraseUnusedOperandsAndResults{
      *this, "test-erase-unused-operands-and-results",
      llvm::cl::desc("Test patterns to erase unused operands and results"),
      llvm::cl::init(false)};
  Option<bool> testEraseUnnecessaryInputs{
      *this, "test-erase-unnecessary-inputs",
      llvm::cl::desc("Test patterns to erase unnecessary inputs"),
      llvm::cl::init(false)};
  Option<bool> testWinogradConv2D{
      *this, "test-winograd-conv2d",
      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
      llvm::cl::init(false)};
  Option<bool> testDecomposeWinogradOps{
      *this, "test-decompose-winograd-ops",
      llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
};
} // namespace

/// Runs the pattern set behind the `test-patterns` option on `funcOp`.
///
/// The set currently contains only CopyVectorizationPattern. The patterns are
/// applied with the greedy driver; its result is explicitly discarded, so a
/// failure to converge does not fail the pass.
static void applyPatterns(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<CopyVectorizationPattern>(ctx);

  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

/// Registers LinalgCopyVTRForwardingPattern followed by
/// LinalgCopyVTWForwardingPattern and applies them greedily to `funcOp`.
/// The driver's result is explicitly discarded.
static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet forwardingPatterns(context);
  // Variadic add<> registers the patterns in template-argument order.
  forwardingPatterns.add<LinalgCopyVTRForwardingPattern,
                         LinalgCopyVTWForwardingPattern>(context);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardingPatterns));
}

/// Collects CopyVectorizationPattern plus the pad-op and convolution
/// vectorization pattern sets, then applies them greedily to `funcOp`.
/// The driver's result is explicitly discarded.
static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet vecPatterns(context);
  vecPatterns.add<CopyVectorizationPattern>(context);
  populatePadOpVectorizationPatterns(vecPatterns);
  populateConvolutionVectorizationPatterns(vecPatterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(vecPatterns));
}

/// Applies GeneralizePadOpPattern greedily to `funcOp`; the driver's result
/// is explicitly discarded.
static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet padPatterns(context);
  padPatterns.add<GeneralizePadOpPattern>(context);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(padPatterns));
}

/// Applies GeneralizeOuterUnitDimsPackOpPattern greedily to `funcOp`; the
/// driver's result is explicitly discarded.
static void applyGeneralizeTensorPackPatterns(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet packPatterns(context);
  packPatterns.add<GeneralizeOuterUnitDimsPackOpPattern>(context);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(packPatterns));
}

/// Applies GeneralizeOuterUnitDimsUnPackOpPattern greedily to `funcOp`; the
/// driver's result is explicitly discarded.
static void applyGeneralizeTensorUnPackPatterns(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet unpackPatterns(context);
  unpackPatterns.add<GeneralizeOuterUnitDimsUnPackOpPattern>(context);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(unpackPatterns));
}

/// Applies ExtractSliceOfPadTensorSwapPattern greedily to `funcOp`; the
/// driver's result is explicitly discarded.
static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet swapPatterns(context);
  swapPatterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(swapPatterns));
}

/// Fills a pattern set via populateBubbleUpExtractSliceOpPatterns and applies
/// it greedily to `funcOp`; the driver's result is explicitly discarded.
static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet bubbleUpPatterns(context);
  populateBubbleUpExtractSliceOpPatterns(bubbleUpPatterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(bubbleUpPatterns));
}

/// Fills a pattern set via populateSwapExtractSliceWithFillPatterns and
/// applies it greedily to `funcOp`; the driver's result is explicitly
/// discarded.
static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet fillSwapPatterns(context);
  populateSwapExtractSliceWithFillPatterns(fillSwapPatterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(fillSwapPatterns));
}

/// Fills a pattern set via populateEraseUnusedOperandsAndResultsPatterns and
/// applies it greedily to `funcOp`; the driver's result is explicitly
/// discarded.
static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet erasePatterns(context);
  populateEraseUnusedOperandsAndResultsPatterns(erasePatterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(erasePatterns));
}

/// Fills a pattern set via populateEraseUnnecessaryInputsPatterns and applies
/// it greedily to `funcOp`; the driver's result is explicitly discarded.
static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet inputPatterns(context);
  populateEraseUnnecessaryInputsPatterns(inputPatterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(inputPatterns));
}

/// Populates the Winograd conv2d patterns for two (m, r) configurations --
/// (4, 3) first, then (2, 5) -- and applies them greedily to `funcOp`.
/// The driver's result is explicitly discarded.
static void applyWinogradConv2D(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet winogradPatterns(context);
  // Registration order is kept as (m=4, r=3) before (m=2, r=5).
  populateWinogradConv2DPatterns(winogradPatterns, /*m=*/4, /*r=*/3);
  populateWinogradConv2DPatterns(winogradPatterns, /*m=*/2, /*r=*/5);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(winogradPatterns));
}

/// Fills a pattern set via populateDecomposeWinogradOpsPatterns and applies
/// it greedily to `funcOp`; the driver's result is explicitly discarded.
static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet decomposePatterns(context);
  populateDecomposeWinogradOpsPatterns(decomposePatterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(decomposePatterns));
}

/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
  if (testPatterns)
    return applyPatterns(getOperation());
  if (testVectorTransferForwardingPatterns)
    return applyVectorTransferForwardingPatterns(getOperation());
  if (testGenericToVectorPattern)
    return applyLinalgToVectorPatterns(getOperation());
  if (testGeneralizePadTensor)
    return applyGeneralizePadTensorPatterns(getOperation());
  if (testGeneralizeTensorPackOp)
    return applyGeneralizeTensorPackPatterns(getOperation());
  if (testGeneralizeTensorUnPackOp)
    return applyGeneralizeTensorUnPackPatterns(getOperation());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getOperation());
  if (testBubbleUpExtractSliceOpPattern)
    return applyBubbleUpExtractSliceOpPattern(getOperation());
  if (testSwapExtractSliceWithFill)
    return applySwapExtractSliceWithFillPattern(getOperation());
  if (testEraseUnusedOperandsAndResults)
    return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
  if (testEraseUnnecessaryInputs)
    return applyEraseUnnecessaryInputs(getOperation());
  if (testWinogradConv2D)
    return applyWinogradConv2D(getOperation());
  if (testDecomposeWinogradOps)
    return applyDecomposeWinogradOps(getOperation());
}

namespace mlir::test {
/// Registers the TestLinalgTransforms pass by constructing a PassRegistration
/// for it.
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace mlir::test