//===- TosaValidation.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Validate that TOSA dialect input matches the specification for the given
// requirements.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"

#include <string>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAVALIDATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

namespace {

// Checks that a tosa.pad op is fed by constants: its padding operand must
// match a constant, and so must pad_const when that operand is present.
// Ops of any other kind pass trivially.
static LogicalResult checkConstantOperandPad(Operation *op) {
  auto pad = dyn_cast<tosa::PadOp>(op);
  if (!pad)
    return success();

  DenseElementsAttr paddingAttr;
  const bool paddingIsConstant =
      matchPattern(pad.getPadding(), m_Constant(&paddingAttr));
  if (!paddingIsConstant)
    return op->emitOpError("padding of pad is not constant");

  // pad_const is optional; when it is absent the op is assumed to zero-pad,
  // so there is nothing further to verify.
  DenseElementsAttr padConstAttr;
  if (pad.getPadConst() &&
      !matchPattern(pad.getPadConst(), m_Constant(&padConstAttr)))
    return op->emitOpError("pad_const of pad is not constant");

  return success();
}

// Checks that the perms operand of a tosa.transpose op matches a constant.
// Ops of any other kind pass trivially.
static LogicalResult checkConstantOperandTranspose(Operation *op) {
  auto transpose = dyn_cast<tosa::TransposeOp>(op);
  if (!transpose)
    return success();

  DenseElementsAttr permsAttr;
  const bool permsAreConstant =
      matchPattern(transpose.getPerms(), m_Constant(&permsAttr));
  if (!permsAreConstant)
    return op->emitOpError("perms of transpose is not constant");

  return success();
}

// Checks that both the weight and the bias operand of a tosa.fully_connected
// op match constants. The weight is examined first, so a non-constant weight
// is the error reported when both operands are non-constant. Ops of any other
// kind pass trivially.
static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
  auto fullyConnected = dyn_cast<tosa::FullyConnectedOp>(op);
  if (!fullyConnected)
    return success();

  DenseElementsAttr weightAttr;
  const bool weightIsConstant =
      matchPattern(fullyConnected.getWeight(), m_Constant(&weightAttr));
  if (!weightIsConstant)
    return op->emitOpError("weight of fully_connected is not constant");

  DenseElementsAttr biasAttr;
  const bool biasIsConstant =
      matchPattern(fullyConnected.getBias(), m_Constant(&biasAttr));
  if (!biasIsConstant)
    return op->emitOpError("bias of fully_connected is not constant");

  return success();
}

struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;

  // @todo: MAX_LOG2_SIZE value and checks

  bool operator==(const TosaLevel &rhs) {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
  }
};

// Limits for the "8K" level, in member order:
// MAX_RANK = 6, MAX_KERNEL = 8192, MAX_STRIDE = 8192, MAX_SCALE = 256.
static constexpr TosaLevel TOSA_LEVEL_EIGHTK{6, 8192, 8192, 256};
// The "none" level: every limit is zero.
static constexpr TosaLevel TOSA_LEVEL_NONE{0, 0, 0, 0};

//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//

// Pass that walks the operations under its root and reports ops that do not
// satisfy the configured checks: legal element types for the selected
// profile, constant operands (when StrictOperationSpecAlignment is set),
// level limits, and consistency of variable declarations and accesses.
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
  // Default construction only registers the constant-operand checkers; the
  // pass options keep whatever values the generated base class gives them.
  explicit TosaValidation() { populateConstantOperandChecks(); }
  // Construction from explicit options: copies profile, strictness and level
  // into the corresponding pass options.
  explicit TosaValidation(const TosaValidationOptions &options)
      : TosaValidation() {
    this->profile = options.profile;
    this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
    this->level = options.level;
  }
  void runOnOperation() final;

  // Runs every registered constant-operand checker on `op` and stops at the
  // first one that fails (that checker has already emitted the diagnostic).
  LogicalResult applyConstantOperandCheck(Operation *op) {
    for (auto &checker : constCheckers) {
      if (failed(checker(op)))
        return failure();
    }
    return success();
  }

  // Checks `op` against the limits of the configured level.
  LogicalResult applyLevelCheck(Operation *op);

  // check variable read/write data types against variable declarations
  LogicalResult applyVariableCheck(Operation *op);

