//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/DestinationStyleOpInterface.h"

using namespace mlir;

namespace mlir {
#include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
} // namespace mlir

/// Returns how many of `op`'s results have a type that is a `TensorType`.
/// Results of any other type (e.g. memrefs or scalars) are not counted.
static size_t getNumTensorResults(Operation *op) {
  size_t numTensorResults = 0;
  for (auto resultType : op->getResultTypes()) {
    if (isa<TensorType>(resultType))
      ++numTensorResults;
  }
  return numTensorResults;
}

/// Verifies the structural invariants shared by every op implementing
/// DestinationStyleOpInterface:
///   1. every DPS init operand has a tensor or a memref type;
///   2. the number of tensor-typed results equals the number of tensor inits;
///   3. every tensor init has exactly the type of the result tied to it.
/// Emits an op error and returns failure on the first violation found;
/// returns success otherwise.
LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
  DestinationStyleOpInterface dstStyleOp =
      cast<DestinationStyleOpInterface>(op);

  // Collect the tensor-typed inits. Memref inits are accepted but are not
  // type-checked against results below; any other init type is rejected.
  SmallVector<OpOperand *> outputTensorOperands;
  for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
    Type type = operand.get().getType();
    if (isa<TensorType>(type)) {
      outputTensorOperands.push_back(&operand);
    } else if (!isa<BaseMemRefType>(type)) {
      return op->emitOpError("expected that operand #")
             << operand.getOperandNumber() << " is a tensor or a memref";
    }
  }

  // Verify the number of tensor results matches the number of output tensors.
  // Compute the count once so the comparison and the diagnostic agree and the
  // results are not walked a second time.
  size_t numTensorResults = getNumTensorResults(op);
  if (numTensorResults != outputTensorOperands.size())
    return op->emitOpError("expected the number of tensor results (")
           << numTensorResults
           << ") to be equal to the number of output tensors ("
           << outputTensorOperands.size() << ")";

  // Each tensor init must have exactly the same type as its tied result.
  // NOTE(review): if getTiedOpResult maps an init by its position among *all*
  // inits (not only tensor inits), an op mixing memref and tensor inits could
  // request a result index past the last result here — confirm against the
  // interface definition.
  for (OpOperand *opOperand : outputTensorOperands) {
    OpResult result = dstStyleOp.getTiedOpResult(opOperand);
    Type operandType = opOperand->get().getType();
    if (result.getType() != operandType)
      return op->emitOpError("expected type of operand #")
             << opOperand->getOperandNumber() << " (" << operandType << ")"
             << " to match type of corresponding result (" << result.getType()
             << ")";
  }

  return success();
}