//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <functional>

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

// Removes a tosa.concat that has exactly one operand. If the operand's type
// already matches the result type it is forwarded directly; otherwise a
// tensor.cast to the result type is inserted in its place.
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    auto operands = op.getInput1();
    if (operands.size() != 1)
      return failure();

    Value soleOperand = operands.front();
    if (soleOperand.getType() == op.getType())
      rewriter.replaceOp(op, soleOperand);
    else
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                  soleOperand);
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Single-operand concat elimination.
  results.add<ConcatOptimization>(context);
}

// Rewrites select(logical_not(c), a, b) into select(c, b, a) by updating the
// select's operands in place.
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto negation = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!negation)
    return failure();

  Value newPred = negation.getInput1();
  Value newOnTrue = op.getOnFalse();
  Value newOnFalse = op.getOnTrue();
  rewriter.modifyOpInPlace(
      op, [&]() { op->setOperands({newPred, newOnTrue, newOnFalse}); });
  return success();
}

// Composes transpose(transpose(x, inner), outer) into a single
// transpose(x, composed) with composed[i] = inner[outer[i]]. Both
// permutations must be constant and of equal, non-zero length. The composed
// permutation is materialized as an i32 arith.constant.
struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int64_t> outerPerms;
    if (failed(transposeOp.getConstantPerms(outerPerms)))
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");

    SmallVector<int64_t> innerPerms;
    if (failed(innerTranspose.getConstantPerms(innerPerms)))
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");

    if (outerPerms.size() != innerPerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");

    if (outerPerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Output dim i reads the inner transpose's dim outer[i], which itself
    // reads the original input's dim inner[outer[i]].
    SmallVector<int32_t> composedPerms;
    composedPerms.reserve(outerPerms.size());
    for (int64_t outerIdx : outerPerms)
      composedPerms.push_back(static_cast<int32_t>(innerPerms[outerIdx]));

    auto composedTy = RankedTensorType::get(
        {static_cast<int64_t>(composedPerms.size())}, rewriter.getI32Type());
    auto composedAttr = DenseIntElementsAttr::get(composedTy, composedPerms);
    Value composedValue =
        rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), composedAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), composedValue);
    return success();
  }
};

// Rewrites a tosa.transpose with a constant permutation into a tosa.reshape
// when the permutation only relocates dimensions of size 1: the remaining
// dimensions must keep their relative order, so the linear element order is
// unchanged. Transposes adjacent to other transposes are skipped so they can
// be composed instead, and inputs with more than one dynamic dim are skipped.
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    const bool feedsTranspose =
        llvm::any_of(op.getResult().getUsers(), [](Operation *user) {
          return llvm::isa_and_nonnull<tosa::TransposeOp>(user);
        });
    if (feedsTranspose)
      return rewriter.notifyMatchFailure(
          op, "Dest is used by transpose, can compose transposes");

    Value input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    const int64_t rank = inputTy.getRank();
    const auto dynamicDimCount =
        llvm::count_if(inputTy.getShape(),
                       [](int64_t dim) { return ShapedType::isDynamic(dim); });
    if (dynamicDimCount > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    SmallVector<int64_t> permValues;
    for (const APInt &perm : permAttr.getValues<APInt>())
      permValues.push_back(perm.getSExtValue());

    // Only the relative order of non-unit dimensions affects the layout.
    SmallVector<int64_t> nonUnitPerms;
    nonUnitPerms.reserve(permValues.size());
    for (int64_t perm : permValues)
      if (inputTy.getDimSize(perm) != 1)
        nonUnitPerms.push_back(perm);

    if (!llvm::is_sorted(nonUnitPerms))
      return rewriter.notifyMatchFailure(op,
                                         "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(rank);
    for (int64_t dim = 0; dim < rank; ++dim)
      newShape.push_back(inputTy.getDimSize(permValues[dim]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), input, rewriter.getDenseI64ArrayAttr(newShape));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Composition is registered first, followed by the reshape rewrite.
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

// Gives a tosa.pad that lacks a pad constant an explicit rank-0 tosa.const
// pad value. The value is 0 for floats and unquantized integers, and the
// input zero point for quantized integers.
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp op,
                                PatternRewriter &rewriter) const override {
    // Already materialized: nothing to do.
    if (op.getPadConst())
      return failure();

    Value input = op.getInput1();
    Value padding = op.getPadding();
    Type elementTy = llvm::cast<ShapedType>(input.getType()).getElementType();

    Attribute padValueAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      padValueAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy)) {
      auto quantInfo = op.getQuantizationInfo();
      padValueAttr =
          quantInfo ? rewriter.getIntegerAttr(elementTy, quantInfo->getInputZp())
                    : rewriter.getIntegerAttr(elementTy, 0);
    }

    if (!padValueAttr)
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");

    auto scalarAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), padValueAttr);
    auto padValue = rewriter.create<tosa::ConstOp>(
        op.getLoc(), scalarAttr.getType(), scalarAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, padValue},
        op->getAttrs());
    return success();
  }
};

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  // Pad-value materialization.
  results.add<MaterializePadValue>(context);
}

