//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
///
/// Both `buffer` and `type` must be of BaseMemRefType (asserted). If `buffer`
/// already has the requested type, it is returned as is and no op is created.
/// Otherwise, a memref.cast op is created at the location of `buffer` and its
/// result is returned.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  // Note: This helper is shared by the bufferization of several SCF ops, so
  // the assertion message intentionally does not name a specific op.
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "loop bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Check all aliases of `value` and
/// return "true" only if none of them (ignoring the values listed in
/// `exceptions`) is external to `region`. An alias counts as external if it is
/// an op result that is defined outside of the region, or a block argument
/// that does not belong to a region properly nested in the given region (this
/// includes the entry block arguments of the region itself).
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->getBlocks().size() == 1 &&
         "expected region with single block");
  bool foundExternalAlias = false;
  auto inspectAlias = [&](Value alias) {
    // Explicitly allowed aliases are never considered external.
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *owningRegion = alias.getParentRegion();
    const bool externalBbArg =
        isa<BlockArgument>(alias) && !region->isProperAncestor(owningRegion);
    const bool externalOpResult =
        isa<OpResult>(alias) && !region->isAncestor(owningRegion);
    if (externalBbArg || externalOpResult)
      foundExternalAlias = true;
  };
  state.applyOnAliases(value, inspectAlias);
  return !foundExternalAlias;
}

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  /// Every operand of scf.condition is treated as a memory read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// No operand of scf.condition is treated as a memory write.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// Operands of scf.condition are reported as not aliasing with any value.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  /// Operands must always bufferize in place: an out-of-place bufferization
  /// may generate an alloc + copy inside the block, and returning/yielding
  /// allocations should be avoided when possible.
  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    return true;
  }

  /// Replace the op with a new scf.condition that forwards buffers instead of
  /// tensors. Each buffer is cast to the bufferized type of the "after" region
  /// block argument at the same position. Non-tensor args are forwarded
  /// unchanged.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> forwardedArgs;
    unsigned nextArgIdx = 0;
    for (Value arg : conditionOp.getArgs()) {
      const unsigned argIdx = nextArgIdx++;

      // Non-tensor values do not need to be bufferized.
      if (!isa<TensorType>(arg.getType())) {
        forwardedArgs.push_back(arg);
        continue;
      }

      FailureOr<Value> argBuffer = getBuffer(rewriter, arg, options);
      if (failed(argBuffer))
        return failure();

      // The forwarded buffer must have the bufferized type of the matching
      // "after" region block argument.
      FailureOr<BaseMemRefType> afterArgType = bufferization::getBufferType(
          whileOp.getAfterArguments()[argIdx], options);
      if (failed(afterArgType))
        return failure();

      forwardedArgs.push_back(castBuffer(rewriter, *argBuffer, *afterArgType));
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), forwardedArgs);
    return success();
  }
};

