//===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/SymbolTable.h"

using namespace mlir;

/// Returns the unaligned allocation function in `module`, obtained through the
/// LLVM::lookupOrCreate* helpers. The generic allocator is chosen when the
/// converter's `useGenericFunctions` option is set; otherwise malloc is used.
static LLVM::LLVMFuncOp
getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
                     Type indexType) {
  return typeConverter->getOptions().useGenericFunctions
             ? LLVM::lookupOrCreateGenericAllocFn(module, indexType)
             : LLVM::lookupOrCreateMallocFn(module, indexType);
}

/// Returns the aligned allocation function in `module`, obtained through the
/// LLVM::lookupOrCreate* helpers. The generic aligned allocator is chosen when
/// the converter's `useGenericFunctions` option is set; otherwise the
/// aligned_alloc helper is used.
static LLVM::LLVMFuncOp
getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
                  Type indexType) {
  return typeConverter->getOptions().useGenericFunctions
             ? LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType)
             : LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
}

/// Emits LLVM integer arithmetic that rounds `input` up to a multiple of
/// `alignment`:
///   padded = input + (alignment - 1)
///   result = padded - (padded urem alignment)
/// The constant 1 is created with `alignment`'s type, so `input` is expected
/// to share that integer type.
Value AllocationOpLLVMLowering::createAligned(
    ConversionPatternRewriter &rewriter, Location loc, Value input,
    Value alignment) {
  Type intTy = alignment.getType();
  Value alignMinusOne = rewriter.create<LLVM::SubOp>(
      loc, alignment, createIndexAttrConstant(rewriter, loc, intTy, 1));
  Value padded = rewriter.create<LLVM::AddOp>(loc, input, alignMinusOne);
  Value excess = rewriter.create<LLVM::URemOp>(loc, padded, alignment);
  return rewriter.create<LLVM::SubOp>(loc, padded, excess);
}

/// Moves the pointer returned by an allocation function into the address
/// space of `memRefType`. An `llvm.addrspacecast` is inserted only when the
/// two address spaces differ.
/// Returns a null Value if the type converter cannot determine the memref's
/// address space.
/// NOTE(review): `elementPtrType` is not used by this helper.
static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                 Location loc, Value allocatedPtr,
                                 MemRefType memRefType, Type elementPtrType,
                                 const LLVMTypeConverter &typeConverter) {
  auto sourcePtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
  FailureOr<unsigned> targetSpace =
      typeConverter.getMemRefAddressSpace(memRefType);
  if (failed(targetSpace))
    return Value();

  // Nothing to do when the allocator already returns the right address space.
  if (sourcePtrTy.getAddressSpace() == *targetSpace)
    return allocatedPtr;

  auto targetPtrTy =
      LLVM::LLVMPointerType::get(rewriter.getContext(), *targetSpace);
  return rewriter.create<LLVM::AddrSpaceCastOp>(loc, targetPtrTy,
                                                allocatedPtr);
}

/// Allocates the buffer for the memref result of `op` with the unaligned
/// allocation function.
///
/// When `alignment` is non-null:
///  - the request is enlarged by `alignment`;
///  - the returned address is rounded up to the next multiple of `alignment`.
///
/// Returns (allocated pointer, aligned pointer). Both are null Values if the
/// memref's address space cannot be determined.
std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
    ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
    Operation *op, Value alignment) const {
  // Over-allocate so the rounded-up address still leaves room for the data.
  Value requestBytes = sizeBytes;
  if (alignment)
    requestBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);

  MemRefType resultTy = getMemRefResultType(op);
  Type elementPtrType = this->getElementPtrType(resultTy);
  Operation *symbolTableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
  LLVM::LLVMFuncOp allocFn =
      getNotalignedAllocFn(getTypeConverter(), symbolTableOp, getIndexType());
  auto call = rewriter.create<LLVM::CallOp>(loc, allocFn, requestBytes);

  Value basePtr = castAllocFuncResult(rewriter, loc, call.getResult(),
                                      resultTy, elementPtrType,
                                      *getTypeConverter());
  if (!basePtr)
    return std::make_tuple(Value(), Value());

  // Without an alignment request the base pointer doubles as the aligned one.
  if (!alignment)
    return std::make_tuple(basePtr, basePtr);

  // Round the base address up to the requested alignment.
  Value baseInt =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), basePtr);
  Value alignedInt = createAligned(rewriter, loc, baseInt, alignment);
  Value alignedPtr =
      rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignedInt);
  return std::make_tuple(basePtr, alignedPtr);
}

