//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the Arith dialect to the EmitC
// dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Conversion Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Rewrites `arith.constant` into `emitc.constant`, forwarding the value
/// attribute and using the result type produced by the type converter.
class ArithConstantOpConversionPattern
    : public OpConversionPattern<arith::ConstantOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp arithConst,
                  arith::ConstantOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const TypeConverter *converter = this->getTypeConverter();
    Type convertedType = converter->convertType(arithConst.getType());

    // Only rewrite when the result type is convertible.
    if (convertedType) {
      rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, convertedType,
                                                     adaptor.getValue());
      return success();
    }
    return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
  }
};

/// Return the counterpart of \p ty that has the requested signedness.
///
/// For builtin integer types the result is the explicitly signed or unsigned
/// integer type of the same bit width. For EmitC pointer-wide types the result
/// is `size_t` when \p needsUnsigned is set and `ptrdiff_t` otherwise. A type
/// that already matches the request, or that is neither of the above, is
/// returned unchanged.
Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
  if (isa<IntegerType>(ty)) {
    // Already has the requested signedness: nothing to adapt.
    if (ty.isUnsignedInteger() == needsUnsigned)
      return ty;

    return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
                            needsUnsigned
                                ? IntegerType::SignednessSemantics::Unsigned
                                : IntegerType::SignednessSemantics::Signed);
  }

  if (emitc::isPointerWideType(ty)) {
    // `size_t` is the only unsigned pointer-wide type.
    if (isa<emitc::SizeTType>(ty) == needsUnsigned)
      return ty;

    if (needsUnsigned)
      return emitc::SizeTType::get(ty.getContext());
    return emitc::PtrDiffTType::get(ty.getContext());
  }

  return ty;
}

/// Produce \p val converted to type \p ty via an `emitc.cast`.
///
/// The cast is built with `createOrFold`, so the returned value is either the
/// result of a newly inserted cast or whatever the cast folds to (presumably
/// \p val itself when it already has type \p ty).
Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
  Location loc = val.getLoc();
  return rewriter.createOrFold<emitc::CastOp>(loc, ty, val);
}

/// Lowers `arith.cmpf` on scalar floats to EmitC.
///
/// `AlwaysFalse`/`AlwaysTrue` become i1 constants, `ORD`/`UNO` become pure NaN
/// checks, and every other predicate becomes a plain `emitc.cmp` combined with
/// a NaN check: OR-ed with "is unordered" for the U* predicates and AND-ed
/// with "is ordered" for the O* predicates.
class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();

    if (!isa<FloatType>(rhs.getType())) {
      return rewriter.notifyMatchFailure(loc,
                                         "cmpf currently only supported on "
                                         "floats, not tensors/vectors thereof");
    }

    const arith::CmpFPredicate arithPredicate = op.getPredicate();

    // Each U* predicate sets `unordered` and then shares the EmitC predicate
    // of its ordered sibling.
    bool unordered = false;
    emitc::CmpPredicate predicate;
    switch (arithPredicate) {
    case arith::CmpFPredicate::AlwaysFalse:
    case arith::CmpFPredicate::AlwaysTrue: {
      const bool value = arithPredicate == arith::CmpFPredicate::AlwaysTrue;
      auto constant = rewriter.create<emitc::ConstantOp>(
          loc, rewriter.getI1Type(), rewriter.getBoolAttr(value));
      rewriter.replaceOp(op, constant);
      return success();
    }
    case arith::CmpFPredicate::ORD: {
      // ordered, i.e. none of the operands is NaN
      Value ordered = createCheckIsOrdered(rewriter, loc, lhs, rhs);
      rewriter.replaceOp(op, ordered);
      return success();
    }
    case arith::CmpFPredicate::UNO: {
      // unordered, i.e. either operand is nan
      Value notOrdered = createCheckIsUnordered(rewriter, loc, lhs, rhs);
      rewriter.replaceOp(op, notOrdered);
      return success();
    }
    case arith::CmpFPredicate::UEQ:
      unordered = true;
      [[fallthrough]];
    case arith::CmpFPredicate::OEQ:
      predicate = emitc::CmpPredicate::eq;
      break;
    case arith::CmpFPredicate::UGT:
      unordered = true;
      [[fallthrough]];
    case arith::CmpFPredicate::OGT:
      predicate = emitc::CmpPredicate::gt;
      break;
    case arith::CmpFPredicate::UGE:
      unordered = true;
      [[fallthrough]];
    case arith::CmpFPredicate::OGE:
      predicate = emitc::CmpPredicate::ge;
      break;
    case arith::CmpFPredicate::ULT:
      unordered = true;
      [[fallthrough]];
    case arith::CmpFPredicate::OLT:
      predicate = emitc::CmpPredicate::lt;
      break;
    case arith::CmpFPredicate::ULE:
      unordered = true;
      [[fallthrough]];
    case arith::CmpFPredicate::OLE:
      predicate = emitc::CmpPredicate::le;
      break;
    case arith::CmpFPredicate::UNE:
      unordered = true;
      [[fallthrough]];
    case arith::CmpFPredicate::ONE:
      predicate = emitc::CmpPredicate::ne;
      break;
    }

    // Plain comparison first; the NaN handling is layered on top of it.
    Type resultType = op.getType();
    Value plainCmp =
        rewriter.create<emitc::CmpOp>(loc, resultType, predicate, lhs, rhs);

    if (!unordered) {
      // Ordered predicates additionally require both operands to be non-NaN.
      Value bothOrdered = createCheckIsOrdered(rewriter, loc, lhs, rhs);
      rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, resultType,
                                                       bothOrdered, plainCmp);
      return success();
    }

    // Unordered predicates are also true whenever an operand is NaN.
    Value anyNaN = createCheckIsUnordered(rewriter, loc, lhs, rhs);
    rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, resultType, anyNaN,
                                                    plainCmp);
    return success();
  }