// Removes a statically shaped tosa.max_pool2d whose input and output both
// have extent 1 in shape dimensions 1 and 2 (presumably the spatial dims of
// an NHWC layout — confirm against the op definition). Such a pool forwards
// its input unchanged.
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::cast<ShapedType>(input.getType());
    auto outputType = llvm::cast<ShapedType>(op.getOutput().getType());

    if (!(inputType.hasStaticShape() && outputType.hasStaticShape()))
      return failure();

    auto hasUnitDims1And2 = [](ShapedType type) {
      return type.getDimSize(1) == 1 && type.getDimSize(2) == 1;
    };
    if (!hasUnitDims1And2(outputType) || !hasUnitDims1And2(inputType))
      return failure();

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Unit-window max pool elimination.
  results.add<MaxPool2dIsNoOp>(context);
}

// Removes a tosa.clamp whose bounds cover the entire value range of the
// (statically shaped, ranked) input element type. Such a clamp leaves every
// element unchanged.
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(input.getType());
    // dyn_cast yields null for unranked inputs; check before dereferencing.
    if (!inputType || !inputType.hasStaticShape())
      return failure();
    Type inputElementType = inputType.getElementType();

    if (isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating point types can represent infinity,
      // so only a [-inf, +inf] clamp is guaranteed to be a no-op.
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    auto intType = dyn_cast<IntegerType>(inputElementType);
    if (!intType)
      return failure();

    // The integer bounds are 64-bit attributes; interpret them as signed.
    const int64_t minClamp = static_cast<int64_t>(op.getMinInt());
    const int64_t maxClamp = static_cast<int64_t>(op.getMaxInt());
    const unsigned bitWidth = intType.getWidth();

    int64_t intMin;
    int64_t intMax;
    if (intType.isUnsignedInteger()) {
      // The maximum of a >= 64-bit unsigned type does not fit in int64_t
      // (getZExtValue would wrap it to a negative number), and no int64
      // bound can reach it, so such a clamp is never provably a no-op.
      if (bitWidth >= 64)
        return failure();
      intMin = 0;
      intMax =
          static_cast<int64_t>(APInt::getMaxValue(bitWidth).getZExtValue());
    } else {
      // getSExtValue asserts above 64 bits, and int64 bounds could not cover
      // such a range anyway.
      if (bitWidth > 64)
        return failure();
      intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
      intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
    }

    if (minClamp > intMin || maxClamp < intMax)
      return failure();

    rewriter.replaceOp(op, input);
    return success();
  }
};

