//===----------- MultiBuffering.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements multi buffering transformation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"

using namespace mlir;

#define DEBUG_TYPE "memref-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

/// Return true if the op fully overwrite the given `buffer` value.
static bool overrideBuffer(Operation *op, Value buffer) {
  auto copyOp = dyn_cast<memref::CopyOp>(op);
  if (!copyOp)
    return false;
  return copyOp.getTarget() == buffer;
}

/// Replace the uses of `oldOp` with the given `val` and for subview uses
/// propagate the type change. Changing the memref type may require propagating
/// it through subview ops so we cannot just do a replaceAllUse but need to
/// propagate the type change and erase old subview ops.
static void replaceUsesAndPropagateType(RewriterBase &rewriter,
                                        Operation *oldOp, Value val) {
  SmallVector<Operation *> opsToDelete;
  SmallVector<OpOperand *> operandsToReplace;

  // Save the operand to replace / delete later (avoid iterator invalidation).
  // TODO: can we use an early_inc iterator?
  for (OpOperand &use : oldOp->getUses()) {
    // Non-subview ops will be replaced by `val`.
    auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
    if (!subviewUse) {
      operandsToReplace.push_back(&use);
      continue;
    }

    // `subview(old_op)` is replaced by a new `subview(val)`.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(subviewUse);
    Type newType = memref::SubViewOp::inferRankReducedResultType(
        subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
        subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
        subviewUse.getStaticStrides());
    Value newSubview = rewriter.create<memref::SubViewOp>(
        subviewUse->getLoc(), cast<MemRefType>(newType), val,
        subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
        subviewUse.getMixedStrides());

    // Ouch recursion ... is this really necessary?
    replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);

    opsToDelete.push_back(use.getOwner());
  }

  // Perform late replacement.
  // TODO: can we use an early_inc iterator?
  for (OpOperand *operand : operandsToReplace) {
    Operation *op = operand->getOwner();
    rewriter.startOpModification(op);
    operand->set(val);
    rewriter.finalizeOpModification(op);
  }

  // Perform late op erasure.
  // TODO: can we use an early_inc iterator?
  for (Operation *op : opsToDelete)
    rewriter.eraseOp(op);
}

// Transformation to do multi-buffering/array expansion to remove dependencies
// on the temporary allocation between consecutive loop iterations.
// Returns success if the transformation happened and failure otherwise.
// This is not a pattern as it requires propagating the new memref type to its
// uses and requires updating subview ops.
FailureOr<memref::AllocOp>
mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
                          unsigned multiBufferingFactor,
                          bool skipOverrideAnalysis) {
  LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");
  DominanceInfo dom(allocOp->getParentOp());
  LoopLikeOpInterface candidateLoop;
  for (Operation *user : allocOp->getUsers()) {
    auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
    if (!parentLoop) {
      if (isa<memref::DeallocOp>(user)) {
        // Allow dealloc outside of any loop.
        // TODO: The whole precondition function here is very brittle and will
        // need to rethought an isolated into a cleaner analysis.
        continue;
      }
      LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
      LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
      return failure();
    }
    if (!skipOverrideAnalysis) {
      /// Make sure there is no loop-carried dependency on the allocation.
      if (!overrideBuffer(user, allocOp.getResult())) {
        LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
        continue;
      }
      // If this user doesn't dominate all the other users keep looking.
      if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
            return !dom.dominates(user, otherUser);
          })) {
        LLVM_DEBUG(
            DBGS() << "--Skip user: does not dominate all other users\n");
        continue;
      }
    } else {
      if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
            return !isa<memref::DeallocOp>(otherUser) &&
                   !parentLoop->isProperAncestor(otherUser);
          })) {
        LLVM_DEBUG(
            DBGS()
            << "--Skip user: not all other users are in the parent loop\n");
        continue;
      }
    }
    candidateLoop = parentLoop;
    break;
  }

  if (!candidateLoop) {
    LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
    return failure();
  }

  std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
  std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
  std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
  if (!inductionVar || !lowerBound || !singleStep ||
      !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
    LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
    return failure();
  }

  if (!dom.dominates(allocOp.getOperation(), candidateLoop)) {
    LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n");
    return failure();
  }

  LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");

  // 1. Construct the multi-buffered memref type.
  ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
  SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
  llvm::append_range(multiBufferedShape, originalShape);
  LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
  MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
                                .setShape(multiBufferedShape)
                                .setLayout(MemRefLayoutAttrInterface());
  LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");

  // 2. Create the multi-buffered alloc.
  Location loc = allocOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(allocOp);
  auto mbAlloc = rewriter.create<memref::AllocOp>(
      loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
  LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");

  // 3. Within the loop, build the modular leading index (i.e. each loop
  // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
  rewriter.setInsertionPointToStart(
      &candidateLoop.getLoopRegions().front()->front());
  Value ivVal = *inductionVar;
  Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
  Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
  AffineExpr iv, lb, step;
  bindDims(rewriter.getContext(), iv, lb, step);
  Value bufferIndex = affine::makeComposedAffineApply(
      rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
      {ivVal, lbVal, stepVal});
  LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");

  // 4. Build the subview accessing the particular slice, taking modular
  // rotation into account.
  int64_t mbMemRefTypeRank = mbMemRefType.getRank();
  IntegerAttr zero = rewriter.getIndexAttr(0);
  IntegerAttr one = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
  SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
  SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
  // Offset is [bufferIndex, 0 ... 0 ].
  offsets.front() = bufferIndex;
  // Sizes is [1, original_size_0 ... original_size_n ].
  for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
    sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
  // Strides is [1, 1 ... 1 ].
  auto dstMemref =
      cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
          originalShape, mbMemRefType, offsets, sizes, strides));
  Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
                                                     offsets, sizes, strides);
  LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");

  // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
  // handle dealloc uses separately..
  for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
    auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
    if (!deallocOp)
      continue;
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(deallocOp);
    auto newDeallocOp =
        rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
    (void)newDeallocOp;
    LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
    rewriter.eraseOp(deallocOp);
  }

  // 6. RAUW with the particular slice, taking modular rotation into account.
  replaceUsesAndPropagateType(rewriter, allocOp, subview);

  // 7. Finally, erase the old allocOp.
  rewriter.eraseOp(allocOp);

  return mbAlloc;
}

FailureOr<memref::AllocOp>
mlir::memref::multiBuffer(memref::AllocOp allocOp,
                          unsigned multiBufferingFactor,
                          bool skipOverrideAnalysis) {
  IRRewriter rewriter(allocOp->getContext());
  return multiBuffer(rewriter, allocOp, multiBufferingFactor,
                     skipOverrideAnalysis);
}