//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Maps every element of \p toTransform through \p toApply.
///
/// Each element is read from \p toTransform as a \p SrcValType and converted
/// into a \p TargetValType by \p toApply.
///
/// \returns A DenseElementsAttr with the same shape as \p toTransform whose
/// element type is \p targetType and whose values are the mapped elements, in
/// the original iteration order.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {
  // Keep the input shape, but swap in the requested element type.
  ShapedType resultType = toTransform.getType().cloneWith({}, targetType);

  // The output has exactly as many elements as the input, so allocate once.
  SmallVector<TargetValType> mapped;
  mapped.reserve(toTransform.getNumElements());
  for (const SrcValType &element : toTransform.getValues<SrcValType>())
    mapped.push_back(toApply(element));

  return DenseElementsAttr::get(resultType, mapped);
}

// Explicit instantiation of applyElementWise for the APFloat -> APFloat
// transformation with a FloatType target; this is the combination used by
// TosaFoldConstantReciprocal in this file.
template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);

/// Checks that the element type of \p toCheck is a floating-point type.
///
/// \returns success() for float element types; otherwise reports a match
/// failure anchored at \p location through \p rewriter and returns failure.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  Type elementType = toCheck.getType().getElementType();
  if (!isa<FloatType>(elementType))
    return rewriter.notifyMatchFailure(location,
                                       "Unexpected input tensor type: the "
                                       "TOSA spec only allows floats");
  return success();
}

/// Checks that \p toCheck is a constant holding dense elements and that its
/// producer is a `tosa.const` op.
///
/// \returns success() if both hold; otherwise reports a match failure anchored
/// at \p location through \p rewriter and returns failure.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {
  // TODO Binding into a DenseElementsAttr is only done to enforce that the
  // constant is dense; the bound value itself is discarded. A pure check would
  // be nicer.
  DenseElementsAttr denseValues;
  const bool isDenseConstant = matchPattern(toCheck, m_Constant(&denseValues));
  if (!isDenseConstant)
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");

  // m_Constant accepts any constant-like producer, so additionally require
  // that it is specifically a TOSA constant.
  if (!isa<ConstOp>(toCheck.getDefiningOp()))
    return rewriter.notifyMatchFailure(location,
                                       "The reciprocal can only be folded if "
                                       "it operates on a TOSA constant");

  return success();
}

/// Checks that \p toCheck has a floating-point element type and is a dense
/// `tosa.const` tensor.
///
/// The float check runs first; its failure is returned as-is without running
/// the constant check.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  if (LogicalResult isFloat = notifyIfNotFloat(toCheck, location, rewriter);
      failed(isFloat))
    return isFloat;

  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

/// Heuristic deciding whether a unary operation on a constant should be
/// replaced by its folded value.
///
/// Folding can increase memory usage when the input constant has to stay
/// alive next to the newly created one. The heuristic therefore only suggests
/// folding when the memory impact is negligible:
///  - the constant input \p values is a splat, or
///  - the single operand of \p unaryOp has exactly one use.
///
/// \returns true if folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  Value operand = unaryOp->getOperand(0);

  // For splat inputs the number of users does not matter.
  const bool isSplat = isa<SplatElementsAttr>(values);

  // Otherwise fold only if this op is the sole user, so the input can be
  // replaced rather than kept alongside the new constant.
  return isSplat || operand.hasOneUse();
}

/// Transposes the elements of \p data, assumed to be laid out row-major
/// according to \p inputType, using the permutation \p permValues, and packs
/// the result into a DenseElementsAttr of type \p outputType.
///
/// \p inputType and \p outputType must share the same element type.
template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  using ElementType = std::decay_t<decltype(*std::begin(data))>;

  assert(inputType.getElementType() == outputType.getElementType());

  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});

  ArrayRef<int64_t> srcShape = inputType.getShape();

  // Inverting the permutation tells, for each source dimension, which output
  // dimension it lands in; combined with the output strides this gives the
  // contribution of each source coordinate to the destination linear index.
  const auto dstStrides = computeStrides(outputType.getShape());
  const auto srcDimToDstDim = invertPermutationVector(permValues);

  // Pre-fill with the first element; every slot is overwritten below.
  SmallVector<ElementType> transposed(inputType.getNumElements(),
                                      *std::begin(data));

  size_t srcPos = 0;
  for (const auto &element : data) {
    // Peel source coordinates from the innermost dimension outwards.
    uint64_t remaining = srcPos;
    uint64_t dstPos = 0;
    for (size_t dim = srcShape.size(); dim-- > 0;) {
      const uint64_t extent = srcShape[dim];
      dstPos += dstStrides[srcDimToDstDim[dim]] * (remaining % extent);
      remaining /= extent;
    }

    transposed[dstPos] = element;
    ++srcPos;
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<ElementType>(transposed));
}