// Merges clamp(clamp(x, inner), outer) into a single clamp of x over the
// intersection of the two ranges. Chains whose integer or float ranges do
// not overlap are left alone: intersecting them would yield min > max, which
// does not reproduce the composition.
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    auto clampOp = op.getInput().getDefiningOp<tosa::ClampOp>();
    if (!clampOp)
      return failure();

    // std::max/std::min return references into the getters' temporaries, so
    // copy the chosen values out within the same full expression.
    const APFloat minFp = std::max(op.getMinFp(), clampOp.getMinFp());
    const APFloat maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp());

    // The integer getters return the raw 64-bit attribute value; compare the
    // bounds as signed so negative limits order correctly.
    const int64_t minInt = std::max(static_cast<int64_t>(op.getMinInt()),
                                    static_cast<int64_t>(clampOp.getMinInt()));
    const int64_t maxInt = std::min(static_cast<int64_t>(op.getMaxInt()),
                                    static_cast<int64_t>(clampOp.getMaxInt()));

    if (minInt > maxInt || maxFp < minFp)
      return rewriter.notifyMatchFailure(op, "clamp ranges do not overlap");

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(),
        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
        rewriter.getF32FloatAttr(minFp.convertToFloat()),
        rewriter.getF32FloatAttr(maxFp.convertToFloat()));
    return success();
  }
};

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}

// Rewrites slice(concat(a, b, ...)) into a slice of the single concat operand
// that fully contains the sliced range along the concatenation axis. The
// start offset is rebased onto that operand. The concat result and every
// inspected operand must be statically shaped ranked tensors.
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto concatOp = sliceOp.getInput().getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");

    const int32_t axis = concatOp.getAxis();
    // Start offsets relative to the concat operand currently inspected.
    llvm::SmallVector<int64_t> relativeStart(sliceOp.getStart());
    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();

    for (Value operand : concatOp.getInput1()) {
      auto operandType = dyn_cast<RankedTensorType>(operand.getType());
      if (!operandType || !operandType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      const int64_t extent = operandType.getDimSize(axis);
      const bool fitsInOperand =
          relativeStart[axis] >= 0 &&
          relativeStart[axis] + sliceSize[axis] <= extent;
      if (fitsInOperand) {
        rewriter.replaceOpWithNewOp<tosa::SliceOp>(
            sliceOp, sliceOp.getType(), operand,
            rewriter.getDenseI64ArrayAttr(relativeStart),
            rewriter.getDenseI64ArrayAttr(sliceSize));
        return success();
      }
      relativeStart[axis] -= extent;
    }

    return rewriter.notifyMatchFailure(
        sliceOp, "corresponding concat input not found for slice");
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // Slice-of-concat forwarding.
  results.add<ConcatSliceOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

// Constant-folds an elementwise binary op whose operands are both splat
// constants of the same element type, producing a splat of `returnTy`.
// IntFolder is applied to integer splats (as APInt) and FloatFolder to float
// splats (as APFloat). Returns a null attribute when folding is not possible.
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  // Callers may pass a null type (from a failed dyn_cast on an unranked
  // result), and dense attributes can only be built for static shapes.
  if (!returnTy || !returnTy.hasStaticShape())
    return {};
  if (!lhs || !rhs || !lhs.isSplat() || !rhs.isSplat())
    return {};

  auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
  auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
  if (lETy != rETy)
    return {};

  if (llvm::isa<IntegerType>(lETy)) {
    APInt l = lhs.getSplatValue<APInt>();
    APInt r = rhs.getSplatValue<APInt>();
    auto result = IntFolder()(l, r);
    return DenseElementsAttr::get(returnTy, result);
  }

  if (llvm::isa<FloatType>(lETy)) {
    APFloat l = lhs.getSplatValue<APFloat>();
    APFloat r = rhs.getSplatValue<APFloat>();
    auto result = FloatFolder()(l, r);
    return DenseElementsAttr::get(returnTy, result);
  }

  return {};
}

// Returns true if `val` is a non-null splat whose value is zero, interpreted
// according to `elemType` (float or integer). Other element types never
// match.
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (!val || !val.isSplat())
    return false;
  if (llvm::isa<FloatType>(elemType))
    return val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val.getSplatValue<APInt>().isZero();
  return false;
}