private:
  // Registers the per-op checkers used by applyConstantOperandCheck.
  void populateConstantOperandChecks() {
    constCheckers.emplace_back(checkConstantOperandPad);
    constCheckers.emplace_back(checkConstantOperandTranspose);
    constCheckers.emplace_back(checkConstantOperandFullyConnected);
  }

  // Returns true if `v` does not exceed MAX_KERNEL of the active level;
  // otherwise emits an error on `op` mentioning `checkDesc` and returns
  // false. `v` is 64-bit so that attribute and shape values (and products of
  // them) are compared without first being narrowed.
  bool levelCheckKernel(Operation *op, int64_t v,
                        const std::string &checkDesc) {
    if (v > tosaLevel.MAX_KERNEL) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  // Same as levelCheckKernel, but compares against MAX_STRIDE.
  bool levelCheckStride(Operation *op, int64_t v,
                        const std::string &checkDesc) {
    if (v > tosaLevel.MAX_STRIDE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  // Same as levelCheckKernel, but compares against MAX_SCALE.
  bool levelCheckScale(Operation *op, int64_t v, const std::string &checkDesc) {
    if (v > tosaLevel.MAX_SCALE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  // Rank check for a single value. Values whose type is not a ShapedType are
  // accepted. Shaped values fail when they are unranked or when their rank
  // exceeds MAX_RANK; in both cases an error is emitted on `op`.
  bool levelCheckRank(Operation *op, const Value &v,
                      const std::string &checkDesc) {
    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
      if (!type.hasRank()) {
        op->emitOpError() << "failed level check: unranked tensor";
        return false;
      }
      if (type.getRank() > tosaLevel.MAX_RANK) {
        op->emitOpError() << "failed level check: " << checkDesc;
        return false;
      }
    }
    return true;
  }

  // If `op` is a T, rank-checks all of its operands and then all of its
  // results, stopping at the first failure. Other ops pass trivially.
  template <typename T>
  bool levelCheckRanksFor(Operation *op) {
    if (isa<T>(op)) {
      // level check ranks of all operands and results
      for (auto v : op->getOperands()) {
        if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
          return false;
      }
      for (auto v : op->getResults()) {
        if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
          return false;
      }
    }
    return true;
  }

  // Applies levelCheckRanksFor to every op kind listed below. Op kinds that
  // are not listed are not rank-checked here.
  bool levelCheckRanks(Operation *op) {
#define CHECK_RANKS_FOR(tosaOp)                                                \
  if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
    return false;

    // tensor operators:
    CHECK_RANKS_FOR(ArgMax);
    // all activation functions:
    CHECK_RANKS_FOR(Clamp);
    CHECK_RANKS_FOR(Sigmoid);
    CHECK_RANKS_FOR(Tanh);
    // all elementwise binary operators:
    CHECK_RANKS_FOR(Add);
    CHECK_RANKS_FOR(ArithmeticRightShift);
    CHECK_RANKS_FOR(BitwiseAnd);
    CHECK_RANKS_FOR(BitwiseOr);
    CHECK_RANKS_FOR(BitwiseXor);
    CHECK_RANKS_FOR(IntDiv);
    CHECK_RANKS_FOR(LogicalAnd);
    CHECK_RANKS_FOR(LogicalLeftShift);
    CHECK_RANKS_FOR(LogicalRightShift);
    CHECK_RANKS_FOR(LogicalOr);
    CHECK_RANKS_FOR(LogicalXor);
    CHECK_RANKS_FOR(Maximum);
    CHECK_RANKS_FOR(Minimum);
    CHECK_RANKS_FOR(Mul);
    CHECK_RANKS_FOR(Pow);
    CHECK_RANKS_FOR(Sub);
    CHECK_RANKS_FOR(Table);
    // all elementwise unary operators:
    CHECK_RANKS_FOR(Abs);
    CHECK_RANKS_FOR(BitwiseNot);
    CHECK_RANKS_FOR(Ceil);
    CHECK_RANKS_FOR(Clz);
    CHECK_RANKS_FOR(Exp);
    CHECK_RANKS_FOR(Floor);
    CHECK_RANKS_FOR(Log);
    CHECK_RANKS_FOR(LogicalNot);
    CHECK_RANKS_FOR(Negate);
    CHECK_RANKS_FOR(Reciprocal);
    CHECK_RANKS_FOR(Rsqrt);
    // all elementwise ternary operators:
    CHECK_RANKS_FOR(Select);
    // all comparison operators:
    CHECK_RANKS_FOR(Equal);
    CHECK_RANKS_FOR(Greater);
    CHECK_RANKS_FOR(GreaterEqual);
    // all reduction operators:
    CHECK_RANKS_FOR(ReduceAll);
    CHECK_RANKS_FOR(ReduceAny);
    CHECK_RANKS_FOR(ReduceMax);
    CHECK_RANKS_FOR(ReduceMin);
    CHECK_RANKS_FOR(ReduceProd);
    CHECK_RANKS_FOR(ReduceSum);
    // all data layout operators:
    CHECK_RANKS_FOR(Concat);
    CHECK_RANKS_FOR(Pad);
    CHECK_RANKS_FOR(Reshape);
    CHECK_RANKS_FOR(Reverse);
    CHECK_RANKS_FOR(Slice);
    CHECK_RANKS_FOR(Tile);
    CHECK_RANKS_FOR(Transpose);
    // all type conversion operators:
    CHECK_RANKS_FOR(Cast);
    CHECK_RANKS_FOR(Rescale);
    // all data nodes operators:
    CHECK_RANKS_FOR(Const);
    CHECK_RANKS_FOR(Identity);

#undef CHECK_RANKS_FOR
    return true;
  }

  // Pool Op: level check kernel/stride/pad values
  // Pad values are compared against MAX_KERNEL, as the emitted message states.
  template <typename T>
  bool levelCheckPool(Operation *op) {
    if (auto poolOp = dyn_cast<T>(op)) {
      for (auto k : poolOp.getKernel()) {
        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : poolOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      for (auto p : poolOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
    }
    return true;
  }

  // Conv Op: level check dilation/stride/pad values
  // In addition, the product of each dilation and the matching kernel
  // dimension of the weight operand (operand 1) is checked against
  // MAX_KERNEL. Which weight dimensions hold the kernel extents differs per
  // op kind, see the indices used below.
  template <typename T>
  bool levelCheckConv(Operation *op) {
    if (auto convOp = dyn_cast<T>(op)) {

      for (auto k : convOp.getDilation()) {
        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : convOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : convOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      auto dilation = convOp.getDilation();
      if (ShapedType weightType =
              dyn_cast<ShapedType>(op->getOperand(1).getType())) {
        // NOTE(review): the asserts below rely on the weight being ranked with
        // the expected rank; presumably the op verifiers guarantee this
        // before the pass runs - confirm.
        auto shape = weightType.getShape();
        if (isa<tosa::Conv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_y * KH <= MAX_KERNEL)") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_x * KW <= MAX_KERNEL)"))
            return false;
        } else if (isa<tosa::Conv3DOp>(op)) {
          assert(shape.size() == 5);
          assert(dilation.size() == 3);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_d * KD <= MAX_KERNEL)") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_y * KH <= MAX_KERNEL)") ||
              !levelCheckKernel(op, dilation[2] * shape[3],
                                "dilation_x * KW <= MAX_KERNEL)"))
            return false;
        } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[0],
                                "dilation_y * KH <= MAX_KERNEL)") ||
              !levelCheckKernel(op, dilation[1] * shape[1],
                                "dilation_x * KW <= MAX_KERNEL)"))
            return false;
        }
      }
    }
    return true;
  }

  // FFT op: level check H, W in input shape [N,H,W]
  // Applied to every shaped operand of the op.
  template <typename T>
  bool levelCheckFFT(Operation *op) {
    if (isa<T>(op)) {
      for (auto v : op->getOperands()) {
        if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
          auto shape = type.getShape();
          assert(shape.size() == 3);
          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
            return false;
          }
        }
      }
    }
    return true;
  }

  // TransposeConv2d op: level check kH/kW, outpad, and stride
  bool levelCheckTransposeConv2d(Operation *op) {
    if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
      if (ShapedType filterType =
              dyn_cast<ShapedType>(transpose.getFilter().getType())) {
        auto shape = filterType.getShape();
        assert(shape.size() == 4);
        // level check kernel sizes for kH and KW
        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : transpose.getOutPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : transpose.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
    }
    return true;
  }

  // Resize op: level check max scales
  // The scale attribute is read as {y_n, y_d, x_n, x_d}; the integer quotient
  // n/d of each axis is compared against MAX_SCALE.
  bool levelCheckResize(Operation *op) {
    if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
      auto scale = resize.getScale();
      // Read the values at 64-bit width: narrowing them to 16 bits could wrap
      // an out-of-range numerator into a small value that passes the check.
      const int64_t scaleYN = scale[0];
      const int64_t scaleYD = scale[1];
      const int64_t scaleXN = scale[2];
      const int64_t scaleXD = scale[3];
      // A zero denominator cannot be level-checked; report it instead of
      // dividing by zero.
      if (scaleYD == 0 || scaleXD == 0) {
        op->emitOpError() << "failed level check: scale denominator is zero";
        return false;
      }
      if (!levelCheckScale(op, scaleYN / scaleYD,
                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
          !levelCheckScale(op, scaleXN / scaleXD,
                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
        return false;
      }
    }
    return true;
  }

  // configure profile and level values from pass options profileName and
  // levelName
  // Only the level is derived here: EightK selects TOSA_LEVEL_EIGHTK, every
  // other value leaves TOSA_LEVEL_NONE in place.
  void configLevelAndProfile() {
    tosaLevel = TOSA_LEVEL_NONE;
    if (level == TosaLevelEnum::EightK) {
      tosaLevel = TOSA_LEVEL_EIGHTK;
    }
  }

  // Records a variable declaration; fails on a duplicate name.
  bool CheckVariable(Operation *op);
  // Checks a variable read/write against the recorded declaration.
  bool CheckVariableReadOrWrite(Operation *op);

  // Returns whether `type` is an acceptable element type for the profile.
  bool isValidElementType(Type type);

  // Checkers run by applyConstantOperandCheck, in registration order.
  SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
  // Limits of the active level, set by configLevelAndProfile.
  TosaLevel tosaLevel;
  // Declared type of each variable seen so far, keyed by variable name.
  DenseMap<StringAttr, mlir::Type> variablesMap;
};

