//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512]-bit row.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512]-bit tiles that compose the given
/// operand shape.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  // For each 8x128-bit square, a thread is responsible for one 32-bit
  // register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}

/// Returns the first user of the `op` that is vector.contract. If no
/// vector.contract user exists, return failure.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}

FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type at warp-level.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless
  // it is directly consumed by a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;

  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;

  return info;
}

int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}
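
// For example, a 16x16 vector<f16> matrix-A operand has a tile width of 128
// bits, so inferNumRegistersPerMatrixFragment returns
// (16 / 8) * (16 * 16) / 128 = 4, while a 16x8 vector<f32> accumulator has a
// tile width of 256 bits and yields (16 / 8) * (8 * 32) / 256 = 2.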
FailureOr<nvgpu::FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operands
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operands
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operands
  if (elType.isInteger(4)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Integer 32-bit accumulator operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32-bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}

FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
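
// For example, for a 16x16 vector<f16> matrix-A operand (two f16 elements per
// 32-bit register), the map built by getLaneIdAndValueIdToOperandCoord sends
// (laneId, valueId) = (1, 0) to element (0, 2) of the fragment and (0, 2) to
// element (8, 0), matching the mma.sync f16 fragment layout.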
FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;

  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}

FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128-bit row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
  // the `srcMemref` memory of the LdMatrixOp.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

  // The affine expressions in the strided and contiguous dimensions encode the
  // coordinate mapping for the element a thread points to for the warp-wide
  // LdMatrixOp.
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

  // This case corresponds to row-major matrixA, col-major matrixB, or
  // row-major matrixC. This is when the memory layout in `srcMemref` matches
  // the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});

  // This case corresponds to col-major matrixA, row-major matrixB, or
  // col-major matrixC. This is when the memory layout in `srcMemref` does not
  // match the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});

  return failure();
}
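
// For example, a non-transposed 16x16 vector<f16> operand gives numTiles = 4
// (an x4 ldmatrix); the map built by getLaneIdToLdMatrixMatrixCoord points
// lanes 0-15 at rows 0-15 of the first group of 8 contiguous elements (one
// 128-bit line) and lanes 16-31 at rows 0-15 starting at column 8.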
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type should be 2D. Note that it is possible to expand support
  // so that we are robust to extra unit dimensions that failed to fold, but
  // that would significantly increase downstream code complexity in the
  // conversion step. For now, we rely on other patterns to ensure a canonical
  // 2D form is used when targeting the `nvgpu.mma.sync` lowering path.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Currently we can't support reads on tensor types because we need stride
  // information to ensure correctness of downstream assumptions. It is
  // possible to enable this if the caller can assert that the tensor will be
  // lowered in a particular manner.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the read is contiguous. Note that it is
  // possible to expand support for this by scalarizing all the loads during
  // conversion.
  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
  return strides.back() == 1;
}

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // TODO: Currently we rely on lowering to a `vector.store` operation. We
  // could support the transposed write case by lowering to scalarized
  // `memref.store` operations.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  // Currently we can't support writes to tensor types because we need stride
  // information to ensure correctness of downstream assumptions.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the target memref is contiguous. Note
  // that it is possible to expand support for this by scalarizing all the
  // stores during conversion.
  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
  return strides.back() == 1;
}
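
// Typical flow in a conversion pattern (sketch): guard the rewrite with
// canLowerToWarpMatrixOperation, derive the operand's WarpMatrixInfo with
// getWarpMatrixInfo, query getMmaSyncRegisterType for the per-thread fragment
// layout, and build per-lane element indices with
// getLaneIdAndValueIdToOperandCoord or, for ldmatrix-based loads, with
// getLdMatrixParams followed by getLaneIdToLdMatrixMatrixCoord.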