// A type specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of Attribute
// objects. Integer element types that none of the native C++ paths accept
// (presumably odd bit widths or signedness mismatches, e.g. i4 or ui8 —
// confirm against DenseElementsAttr::tryGetValues) fall back to APInt values.
//
// Returns nullptr if the values of \p attr cannot be iterated as any of the
// supported element representations.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  if (auto data = attr.tryGetValues<bool>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int8_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int16_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int32_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int64_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<float>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<APFloat>())
    return transposeType(*data, inputType, outputType, permValues);

  // Generic integer fallback. It is tried last so that every type already
  // handled by a native path above keeps using that cheaper representation.
  if (auto data = attr.tryGetValues<APInt>())
    return transposeType(*data, inputType, outputType, permValues);

  return nullptr;
}

/// Folds `tosa.transpose` of a constant input with constant permutations into
/// a new `tosa.const` holding the transposed values.
///
/// The fold only fires when the input constant's defining op has a single user,
/// so the original constant can be replaced instead of duplicated.
struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = cast<ShapedType>(op.getType());
    // TOSA supports quantized types.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    // The transposition computes strides from the output shape, and the folded
    // DenseElementsAttr needs a fully static type; unranked or dynamic output
    // types cannot be materialized.
    if (!outputType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "output type must have a static shape");

    ElementsAttr inputValues;
    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
      return failure();
    // Make sure the input is a constant that has a single user.
    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
      return failure();

    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return failure();
    auto permValues = llvm::map_to_vector(
        // TOSA allows both 32- and 64-bit integer tensors here.
        permAttr.getValues<APInt>(),
        [](const APInt &val) { return val.getSExtValue(); });

    auto inputType = cast<ShapedType>(op.getInput1().getType());

    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    if (!resultAttr) {
      return rewriter.notifyMatchFailure(
          op, "unsupported attribute or element type");
    }

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
};

/// Folds `tosa.reciprocal` of a dense floating-point `tosa.const` into a new
/// `tosa.const` holding the element-wise results of
/// ReciprocalOp::calcOneElement.
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = recip.getInput1();

    // Bail out unless the input is a dense float TOSA constant.
    if (LogicalResult applicable =
            notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
        failed(applicable))
      return applicable;

    // The precondition above guarantees that this match succeeds.
    DenseElementsAttr inputValues;
    matchPattern(inputTensor, m_Constant(&inputValues));

    // Avoid growing memory usage by duplicating large constants.
    if (!constantUnaryOpShouldBeFolded(recip, inputValues))
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");

    auto elementType = cast<FloatType>(inputValues.getElementType());
    DenseElementsAttr folded = applyElementWise<APFloat, APFloat, FloatType>(
        inputValues, &ReciprocalOp::calcOneElement, elementType);

    // Swap the reciprocal for a constant carrying the precomputed values.
    rewriter.replaceOpWithNewOp<ConstOp>(recip, folded.getType(), folded);
    return success();
  }
};

/// Converts the row-major linear \p index into per-dimension coordinates of a
/// tensor with shape \p tensorShape.
///
/// \returns One coordinate per dimension of \p tensorShape.
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
  // Walk from the innermost (fastest varying) dimension outwards.
  for (size_t dim = tensorShape.size(); dim > 0; --dim) {
    const int64_t extent = tensorShape[dim - 1];
    position[dim - 1] = index % extent;
    index /= extent;
  }
  return position;
}

/// Converts the per-dimension coordinates \p position into the row-major
/// linear index of a tensor with shape \p tensorShape.
///
/// Only the first position.size() dimensions of \p tensorShape are read.
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {
  int64_t linearIndex = 0;
  int64_t stride = 1;
  // Innermost dimension has stride 1; each outer one scales by the extents
  // of all dimensions inside it.
  for (size_t dim = position.size(); dim-- > 0;) {
    linearIndex += position[dim] * stride;
    stride *= tensorShape[dim];
  }
  return linearIndex;
}

