//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"

#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "int-range-analysis"

using namespace mlir;
using namespace mlir::index;
using namespace mlir::intrange;

//===----------------------------------------------------------------------===//
// Constants
//===----------------------------------------------------------------------===//

void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                   SetIntRangeFn setResultRange) {
  const APInt &value = getValue();
  setResultRange(getResult(), ConstantIntRanges::constant(value));
}

void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  bool value = getValue();
  APInt asInt(/*numBits=*/1, value);
  setResultRange(getResult(), ConstantIntRanges::constant(asInt));
}

//===----------------------------------------------------------------------===//
// Arithmec operations. All of these operations will have their results inferred
// using both the 64-bit values and truncated 32-bit values of their inputs,
// with the results being the union of those inferences, except where the
// truncation of the 64-bit result is equal to the 32-bit result (at which time
// we take the 64-bit result).
//===----------------------------------------------------------------------===//

// Some arithmetic inference functions allow specifying special overflow / wrap
// behavior. We do not require this for the IndexOps and use this helper to call
// the inference function without any `OverflowFlags`.
static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
    return inferWithOvfFn(argRanges, OverflowFlags::None);
  };
}

void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
                                           argRanges, CmpMode::Both));
}

void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
                                           argRanges, CmpMode::Both));
}

void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
                                           argRanges, CmpMode::Both));
}

void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
}

void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
}

void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                   SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
}

void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                   SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
}

void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
  return setResultRange(
      getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
}

void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
}

void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
}

void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
}

void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
}

void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
}

void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
}

void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
                                           argRanges, CmpMode::Both));
}

void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
}

void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
}

void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
}

void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                             SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
}

void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
}

//===----------------------------------------------------------------------===//
// Casts
//===----------------------------------------------------------------------===//

static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
                                      unsigned srcWidth, unsigned destWidth,
                                      bool isSigned) {
  if (srcWidth < destWidth)
    return isSigned ? extSIRange(range, destWidth)
                    : extUIRange(range, destWidth);
  if (srcWidth > destWidth)
    return truncRange(range, destWidth);
  return range;
}

// When casting to `index`, we will take the union of the possible fixed-width
// casts.
static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
                                        Type sourceType, Type destType,
                                        bool isSigned) {
  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
  if (sourceType.isIndex())
    return makeLikeDest(range, srcWidth, destWidth, isSigned);
  // We are casting to indexs, so use the union of the 32-bit and 64-bit casts
  ConstantIntRanges storageRange =
      makeLikeDest(range, srcWidth, destWidth, isSigned);
  ConstantIntRanges minWidthRange =
      makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
  ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
  ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
  return ret;
}

void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
  Type sourceType = getOperand().getType();
  Type destType = getResult().getType();
  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
                                             /*isSigned=*/true));
}

void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
  Type sourceType = getOperand().getType();
  Type destType = getResult().getType();
  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
                                             /*isSigned=*/false));
}

//===----------------------------------------------------------------------===//
// CmpOp
//===----------------------------------------------------------------------===//

void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  index::IndexCmpPredicate indexPred = getPred();
  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  APInt min = APInt::getZero(1);
  APInt max = APInt::getAllOnes(1);

  std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);

  ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
                    rhsTrunc = truncRange(rhs, indexMinWidth);
  std::optional<bool> truthValue32 =
      intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);

  if (truthValue64 == truthValue32) {
    if (truthValue64.has_value() && *truthValue64)
      min = max;
    else if (truthValue64.has_value() && !(*truthValue64))
      max = min;
  }
  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
}

//===----------------------------------------------------------------------===//
// SizeOf, which is bounded between the two supported bitwidth (32 and 64).
//===----------------------------------------------------------------------===//

void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
  unsigned storageWidth =
      ConstantIntRanges::getStorageBitwidth(getResult().getType());
  APInt min(/*numBits=*/storageWidth, indexMinWidth);
  APInt max(/*numBits=*/storageWidth, indexMaxWidth);
  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
}