// Returns true if `val` is a non-null splat acting as the multiplicative
// identity. For floats that is exactly 1.0. For integers multiplied with a
// right `shift`, it is 1 << shift, so that (x * val) >> shift == x.
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (!val || !val.isSplat())
    return false;

  if (llvm::isa<FloatType>(elemType))
    return val.getSplatValue<APFloat>().isExactlyValue(1.0);

  if (llvm::isa<IntegerType>(elemType)) {
    // 1 << shift is undefined for negative shifts or shifts >= 64 on
    // int64_t; reject those (and the sign bit) rather than invoke UB.
    if (shift < 0 || shift >= 63)
      return false;
    const APInt splat = val.getSplatValue<APInt>();
    // getSExtValue asserts for values that need more than 64 bits.
    if (splat.getSignificantBits() > 64)
      return false;
    return splat.getSExtValue() == (int64_t{1} << shift);
  }

  return false;
}

// Folds x + 0 -> x (either operand order, only when the surviving operand
// already has the result type) and splat + splat -> splat.
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  if (!resultTy || !lhsTy || !rhsTy)
    return {};

  // An ElementsAttr can only hold int/index/float elements.
  const bool foldableElements =
      lhsTy.getElementType().isIntOrIndexOrFloat() &&
      rhsTy.getElementType().isIntOrIndexOrFloat();
  if (!foldableElements)
    return {};

  Type resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (lhsAttr && rhsAttr)
    return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
  return {};
}

// Folds argmax over an axis of extent 1 to an all-zero index tensor, since
// the only candidate index along that axis is 0. Requires static shapes and
// an integer output element type.
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  if (inputTy.getDimSize(getAxis()) != 1)
    return {};

  // Build the zero at the output's own width. A plain `int` payload only
  // matches 32-bit element types and asserts otherwise (e.g. i64 indices).
  auto outputETy = llvm::dyn_cast<IntegerType>(outputTy.getElementType());
  if (!outputETy)
    return {};
  return DenseElementsAttr::get(outputTy,
                                APInt::getZero(outputETy.getWidth()));
}

// Folds tosa.int_div for identical operand types:
//   0 / x -> 0 and x / 1 -> x (only when the lhs type equals the result type),
//   splat / splat -> splat (signed division; division by zero is not folded).
OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  // IntDivOp inputs must be integer type, no need to check for quantized type
  auto resultETy = resultTy.getElementType();
  if (!llvm::isa<IntegerType>(resultETy))
    return {};

  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Identity folds return something of the lhs type, which must match the
  // result type for the fold to be valid.
  if (lhsTy == resultTy) {
    if (lhsAttr && lhsAttr.isSplat() &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;

    if (rhsAttr && rhsAttr.isSplat() &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  // A dense result constant needs a static shape.
  if (!resultTy.hasStaticShape())
    return {};

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    // Division by zero is undefined and APInt::sdiv asserts on it; leave
    // the op unfolded.
    if (r.isZero())
      return {};
    return DenseElementsAttr::get(resultTy, l.sdiv(r));
  }

  return {};
}

namespace {
// Constant-folds tosa.mul on two splat operands into a splat of `ty`.
// For integers the product is formed at twice the result width and then
// truncated, so operands narrower than the result (e.g. i8 x i8 -> i32) are
// handled. With a non-zero `shift` it follows the TOSA MUL definition: add
// the rounding term 1 << (shift - 1), then arithmetic-shift right by
// `shift`. Returns a null attribute when folding is not possible.
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  // Dense constants need a statically shaped result type.
  if (!ty || !ty.hasStaticShape())
    return {};
  if (!rhs || !lhs || !rhs.isSplat() || !lhs.isSplat())
    return {};

  Type elementTy = ty.getElementType();

  if (llvm::isa<IntegerType>(elementTy)) {
    APInt l = lhs.getSplatValue<APInt>();
    APInt r = rhs.getSplatValue<APInt>();

    const unsigned bitwidth = elementTy.getIntOrFloatBitWidth();
    const unsigned wideWidth = bitwidth * 2;
    // sext cannot narrow, and the shift must stay inside the wide product.
    if (l.getBitWidth() > wideWidth || r.getBitWidth() > wideWidth)
      return {};
    if (shift < 0 || static_cast<unsigned>(shift) >= wideWidth)
      return {};

    APInt product = l.sext(wideWidth) * r.sext(wideWidth);
    if (shift > 0) {
      product += APInt(wideWidth, 1).shl(static_cast<unsigned>(shift - 1));
      product.ashrInPlace(static_cast<unsigned>(shift));
    }
    return DenseElementsAttr::get(ty, product.trunc(bitwidth));
  }

  if (llvm::isa<FloatType>(elementTy)) {
    APFloat l = lhs.getSplatValue<APFloat>();
    APFloat r = rhs.getSplatValue<APFloat>();
    APFloat result = l * r;
    return DenseElementsAttr::get(ty, result);
  }

  return {};
}
} // namespace

