//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a commutativity utility pattern and a function to
// populate this pattern. The function is intended to be used inside passes to
// simplify the matching of commutative operations by fixing the order of their
// operands.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/CommutativityUtils.h"

#include <queue>

using namespace mlir;

/// Classifies an "ancestor", i.e., an op or a block argument that appears in
/// the backward slice of a value.
///
/// The numeric order of the enumerators is significant: `AncestorKey` compares
/// its `type` field first, so block arguments order before non-constant-like
/// ops, which in turn order before constant-like ops.
enum AncestorType {
  /// The ancestor is a block argument.
  BLOCK_ARGUMENT = 0,

  /// The ancestor is an op that is not constant-like.
  NON_CONSTANT_OP = 1,

  /// The ancestor is a constant-like op.
  CONSTANT_OP = 2
};

/// Stores the "key" associated with a single ancestor, i.e., the information
/// by which two ancestors are ordered relative to each other.
struct AncestorKey {
  /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
  /// the ancestor. Defaults to `BLOCK_ARGUMENT`, which is what a null op
  /// denotes.
  AncestorType type = BLOCK_ARGUMENT;

  /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
  /// `CONSTANT_OP`. Else, holds "".
  StringRef opName;

  /// Builds the key of the ancestor `op`. A null `op` stands for a block
  /// argument; otherwise the key records whether `op` is constant-like and
  /// the name of `op`.
  ///
  /// The constructor is `explicit` so that an `Operation *` (or `nullptr`) is
  /// never silently converted into a key.
  explicit AncestorKey(Operation *op) {
    if (!op)
      return;
    type =
        op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP;
    opName = op->getName().getStringRef();
  }

  /// Overloaded operator `<` for `AncestorKey`.
  ///
  /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
  /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
  /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
  /// ones are the ones with smaller op names (lexicographically).
  ///
  /// TODO: Include other information like attributes, value type, etc., to
  /// enhance this comparison. For example, currently this comparison doesn't
  /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
  /// `addi (in i64)`. Such an enhancement should only be done if the need
  /// arises.
  bool operator<(const AncestorKey &key) const {
    return std::tie(type, opName) < std::tie(key.type, key.opName);
  }
};

/// Stores a commutative operand along with its BFS traversal information.
struct CommutativeOperand {
  /// Stores the operand.
  Value operand;

  /// Stores the queue of ancestors of the operand's BFS traversal at a
  /// particular point in time.
  std::queue<Operation *> ancestorQueue;

  /// Stores the list of ancestors that have been visited by the BFS traversal
  /// at a particular point in time.
  DenseSet<Operation *> visitedAncestors;

  /// Stores the operand's "key". This "key" is defined as a list of the
  /// "AncestorKeys" associated with the ancestors of this operand, in a
  /// breadth-first order.
  ///
  /// So, if an operand, say `A`, was produced as follows:
  ///
  /// `<block argument>`  `<block argument>`
  ///             \          /
  ///              \        /
  ///             `arith.subi`           `arith.constant`
  ///                       \            /
  ///                        `arith.addi`
  ///                              |
  ///                         returns `A`
  ///
  /// Then, the ancestors of `A`, in the breadth-first order are:
  /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
  /// `<block argument>`.
  ///
  /// Thus, the "key" associated with operand `A` is:
  /// {
  ///  {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
  ///  {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
  ///  {type: `CONSTANT_OP`, opName: "arith.constant"},
  ///  {type: `BLOCK_ARGUMENT`, opName: ""},
  ///  {type: `BLOCK_ARGUMENT`, opName: ""}
  /// }
  SmallVector<AncestorKey, 4> key;

  /// Push an ancestor into the operand's BFS information structure. This
  /// entails it being pushed into the queue (always) and inserted into the
  /// "visited ancestors" list (iff it is an op rather than a block argument).
  void pushAncestor(Operation *op) {
    ancestorQueue.push(op);
    if (op)
      visitedAncestors.insert(op);
  }

