//===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/ViewLikeInterface.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// ViewLike Interfaces
//===----------------------------------------------------------------------===//

/// Include the definitions of the view-like interfaces.
#include "mlir/Interfaces/ViewLikeInterface.cpp.inc"

/// Verifies one offset/size/stride-style list of an op.
///
/// Succeeds when `staticVals` has exactly `numElements` entries and `values`
/// provides exactly one SSA operand for every entry of `staticVals` that is
/// ShapedType::kDynamic. Otherwise emits an error on `op` that names the list
/// via `name` (e.g. "offset") and returns failure.
LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
                                                   StringRef name,
                                                   unsigned numElements,
                                                   ArrayRef<int64_t> staticVals,
                                                   ValueRange values) {
  // The static list must have one slot per element.
  if (staticVals.size() != numElements)
    return op->emitError("expected ") << numElements << " " << name
                                      << " values, got " << staticVals.size();

  // Every dynamic slot must be backed by exactly one SSA operand.
  unsigned numDynamicSlots = 0;
  for (int64_t staticVal : staticVals)
    if (ShapedType::isDynamic(staticVal))
      ++numDynamicSlots;

  if (values.size() == numDynamicSlots)
    return success();
  return op->emitError("expected ")
         << numDynamicSlots << " dynamic " << name << " values";
}