// Folds x * 0 -> 0 and x * one -> x with the constant on either side. For
// integers, "one" means 1 << shift. Both folds apply only when the other
// operand already has the result type. Otherwise it falls back to splat
// constant folding.
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!(lhsTy && rhsTy && resultTy))
    return {};

  Type resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // The shift attribute is only meaningful for integer multiplication.
  int64_t shift = 0;
  if (llvm::isa<IntegerType>(resultETy))
    shift = getShift();

  // Constant on the left.
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }

  // Constant on the right.
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}

// Folds x - 0 -> x (when x already has the result type) and
// splat - splat -> splat.
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  if (!resultTy || !lhsTy || !rhsTy)
    return {};

  // An ElementsAttr can only hold int/index/float elements.
  if (!(lhsTy.getElementType().isIntOrIndexOrFloat() &&
        rhsTy.getElementType().isIntOrIndexOrFloat()))
    return {};

  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultTy.getElementType(), rhsAttr))
    return getInput1();

  if (lhsAttr && rhsAttr)
    return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(
        lhsAttr, rhsAttr, resultTy);
  return {};
}

namespace {
// Adapts a boolean comparison functor `Cmp` to the binaryFolder interface by
// returning its result as a 1-bit APInt. Only the overload matching the
// operand kind is ever instantiated.
template <typename Cmp>
struct ComparisonFold {
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }
  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

// Signed integer "greater than", as a 1-bit APInt.
struct APIntFoldGreater {
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

// Signed integer "greater than or equal", as a 1-bit APInt.
struct APIntFoldGreaterEqual {
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

// Folds a > b on two splat constants (signed for integers).
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && rhsAttr)
    return binaryFolder<APIntFoldGreater,
                        ComparisonFold<std::greater<APFloat>>>(
        lhsAttr, rhsAttr, llvm::dyn_cast<RankedTensorType>(getType()));
  return {};
}

// Folds a >= b on two splat constants (signed for integers).
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && rhsAttr)
    return binaryFolder<APIntFoldGreaterEqual,
                        ComparisonFold<std::greater_equal<APFloat>>>(
        lhsAttr, rhsAttr, llvm::dyn_cast<RankedTensorType>(getType()));
  return {};
}

// Folds x == x to true for integer operands (static result shape only), and
// a == b on two splat constants.
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  Value lhs = getInput1();
  Value rhs = getInput2();

  // Self-comparison is only trivially true for integers; floats are
  // excluded because of values such as NaN.
  Type lhsETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
  if (lhs == rhs && llvm::isa<IntegerType>(lhsETy) && resultTy &&
      resultTy.hasStaticShape())
    return DenseElementsAttr::get(resultTy, true);

  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (!(lhsAttr && rhsAttr))
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