  /// Refresh the key.
  ///
  /// Refreshing a key entails making it up-to-date with the operand's BFS
  /// traversal that has happened till that point in time, i.e, appending the
  /// existing key with the front ancestor's "AncestorKey". Note that a key
  /// directly reflects the BFS and thus needs to be refreshed during the
  /// progression of the traversal.
  void refreshKey() {
    if (ancestorQueue.empty())
      return;

    Operation *frontAncestor = ancestorQueue.front();
    AncestorKey frontAncestorKey(frontAncestor);
    key.push_back(frontAncestorKey);
  }

  /// Pop the front ancestor, if any, from the queue and then push its adjacent
  /// unvisited ancestors, if any, to the queue (this is the main body of the
  /// BFS algorithm).
  void popFrontAndPushAdjacentUnvisitedAncestors() {
    if (ancestorQueue.empty())
      return;
    Operation *frontAncestor = ancestorQueue.front();
    ancestorQueue.pop();
    if (!frontAncestor)
      return;
    for (Value operand : frontAncestor->getOperands()) {
      Operation *operandDefOp = operand.getDefiningOp();
      if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
        pushAncestor(operandDefOp);
    }
  }
};

/// Sorts the operands of `op` in ascending order of the "key" associated with
/// each operand iff `op` is commutative. This is a stable sort.
///
/// After the application of this pattern, since the commutative operands now
/// have a deterministic order in which they occur in an op, the matching of
/// large DAGs becomes much simpler, i.e., requires much less number of checks
/// to be written by a user in her/his pattern matching function.
///
/// Keys are compared lexicographically, entry by entry, and a key that is a
/// proper prefix of another key is the smaller one. The keys are computed
/// lazily: an operand's breadth-first traversal is only advanced as far as is
/// needed to decide a comparison. The keys listed in the examples below are
/// the complete keys.
///
/// Some examples of such a sorting:
///
/// Assume that the sorting is being applied to `foo.commutative`, which is a
/// commutative op.
///
/// Example 1:
///
/// %1 = foo.const 0
/// %2 = foo.mul <block argument>, <block argument>
/// %3 = foo.commutative %1, %2
///
/// Here,
/// 1. The key associated with %1 is:
///     `{
///       {CONSTANT_OP, "foo.const"}
///      }`
/// 2. The key associated with %2 is:
///     `{
///       {NON_CONSTANT_OP, "foo.mul"},
///       {BLOCK_ARGUMENT, ""},
///       {BLOCK_ARGUMENT, ""}
///      }`
///
/// The key of %2 < the key of %1
/// Thus, the sorted `foo.commutative` is:
/// %3 = foo.commutative %2, %1
///
/// Example 2:
///
/// %1 = foo.const 0
/// %2 = foo.mul <block argument>, <block argument>
/// %3 = foo.mul %2, %1
/// %4 = foo.add %2, %1
/// %5 = foo.commutative %1, %2, %3, %4
///
/// Here,
/// 1. The key associated with %1 is:
///     `{
///       {CONSTANT_OP, "foo.const"}
///      }`
/// 2. The key associated with %2 is:
///     `{
///       {NON_CONSTANT_OP, "foo.mul"},
///       {BLOCK_ARGUMENT, ""},
///       {BLOCK_ARGUMENT, ""}
///      }`
/// 3. The key associated with %3 is:
///     `{
///       {NON_CONSTANT_OP, "foo.mul"},
///       {NON_CONSTANT_OP, "foo.mul"},
///       {CONSTANT_OP, "foo.const"},
///       {BLOCK_ARGUMENT, ""},
///       {BLOCK_ARGUMENT, ""}
///      }`
/// 4. The key associated with %4 is:
///     `{
///       {NON_CONSTANT_OP, "foo.add"},
///       {NON_CONSTANT_OP, "foo.mul"},
///       {CONSTANT_OP, "foo.const"},
///       {BLOCK_ARGUMENT, ""},
///       {BLOCK_ARGUMENT, ""}
///      }`
///
/// The key of %4 is the smallest ("foo.add" < "foo.mul" in the first entry).
/// The keys of %2 and %3 agree in the first entry and differ in the second,
/// where `BLOCK_ARGUMENT` < `NON_CONSTANT_OP`, so the key of %2 < the key of
/// %3. The key of %1 is the largest (`CONSTANT_OP` in the first entry).
/// Thus, the sorted `foo.commutative` is:
/// %5 = foo.commutative %4, %2, %3, %1
class SortCommutativeOperands : public RewritePattern {
public:
  SortCommutativeOperands(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}