/// Reduces, using OperationType::calcOneElement, every element of
/// \p oldTensorAttr (of shape \p oldShape) that lies along \p reductionAxis and
/// maps to the element at linear index \p reductionIndex of the reduced tensor
/// (\p oldShape with \p reductionAxis collapsed to size 1).
///
/// Requires oldShape[reductionAxis] >= 1: the first element along the axis
/// seeds the reduction.
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {

  llvm::SmallVector<int64_t> newShape(oldShape);
  newShape[reductionAxis] = 1;
  /// Let's calculate the position of the index
  llvm::SmallVector<int64_t> position =
      getPositionFromIndex(reductionIndex, newShape);
  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
  /// Starting from the first position along the reduction axis
  position[reductionAxis] = 0;
  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];

  // Distance in the row-major input between consecutive elements along the
  // reduction axis. It is loop-invariant, and it is accumulated in int64_t
  // because an `int` product overflows for large tensors.
  const int64_t stride =
      std::accumulate(oldShape.begin() + reductionAxis + 1, oldShape.end(),
                      int64_t{1}, std::multiplies<int64_t>());

  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
       ++reductionAxisVal) {
    int64_t index = indexAtOldTensor + stride * reductionAxisVal;
    reducedValue =
        OperationType::calcOneElement(reducedValue, oldTensor[index]);
  }
  return reducedValue;
}

/// Folds a TOSA reduction (OperationType) whose input is a `tosa.const` with an
/// integer element type into a new `tosa.const` holding the reduced values.
///
/// Unless aggressiveReduceConstant is set, the fold only fires when the
/// reduction is the sole user of the constant input.
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {

  ReduceConstantOptimization(MLIRContext *context,
                             bool aggressiveReduceConstant)
      : OpRewritePattern<OperationType>(context),
        aggressiveReduceConstant(aggressiveReduceConstant) {}

  using OpRewritePattern<OperationType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OperationType op,
                                PatternRewriter &rewriter) const override {
    Value inputOp = op.getInput();
    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();

    if (!constOp)
      return rewriter.notifyMatchFailure(
          op, "reduce input must be const operation");

    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
      return rewriter.notifyMatchFailure(
          op, "input operation has more than one user");

    auto resultType = cast<ShapedType>(op.getOutput().getType());

    if (!resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "result type shape is not static");

    auto reductionAxis = op.getAxis();
    const auto denseElementsAttr = constOp.getValue();
    const auto shapedOldElementsValues =
        cast<ShapedType>(denseElementsAttr.getType());

    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
      return rewriter.notifyMatchFailure(
          op, "reduce input currently supported with integer type");

    auto oldShape = shapedOldElementsValues.getShape();

    // Use the int64_t element count of the (static) result type; an `int`
    // product of the dimensions can overflow for large tensors.
    const int64_t newNumOfElements = resultType.getNumElements();

    // A non-empty output over an empty input means the reduced axis has size 0.
    // There is then no element to seed the reduction with, and
    // calculateReducedValue would read out of bounds.
    if (newNumOfElements > 0 && shapedOldElementsValues.getNumElements() == 0)
      return rewriter.notifyMatchFailure(
          op, "cannot fold a reduction over an empty input");

    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);

    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
         ++reductionIndex) {

      /// Let's reduce all the elements along this reduction axis
      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
    }

    auto rankedTensorType = cast<RankedTensorType>(resultType);
    auto denseAttr =
        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
    return success();
  }

  // Defaults to false so that patterns built through the inherited
  // OpRewritePattern constructors have a well-defined (conservative) value.
  const bool aggressiveReduceConstant = false;
};

} // namespace

/// Registers ReduceConstantOptimization for every supported TOSA reduction
/// (all, any, max, min, prod, sum), in that order. Each pattern receives
/// \p ctx and \p aggressiveReduceConstant.
void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {
  // The variadic form of add() constructs each listed pattern type, left to
  // right, with the same constructor arguments.
  patterns.add<ReduceConstantOptimization<ReduceAllOp>,
               ReduceConstantOptimization<ReduceAnyOp>,
               ReduceConstantOptimization<ReduceMaxOp>,
               ReduceConstantOptimization<ReduceMinOp>,
               ReduceConstantOptimization<ReduceProdOp>,
               ReduceConstantOptimization<ReduceSumOp>>(
      ctx, aggressiveReduceConstant);
}

/// Adds the pattern that folds `tosa.transpose` of a single-user constant
/// input into a new `tosa.const` (TosaFoldConstantTranspose) to \p patterns.
void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantTranspose>(ctx);
}

/// Adds the pattern that folds `tosa.reciprocal` of a dense float `tosa.const`
/// into a new `tosa.const` (TosaFoldConstantReciprocal) to \p patterns.
void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantReciprocal>(ctx);
}