// Folds tosa.cast when the types already match, or when the input is a splat
// constant and the result type is static. Float<->float and int<->float use
// round-to-nearest-even. i1 is treated as boolean, following TOSA:
//   - int -> i1 yields (value != 0);
//   - i1 -> int/float yields 1 for true (zero-extension, not sign-extension).
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  // Only splats are folded, and a splat attribute needs a static shape.
  if (!operand.isSplat() || !outTy.hasStaticShape())
    return {};

  const bool inIsBool = inETy.isInteger(1);

  if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
    bool overflow;
    auto splatVal = operand.getSplatValue<APFloat>();
    auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
    splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                     &overflow);
    return SplatElementsAttr::get(outTy, splatVal);
  }

  if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
    // Booleans convert as unsigned so that true becomes 1.0, not -1.0.
    auto unsign =
        llvm::cast<IntegerType>(inETy).isUnsignedInteger() || inIsBool;
    APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
    splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                              llvm::RoundingMode::NearestTiesToEven);
    return SplatElementsAttr::get(outTy, splatVal);
  }

  if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
    auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
    auto intVal = APSInt(
        llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
    auto floatVal = operand.getSplatValue<APFloat>();
    bool exact;
    floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                              &exact);
    return SplatElementsAttr::get(outTy, intVal);
  }

  if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
    auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
    bool trunc =
        inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
    auto intVal = operand.getSplatValue<APInt>();
    auto bitwidth = outETy.getIntOrFloatBitWidth();

    if (outETy.isInteger(1)) {
      // Casting to bool tests for non-zero; truncation would keep only the
      // lowest bit (e.g. 2 -> false).
      intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
    } else if (trunc) {
      intVal = intVal.trunc(bitwidth);
    } else if (unsignIn || inIsBool) {
      // Booleans widen to 0/1; sign-extension would turn true into -1.
      intVal = intVal.zext(bitwidth);
    } else {
      intVal = intVal.sext(bitwidth);
    }

    return SplatElementsAttr::get(outTy, intVal);
  }

  return {};
}

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
  // A constant folds to the value attribute it carries.
  return getValueAttr();
}

// Defines OP::fold for a reduction: the op is an identity (folds to its
// input) when the input is ranked, its type equals the result type, and it
// is either rank 0 or has extent 1 along the reduced axis.
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    auto shapedTy = llvm::cast<ShapedType>(getInput().getType());              \
    const bool isIdentity =                                                    \
        shapedTy.hasRank() && shapedTy == getType() &&                         \
        (shapedTy.getRank() == 0 || shapedTy.getDimSize(getAxis()) == 1);      \
    return isIdentity ? OpFoldResult(getInput()) : OpFoldResult();             \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

// Folds tosa.reshape in these cases:
//   - identical input/output types with fewer than two dynamic dims;
//   - reshape(reshape(x)): rewires this op to x in place;
//   - constant input with a static result: splats are re-shaped freely,
//     other constants only when this reshape is their sole user.
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types are the same. This is only safe when
  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
  // there may still be a productive reshape.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(reshape-attr(x))
  auto operand =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  if (!operand)
    return {};

  // Constants must have static shape.
  if (!outputTy.hasStaticShape())
    return {};

  // Okay to duplicate splat constants.
  if (operand.isSplat())
    return SplatElementsAttr::get(outputTy,
                                  operand.getSplatValue<Attribute>());

  // Don't duplicate other constants.
  if (!getInput1().hasOneUse())
    return {};

  // Take the target shape from the (static) result type rather than the
  // new_shape attribute. new_shape can encode an inferred dimension (e.g. -1),
  // which is not a valid dimension size for a dense constant.
  return operand.reshape(
      llvm::cast<ShapedType>(operand.getType()).clone(outputTy.getShape()));
}

// Folds tosa.pad to its input when the padding is a constant all-zero splat
// and the result type is identical to the input type.
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  if (getInput1().getType() != getType())
    return {};

  // The padding constant may be a non-dense ElementsAttr (e.g. a resource),
  // so a checked cast is required instead of llvm::cast.
  auto densePad =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getPadding());
  if (densePad && densePad.isSplat() &&
      densePad.getSplatValue<APInt>().isZero())
    return getInput1();

  return {};
}