private:
  /// Emit an i1 comparison of \p operand against itself using \p pred.
  Value compareWithSelf(ConversionPatternRewriter &rewriter, Location loc,
                        Value operand, emitc::CmpPredicate pred) const {
    return rewriter.create<emitc::CmpOp>(loc, rewriter.getI1Type(), pred,
                                         operand, operand);
  }

  /// Return a value that is true if \p operand is NaN.
  Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
              Value operand) const {
    // Only NaN compares unequal to itself.
    return compareWithSelf(rewriter, loc, operand, emitc::CmpPredicate::ne);
  }

  /// Return a value that is true if \p operand is not NaN.
  Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
                 Value operand) const {
    // Every non-NaN value compares equal to itself.
    return compareWithSelf(rewriter, loc, operand, emitc::CmpPredicate::eq);
  }

  /// Return a value that is true if the operands \p first and \p second are
  /// unordered (i.e., at least one of them is NaN).
  Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
                               Location loc, Value first, Value second) const {
    Value nanFirst = isNaN(rewriter, loc, first);
    Value nanSecond = isNaN(rewriter, loc, second);
    return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
                                               nanFirst, nanSecond);
  }

  /// Return a value that is true if the operands \p first and \p second are
  /// both ordered (i.e., neither of them is NaN).
  Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
                             Value first, Value second) const {
    Value okFirst = isNotNaN(rewriter, loc, first);
    Value okSecond = isNotNaN(rewriter, loc, second);
    return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
                                                okFirst, okSecond);
  }
};

/// Lowers `arith.cmpi` to `emitc.cmp`, casting both operands to a type whose
/// signedness matches the predicate before comparing.
class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Return true if \p pred is one of the unsigned ordering predicates.
  /// Equality and signed predicates are compared as signed values.
  bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
    switch (pred) {
    case arith::CmpIPredicate::ult:
    case arith::CmpIPredicate::ule:
    case arith::CmpIPredicate::ugt:
    case arith::CmpIPredicate::uge:
      return true;
    case arith::CmpIPredicate::eq:
    case arith::CmpIPredicate::ne:
    case arith::CmpIPredicate::slt:
    case arith::CmpIPredicate::sle:
    case arith::CmpIPredicate::sgt:
    case arith::CmpIPredicate::sge:
      return false;
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  /// Map \p pred to the signedness-agnostic EmitC comparison predicate.
  emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
    switch (pred) {
    case arith::CmpIPredicate::eq:
      return emitc::CmpPredicate::eq;
    case arith::CmpIPredicate::ne:
      return emitc::CmpPredicate::ne;
    case arith::CmpIPredicate::ult:
    case arith::CmpIPredicate::slt:
      return emitc::CmpPredicate::lt;
    case arith::CmpIPredicate::ule:
    case arith::CmpIPredicate::sle:
      return emitc::CmpPredicate::le;
    case arith::CmpIPredicate::ugt:
    case arith::CmpIPredicate::sgt:
      return emitc::CmpPredicate::gt;
    case arith::CmpIPredicate::uge:
    case arith::CmpIPredicate::sge:
      return emitc::CmpPredicate::ge;
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type type = adaptor.getLhs().getType();
    const bool isIntegral =
        type && (isa<IntegerType>(type) || emitc::isPointerWideType(type));
    if (!isIntegral) {
      return rewriter.notifyMatchFailure(
          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
    }

    const arith::CmpIPredicate arithPred = op.getPredicate();
    const bool needsUnsigned = needsUnsignedCmp(arithPred);
    const emitc::CmpPredicate pred = toEmitCPred(arithPred);

    // Compare in a type with the signedness implied by the predicate.
    Type compareType = adaptIntegralTypeSignedness(type, needsUnsigned);
    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, compareType);
    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, compareType);

    rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
    return success();
  }
};