// Checks `op` against the limits of the active level. The level-specific
// helpers emit the diagnostics themselves; this function only turns their
// verdict into a LogicalResult.
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
  // With the "none" level there are no limits to enforce.
  if (tosaLevel == TOSA_LEVEL_NONE)
    return success();

  if (!levelCheckRanks(op))
    return failure();

  // additional level checks from spec 0.70
  // Evaluated left to right; the first failing check short-circuits the rest.
  const bool withinLimits = levelCheckPool<tosa::AvgPool2dOp>(op) &&
                            levelCheckConv<tosa::Conv2DOp>(op) &&
                            levelCheckConv<tosa::Conv3DOp>(op) &&
                            levelCheckConv<tosa::DepthwiseConv2DOp>(op) &&
                            levelCheckFFT<tosa::FFT2dOp>(op) &&
                            levelCheckPool<tosa::MaxPool2dOp>(op) &&
                            levelCheckFFT<tosa::RFFT2dOp>(op) &&
                            levelCheckTransposeConv2d(op) &&
                            levelCheckResize(op);
  return success(withinLimits);
}

// Decides whether a value of type `type` may be used with a variable that was
// declared with `declaredType`. For now the two types must be equal.
inline bool CompatibleTypes(const mlir::Type &type,
                            const mlir::Type &declaredType) {
  const bool sameType = (type == declaredType);
  return sameType;
}