/// Look for the scf.yield terminators among the blocks of the given
/// scf.execute_region. Return the scf.yield op if exactly one block is
/// terminated by one. Return an empty (null) op if there is no such block or
/// more than one.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp uniqueYield;
  for (Block &block : executeRegionOp.getRegion()) {
    auto candidate = dyn_cast<scf::YieldOp>(block.getTerminator());
    // Blocks with a different kind of terminator are irrelevant.
    if (!candidate)
      continue;
    // A second scf.yield means that there is no unique one.
    if (uniqueYield)
      return {};
    uniqueYield = candidate;
  }
  return uniqueYield;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  /// All values owned by this op are reported as writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  /// Reject ops that do not have exactly one scf.yield terminator.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  /// Return the OpOperands that alias with `value`. Block arguments are
  /// handled by the branch-op logic of the base model. An op result aliases
  /// with the operand of the unique scf.yield at the same position.
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  /// Replace the op with a new scf.execute_region whose result types are the
  /// types of the values yielded by the unique scf.yield. The region body is
  /// moved over and the signature of every block is bufferized. Tensor results
  /// of the old op are replaced with to_tensor ops of the new results.
  ///
  /// Fails with an op error if there is no unique scf.yield.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Do not rely on `verifyAnalysis` having rejected such ops earlier: fail
    // gracefully instead of dereferencing a null op below.
    if (!yieldOp)
      return op->emitOpError("op without unique scf.yield is not supported");
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  /// Return the yield operands of the then/else branches that correspond to
  /// the given op result. Both are reported as "Equivalent", but not definite.
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  /// Create a new scf.if (with else region) whose tensor results are replaced
  /// with their buffer types, move the then/else blocks of the old op into it
  /// and replace the old op with the results of the new op.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  /// Compute the buffer type of the given op result from the values that are
  /// yielded at the same position by the then/else branches:
  /// * Same buffer type on both branches: use that type.
  /// * Different memory spaces: emit an error.
  /// * Otherwise: use a memref type with fully dynamic layout map.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    auto opResult = cast<OpResult>(value);
    unsigned resultNum = opResult.getResultNumber();

    // Helper function to get the buffer type of the value that is yielded by
    // the given scf.yield for this result.
    auto getYieldedBufferType =
        [&](scf::YieldOp yieldOp) -> FailureOr<BaseMemRefType> {
      Value yieldedValue = yieldOp.getOperand(resultNum);
      // The branch was already bufferized.
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      return bufferization::getBufferType(yieldedValue, options,
                                          invocationStack);
    };

    // Determine buffer types of the true/false branches.
    FailureOr<BaseMemRefType> thenBufferType =
        getYieldedBufferType(thenYieldOp);
    if (failed(thenBufferType))
      return failure();
    FailureOr<BaseMemRefType> elseBufferType =
        getYieldedBufferType(elseYieldOp);
    if (failed(elseBufferType))
      return failure();

    // Best case: Both branches have the exact same buffer type.
    if (*thenBufferType == *elseBufferType)
      return *thenBufferType;

    // Memory space mismatch.
    if (thenBufferType->getMemorySpace() != elseBufferType->getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType->getMemorySpace());
  }
};

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  /// Return the yield operands of all case blocks and of the default block
  /// that correspond to the given op result. All of them are reported as
  /// "Equivalent", but not definite.
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  /// Create a new scf.index_switch whose tensor results are replaced with
  /// their buffer types, move the case/default regions of the old op into it
  /// and replace the old op with the results of the new op.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  /// Compute the buffer type of the given op result from the values that are
  /// yielded at the same position by the default block and all case blocks:
  /// * Same buffer type everywhere: use that type.
  /// * Different memory spaces: emit an error.
  /// * Otherwise: use a memref type with fully dynamic layout map.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      // The case was already bufferized.
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      return bufferization::getBufferType(yieldedValue, options,
                                          invocationStack);
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both branches have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};

/// Helper function for loop bufferization. Collect the positions (within
/// `values`) of all values that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> tensorPositions;
  int64_t position = 0;
  for (Value candidate : values) {
    if (isa<TensorType>(candidate.getType()))
      tensorPositions.insert(position);
    ++position;
  }
  return tensorPositions;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent". Only
/// positions that exist in both lists and where both values have a tensor type
/// are taken into account.
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  DenseSet<int64_t> equivalentIndices;
  unsigned int numPairs = std::min(bbArgs.size(), yieldedValues.size());
  for (unsigned int idx = 0; idx != numPairs; ++idx) {
    Value bbArg = bbArgs[idx];
    Value yielded = yieldedValues[idx];
    const bool bothTensors = isa<TensorType>(bbArg.getType()) &&
                             isa<TensorType>(yielded.getType());
    if (bothTensors && state.areEquivalentBufferizedValues(bbArg, yielded))
      equivalentIndices.insert(idx);
  }
  return equivalentIndices;
}

