//===- ExtractAddressComputations.cpp - Extract address computations  -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This transformation pass rewrites loading/storing from/to a memref with
/// offsets into loading/storing from/to a subview and without any offset on
/// the instruction itself.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {

//===----------------------------------------------------------------------===//
// Helper functions for the `load base[off0...]`
//  => `load (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===//

// Source-memref accessor for memref.load, shaped to fit the
// `getFailureOrSrcMemRef` hook of LoadStoreLikeOpRewriter.
// Returns the memref operand of `loadOp`; this never yields failure.
// \see LoadStoreLikeOpRewriter.
static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
  Value srcMemRef = loadOp.getMemRef();
  return srcMemRef;
}

// Implements the `rebuildOpFromAddressAndIndices` hook of
// LoadStoreLikeOpRewriter for memref.load.
// Creates, at the location of `loadOp`, a new memref.load reading
// `srcMemRef[indices]` and carrying over the nontemporal flag of `loadOp`.
// \see LoadStoreLikeOpRewriter.
static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
                                    memref::LoadOp loadOp, Value srcMemRef,
                                    ArrayRef<Value> indices) {
  auto newLoad = rewriter.create<memref::LoadOp>(
      loadOp.getLoc(), srcMemRef, indices, loadOp.getNontemporal());
  return newLoad;
}

// Implements the `getViewSizeForEachDim` hook of LoadStoreLikeOpRewriter for
// memref.load.
// Returns one constant index attribute `1` per dimension of the memref type
// accessed by `loadOp`.
// \see LoadStoreLikeOpRewriter.
static SmallVector<OpFoldResult>
getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
  unsigned rank = loadOp.getMemRefType().getRank();
  SmallVector<OpFoldResult> unitSizes(rank, rewriter.getIndexAttr(1));
  return unitSizes;
}

//===----------------------------------------------------------------------===//
// Helper functions for the `store val, base[off0...]`
//  => `store val, (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===//

// Source-memref accessor for memref.store, shaped to fit the
// `getFailureOrSrcMemRef` hook of LoadStoreLikeOpRewriter.
// Returns the memref operand of `storeOp`; this never yields failure.
// \see LoadStoreLikeOpRewriter.
static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
  Value dstMemRef = storeOp.getMemRef();
  return dstMemRef;
}

// Implements the `rebuildOpFromAddressAndIndices` hook of
// LoadStoreLikeOpRewriter for memref.store.
// Creates, at the location of `storeOp`, a new memref.store writing the same
// value to `srcMemRef[indices]` and carrying over the nontemporal flag of
// `storeOp`.
// \see LoadStoreLikeOpRewriter.
static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
                                      memref::StoreOp storeOp, Value srcMemRef,
                                      ArrayRef<Value> indices) {
  auto newStore = rewriter.create<memref::StoreOp>(
      storeOp.getLoc(), storeOp.getValueToStore(), srcMemRef, indices,
      storeOp.getNontemporal());
  return newStore;
}

// Implements the `getViewSizeForEachDim` hook of LoadStoreLikeOpRewriter for
// memref.store.
// Returns one constant index attribute `1` per dimension of the memref type
// accessed by `storeOp`.
// \see LoadStoreLikeOpRewriter.
static SmallVector<OpFoldResult>
getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
  unsigned rank = storeOp.getMemRefType().getRank();
  SmallVector<OpFoldResult> unitSizes(rank, rewriter.getIndexAttr(1));
  return unitSizes;
}

//===----------------------------------------------------------------------===//
// Helper functions for the `ldmatrix base[off0...]`
//  => `ldmatrix (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===//

// Source-memref accessor for nvgpu.ldmatrix, shaped to fit the
// `getFailureOrSrcMemRef` hook of LoadStoreLikeOpRewriter.
// Returns the source memref operand of `ldMatrixOp`; this never yields
// failure.
// \see LoadStoreLikeOpRewriter.
static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
  Value srcMemRef = ldMatrixOp.getSrcMemref();
  return srcMemRef;
}

// Implements the `rebuildOpFromAddressAndIndices` hook of
// LoadStoreLikeOpRewriter for nvgpu.ldmatrix.
// Creates, at the location of `ldMatrixOp`, a new nvgpu.ldmatrix reading
// `srcMemRef[indices]` with the same result type, transpose setting and
// number of tiles as `ldMatrixOp`.
// \see LoadStoreLikeOpRewriter.
static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
                                           nvgpu::LdMatrixOp ldMatrixOp,
                                           Value srcMemRef,
                                           ArrayRef<Value> indices) {
  auto newLdMatrix = rewriter.create<nvgpu::LdMatrixOp>(
      ldMatrixOp.getLoc(), ldMatrixOp.getResult().getType(), srcMemRef,
      indices, ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
  return newLdMatrix;
}

//===----------------------------------------------------------------------===//
// Helper functions for the `transfer_read base[off0...]`
//  => `transfer_read (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===//

// Source-memref accessor for transfer-like ops (instantiated in this file for
// vector.transfer_read and vector.transfer_write), shaped to fit the
// `getFailureOrSrcMemRef` hook of LoadStoreLikeOpRewriter.
// Returns the source of `transferLikeOp` when its type is a MemRefType and
// failure otherwise.
// \see LoadStoreLikeOpRewriter.
template <typename TransferLikeOp>
static FailureOr<Value>
getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
  Value source = transferLikeOp.getSource();
  // Anything that is not a memref cannot be rewritten through a subview.
  if (!isa<MemRefType>(source.getType()))
    return failure();
  return source;
}