  /// Reorders the operands of `op` by ascending key. Returns `failure()` when
  /// `op` is not commutative or when its operands are already in sorted
  /// order; otherwise updates the operands in place and returns `success()`.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // If `op` is not commutative, do nothing.
    if (!op->hasTrait<OpTrait::IsCommutative>())
      return failure();

    // Advances the BFS of `commOperand` until its key has an entry at
    // `keyIndex`, or until the traversal is exhausted (in which case the key
    // is complete and simply has no such entry).
    auto extendKeyTo = [](CommutativeOperand &commOperand, unsigned keyIndex) {
      while (commOperand.key.size() <= keyIndex &&
             !commOperand.ancestorQueue.empty()) {
        commOperand.popFrontAndPushAdjacentUnvisitedAncestors();
        commOperand.refreshKey();
      }
    };

    // Custom comparator for two commutative operands, which returns true iff
    // the "key" of `commOperandA` < the "key" of `commOperandB`, i.e.,
    // 1. In the first unequal pair of corresponding AncestorKeys, the
    // AncestorKey in `commOperandA` is smaller, or,
    // 2. Both the AncestorKeys in every pair are the same and the size of
    // `commOperandA`'s "key" is smaller.
    //
    // The result depends only on the (lazily extended) keys and never on how
    // far either traversal happened to be advanced by earlier comparisons.
    // This is required for the comparator to be a strict weak ordering, which
    // `std::stable_sort` relies on.
    //
    // Note that `operator->` on a `const std::unique_ptr` yields a mutable
    // pointee, so the traversal state can be advanced without any cast.
    auto commutativeOperandComparator =
        [&](const std::unique_ptr<CommutativeOperand> &commOperandA,
            const std::unique_ptr<CommutativeOperand> &commOperandB) {
          if (commOperandA->operand == commOperandB->operand)
            return false;

          // Walk both keys in lockstep until an order can be determined.
          for (unsigned keyIndex = 0;; ++keyIndex) {
            extendKeyTo(*commOperandA, keyIndex);
            extendKeyTo(*commOperandB, keyIndex);

            bool keyAExhausted = commOperandA->key.size() <= keyIndex;
            bool keyBExhausted = commOperandB->key.size() <= keyIndex;
            // All earlier entries were equal. If a key ends here, it is a
            // prefix of the other one; it is smaller only if the other key
            // still has entries left.
            if (keyAExhausted || keyBExhausted)
              return keyAExhausted && !keyBExhausted;

            if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
              return true;
            if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
              return false;
          }
        };

    // Populate the list of commutative operands. Each operand starts with its
    // defining op (null for a block argument) as the sole queued ancestor and
    // with the corresponding first key entry.
    SmallVector<Value, 2> operands = op->getOperands();
    SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands;
    for (Value operand : operands) {
      std::unique_ptr<CommutativeOperand> commOperand =
          std::make_unique<CommutativeOperand>();
      commOperand->operand = operand;
      commOperand->pushAncestor(operand.getDefiningOp());
      commOperand->refreshKey();
      commOperands.push_back(std::move(commOperand));
    }

    // Sort the operands.
    std::stable_sort(commOperands.begin(), commOperands.end(),
                     commutativeOperandComparator);
    SmallVector<Value, 2> sortedOperands;
    for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
      sortedOperands.push_back(commOperand->operand);

    // Report failure when nothing changed so that the pattern is not reported
    // as having modified the op.
    if (sortedOperands == operands)
      return failure();
    rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); });
    return success();
  }
};

/// Adds the `SortCommutativeOperands` pattern to `patterns`, constructing it
/// with the context obtained from `patterns.getContext()`.
void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<SortCommutativeOperands>(context);
}