/// Helper function for loop bufferization. Look up the buffer of every tensor
/// operand in `operands`. Non-tensor operands are passed through as they are.
/// The returned list has one entry per operand, in operand order. Fails if the
/// buffer of one of the tensor operands cannot be retrieved.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options) {
  SmallVector<Value> buffers;
  for (OpOperand &operand : operands) {
    Value operandValue = operand.get();
    // Only tensors are bufferized.
    if (!isa<TensorType>(operandValue.getType())) {
      buffers.push_back(operandValue);
      continue;
    }
    FailureOr<Value> maybeBuffer = getBuffer(rewriter, operandValue, options);
    if (failed(maybeBuffer))
      return failure();
    buffers.push_back(*maybeBuffer);
  }
  return buffers;
}

/// Helper function for loop bufferization. Compute the replacement values for
/// the bbArgs of the new (bufferized) loop op: every bbArg whose position is
/// contained in `tensorIndices` (a former tensor that is now a memref) is
/// wrapped in a ToTensorOp, all other bbArgs are used directly. This allows
/// the block body to be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> replacements;
  size_t position = 0;
  for (Value bbArg : bbArgs) {
    Value replacement = bbArg;
    if (tensorIndices.contains(position))
      replacement =
          rewriter.create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg)
              .getResult();
    replacements.push_back(replacement);
    ++position;
  }
  return replacements;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// The bufferized types of the init_arg and of the yielded value are computed