// Handles variable declarations: for a tosa.variable op, records the declared
// type under the variable's name. Emits an error and returns false when the
// name is already recorded. Every other op is accepted without any action.
bool TosaValidation::CheckVariable(Operation *op) {
  if (!isa<mlir::tosa::VariableOp>(op))
    return true;

  auto varName = cast<mlir::StringAttr>(op->getAttr("name"));
  const bool alreadyDeclared = variablesMap.count(varName) != 0;
  if (alreadyDeclared) {
    op->emitOpError() << "name has already been declared";
    return false;
  }

  // Remember the declared type so later reads and writes can be checked.
  mlir::Type declaredType =
      cast<mlir::TypeAttr>(op->getAttr("type")).getValue();
  variablesMap[varName] = declaredType;
  return true;
}

// Handles variable accesses: for a tosa.variable.read / tosa.variable.write
// op, the named variable must already be recorded, and the type of every
// operand and every result must be compatible with the recorded type. On the
// first violation an error is emitted and false is returned. Every other op
// is accepted without any action.
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
  if (!isa<mlir::tosa::VariableReadOp, mlir::tosa::VariableWriteOp>(op))
    return true;

  auto varName = cast<mlir::StringAttr>(op->getAttr("name"));
  auto declaration = variablesMap.find(varName);
  if (declaration == variablesMap.end()) {
    op->emitOpError() << "name has not been declared";
    return false;
  }
  const mlir::Type declaredType = declaration->second;

  for (mlir::Type operandType : op->getOperandTypes()) {
    if (CompatibleTypes(operandType, declaredType))
      continue;
    op->emitOpError() << "operand type does not equal variable type";
    return false;
  }

  for (mlir::Type resultType : op->getResultTypes()) {
    if (CompatibleTypes(resultType, declaredType))
      continue;
    op->emitOpError() << "result type does not equal variable type";
    return false;
  }

  return true;
}