/// Verifies the structural invariants shared by ops implementing
/// OffsetSizeAndStrideOpInterface:
///   - the offset list has either a single entry (allowed only when the offset
///     max rank is 1) or as many entries as the size list;
///   - the size and stride lists have the same number of entries;
///   - each list has exactly `maxRanks[i]` static entries and one dynamic
///     operand per ShapedType::kDynamic entry;
///   - static offsets and static sizes are non-negative.
LogicalResult
mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
  std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();

  // Ranks are compared on the static lists rather than on getMixed*().
  // Assumption: each mixed list has one entry per static entry, with every
  // kDynamic static entry paired with the next dynamic operand (this matches
  // the default interface implementation; NOTE(review): confirm no op
  // overrides getMixed*() differently). Assembling the mixed lists here,
  // before verifyListOfOperandsOrIntegers below has checked that the dynamic
  // operand counts are consistent, could index past the end of the operand
  // list on malformed IR. Reading the sizes once also avoids repeatedly
  // materializing the mixed vectors.
  size_t numOffsets = op.getStaticOffsets().size();
  size_t numSizes = op.getStaticSizes().size();
  size_t numStrides = op.getStaticStrides().size();

  // Offsets can come in 2 flavors:
  //   1. Either single entry (when maxRanks == 1).
  //   2. Or as an array whose rank must match that of the mixed sizes.
  // So that the result type is well-formed.
  bool isSingleOffsetForm = numOffsets == 1 && maxRanks[0] == 1;
  if (!isSingleOffsetForm && numOffsets != numSizes)
    return op->emitError(
               "expected mixed offsets rank to match mixed sizes rank (")
           << numOffsets << " vs " << numSizes
           << ") so the rank of the result type is well-formed.";
  // Ranks of mixed sizes and strides must always match so the result type is
  // well-formed.
  if (numSizes != numStrides)
    return op->emitError(
               "expected mixed sizes rank to match mixed strides rank (")
           << numSizes << " vs " << numStrides
           << ") so the rank of the result type is well-formed.";

  if (failed(verifyListOfOperandsOrIntegers(
          op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
    return failure();
  if (failed(verifyListOfOperandsOrIntegers(
          op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
    return failure();
  if (failed(verifyListOfOperandsOrIntegers(
          op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
    return failure();

  // Static offsets and sizes must be non-negative; the kDynamic sentinel is
  // exempt because it stands for a runtime value.
  for (int64_t offset : op.getStaticOffsets()) {
    if (offset < 0 && !ShapedType::isDynamic(offset))
      return op->emitError("expected offsets to be non-negative, but got ")
             << offset;
  }
  for (int64_t size : op.getStaticSizes()) {
    if (size < 0 && !ShapedType::isDynamic(size))
      return op->emitError("expected sizes to be non-negative, but got ")
             << size;
  }
  return success();
}

/// Returns the opening character for a non-optional, non-None delimiter kind.
/// Any other delimiter kind is a programming error.
static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
  using Delim = AsmParser::Delimiter;
  if (delimiter == Delim::Paren)
    return '(';
  if (delimiter == Delim::LessGreater)
    return '<';
  if (delimiter == Delim::Square)
    return '[';
  if (delimiter == Delim::Braces)
    return '{';
  llvm_unreachable("unsupported delimiter");
}

/// Returns the closing character for a non-optional, non-None delimiter kind.
/// Any other delimiter kind is a programming error.
static char getRightDelimiter(AsmParser::Delimiter delimiter) {
  using Delim = AsmParser::Delimiter;
  if (delimiter == Delim::Paren)
    return ')';
  if (delimiter == Delim::LessGreater)
    return '>';
  if (delimiter == Delim::Square)
    return ']';
  if (delimiter == Delim::Braces)
    return '}';
  llvm_unreachable("unsupported delimiter");
}

/// Prints `integers` as a comma-separated list surrounded by `delimiter`.
///
/// Each entry equal to ShapedType::kDynamic is printed as the next SSA value
/// from `values`, followed by ` : <type>` when `valueTypes` is non-empty. Other
/// entries are printed as integer literals. When `scalables` is non-empty, an
/// entry whose flag is set is additionally wrapped in `[` `]`. `op` is unused.
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                 OperandRange values,
                                 ArrayRef<int64_t> integers,
                                 ArrayRef<bool> scalables, TypeRange valueTypes,
                                 AsmParser::Delimiter delimiter) {
  const char open = getLeftDelimiter(delimiter);
  const char close = getRightDelimiter(delimiter);
  printer << open;

  if (!integers.empty()) {
    unsigned nextValue = 0;
    unsigned position = 0;
    llvm::interleaveComma(integers, printer, [&](int64_t entry) {
      const bool wrapInBrackets = !scalables.empty() && scalables[position];
      ++position;

      if (wrapInBrackets)
        printer << "[";
      if (!ShapedType::isDynamic(entry)) {
        printer << entry;
      } else {
        printer << values[nextValue];
        if (!valueTypes.empty())
          printer << " : " << valueTypes[nextValue];
        ++nextValue;
      }
      if (wrapInBrackets)
        printer << "]";
    });
  }

  printer << close;
}

/// Parses a `delimiter`-surrounded, comma-separated list whose elements are
/// each one of:
///   - an integer literal, stored in `integers` as-is;
///   - an SSA operand, appended to `values` and recorded in `integers` as
///     ShapedType::kDynamic; when `valueTypes` is non-null the operand must be
///     followed by `: <type>`, which is appended to `*valueTypes`;
///   - either of the above wrapped in `[` `]`, marking a scalable index.
/// `scalables` receives one flag per element, true for the bracketed form.
/// This is the inverse of printDynamicIndexList.
ParseResult mlir::parseDynamicIndexList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
    DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
    SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {

  SmallVector<int64_t, 4> integerVals;
  SmallVector<bool, 4> scalableVals;
  auto parseIntegerOrValue = [&]() {
    // When encountering `[`, assume that this is a scalable index. The bracket
    // must be consumed *before* trying the operand: the printer emits
    // `[%v]` for scalable SSA values, and parseOptionalOperand does not match
    // when the current token is `[`.
    scalableVals.push_back(parser.parseOptionalLSquare().succeeded());

    OpAsmParser::UnresolvedOperand operand;
    auto res = parser.parseOptionalOperand(operand);
    if (res.has_value() && succeeded(res.value())) {
      values.push_back(operand);
      integerVals.push_back(ShapedType::kDynamic);
      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
        return failure();
    } else {
      int64_t integer;
      if (failed(parser.parseInteger(integer)))
        return failure();
      integerVals.push_back(integer);
    }

    // If this is assumed to be a scalable index, verify that there's a closing
    // `]`.
    if (scalableVals.back() && parser.parseOptionalRSquare().failed())
      return failure();
    return success();
  };
  if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
                                     " in dynamic index list"))
    return parser.emitError(parser.getNameLoc())
           << "expected SSA value or integer";
  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
  scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
  return success();
}

/// Returns true when `a` and `b` have offset, size and stride lists of equal
/// lengths and `cmp` returns true for every pair of corresponding mixed
/// (static-or-dynamic) entries. The lists are compared in the order offsets,
/// sizes, strides, stopping at the first mismatch.
bool mlir::detail::sameOffsetsSizesAndStrides(
    OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
    llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
  // Cheap early-out: lists of different lengths can never match.
  const bool sameLengths =
      a.getStaticOffsets().size() == b.getStaticOffsets().size() &&
      a.getStaticSizes().size() == b.getStaticSizes().size() &&
      a.getStaticStrides().size() == b.getStaticStrides().size();
  if (!sameLengths)
    return false;

  auto pairwiseEqual = [&](ArrayRef<OpFoldResult> lhs,
                           ArrayRef<OpFoldResult> rhs) {
    for (auto [lhsEntry, rhsEntry] : llvm::zip(lhs, rhs))
      if (!cmp(lhsEntry, rhsEntry))
        return false;
    return true;
  };

  return pairwiseEqual(a.getMixedOffsets(), b.getMixedOffsets()) &&
         pairwiseEqual(a.getMixedSizes(), b.getMixedSizes()) &&
         pairwiseEqual(a.getMixedStrides(), b.getMixedStrides());
}

/// Returns how many of the first `idx` entries of `staticVals` are
/// ShapedType::kDynamic. `idx` must not exceed `staticVals.size()`.
unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
                                                   unsigned idx) {
  ArrayRef<int64_t> prefix = staticVals.take_front(idx);
  return llvm::count_if(
      prefix, [](int64_t entry) { return ShapedType::isDynamic(entry); });
}