//===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for testing fusion of elementwise operations in
// Linalg; the pattern set under test is selected mainly through pass options.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

/// Collects into `operandSet` the values read by `op`.
///
/// For an op implementing `linalg::LinalgOp` only the values returned by
/// `getDpsInputs()` are inserted; for any other op every operand is inserted.
/// A null `op` leaves `operandSet` untouched.
static void addOperands(Operation *op, SetVector<Value> &operandSet) {
  if (op == nullptr)
    return;

  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    SmallVector<Value> dpsInputs = linalgOp.getDpsInputs();
    operandSet.insert(dpsInputs.begin(), dpsInputs.end());
    return;
  }

  operandSet.insert(op->operand_begin(), op->operand_end());
}

/// Fusion control function that bounds the number of operands of the fused op.
///
/// The operand set is built with `addOperands`: first from the consumer (the
/// owner of `fusedOperand`), then the producer's single result is removed from
/// it, then the producer's own operands are added. Fusion is allowed when the
/// resulting set holds at most `limit` distinct values.
///
/// Returns false when `fusedOperand` has no defining op or when the producer
/// does not have exactly one result.
///
/// \tparam limit maximum number of distinct operands; must be non-negative.
template <int limit = 3>
static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
  // A negative limit would silently convert to a huge unsigned value in the
  // size comparison below and accept every fusion; reject it at compile time.
  static_assert(limit >= 0, "operand limit must be non-negative");

  Operation *producer = fusedOperand->get().getDefiningOp();
  if (!producer || producer->getNumResults() != 1)
    return false;

  Operation *consumer = fusedOperand->getOwner();
  SetVector<Value> fusedOpOperands;
  addOperands(consumer, fusedOpOperands);
  // The producer's result is dropped before the producer's operands are added,
  // so the order of these three steps matters.
  fusedOpOperands.remove(producer->getResult(0));
  addOperands(producer, fusedOpOperands);
  return fusedOpOperands.size() <= static_cast<size_t>(limit);
}

namespace {

/// Test pattern that fuses a `linalg.generic` with the producer of one of its
/// operands, even if that producer has multiple uses.
///
/// The first operand accepted by `linalg::areElementwiseOpsFusable` is fused
/// through `linalg::fuseElementwiseOps`. Every use of a replaced value is then
/// redirected to its replacement, except uses owned by the matched op itself,
/// and finally the matched op is erased.
struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Pick the first operand that is reported as fusable.
    OpOperand *candidate = nullptr;
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!linalg::areElementwiseOpsFusable(&opOperand))
        continue;
      candidate = &opOperand;
      break;
    }
    if (candidate == nullptr)
      return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");

    std::optional<linalg::ElementwiseOpFusionResult> fused =
        linalg::fuseElementwiseOps(rewriter, candidate);
    if (!fused.has_value())
      return rewriter.notifyMatchFailure(genericOp, "fusion failed");

    // Uses owned by the matched op are left untouched; that op is erased below.
    Operation *matchedOp = genericOp.getOperation();
    auto isOutsideMatchedOp = [matchedOp](OpOperand &use) {
      return use.getOwner() != matchedOp;
    };
    for (auto [origValue, replacement] : fused->replacements)
      rewriter.replaceUsesWithIf(origValue, replacement, isOutsideMatchedOp);

    rewriter.eraseOp(genericOp);
    return success();
  }
};