/// with bufferization::getBufferType. (Computing the bufferized type of the
/// yielded value usually requires computing the bufferized type of the
/// iter_arg again; the implementation of getBufferType traces back the
/// use-def chain of the given value and computes a buffer type along the way.)
/// If both buffer types are equal, no casts are needed and the computed buffer
/// type can be used directly. Otherwise, the buffer types can only differ in
/// their layout map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
  // Start with the buffer type of the init_arg.
  FailureOr<BaseMemRefType> maybeInitType =
      bufferization::getBufferType(initArg, options, invocationStack);
  if (failed(maybeInitType))
    return failure();
  BaseMemRefType initType = *maybeInitType;

  // If the iter_arg is already twice on the invocation stack, just take the
  // type of the init_arg. This is to avoid infinite loops when calculating
  // the buffer type. This will most likely result in computing a memref type
  // with a fully dynamic layout map.
  //
  // Note: For more precise layout map computation, a fixpoint iteration could
  // be done (i.e., re-computing the yielded buffer type until the bufferized
  // iter_arg type no longer changes). This current implementation immediately
  // switches to a fully dynamic layout map when a mismatch between bufferized
  // init_arg type and bufferized yield value type is detected.
  if (llvm::count(invocationStack, iterArg) >= 2)
    return initType;

  // Buffer type of the yielded value. If scf.yield was already bufferized, the
  // yielded value has a memref type that can be used directly.
  auto yieldedType = dyn_cast<BaseMemRefType>(yieldedValue.getType());
  if (!yieldedType) {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    FailureOr<BaseMemRefType> maybeYieldedType =
        bufferization::getBufferType(yieldedValue, options, invocationStack);
    if (failed(maybeYieldedType))
      return failure();
    yieldedType = *maybeYieldedType;
  }

  // Matching types can be used as they are.
  if (initType == yieldedType)
    return yieldedType;

  // The types do not match: the buffer type must be promoted to a fully
  // dynamic layout map. This requires matching memory spaces.
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  if (initType.getMemorySpace() != yieldedType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto rankedYieldedType = dyn_cast<MemRefType>(yieldedType)) {
    assert(
        llvm::all_equal({rankedYieldedType.getShape(),
                         cast<MemRefType>(initType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(iterTensorType,
                                             yieldedType.getMemorySpace());
}

/// Return `true` if the given loop may have 0 iterations. This is the case if
/// the lower bound or the upper bound is not a known constant, or if the
/// constant upper bound is not greater than the constant lower bound.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lowerBound =
      getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> upperBound =
      getConstantIntValue(forOp.getUpperBound());
  // Both bounds are static: compare them directly.
  if (lowerBound.has_value() && upperBound.has_value())
    return *upperBound <= *lowerBound;
  // At least one bound is dynamic: be conservative.
  return true;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  /// Return `true` if the loop may have zero iterations or if the iter_arg
  /// that is tied to the given operand is read according to the analysis
  /// state.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  /// Operands of scf.for are always treated as a memory write.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  /// Return the loop result that is tied to the given operand, together with
  /// its buffer relation. The alias is definite only if the relation is
  /// "Equivalent".
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  /// Return "Equivalent" if the iter_arg tied to the given result and the
  /// value yielded for it are equivalent bufferized values. Return "Unknown"
  /// otherwise.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  /// All values owned by this op are reported as writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  /// Resolve conflicts on the tensor operands of the op. In addition, if the
  /// `enforceAliasingInvariants` option is set, replace every yielded tensor
  /// that may alias with a value external to the loop body (other than the
  /// corresponding iter_arg) with a new tensor allocation.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of the
    // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    // New allocations are created right before the scf.yield.
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(state))) {
        // Non-tensor values and values without external aliases are yielded
        // unchanged.
        yieldValues.push_back(it.value());
        continue;
      }
      // The yielded value may alias with an external value: yield a new
      // allocation instead.
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  /// Compute the buffer type of a tensor value owned by this op. For an op
  /// result, this is the buffer type of the tied iter_arg. For an iter_arg,
  /// the type is derived from the corresponding init_arg and yielded value
  /// (see `computeLoopRegionIterArgBufferType`).
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, invocationStack);
  }

  /// Replace the op with a new scf.for that has memref init_args instead of
  /// tensor init_args. The buffers of the init_args are cast to the buffer
  /// types of the corresponding loop results. The loop body is moved to the
  /// new op; it still operates on tensors, so the memref iter_args of the new
  /// loop are wrapped in ToTensorOps.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    // Carry over the attributes of the old op.
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    // The induction variable is the first block argument of the loop body.
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Verify that the yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// allowed with the `allowReturnAllocsFromLoops` option. If that option is
  /// set, no verification is performed.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      // Only tensor results are bufferized.
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  /// Conservatively report every operand of the loop as a memory read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  /// Conservatively report every operand of the loop as a memory write.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  /// Return the OpResult that may alias with the given operand: the result at
  /// the same index, provided it exists and has the same type as the operand.
  /// The alias is reported as definite only if the relation is "Equivalent".
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  /// Return "Equivalent" if, for the given result index, both the value
  /// forwarded by scf.condition and the value yielded by scf.yield are
  /// equivalent to the bbArg at the same index of their respective region.
  /// Return "Unknown" otherwise.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    // "before" region: bbArg vs. value forwarded by scf.condition.
    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    // "after" region: bbArg vs. value yielded by scf.yield.
    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  /// Values defined by this op are always reported as writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  /// Resolve operand conflicts and, if `enforceAliasingInvariants` is set,
  /// replace every tensor value forwarded by scf.condition that is not
  /// equivalent to its bbArg in both regions with a newly allocated tensor.
  /// Returns failure if conflict resolution or an allocation fails.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      // Non-tensor values and values that are equivalent in both regions are
      // forwarded unchanged.
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  /// Create a new scf.while whose tensor operands, bbArgs and results are
  /// replaced by memrefs, move both region bodies of the old op into it and
  /// replace the old op. Returns failure if a buffer or a buffer type cannot
  /// be computed.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    // A failure to compute a buffer type is propagated to the caller instead
    // of dereferencing a failed result.
    SmallVector<Type> argsTypesAfter;
    argsTypesAfter.reserve(whileOp.getAfterArguments().size());
    for (BlockArgument bbArg : whileOp.getAfterArguments()) {
      if (!isa<TensorType>(bbArg.getType())) {
        argsTypesAfter.push_back(bbArg.getType());
        continue;
      }
      FailureOr<BaseMemRefType> bufferType =
          bufferization::getBufferType(bbArg, options);
      if (failed(bufferType))
        return failure();
      argsTypesAfter.push_back(*bufferType);
    }

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Compute the buffer type of a tensor value owned by this op: a bbArg of
  /// the "before" region, a bbArg of the "after" region, or an OpResult.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    // "before" region: tensor values forwarded by scf.condition must be
    // equivalent to the bbArg at the same index of the enclosing block.
    auto conditionOp = whileOp.getConditionOp();
    Block *beforeBlock = conditionOp->getBlock();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= beforeBlock->getNumArguments() ||
          !state.areEquivalentBufferizedValues(
              it.value(), beforeBlock->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    // "after" region: tensor values yielded by scf.yield must be equivalent
    // to the bbArg at the same index of the enclosing block.
    auto yieldOp = whileOp.getYieldOp();
    Block *afterBlock = yieldOp->getBlock();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= afterBlock->getNumArguments() ||
          !state.areEquivalentBufferizedValues(
              it.value(), afterBlock->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// External bufferization model for scf.yield. The analysis hooks describe how
/// yielded values relate to the results of the parent op; `bufferize` swaps
/// the terminator for one that yields buffers instead of tensors.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  /// Every yield operand is reported as a memory read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// No yield operand is reported as a memory write.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// For scf.if and scf.execute_region parents, the operand aliases with the
  /// parent result at the same index ("Equivalent"; explicitly not definite in
  /// the scf.if case). For any other parent, no aliasing value is reported.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    Operation *parentOp = op->getParentOp();
    if (isa<scf::IfOp>(parentOp))
      return {{parentOp->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    if (isa<scf::ExecuteRegionOp>(parentOp))
      return {{parentOp->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  /// Replace the scf.yield with a new one whose tensor operands are replaced
  /// by their buffers, cast to the buffer type expected by the parent op where
  /// applicable. Emits an error for unsupported parent ops.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    Operation *parentOp = yieldOp->getParentOp();
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(parentOp))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> yieldedBuffers;
    for (const auto &en : llvm::enumerate(yieldOp.getResults())) {
      Value yielded = en.value();

      // Non-tensor operands are forwarded as they are.
      if (!isa<TensorType>(yielded.getType())) {
        yieldedBuffers.push_back(yielded);
        continue;
      }

      FailureOr<Value> maybeBuffer = getBuffer(rewriter, yielded, options);
      if (failed(maybeBuffer))
        return failure();
      Value buffer = *maybeBuffer;

      // We may have to cast the value before yielding it. Pick the value whose
      // buffer type the yielded buffer has to match: the parent's result for
      // scf.for/scf.if/scf.index_switch, the "before" bbArg for scf.while.
      // For the remaining parents, no cast is inserted.
      Value typeSource;
      if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(parentOp))
        typeSource = parentOp->getResult(en.index());
      else if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp))
        typeSource = whileOp.getBeforeArguments()[en.index()];

      if (typeSource) {
        FailureOr<BaseMemRefType> expectedType =
            bufferization::getBufferType(typeSource, options);
        if (failed(expectedType))
          return failure();
        buffer = castBuffer(rewriter, buffer, *expectedType);
      }
      yieldedBuffers.push_back(buffer);
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, yieldedBuffers);
    return success();
  }
};