/// Lowers scalar `arith.negf` to `emitc.unary_minus`. Tensor and vector
/// operands, and float types EmitC does not support, are rejected.
class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value operand = adaptor.getOperand();
    Type operandType = operand.getType();

    if (isa<TensorType, VectorType>(operandType)) {
      return rewriter.notifyMatchFailure(
          op.getLoc(),
          "negf currently only supports scalar types, not vectors or tensors");
    }

    if (emitc::isSupportedFloatType(operandType)) {
      rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, operandType,
                                                       operand);
      return success();
    }

    return rewriter.notifyMatchFailure(
        op.getLoc(), "floating-point type is not supported by EmitC");
  }
};

/// Lowers unary integer/index cast ops to chains of `emitc.cast`.
///
/// \tparam ArithOp        the arith cast op being converted.
/// \tparam castToUnsigned whether the width-changing cast is carried out on
///                        unsigned types (zero-extending semantics).
template <typename ArithOp, bool castToUnsigned>
class CastConversion : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Accepts builtin integers and EmitC pointer-wide types; rejects null.
    auto isIntegralType = [](Type t) {
      return t && (isa<IntegerType>(t) || emitc::isPointerWideType(t));
    };

    Type opReturnType = this->getTypeConverter()->convertType(op.getType());
    if (!isIntegralType(opReturnType))
      return rewriter.notifyMatchFailure(
          op, "expected integer or size_t/ssize_t/ptrdiff_t result type");

    if (adaptor.getOperands().size() != 1) {
      return rewriter.notifyMatchFailure(
          op, "CastConversion only supports unary ops");
    }

    Type operandType = adaptor.getIn().getType();
    if (!isIntegralType(operandType))
      return rewriter.notifyMatchFailure(
          op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");

    // Signed (sign-extending) casts from i1 are not supported.
    if (!castToUnsigned && operandType.isInteger(1))
      return rewriter.notifyMatchFailure(op,
                                         "operation not supported on i1 type");

    Location loc = op.getLoc();

    // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
    // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
    // truncation.
    if (opReturnType.isInteger(1)) {
      // Pointer-wide constants carry an index-typed attribute.
      Type attrType = operandType;
      if (emitc::isPointerWideType(operandType))
        attrType = rewriter.getIndexType();

      auto one = rewriter.create<emitc::ConstantOp>(
          loc, operandType, rewriter.getOneAttr(attrType));
      auto lowestBit = rewriter.create<emitc::BitwiseAndOp>(
          loc, operandType, adaptor.getIn(), one);
      rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType, lowestBit);
      return success();
    }

    // Narrowing between builtin integers is always done on unsigned types.
    bool isTruncation = false;
    if (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType))
      isTruncation = operandType.getIntOrFloatBitWidth() >
                     opReturnType.getIntOrFloatBitWidth();
    const bool doUnsigned = castToUnsigned || isTruncation;

    // Source and destination types with the signedness to perform the cast
    // in; both adjustments preserve the bit width.
    Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
    Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);

    // Step 1: bring the operand to the signedness-adjusted source type.
    Value source = adaptValueType(adaptor.getIn(), rewriter, castSrcType);

    // Step 2: the actual cast, which may change the bit width.
    Value widthCast =
        rewriter.template create<emitc::CastOp>(loc, castDestType, source);

    // Step 3: return to the type the converted op is expected to produce.
    Value result = adaptValueType(widthCast, rewriter, opReturnType);

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// `CastConversion` specialised with `castToUnsigned = true`.
template <typename ArithOp>
class UnsignedCastConversion : public CastConversion<ArithOp, true> {
  using Base = CastConversion<ArithOp, true>;
  using Base::Base;
};