// Folds a tosa.resize that is an exact copy of its input. That requires:
// scale[0] == scale[1] and scale[2] == scale[3] (unit scaling, assuming the
// TOSA numerator/denominator layout), zero offset, zero border, and identical
// input/result types.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> scale = getScale();
  if (!(scale[0] == scale[1] && scale[2] == scale[3]))
    return {};

  ArrayRef<int64_t> offset = getOffset();
  if (!(offset[0] == 0 && offset[1] == 0))
    return {};

  ArrayRef<int64_t> border = getBorder();
  if (!(border[0] == 0 && border[1] == 0))
    return {};

  Value input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  return inputTy == resultTy ? OpFoldResult(input) : OpFoldResult();
}

// Folds tosa.reverse when reversal cannot change the data: a splat constant
// input folds to that same splat, and a rank-0 input or a unit-length reversed
// axis folds to the input value.
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto input = getInput();
  auto inputType = llvm::cast<ShapedType>(input.getType());
  auto reversedAxis = getAxis();

  // Every element of a splat is identical, so its reversal is itself.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput()))
    return splat;

  // Without a rank the extent of the reversed axis is unknown.
  if (!inputType.hasRank())
    return {};

  // Scalars and axes of length 1 have nothing to reorder.
  const bool isNoOp = inputType.getRank() == 0 ||
                      inputType.getDimSize(reversedAxis) == 1;
  if (isNoOp)
    return input;

  return {};
}

// Folds tosa.slice.
//  * A slice whose static result type equals its input type is the input.
//  * A splat constant input folds to a splat of the (static) result type.
//  * A constant input whose static result holds exactly one element folds to
//    a splat of the element found at the slice's start coordinates.
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto sourceType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
  if (!sourceType || !resultType)
    return {};

  if (sourceType == resultType && sourceType.hasStaticShape())
    return getInput();

  // The remaining folds all need a constant input.
  Attribute constInput = adaptor.getInput();
  if (!constInput)
    return {};

  // An ElementsAttr can only be built for int/float/index element types.
  auto canBuildElements = [](RankedTensorType type) {
    return type.getElementType().isIntOrIndexOrFloat();
  };
  if (!canBuildElements(sourceType) || !canBuildElements(resultType))
    return {};

  auto elements = llvm::cast<ElementsAttr>(constInput);
  if (elements.isSplat() && resultType.hasStaticShape())
    return SplatElementsAttr::get(resultType,
                                  elements.getSplatValue<Attribute>());

  const bool extractsSingleElement = sourceType.hasStaticShape() &&
                                     resultType.hasStaticShape() &&
                                     resultType.getNumElements() == 1;
  if (!extractsSingleElement)
    return {};

  // Look up the single element at the slice start and splat it.
  llvm::SmallVector<uint64_t> startCoords(getStart());
  auto element = elements.getValues<Attribute>()[startCoords];
  return SplatElementsAttr::get(resultType, element);
}

// Folds tosa.select when the outcome does not depend on per-element
// predicate values: identical branches fold to that branch, and a splat
// constant predicate folds to the branch it selects.
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  auto trueValue = getOnTrue();
  auto falseValue = getOnFalse();

  // With identical branches the predicate is irrelevant.
  if (trueValue == falseValue)
    return trueValue;

  auto predAttr =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  // Only a uniform (splat) constant predicate picks a whole branch.
  if (!predAttr || !predAttr.isSplat())
    return {};

  if (predAttr.getSplatValue<APInt>().getBoolValue())
    return trueValue;
  return falseValue;
}

// Folds tosa.tile to its input when every multiple is 1 and the result type
// matches the input type.
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  const bool repeatsSomeAxis = llvm::any_of(
      getMultiples(), [](int64_t multiple) { return multiple != 1; });
  if (repeatsSomeAxis || getInput1().getType() != getType())
    return {};
  return getInput1();
}