// Implements the `rebuildOpFromAddressAndIndices` hook of
// LoadStoreLikeOpRewriter for vector.transfer_read.
// Creates, at the location of `transferReadOp`, a new vector.transfer_read
// from `srcMemRef[indices]` that keeps the result type, permutation map,
// padding, mask and in_bounds attribute of `transferReadOp`.
// \see LoadStoreLikeOpRewriter.
static vector::TransferReadOp
rebuildTransferReadOp(RewriterBase &rewriter,
                      vector::TransferReadOp transferReadOp, Value srcMemRef,
                      ArrayRef<Value> indices) {
  auto newRead = rewriter.create<vector::TransferReadOp>(
      transferReadOp.getLoc(), transferReadOp.getResult().getType(), srcMemRef,
      indices, transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
      transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
  return newRead;
}

//===----------------------------------------------------------------------===//
// Helper functions for the `transfer_write base[off0...]`
//  => `transfer_write (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===//

// Implements the `rebuildOpFromAddressAndIndices` hook of
// LoadStoreLikeOpRewriter for vector.transfer_write.
// Creates, at the location of `transferWriteOp`, a new vector.transfer_write
// of the same value to `srcMemRef[indices]` that keeps the permutation map
// attribute, mask and in_bounds attribute of `transferWriteOp`.
// \see LoadStoreLikeOpRewriter.
static vector::TransferWriteOp
rebuildTransferWriteOp(RewriterBase &rewriter,
                       vector::TransferWriteOp transferWriteOp, Value srcMemRef,
                       ArrayRef<Value> indices) {
  auto newWrite = rewriter.create<vector::TransferWriteOp>(
      transferWriteOp.getLoc(), transferWriteOp.getValue(), srcMemRef, indices,
      transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
      transferWriteOp.getInBoundsAttr());
  return newWrite;
}

//===----------------------------------------------------------------------===//
// Generic helper functions used as default implementation in
// LoadStoreLikeOpRewriter.
//===----------------------------------------------------------------------===//

/// Infallible wrapper around a `getFailureOrSrcMemRef` callback.
/// It forwards \p loadStoreLikeOp to the callback and unwraps the result.
/// The callback is expected to succeed: this is checked with an assert, so
/// this helper must only be used where the source is known to be a memref.
template <typename LoadStoreLikeOp,
          FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
  const FailureOr<Value> failureOrSrcMemRef =
      getFailureOrSrcMemRef(loadStoreLikeOp);
  assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
  Value srcMemRef = *failureOrSrcMemRef;
  return srcMemRef;
}

/// Default implementation of the `getViewSizeForEachDim` hook.
/// For each dimension, the returned size is the size of the source memref
/// minus the index used by \p loadStoreLikeOp for that dimension, i.e., the
/// largest size the view can have while staying in bounds.
/// The source memref is obtained by calling getSrcMemRef on
/// \p loadStoreLikeOp and its sizes are read through a
/// memref.extract_strided_metadata op created at the op's location.
template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
static SmallVector<OpFoldResult>
getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
                               LoadStoreLikeOp loadStoreLikeOp) {
  Location loc = loadStoreLikeOp.getLoc();
  auto metadataOp = rewriter.create<memref::ExtractStridedMetadataOp>(
      loc, getSrcMemRef(loadStoreLikeOp));
  SmallVector<OpFoldResult> srcSizes = metadataOp.getConstifiedMixedSizes();
  SmallVector<OpFoldResult> offsets =
      getAsOpFoldResult(loadStoreLikeOp.getIndices());

  // remainingSize = srcSize (symbol 0) - offset (symbol 1).
  AffineExpr remainingSize =
      rewriter.getAffineSymbolExpr(0) - rewriter.getAffineSymbolExpr(1);

  SmallVector<OpFoldResult> viewSizes;
  for (auto [srcSize, offset] : llvm::zip(srcSizes, offsets))
    viewSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, loc, remainingSize, {srcSize, offset}));
  return viewSizes;
}