/// `CastConversion` specialised with `castToUnsigned = false`.
template <typename ArithOp>
class SignedCastConversion : public CastConversion<ArithOp, false> {
  using Base = CastConversion<ArithOp, false>;
  using Base::Base;
};

/// One-to-one lowering: replaces \p ArithOp by \p EmitCOp applied to the
/// converted operands, using the converted result type.
template <typename ArithOp, typename EmitCOp>
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (Type convertedType =
            this->getTypeConverter()->convertType(arithOp.getType())) {
      rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, convertedType,
                                                    adaptor.getOperands());
      return success();
    }

    return rewriter.notifyMatchFailure(arithOp,
                                       "converting result type failed");
  }
};

/// Lowers a binary integer arithmetic op to \p EmitCOp. When the converted
/// type is signed or signless and the op does not carry the `nsw` flag, the
/// computation is done on the unsigned type of the same width and the result
/// is cast back.
template <typename ArithOp, typename EmitCOp>
class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type type = this->getTypeConverter()->convertType(op.getType());
    const bool isIntegral =
        type && (isa<IntegerType>(type) || emitc::isPointerWideType(type));
    if (!isIntegral) {
      return rewriter.notifyMatchFailure(
          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
    }

    // arith expects wrap-around arithmetic, which doesn't happen on `bool`.
    if (type.isInteger(1))
      return rewriter.notifyMatchFailure(op, "i1 type is not implemented");

    // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
    // we compute in unsigned integers to avoid UB.
    const bool mapsToSignedC =
        type.isSignlessInteger() || type.isSignedInteger();
    Type arithmeticType = type;
    if (mapsToSignedC &&
        !bitEnumContainsAll(op.getOverflowFlags(),
                            arith::IntegerOverflowFlags::nsw)) {
      arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
                                               /*isSigned=*/false);
    }

    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);

    Value computed = rewriter.template create<EmitCOp>(
        op.getLoc(), arithmeticType, lhs, rhs);

    // Cast back to the converted result type.
    rewriter.replaceOp(op, adaptValueType(computed, rewriter, type));
    return success();
  }
};

/// Lowers a binary bitwise op on builtin integers to \p EmitCOp. i1 operands
/// are used as-is; all other widths are computed on the unsigned type of the
/// same width and the result is cast back.
template <typename ArithOp, typename EmitCOp>
class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type type = this->getTypeConverter()->convertType(op.getType());
    if (!isa_and_nonnull<IntegerType>(type)) {
      return rewriter.notifyMatchFailure(
          op,
          "expected integer type, vector/tensor support not yet implemented");
    }

    if (!type.isInteger(1)) {
      // Bitwise ops are defined by the C standard on unsigned operands.
      Type unsignedType =
          adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);

      Value lhs = adaptValueType(adaptor.getLhs(), rewriter, unsignedType);
      Value rhs = adaptValueType(adaptor.getRhs(), rewriter, unsignedType);

      Value computed = rewriter.template create<EmitCOp>(op.getLoc(),
                                                         unsignedType, lhs, rhs);

      // Cast back to the converted result type.
      rewriter.replaceOp(op, adaptValueType(computed, rewriter, type));
      return success();
    }

    // Bitwise ops can be performed directly on booleans
    rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};