// Runs the declaration check and then, only if that passed, the read/write
// check on `op`. Diagnostics are emitted by the two helpers.
LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
  const bool variablesOk = CheckVariable(op) && CheckVariableReadOrWrite(op);
  return success(variablesOk);
}

// Returns whether `type` is accepted as an element type.
//  * Float types: rejected outright for the BaseInference profile; otherwise
//    only f32, f16 and bf16 are accepted.
//  * Unsigned integers: only widths 8 and 16 are accepted.
//  * All other integers: widths 1, 4, 8, 16, 32, 48 and 64 are accepted.
//  * Any type that is neither float nor integer is accepted.
bool TosaValidation::isValidElementType(Type type) {
  if (isa<FloatType>(type)) {
    const bool floatsAllowed = profile != TosaProfileEnum::BaseInference;
    return floatsAllowed && (type.isF32() || type.isF16() || type.isBF16());
  }

  auto intTy = dyn_cast<IntegerType>(type);
  if (!intTy)
    return true;

  const auto width = intTy.getWidth();
  if (intTy.isUnsigned())
    return width == 8 || width == 16;

  // Signless - treated as signed.
  switch (width) {
  case 1:
  case 4:
  case 8:
  case 16:
  case 32:
  case 48:
  case 64:
    return true;
  default:
    return false;
  }
}

void TosaValidation::runOnOperation() {
  configLevelAndProfile();
  getOperation().walk([&](Operation *op) {
    for (Value operand : op->getOperands()) {
      auto elementTy = getElementTypeOrSelf(operand);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }
    for (Type resultTy : op->getResultTypes()) {
      auto elementTy = getElementTypeOrSelf(resultTy);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }

    // Some uses of TOSA rely on the constant operands of particular
    // operations.
    if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
      signalPassFailure();

    // do level checks
    if (failed(applyLevelCheck(op)))
      signalPassFailure();

    // do variable type checks
    if (failed(applyVariableCheck(op)))
      signalPassFailure();
  });
}
} // namespace