/// Pattern that moves the address computation of a load/store-like op into a
/// memref.subview, so that the op itself only uses zero indices.
/// E.g., %ld = memref.load %base[%off0]...[%offN]
/// =>
/// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
/// %ld = memref.load %new_base[0,..,0] :
///    memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
///
/// The pattern is customized through the following hooks:
///
/// `getFailureOrSrcMemRef` returns the source memref of the given
/// load/store-like operation, or failure when the operation does not access
/// a memref, in which case the pattern does not apply.
///
/// `getViewSizeForEachDim` returns the sizes of the view that feeds the new
/// operation. It must return exactly one size per dimension of the source
/// memref, and each size must be at least as big as what the operation
/// actually accesses. The original `loadStoreOp` is provided to compute
/// these sizes. By default, the sizes are the source sizes minus the
/// offsets (see getGenericOpViewSizeForEachDim).
///
/// `rebuildOpFromAddressAndIndices` uses the given rewriter to create a new
/// LoadStoreLikeOp that accesses srcMemRef[indices]. The results of the
/// returned operation replace the ones of loadStoreOp.
template <typename LoadStoreLikeOp,
          FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
          LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
              RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
              Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
          SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
              RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
              getGenericOpViewSizeForEachDim<
                  LoadStoreLikeOp,
                  getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
  using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
                                PatternRewriter &rewriter) const override {
    // Only operations that access a memref can be rewritten.
    FailureOr<Value> maybeSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
    if (failed(maybeSrcMemRef))
      return rewriter.notifyMatchFailure(loadStoreLikeOp,
                                         "source is not a memref");
    Value srcMemRef = *maybeSrcMemRef;

    // A 0-D memref has no index, hence no address computation to extract.
    unsigned loadStoreRank = cast<MemRefType>(srcMemRef.getType()).getRank();
    if (loadStoreRank == 0)
      return rewriter.notifyMatchFailure(loadStoreLikeOp,
                                         "0-D accesses don't need rewriting");

    // Likewise, when every index is already the constant zero, the op is
    // already in the form this pattern produces.
    SmallVector<OpFoldResult> offsets =
        getAsOpFoldResult(loadStoreLikeOp.getIndices());
    bool allOffsetsAreZero =
        llvm::all_of(offsets, [](const OpFoldResult &offset) {
          return isConstantIntValue(offset, 0);
        });
    if (allOffsetsAreZero)
      return rewriter.notifyMatchFailure(
          loadStoreLikeOp, "no computation to extract: offsets are 0s");

    SmallVector<OpFoldResult> sizes =
        getViewSizeForEachDim(rewriter, loadStoreLikeOp);
    assert(sizes.size() == loadStoreRank &&
           "Expected one size per load dimension");

    // The strides of the subview are relative to the strides of the source
    // memref: the subview inherits the source strides and applies them to
    // the offsets itself, so the multipliers passed here are all ones.
    SmallVector<OpFoldResult> unitStrides(loadStoreRank,
                                          rewriter.getIndexAttr(1));
    Location loc = loadStoreLikeOp.getLoc();
    auto subview = rewriter.create<memref::SubViewOp>(
        loc, /*source=*/srcMemRef, /*offsets=*/offsets, /*sizes=*/sizes,
        /*strides=*/unitStrides);

    // Re-create the op on top of the subview; a single zero constant is
    // shared by all the indices.
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeroIndices(loadStoreRank, zero);
    LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
        rewriter, loadStoreLikeOp, subview.getResult(), zeroIndices);
    rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
    return success();
  }
};
} // namespace

void memref::populateExtractAddressComputationsPatterns(
    RewritePatternSet &patterns) {
  // memref.load and memref.store access a single element: they provide their
  // own view-size hook that returns a size of 1 for every dimension.
  using LoadRewriter = LoadStoreLikeOpRewriter<
      memref::LoadOp,
      /*getFailureOrSrcMemRef=*/getLoadOpSrcMemRef,
      /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
      /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>;
  using StoreRewriter = LoadStoreLikeOpRewriter<
      memref::StoreOp,
      /*getFailureOrSrcMemRef=*/getStoreOpSrcMemRef,
      /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
      /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>;

  // The remaining ops rely on the default view-size hook of
  // LoadStoreLikeOpRewriter.
  using LdMatrixRewriter = LoadStoreLikeOpRewriter<
      nvgpu::LdMatrixOp,
      /*getFailureOrSrcMemRef=*/getLdMatrixOpSrcMemRef,
      /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>;
  using TransferReadRewriter = LoadStoreLikeOpRewriter<
      vector::TransferReadOp,
      /*getFailureOrSrcMemRef=*/
      getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
      /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>;
  using TransferWriteRewriter = LoadStoreLikeOpRewriter<
      vector::TransferWriteOp,
      /*getFailureOrSrcMemRef=*/
      getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
      /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>;

  patterns.add<LoadRewriter, StoreRewriter, LdMatrixRewriter,
               TransferReadRewriter, TransferWriteRewriter>(
      patterns.getContext());
}