// Folds tosa.transpose.
//  * A splat constant input (same element type, static result shape) folds
//    to the splat reshaped to the result type, since permuting a splat only
//    changes its shape.
//  * An identity permutation with an unchanged type folds to the input.
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

  auto constInput =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  if (constInput && constInput.isSplat() && resultTy.hasStaticShape() &&
      constInput.getType().getElementType() == resultTy.getElementType())
    return constInput.reshape(resultTy);

  // Forwarding the input requires the type to be unchanged.
  if (getInput1().getType() != getType())
    return {};

  // The permutation must be a known constant...
  SmallVector<int64_t> perms;
  if (failed(getConstantPerms(perms)))
    return {};

  // ...and equal to [0, 1, ..., n-1].
  for (size_t dim = 0, numDims = perms.size(); dim < numDims; ++dim) {
    if (perms[dim] != static_cast<int64_t>(dim))
      return {};
  }

  return getInput1();
}

// Folds the element-wise pattern log(exp(x)) to x.
OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto producer = getInput1().getDefiningOp<tosa::ExpOp>();
  if (!producer)
    return {};
  return producer.getInput1();
}

// Folds the element-wise pattern exp(log(x)) to x.
// NOTE(review): this assumes log(x) is well-defined for every element. Under
// IEEE semantics log of a negative value is NaN, so exp(log(x)) would not
// reproduce x there. Confirm the rewrite is intended for all inputs.
OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto producer = getInput1().getDefiningOp<tosa::LogOp>();
  if (!producer)
    return {};
  return producer.getInput1();
}

/// Folds the element-wise pattern negate(negate(x)) to x.
///
/// The double negation is only the identity when neither negate shifts by a
/// zero point. Each op must either carry no quantization info or have input
/// and output zero points equal to 0. Per the TOSA negate definition, a
/// quantized negate computes -(x - input_zp) + output_zp, so non-zero zero
/// points generally break the identity.
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto inner = getInput1().getDefiningOp<tosa::NegateOp>();
  if (!inner)
    return {};

  // True when the negate applies no zero-point adjustment.
  auto hasZeroZeroPoints = [](tosa::NegateOp negate) {
    auto quantInfo = negate.getQuantizationInfo();
    return !quantInfo ||
           (quantInfo->getInputZp() == 0 && quantInfo->getOutputZp() == 0);
  };
  if (!hasZeroZeroPoints(*this) || !hasZeroZeroPoints(inner))
    return {};

  return inner.getInput1();
}

// Folds the element-wise pattern abs(abs(x)) to abs(x): when the input is
// already produced by an abs, that inner result is reused.
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  Value input = getInput1();
  if (!input.getDefiningOp<tosa::AbsOp>())
    return {};
  return input;
}

// Flattens nested concats along the same axis in place:
// concat(a, concat(b, c), d) on axis k becomes concat(a, b, c, d). The fold
// rewrites this op's operand list and returns its own result. It returns
// nothing when no operand is a same-axis concat.
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Reserve assuming flattening at most doubles the operand count.
  SmallVector<Value, 8> flattened;
  flattened.reserve(2 * getNumOperands());

  bool didFlatten = false;
  for (Value operand : getOperands()) {
    auto inner = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    // Only concats on the same axis can be spliced into this one.
    if (inner && inner.getAxis() == getAxis()) {
      llvm::append_range(flattened, inner->getOperands());
      didFlatten = true;
    } else {
      flattened.push_back(operand);
    }
  }

  if (!didFlatten)
    return {};

  getOperation()->setOperands(flattened);
  return getResult();
}

/// Constant-folds tosa.reciprocal for a splat floating-point input. The
/// result is a splat of the reciprocal value, using the op's result type.
OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto inputAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  // DenseElementsAttr::get requires a statically shaped type. The result type
  // may still be dynamic (or unranked) even when the input is a constant, so
  // bail out rather than crash.
  auto shapeType = llvm::cast<ShapedType>(getType());
  if (!shapeType.hasStaticShape())
    return {};

  // Only floating-point element types are folded.
  if (!llvm::isa<FloatType>(inputAttr.getElementType()))
    return {};

  auto floatVal = inputAttr.getSplatValue<APFloat>();
  return DenseElementsAttr::get(shapeType,
                                ReciprocalOp::calcOneElement(floatVal));
}