/// Lowers an arith shift op to \p EmitCOp guarded by a runtime check on the
/// shift amount.
///
/// The shifted value is cast to a type whose signedness is selected by
/// \p isUnsignedOp, the shift amount is cast to an unsigned type, and the
/// result is produced by an `emitc.expression` that yields the shift result
/// when the amount is smaller than the bit width and the constant 0 otherwise.
/// i1 operands are not supported.
template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
class ShiftOpConversion : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Type type = this->getTypeConverter()->convertType(op.getType());
    if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
      return rewriter.notifyMatchFailure(
          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
    }

    if (type.isInteger(1)) {
      return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
    }

    // Type the shift itself is performed in.
    Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);

    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
    // Shift amount interpreted as unsigned per Arith dialect spec.
    Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
                                               /*needsUnsigned=*/true);
    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);

    // Add a runtime check for overflow
    Value width;
    if (emitc::isPointerWideType(type)) {
      // The width of a pointer-wide type is not known here, so it is computed
      // in the generated code as 8 * sizeof(...).
      // NOTE(review): `sizeof` is applied to the constant `eight`, which has
      // type `rhsType`; presumably this equals the size of the shift amount's
      // type — confirm this operand is the intended one.
      Value eight = rewriter.create<emitc::ConstantOp>(
          op.getLoc(), rhsType, rewriter.getIndexAttr(8));
      emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
          op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
      width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
                                            sizeOfCall.getResult(0));
    } else {
      // Fixed-width integers: the width is a compile-time constant.
      width = rewriter.create<emitc::ConstantOp>(
          op.getLoc(), rhsType,
          rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
    }

    // True when the shift amount is strictly smaller than the bit width.
    Value excessCheck = rewriter.create<emitc::CmpOp>(
        op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);

    // Any concrete value is a valid refinement of poison.
    Value poison = rewriter.create<emitc::ConstantOp>(
        op.getLoc(), arithmeticType,
        (isa<IntegerType>(arithmeticType)
             ? rewriter.getIntegerAttr(arithmeticType, 0)
             : rewriter.getIndexAttr(0)));

    // Build the shift and the select between its result and the poison
    // stand-in inside the body of a single emitc.expression.
    emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
        op.getLoc(), arithmeticType, /*do_not_inline=*/false);
    Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
    // Remember the insertion point so it can be restored after filling the
    // expression body.
    auto currentPoint = rewriter.getInsertionPoint();
    rewriter.setInsertionPointToStart(&bodyBlock);
    Value arithmeticResult =
        rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
    Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
        op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
    rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
    rewriter.setInsertionPoint(op->getBlock(), currentPoint);

    // Cast the expression result back to the converted result type.
    Value result = adaptValueType(ternary, rewriter, type);

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// `ShiftOpConversion` specialised with `isUnsignedOp = false`.
template <typename ArithOp, typename EmitCOp>
class SignedShiftOpConversion final
    : public ShiftOpConversion<ArithOp, EmitCOp, false> {
  using Base = ShiftOpConversion<ArithOp, EmitCOp, false>;
  using Base::Base;
};

/// `ShiftOpConversion` specialised with `isUnsignedOp = true`.
template <typename ArithOp, typename EmitCOp>
class UnsignedShiftOpConversion final
    : public ShiftOpConversion<ArithOp, EmitCOp, true> {
  using Base = ShiftOpConversion<ArithOp, EmitCOp, true>;
  using Base::Base;
};

/// Lowers `arith.select` with a scalar i1 condition to `emitc.conditional`,
/// forwarding the converted operands unchanged.
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
  using OpConversionPattern<arith::SelectOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = getTypeConverter()->convertType(selectOp.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(selectOp, "type conversion failed");

    Type conditionType = adaptor.getCondition().getType();
    if (conditionType.isInteger(1)) {
      rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(
          selectOp, resultType, adaptor.getOperands());
      return success();
    }

    return rewriter.notifyMatchFailure(
        selectOp, "can only be converted if condition is a scalar of type i1");
  }
};

// Floating-point to integer conversions.
//
// Lowers `arith.fptosi` / `arith.fptoui` on supported scalar float operands to
// `emitc.cast`. For the unsigned variant the value is first cast to the
// unsigned integer type of the *destination* width and then cast to the
// converted destination type. Casts to i1 and to unsupported integer types
// are rejected.
template <typename CastOp>
class FtoICastOpConversion : public OpConversionPattern<CastOp> {
public:
  FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern<CastOp>(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Type operandType = adaptor.getIn().getType();
    if (!emitc::isSupportedFloatType(operandType))
      return rewriter.notifyMatchFailure(castOp,
                                         "unsupported cast source type");

    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
    if (!dstType)
      return rewriter.notifyMatchFailure(castOp, "type conversion failed");

    // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
    // truncated to 0, whereas a boolean conversion would return true.
    if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
      return rewriter.notifyMatchFailure(castOp,
                                         "unsupported cast destination type");

    // Convert to unsigned if it's the "ui" variant
    // Signless is interpreted as signed, so no need to cast for "si"
    const bool isUnsignedCast = isa<arith::FPToUIOp>(castOp);
    Type actualResultType = dstType;
    if (isUnsignedCast) {
      // The intermediate unsigned type must have the width of the destination
      // integer, not of the source float: otherwise e.g. f32 -> i64 would go
      // through a 32-bit unsigned integer and lose the upper bits.
      actualResultType =
          rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
                                  /*isSigned=*/false);
    }

