//===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::complex;

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  return getValue();
}

void ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cst");
}

bool ConstantOp::isBuildableWith(Attribute value, Type type) {
  if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
    auto complexTy = llvm::dyn_cast<ComplexType>(type);
    if (!complexTy || arrAttr.size() != 2)
      return false;
    auto complexEltTy = complexTy.getElementType();
    if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
      auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
      return im && fre.getType() == complexEltTy &&
             im.getType() == complexEltTy;
    }
    if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
      auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
      return im && ire.getType() == complexEltTy &&
             im.getType() == complexEltTy;
    }
  }
  return false;
}

LogicalResult ConstantOp::verify() {
  ArrayAttr arrayAttr = getValue();
  if (arrayAttr.size() != 2) {
    return emitOpError(
        "requires 'value' to be a complex constant, represented as array of "
        "two values");
  }

  auto complexEltTy = getType().getElementType();
  if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
      !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
    return emitOpError(
        "requires attribute's elements to be float or integer attributes");
  auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
  auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
  if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
    return emitOpError()
           << "requires attribute's element types (" << re.getType() << ", "
           << im.getType()
           << ") to match the element type of the op's return type ("
           << complexEltTy << ")";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
  if (getOperand().getType() == getType())
    return getOperand();

  return {};
}

LogicalResult BitcastOp::verify() {
  auto operandType = getOperand().getType();
  auto resultType = getType();

  // We allow this to be legal as it can be folded away.
  if (operandType == resultType)
    return success();

  if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
    return emitOpError("operand must be int/float/complex");
  }

  if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
    return emitOpError("result must be int/float/complex");
  }

  if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
    return emitOpError(
        "requires that either input or output has a complex type");
  }

  if (isa<ComplexType>(resultType))
    std::swap(operandType, resultType);

  int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
                                .getElementType()
                                .getIntOrFloatBitWidth() *
                            2;
  int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();

  if (operandBitwidth != resultBitwidth) {
    return emitOpError("casting bitwidths do not match");
  }

  return success();
}

struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp op,
                                PatternRewriter &rewriter) const override {
    if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
      if (isa<ComplexType>(op.getType()) ||
          isa<ComplexType>(defining.getOperand().getType())) {
        // complex.bitcast requires that input or output is complex.
        rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
                                               defining.getOperand());
      } else {
        rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
                                                      defining.getOperand());
      }
      return success();
    }

    if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
      rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
                                             defining.getOperand());
      return success();
    }

    return failure();
  }
};

struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
  using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::BitcastOp op,
                                PatternRewriter &rewriter) const override {
    if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
      rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
                                                      defining.getOperand());
      return success();
    }

    return failure();
  }
};

void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<MergeComplexBitcast, MergeArithBitcast>(context);
}

//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//

OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
  // Fold complex.create(complex.re(op), complex.im(op)).
  if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
    if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
      if (reOp.getOperand() == imOp.getOperand()) {
        return reOp.getOperand();
      }
    }
  }
  return {};
}

//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//

OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
  ArrayAttr arrayAttr =
      llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
  if (arrayAttr && arrayAttr.size() == 2)
    return arrayAttr[1];
  if (auto createOp = getOperand().getDefiningOp<CreateOp>())
    return createOp.getOperand(1);
  return {};
}

namespace {
template <typename OpKind, int ComponentIndex>
struct FoldComponentNeg final : OpRewritePattern<OpKind> {
  using OpRewritePattern<OpKind>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpKind op,
                                PatternRewriter &rewriter) const override {
    auto negOp = op.getOperand().template getDefiningOp<NegOp>();
    if (!negOp)
      return failure();

    auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
    if (!createOp)
      return failure();

    Type elementType = createOp.getType().getElementType();
    assert(isa<FloatType>(elementType));

    rewriter.replaceOpWithNewOp<arith::NegFOp>(
        op, elementType, createOp.getOperand(ComponentIndex));
    return success();
  }
};
} // namespace

void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  results.add<FoldComponentNeg<ImOp, 1>>(context);
}

//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//

OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
  ArrayAttr arrayAttr =
      llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
  if (arrayAttr && arrayAttr.size() == 2)
    return arrayAttr[0];
  if (auto createOp = getOperand().getDefiningOp<CreateOp>())
    return createOp.getOperand(0);
  return {};
}

void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  results.add<FoldComponentNeg<ReOp, 0>>(context);
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  // complex.add(complex.sub(a, b), b) -> a
  if (auto sub = getLhs().getDefiningOp<SubOp>())
    if (getRhs() == sub.getRhs())
      return sub.getLhs();

  // complex.add(b, complex.sub(a, b)) -> a
  if (auto sub = getRhs().getDefiningOp<SubOp>())
    if (getLhs() == sub.getRhs())
      return sub.getLhs();

  // complex.add(a, complex.constant<0.0, 0.0>) -> a
  if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
    auto arrayAttr = constantOp.getValue();
    if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
        llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
      return getLhs();
    }
  }

  return {};
}

//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  // complex.sub(complex.add(a, b), b) -> a
  if (auto add = getLhs().getDefiningOp<AddOp>())
    if (getRhs() == add.getRhs())
      return add.getLhs();

  // complex.sub(a, complex.constant<0.0, 0.0>) -> a
  if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
    auto arrayAttr = constantOp.getValue();
    if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
        llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
      return getLhs();
    }
  }

  return {};
}

//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//

OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
  // complex.neg(complex.neg(a)) -> a
  if (auto negOp = getOperand().getDefiningOp<NegOp>())
    return negOp.getOperand();

  return {};
}

//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//

OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
  // complex.log(complex.exp(a)) -> a
  if (auto expOp = getOperand().getDefiningOp<ExpOp>())
    return expOp.getOperand();

  return {};
}

//===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//

OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
  // complex.exp(complex.log(a)) -> a
  if (auto logOp = getOperand().getDefiningOp<LogOp>())
    return logOp.getOperand();

  return {};
}

//===----------------------------------------------------------------------===//
// ConjOp
//===----------------------------------------------------------------------===//

OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
  // complex.conj(complex.conj(a)) -> a
  if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
    return conjOp.getOperand();

  return {};
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto constant = getRhs().getDefiningOp<ConstantOp>();
  if (!constant)
    return {};

  ArrayAttr arrayAttr = constant.getValue();
  APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
  APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();

  if (!imag.isZero())
    return {};

  // complex.mul(a, complex.constant<1.0, 0.0>) -> a
  if (real == APFloat(real.getSemantics(), 1))
    return getLhs();

  return {};
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto rhs = adaptor.getRhs();
  if (!rhs)
    return {};

  ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
  if (!arrayAttr || arrayAttr.size() != 2)
    return {};

  APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
  APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();

  if (!imag.isZero())
    return {};

  // complex.div(a, complex.constant<1.0, 0.0>) -> a
  if (real == APFloat(real.getSemantics(), 1))
    return getLhs();

  return {};
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"