//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::vector;

namespace mlir {
namespace vector {
namespace {

/// External model of BufferizableOpInterface for vector.transfer_read.
///
/// The op only reads its tensor source and returns a vector, so none of its
/// results alias the source. It bufferizes to a new vector.transfer_read that
/// reads from the memref buffer of the original tensor source.
struct TransferReadOpInterface
    : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
                                                    vector::TransferReadOp> {
  /// The tensor source is always read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return true;
  }

  /// The tensor source is never written.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return false;
  }

  /// No result of the op aliases the tensor source.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  /// Rewrites the op into a vector.transfer_read on the source's buffer,
  /// keeping indices, permutation map, padding, mask and in_bounds unchanged.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto transferRead = cast<vector::TransferReadOp>(op);
    assert(isa<TensorType>(transferRead.getShapedType()) &&
           "only tensor types expected");

    FailureOr<Value> sourceBuffer =
        getBuffer(rewriter, transferRead.getSource(), options);
    if (succeeded(sourceBuffer)) {
      replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
          rewriter, transferRead, transferRead.getVectorType(), *sourceBuffer,
          transferRead.getIndices(), transferRead.getPermutationMap(),
          transferRead.getPadding(), transferRead.getMask(),
          transferRead.getInBoundsAttr());
      return success();
    }
    return failure();
  }
};

/// Bufferization of vector.transfer_write. Replaced with a new
/// vector.transfer_write that operates on a memref and has no result; the
/// op's tensor result is replaced with the destination buffer.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
struct TransferWriteOpInterface
    : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
                                                     vector::TransferWriteOp> {
  /// Returns false only if the write provably overwrites every element of the
  /// destination, i.e. the old destination contents can never be observed.
  /// Whenever that cannot be proven, conservatively reports a read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);

    // Destination must have static shape.
    ShapedType destType = writeOp.getShapedType();
    if (!destType.hasStaticShape())
      return true;

    // All offsets must be 0.
    for (Value offset : writeOp.getIndices()) {
      if (getConstantIntValue(offset) != 0)
        return true;
    }

    // There is no mask.
    if (writeOp.isMasked())
      return true;

    // Determine how many elements are written along each destination
    // dimension. Vector dimension `i` is written along the destination
    // dimension selected by the i-th result of the permutation map, which is
    // not necessarily destination dimension `i` (e.g. transposed maps, or the
    // default minor-identity map when the vector has a lower rank than the
    // destination). Destination dimensions that the map does not select are
    // written at a single (zero) index only, hence the default size of 1.
    ArrayRef<int64_t> vecShape = writeOp.getVectorType().getShape();
    AffineMap permMap = writeOp.getPermutationMap();
    SmallVector<int64_t> writtenSizes(destType.getRank(), 1);
    for (unsigned vecDim = 0, e = permMap.getNumResults(); vecDim < e;
         ++vecDim) {
      // Non-dimension results (broadcasts) select no destination dimension;
      // skipping them can only under-estimate coverage, which is conservative.
      auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(vecDim));
      if (!dimExpr)
        continue;
      writtenSizes[dimExpr.getPosition()] = vecShape[vecDim];
    }

    // Must write at least the full size of every destination dimension.
    for (auto [destSize, writtenSize] :
         llvm::zip(destType.getShape(), writtenSizes)) {
      if (destSize > writtenSize)
        return true;
    }

    return false;
  }

  /// Creates a result-less vector.transfer_write on the destination's buffer
  /// and replaces the op's tensor result with that buffer.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);
    assert(isa<TensorType>(writeOp.getShapedType()) &&
           "only tensor types expected");

    // Create a new transfer_write on buffer that doesn't have a return value.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, writeOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
        writeOp.getIndices(), writeOp.getPermutationMapAttr(),
        writeOp.getMask(), writeOp.getInBoundsAttr());
    replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);

    return success();
  }
};

/// External model of BufferizableOpInterface for vector.gather.
///
/// The op reads from its tensor base and yields a vector, so it never writes
/// the base and no result aliases it. It bufferizes to a new vector.gather
/// whose base is the memref buffer of the original tensor base.
struct GatherOpInterface
    : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
                                                    vector::GatherOp> {
  /// The tensor base is always read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return true;
  }

  /// The tensor base is never written.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return false;
  }

  /// No result of the op aliases the tensor base.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  /// Rewrites the op into a vector.gather on the base's buffer, keeping the
  /// indices, index vector, mask and pass-through operands unchanged.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto gather = cast<vector::GatherOp>(op);
    assert(isa<TensorType>(gather.getBaseType()) &&
           "only tensor types expected");

    FailureOr<Value> baseBuffer = getBuffer(rewriter, gather.getBase(), options);
    if (succeeded(baseBuffer)) {
      replaceOpWithNewBufferizedOp<vector::GatherOp>(
          rewriter, gather, gather.getVectorType(), *baseBuffer,
          gather.getIndices(), gather.getIndexVec(), gather.getMask(),
          gather.getPassThru());
      return success();
    }
    return failure();
  }
};