    Value result = rewriter.create<emitc::CastOp>(
        castOp.getLoc(), actualResultType, adaptor.getOperands());

    // Bring the unsigned intermediate back to the converted destination type.
    if (isUnsignedCast) {
      result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
    }
    rewriter.replaceOp(castOp, result);

    return success();
  }
};

// Integer to floating-point conversions.
//
// Lowers `arith.sitofp` / `arith.uitofp` on supported scalar integer operands
// to `emitc.cast`. For the unsigned variant the operand is first cast to the
// unsigned integer type of the same width, unless it already has that type.
template <typename CastOp>
class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
public:
  ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern<CastOp>(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Vectors in particular are not supported
    Type srcType = adaptor.getIn().getType();
    if (!emitc::isSupportedIntegerType(srcType))
      return rewriter.notifyMatchFailure(castOp,
                                         "unsupported cast source type");

    Type floatType = this->getTypeConverter()->convertType(castOp.getType());
    if (!floatType)
      return rewriter.notifyMatchFailure(castOp, "type conversion failed");

    if (!emitc::isSupportedFloatType(floatType))
      return rewriter.notifyMatchFailure(castOp,
                                         "unsupported cast destination type");

    // Signless is interpreted as signed, so only the "ui" variant needs an
    // explicit cast of the operand to an unsigned integer.
    Value source = adaptor.getIn();
    if (isa<arith::UIToFPOp>(castOp)) {
      Type unsignedType =
          rewriter.getIntegerType(srcType.getIntOrFloatBitWidth(),
                                  /*isSigned=*/false);
      if (unsignedType != srcType) {
        source = rewriter.template create<emitc::CastOp>(castOp.getLoc(),
                                                         unsignedType, source);
      }
    }

    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, floatType, source);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

/// Register the EmitC size_t type conversions on \p typeConverter and add all
/// Arith-to-EmitC conversion patterns defined in this file to \p patterns.
void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
                                        RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();

  mlir::populateEmitCSizeTTypeConversions(typeConverter);

  // clang-format off
  // Constants and ops that map one-to-one onto an EmitC op.
  patterns.add<
    ArithConstantOpConversionPattern,
    ArithOpConversion<arith::AddFOp, emitc::AddOp>,
    ArithOpConversion<arith::DivFOp, emitc::DivOp>,
    ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
    ArithOpConversion<arith::MulFOp, emitc::MulOp>,
    ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
    ArithOpConversion<arith::SubFOp, emitc::SubOp>
  >(typeConverter, context);

  // Integer arithmetic, bitwise and shift ops.
  patterns.add<
    IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
    IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
    IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
    BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
    BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
    BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
    UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
    SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
    UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>
  >(typeConverter, context);

  // Comparisons, negation and select.
  patterns.add<
    CmpFOpConversion,
    CmpIOpConversion,
    NegFOpConversion,
    SelectOpConversion
  >(typeConverter, context);

  // Integer/index casts.
  patterns.add<
    // Truncation is guaranteed for unsigned types.
    UnsignedCastConversion<arith::TruncIOp>,
    SignedCastConversion<arith::ExtSIOp>,
    UnsignedCastConversion<arith::ExtUIOp>,
    SignedCastConversion<arith::IndexCastOp>,
    UnsignedCastConversion<arith::IndexCastUIOp>
  >(typeConverter, context);

  // Conversions between integers and floats.
  patterns.add<
    ItoFCastOpConversion<arith::SIToFPOp>,
    ItoFCastOpConversion<arith::UIToFPOp>,
    FtoICastOpConversion<arith::FPToSIOp>,
    FtoICastOpConversion<arith::FPToUIOp>
  >(typeConverter, context);
  // clang-format on
}