/// Test pass that exercises the Linalg elementwise-fusion, reshape-folding and
/// dimension-collapsing pattern sets on a `func::FuncOp`.
///
/// Each pass option selects one pattern set. `runOnOperation` checks the
/// options in a fixed order and applies only the first one that is enabled.
struct TestLinalgElementwiseFusion
    : public PassWrapper<TestLinalgElementwiseFusion,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion)

  TestLinalgElementwiseFusion() = default;
  // Only the base is copied; the option members are re-created through their
  // in-class initializers, which bind them to the new pass object.
  TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, linalg::LinalgDialect,
                    memref::MemRefDialect, tensor::TensorDialect>();
  }
  StringRef getArgument() const final {
    return "test-linalg-elementwise-fusion-patterns";
  }
  StringRef getDescription() const final {
    return "Test Linalg element wise operation fusion patterns";
  }

  Option<bool> fuseGenericOps{
      *this, "fuse-generic-ops",
      llvm::cl::desc("Test fusion of generic operations."),
      llvm::cl::init(false)};

  Option<bool> fuseGenericOpsControl{
      *this, "fuse-generic-ops-control",
      llvm::cl::desc(
          "Test fusion of generic operations with a control function."),
      llvm::cl::init(false)};

  Option<bool> fuseWithReshapeByExpansion{
      *this, "fuse-with-reshape-by-expansion",
      llvm::cl::desc(
          "Test fusion of generic operations with reshape by expansion"),
      llvm::cl::init(false)};

  Option<bool> controlFuseByExpansion{
      *this, "control-fusion-by-expansion",
      llvm::cl::desc(
          "Test controlling fusion of reshape with generic op by expansion"),
      llvm::cl::init(false)};

  Option<bool> fuseWithReshapeByCollapsing{
      *this, "fuse-with-reshape-by-collapsing",
      llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
                     "collapse the iteration space of the consumer"),
      llvm::cl::init(false)};

  Option<bool> fuseWithReshapeByCollapsingWithControlFn{
      *this, "fuse-with-reshape-by-collapsing-control",
      llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
                     "fusion patterns that "
                     "collapse the iteration space of the consumer"),
      llvm::cl::init(false)};

  Option<bool> fuseMultiUseProducer{
      *this, "fuse-multiuse-producer",
      llvm::cl::desc("Test fusion of producer ops with multiple uses"),
      llvm::cl::init(false)};

  ListOption<int64_t> collapseDimensions{
      *this, "collapse-dimensions-control",
      llvm::cl::desc("Test controlling dimension collapse pattern")};

  /// Greedily applies `patterns` to the body of the function being processed
  /// and signals pass failure if the driver reports failure.
  void applyGreedily(RewritePatternSet &&patterns) {
    func::FuncOp funcOp = this->getOperation();
    if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
                                            std::move(patterns))))
      signalPassFailure();
  }

  /// Builds the pattern set selected by the first enabled option and applies
  /// it to the function. Does nothing when no option is enabled.
  void runOnOperation() override {
    MLIRContext *context = &this->getContext();

    if (fuseGenericOps) {
      RewritePatternSet fusionPatterns(context);
      // Unconditionally allow fusion.
      auto controlFn = [](OpOperand * /*fusedOperand*/) { return true; };
      linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
      applyGreedily(std::move(fusionPatterns));
      return;
    }

    if (fuseGenericOpsControl) {
      RewritePatternSet fusionPatterns(context);
      // Only fuse when the fused op would have at most four operands.
      linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
                                                   setFusedOpOperandLimit<4>);
      applyGreedily(std::move(fusionPatterns));
      return;
    }

    if (fuseWithReshapeByExpansion) {
      RewritePatternSet fusionPatterns(context);
      linalg::populateFoldReshapeOpsByExpansionPatterns(
          fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
      applyGreedily(std::move(fusionPatterns));
      return;
    }

    if (controlFuseByExpansion) {
      RewritePatternSet fusionPatterns(context);

      linalg::ControlFusionFn controlReshapeFusionFn =
          [](OpOperand *fusedOperand) {
            Operation *producer = fusedOperand->get().getDefiningOp();
            if (!producer)
              return false;

            // A collapse_shape producer is only accepted when its source is
            // itself defined by a Linalg op.
            if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) {
              if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
                return false;
              }
            }

            // An expand_shape consumer is only accepted when its single use is
            // a DPS init operand of a Linalg op.
            Operation *consumer = fusedOperand->getOwner();
            if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
              if (expandOp->hasOneUse()) {
                OpOperand &use = *expandOp->getUses().begin();
                auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
                if (linalgOp && linalgOp.isDpsInit(&use))
                  return true;
              }
              return false;
            }
            return true;
          };

      linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
                                                        controlReshapeFusionFn);
      applyGreedily(std::move(fusionPatterns));
      return;
    }

    if (fuseWithReshapeByCollapsing) {
      RewritePatternSet patterns(context);
      linalg::populateFoldReshapeOpsByCollapsingPatterns(
          patterns, [](OpOperand * /*fusedOperand */) { return true; });
      applyGreedily(std::move(patterns));
      return;
    }

    if (fuseWithReshapeByCollapsingWithControlFn) {
      RewritePatternSet patterns(context);
      linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool {
        // `getDefiningOp<OpTy>()` yields a null op when the value has no
        // defining op (e.g. a block argument), so this check is null-safe,
        // unlike calling `isa<>` on the raw producer pointer.
        if (fusedOperand->get().getDefiningOp<tensor::ExpandShapeOp>()) {
          // Skip fusing the first operand.
          return fusedOperand->getOperandNumber() != 0;
        }
        return true;
      };
      linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
      applyGreedily(std::move(patterns));
      return;
    }

    if (fuseMultiUseProducer) {
      RewritePatternSet patterns(context);
      patterns.insert<TestMultiUseProducerFusion>(context);
      applyGreedily(std::move(patterns));
      return;
    }

    if (!collapseDimensions.empty()) {
      SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
                                   collapseDimensions.end());
      // Every op is asked to collapse the same single group of dimensions,
      // taken verbatim from the pass option. `dims` is captured by reference
      // and stays alive until the patterns have been applied below.
      linalg::GetCollapsableDimensionsFn collapseFn =
          [&dims](linalg::LinalgOp op) {
            SmallVector<ReassociationIndices> reassociations;
            reassociations.emplace_back(dims);
            return reassociations;
          };
      RewritePatternSet patterns(context);
      linalg::populateCollapseDimensions(patterns, collapseFn);
      applyGreedily(std::move(patterns));
      return;
    }
  }
};

} // namespace

namespace mlir::test {
/// Registers the `TestLinalgElementwiseFusion` pass via `PassRegistration`.
void registerTestLinalgElementwiseFusion() {
  PassRegistration<TestLinalgElementwiseFusion>();
}
} // namespace mlir::test