/// Bufferization of vector.mask. Replaced with a new vector.mask that
/// operates on a memref.
///
/// vector.mask has no tensor operands; each of its results is equivalent to
/// the corresponding operand of its vector.yield terminator.
struct MaskOpInterface
    : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
                                                    vector::MaskOp> {
  /// Result `i` of the vector.mask aliases (is equivalent to) operand `i` of
  /// the terminator.
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // MaskOps do not have tensor OpOperands. The yielded values are normally
    // results of the wrapped op.
    auto maskOp = cast<vector::MaskOp>(op);
    auto result = cast<OpResult>(value);
    assert(result.getOwner() == op && "expected a result of this op");
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    return {{&yieldOp->getOpOperand(result.getResultNumber()),
             BufferRelation::Equivalent}};
  }

  /// Resolves conflicts of tensor operands and rejects bodies that would need
  /// to bufferize out-of-place.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    // TODO: Remove this function when vector.mask bodies can bufferize
    // out-of-place. This is currently not supported because yielding allocs
    // from a block leads to a memory leak and because vector.mask supports only
    // a single op in its body.
    auto maskOp = cast<vector::MaskOp>(op);
    if (!maskOp.getMaskRegion()
             .front()
             .getOps<bufferization::AllocTensorOp>()
             .empty())
      return op->emitOpError("body must bufferize in-place");

    return success();
  }

  /// Rebuilds the vector.mask so that it only returns values produced by the
  /// masked op. Yielded values that are not results of the masked op (already
  /// bufferized memrefs defined outside the mask) are forwarded directly to
  /// the users of the old op.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto maskOp = cast<vector::MaskOp>(op);

    // Do not bufferize if the masked op is not bufferizable. `maskedOp` may be
    // null for an empty vector.mask (terminator only); such a mask merely
    // forwards values defined outside of it and is handled by the generic
    // path below. dynCastBufferizableOp expects a non-null op.
    Operation *maskedOp = maskOp.getMaskableOp();
    if (maskedOp && !options.dynCastBufferizableOp(maskedOp))
      return success();

    // Update the terminator: Drop all operands that are not results of the
    // masked op.
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
    SmallVector<Value> newYieldedValues;
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      if (maskedOp &&
          llvm::is_contained(maskedOp->getOpResults(), it.value())) {
        newYieldedValues.push_back(it.value());
      } else {
        // This used to be a tensor result of the masked op, but is now a memref
        // that is defined outside of the vector.mask op.
        newReturnValues[it.index()] = it.value();
      }
    }
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().assign(newYieldedValues);
    });

    // Create a new vector.mask op.
    ValueRange newYieldedValuesRange(newYieldedValues);
    TypeRange newResultTypes(newYieldedValuesRange);
    auto newOp = rewriter.create<vector::MaskOp>(
        op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
        /*maskableOp=*/nullptr,
        /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
    newOp.getRegion().takeBody(maskOp.getMaskRegion());

    // Replace all uses of the old vector.mask op. Slots that were not filled
    // with a forwarded value take the new op's results, in order.
    unsigned nextResult = 0;
    for (Value &returnValue : newReturnValues) {
      if (!returnValue)
        returnValue = newOp->getResult(nextResult++);
    }
    replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
    return success();
  }
};

/// Bufferization of vector.yield. Replaced with a new vector.yield that
/// operates on a memref.
///
/// Only vector.yield ops terminating a vector.mask are supported.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    vector::YieldOp> {
  /// Yielded tensors are considered read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// Yielding never writes to a tensor.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// Operand `i` is equivalent to result `i` of the parent op.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
             BufferRelation::Equivalent}};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  /// Replaces every tensor operand with its buffer; non-tensor operands are
  /// kept as-is.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<vector::YieldOp>(op);

    // Only supported as a vector.mask terminator.
    auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
    if (!maskOp)
      return yieldOp->emitError("unsupported vector::YieldOp parent");

    // Do not bufferize if the masked op is not bufferizable. Use the same
    // masked-op lookup as the vector.mask model so that both agree on when
    // bufferization is skipped. An empty vector.mask has no masked op; its
    // terminator is bufferized like any other.
    Operation *maskedOp = maskOp.getMaskableOp();
    if (maskedOp && !options.dynCastBufferizableOp(maskedOp))
      return success();

    // Create a new terminator with the same number of operands. Some of these
    // may get dropped during the bufferization of vector.mask.
    SmallVector<Value> newResults;
    for (Value value : yieldOp.getOperands()) {
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        newResults.push_back(*maybeBuffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

} // namespace
} // namespace vector
} // namespace mlir

void mlir::vector::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  // Register a vector-dialect extension that attaches each external
  // bufferization model defined in this file to its op.
  registry.addExtension(
      +[](MLIRContext *context, vector::VectorDialect * /*dialect*/) {
        vector::TransferReadOp::attachInterface<TransferReadOpInterface>(
            *context);
        vector::TransferWriteOp::attachInterface<TransferWriteOpInterface>(
            *context);
        vector::GatherOp::attachInterface<GatherOpInterface>(*context);
        vector::MaskOp::attachInterface<MaskOpInterface>(*context);
        vector::YieldOp::attachInterface<YieldOpInterface>(*context);
      });
}