/// Returns the size in bytes of one element of `memRefType`.
///
/// The data layout comes from the converter's DataLayoutAnalysis (the layout
/// in effect above `op`) when one is available; otherwise `*defaultLayout` is
/// used. Element types are sized as follows:
///  - ranked memref elements: the size of their memref descriptor;
///  - unranked memref elements: the size of their unranked descriptor;
///  - all other types: the layout's type size.
unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
    MemRefType memRefType, Operation *op,
    const DataLayout *defaultLayout) const {
  const DataLayoutAnalysis *dlAnalysis =
      getTypeConverter()->getDataLayoutAnalysis();
  const DataLayout &dl =
      dlAnalysis ? dlAnalysis->getAbove(op) : *defaultLayout;

  Type eltTy = memRefType.getElementType();
  if (auto rankedElt = dyn_cast<MemRefType>(eltTy))
    return getTypeConverter()->getMemRefDescriptorSize(rankedElt, dl);
  if (auto unrankedElt = dyn_cast<UnrankedMemRefType>(eltTy))
    return getTypeConverter()->getUnrankedMemRefDescriptorSize(unrankedElt,
                                                               dl);
  return dl.getTypeSize(eltTy);
}

/// Returns true if the byte size of a `type` buffer is statically known to be
/// a multiple of `factor`. Only the element size and the static dimensions
/// are considered.
///
/// Dynamic dimensions are skipped. If the product of the known factors is
/// already a multiple of `factor`, multiplying by any runtime extent keeps it
/// so. Otherwise the answer is conservatively "no".
///
/// The product is tracked modulo `factor`, so large static shapes cannot wrap
/// the 64-bit accumulator and produce a wrong answer. Both operands of each
/// multiplication are below `factor`, so no intermediate product can overflow
/// for factors below 2^32.
///
/// A `factor` of zero imposes no constraint and is reported as satisfied.
/// This also avoids a division by zero and keeps callers from emitting padding
/// arithmetic with a zero divisor.
bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
    MemRefType type, uint64_t factor, Operation *op,
    const DataLayout *defaultLayout) const {
  if (factor == 0)
    return true;

  uint64_t remainder =
      getMemRefEltSizeInBytes(type, op, defaultLayout) % factor;
  // Once the running product is a multiple of `factor` it stays one, so stop
  // early.
  for (unsigned i = 0, e = type.getRank(); i < e && remainder != 0; ++i) {
    if (type.isDynamicDim(i))
      continue;
    uint64_t dim = static_cast<uint64_t>(type.getDimSize(i)) % factor;
    remainder = (remainder * dim) % factor;
  }
  return remainder == 0;
}

/// Allocates the buffer for the memref result of `op` through the aligned
/// allocation function, passing (`alignment`, size) as arguments.
///
/// `sizeBytes` is rounded up to a multiple of `alignment` unless the static
/// shape already guarantees it.
///
/// Returns the allocated pointer cast into the memref's address space, or a
/// null Value if that address space cannot be determined.
Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
    ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
    Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
  Type indexTy = getIndexType();
  Value alignmentVal =
      createIndexAttrConstant(rewriter, loc, indexTy, alignment);

  MemRefType resultTy = getMemRefResultType(op);
  // aligned_alloc expects the size to be a multiple of the alignment, so pad
  // it unless that is already known to hold.
  bool sizeAlreadyAligned =
      isMemRefSizeMultipleOf(resultTy, alignment, op, defaultLayout);
  if (!sizeAlreadyAligned)
    sizeBytes = createAligned(rewriter, loc, sizeBytes, alignmentVal);

  Type elementPtrType = this->getElementPtrType(resultTy);
  Operation *symbolTableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
  LLVM::LLVMFuncOp allocFn =
      getAlignedAllocFn(getTypeConverter(), symbolTableOp, indexTy);
  auto call = rewriter.create<LLVM::CallOp>(
      loc, allocFn, ValueRange({alignmentVal, sizeBytes}));

  return castAllocFuncResult(rewriter, loc, call.getResult(), resultTy,
                             elementPtrType, *getTypeConverter());
}

void AllocLikeOpLLVMLowering::setRequiresNumElements() {
  requiresNumElements = true;
}

/// Lowers an alloc-like `op` to LLVM in three steps:
///  1. compute the descriptor sizes, strides and total allocation size;
///  2. obtain the (allocated, aligned) pointers from `allocateBuffer`;
///  3. replace `op` with a memref descriptor built from them.
///
/// Fails to match if the result memref type is rejected by
/// isConvertibleAndHasIdentityMaps, or if either returned pointer is null.
LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  MemRefType resultTy = getMemRefResultType(op);
  if (!isConvertibleAndHasIdentityMaps(resultTy))
    return rewriter.notifyMatchFailure(op, "incompatible memref type");

  // Static extents become constants and dynamic extents come from `operands`.
  // A zero-dimensional memref is treated as a single scalar (size 1).
  SmallVector<Value, 4> sizes, strides;
  Value totalSize;
  this->getMemRefDescriptorSizes(loc, resultTy, operands, rewriter, sizes,
                                 strides, totalSize, !requiresNumElements);

  auto [basePtr, dataPtr] = this->allocateBuffer(rewriter, loc, totalSize, op);
  if (!basePtr || !dataPtr)
    return rewriter.notifyMatchFailure(loc,
                                       "underlying buffer allocation failed");

  auto descriptor = this->createMemRefDescriptor(
      loc, resultTy, basePtr, dataPtr, sizes, strides, rewriter);
  rewriter.replaceOp(op, {descriptor});
  return success();
}