/// Return `true` if the given loop may have 0 iterations, i.e., if for some
/// dimension the lower or upper bound is not a known constant, or the constant
/// lower bound is not smaller than the constant upper bound.
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
  for (auto [lowerBound, upperBound] : llvm::zip(
           forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound())) {
    std::optional<int64_t> lower = getConstantIntValue(lowerBound);
    std::optional<int64_t> upper = getConstantIntValue(upperBound);
    // A bound that is not a known constant: cannot rule out an empty range.
    if (!lower || !upper)
      return true;
    // Statically empty range in this dimension.
    if (*lower >= *upper)
      return true;
  }
  return false;
}

/// External bufferization model for scf.forall. Bufferizing the op also takes
/// care of the terminator of its region. The terminators (InParallelOp and
/// ParallelInsertSliceOp) have op interfaces of their own, but those are only
/// used during analysis, not for bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  /// A shared_out is read if the loop may run zero times or if its tied bbArg
  /// is read inside the loop.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);

    // With zero iterations, the results of the op are their corresponding
    // shared_outs, meaning that the shared_outs bufferize to a read.
    // Otherwise, scf::ForallOp alone doesn't bufferize to a memory read, but
    // one of the uses of its matching bbArg may.
    return mayHaveZeroIterations(forallOp) ||
           state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  /// The operand aliases with its tied OpResult; the relation is "Equivalent".
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    OpResult tiedResult = forallOp.getTiedOpResult(&opOperand);
    return {{tiedResult, BufferRelation::Equivalent}};
  }

  /// Values defined by this op are always reported as writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  /// Replace the scf.forall with a result-less scf.forall that operates on the
  /// buffers of the outputs. Uses of the old results are replaced with those
  /// buffers.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> outputBuffers;
    for (Value output : forallOp.getOutputs()) {
      FailureOr<Value> maybeBuffer = getBuffer(rewriter, output, options);
      if (failed(maybeBuffer))
        return failure();
      outputBuffers.push_back(*maybeBuffer);
    }

    // Use buffers instead of block arguments: the first `rank` bbArgs are
    // skipped, each remaining bbArg is replaced by a ToTensorOp of its buffer.
    Block *oldBody = forallOp.getBody();
    rewriter.setInsertionPointToStart(oldBody);
    for (const auto &pair : llvm::zip(
             oldBody->getArguments().drop_front(rank), outputBuffers)) {
      BlockArgument outputBbArg = std::get<0>(pair);
      Value outputBuffer = std::get<1>(pair);
      Value asTensor =
          rewriter.create<ToTensorOp>(forallOp.getLoc(), outputBuffer);
      outputBbArg.replaceAllUsesWith(asTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    auto newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    // Keep discardable attributes from the original op.
    newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    Block *newBody = newForallOp.getBody();
    rewriter.eraseOp(newBody->getTerminator());

    // Move over block contents of the old op. The bbArgs of the new block
    // replace the leading bbArgs of the old block; the old output bbArgs are
    // mapped to null values.
    SmallVector<Value> bbArgReplacements(newBody->getArguments().begin(),
                                         newBody->getArguments().end());
    bbArgReplacements.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(oldBody, newBody, bbArgReplacements);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, outputBuffers);

    return success();
  }

  /// The bufferized type of a tensor bbArg or of an OpResult is the bufferized
  /// type of the corresponding output operand.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    // Find the output operand that corresponds to the given value.
    Value tiedOutput;
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      tiedOutput = forallOp.getTiedOpOperand(bbArg)->get();
    else
      tiedOutput =
          forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()];

    return bufferization::getBufferType(tiedOutput, options, invocationStack);
  }

  /// The region is repetitive if, in some dimension, a bound or the step is
  /// not a known constant or `lb + step < ub` holds.
  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if it has 1 or more steps.
    // If the control variables are dynamic, it is also considered so.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbValue = getConstantIntValue(lb);
      std::optional<int64_t> ubValue = getConstantIntValue(ub);
      std::optional<int64_t> stepValue = getConstantIntValue(step);
      if (!lbValue || !ubValue || !stepValue)
        return true;
      if (*lbValue + *stepValue < *ubValue)
        return true;
    }
    return false;
  }

  /// The region is treated as parallel exactly when it is repetitive.
  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

/// Nothing to do for InParallelOp. The model only provides a `bufferize`
/// hook, which is not expected to be invoked.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  /// Must never be called: reaching this function is reported as a programmer
  /// error via `llvm_unreachable`.
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    // Not reached; kept so that the function has a return statement.
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

/// Add an extension to `registry` that attaches the BufferizableOpInterface
/// external models defined in this file to their SCF ops.
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(
      +[](MLIRContext *context, scf::SCFDialect * /*dialect*/) {
        // One external model per op, listed alphabetically by op name. The
        // attachments are independent of each other.
        ConditionOp::attachInterface<ConditionOpInterface>(*context);
        ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*context);
        ForOp::attachInterface<ForOpInterface>(*context);
        ForallOp::attachInterface<ForallOpInterface>(*context);
        IfOp::attachInterface<IfOpInterface>(*context);
        InParallelOp::attachInterface<InParallelOpInterface>(*context);
        IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*context);
        WhileOp::attachInterface<WhileOpInterface>(*context);
        YieldOp::attachInterface<YieldOpInterface>(*context);
      });
}