//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.WhileOp's into SCF.ForOp's.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
/// Rewrite pattern that delegates to `scf::upliftWhileToForLoop` for every
/// matched `scf.while` op and reports whether the uplifting succeeded.
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::WhileOp loop,
                                PatternRewriter &rewriter) const override {
    // The created `scf.for` itself is not needed here, only the outcome.
    FailureOr<scf::ForOp> uplifted = scf::upliftWhileToForLoop(rewriter, loop);
    return success(succeeded(uplifted));
  }
};
} // namespace

/// Tries to rewrite `loop` into an equivalent `scf.for`.
///
/// The following shape is required, otherwise a match failure is reported and
/// the IR is left untouched:
///   * the `before` block contains a single `arith.cmpi` (`slt` with the
///     induction variable on the lhs, or `sgt` with it on the rhs) whose only
///     use is the `scf.condition` condition;
///   * all `before` block arguments are forwarded, in order, by
///     `scf.condition`;
///   * the other `cmpi` operand (upper bound) is defined outside of the loop;
///   * the value yielded for the induction variable by the `after` block is an
///     `arith.addi` of the induction variable and a value defined outside of
///     the loop (step).
///
/// On success `loop` is replaced: the remaining loop-carried values become
/// `iter_args` of the new `scf.for`, and the induction variable result is
/// recomputed after the new loop from lower bound, upper bound and step.
///
/// \param rewriter rewriter used for all IR modifications.
/// \param loop     the `scf.while` op to uplift.
/// \returns the newly created `scf.for` op, or failure if `loop` does not
///          match the expected shape.
FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
                                                      scf::WhileOp loop) {
  Block *beforeBody = loop.getBeforeBody();
  if (!llvm::hasSingleElement(beforeBody->without_terminator()))
    return rewriter.notifyMatchFailure(loop, "Loop body must have single op");

  auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
  if (!cmp)
    return rewriter.notifyMatchFailure(loop,
                                       "Loop body must have single cmp op");

  scf::ConditionOp beforeTerm = loop.getConditionOp();
  if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Expected single condition use: " << *cmp;
    });

  // All `before` block args must be directly forwarded to ConditionOp.
  // They will be converted to `scf.for` `iter_vars` except induction var.
  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
    return rewriter.notifyMatchFailure(loop, "Invalid args order");

  using Pred = arith::CmpIPredicate;
  Pred predicate = cmp.getPredicate();
  if (predicate != Pred::slt && predicate != Pred::sgt)
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
    });

  BlockArgument inductionVar;
  Value ub;
  DominanceInfo dom;

  // Check if cmp has a suitable form. One of the arguments must be a `before`
  // block arg, other must be defined outside `scf.while` and will be treated
  // as upper bound. `iv slt ub` and `ub sgt iv` are both accepted.
  for (bool reverse : {false, true}) {
    auto expectedPred = reverse ? Pred::sgt : Pred::slt;
    if (predicate != expectedPred)
      continue;

    auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
    auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();

    auto blockArg = dyn_cast<BlockArgument>(arg1);
    if (!blockArg || blockArg.getOwner() != beforeBody)
      continue;

    if (!dom.properlyDominates(arg2, loop))
      continue;

    inductionVar = blockArg;
    ub = arg2;
    break;
  }

  if (!inductionVar)
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Unrecognized cmp form: " << *cmp;
    });

  // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
  // arg.
  if (!llvm::hasNItems(inductionVar.getUses(), 2))
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Unrecognized induction var: " << inductionVar;
    });

  Block *afterBody = loop.getAfterBody();
  scf::YieldOp afterTerm = loop.getYieldOp();
  unsigned argNumber = inductionVar.getArgNumber();
  Value afterTermIndArg = afterTerm.getResults()[argNumber];

  Value inductionVarAfter = afterBody->getArgument(argNumber);

  // Find suitable `addi` op inside `after` block, one of the args must be an
  // Induction var passed from `before` block and second arg must be defined
  // outside of the loop and will be considered step value.
  // TODO: Add `subi` support?
  auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
  if (!addOp)
    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");

  // `step` stays null if neither `addi` operand is the induction var.
  Value step;
  if (addOp.getLhs() == inductionVarAfter) {
    step = addOp.getRhs();
  } else if (addOp.getRhs() == inductionVarAfter) {
    step = addOp.getLhs();
  }

  if (!step || !dom.properlyDominates(step, loop))
    return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");

  Value lb = loop.getInits()[argNumber];

  assert(lb.getType().isIntOrIndex());
  assert(lb.getType() == ub.getType());
  assert(lb.getType() == step.getType());

  // Matching is done, IR modifications start here. Build the replacement at
  // the position of the original loop regardless of the insertion point the
  // caller left in the rewriter, and restore that insertion point on exit.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(loop);

  llvm::SmallVector<Value> newArgs;

  // Populate inits for new `scf.for`, skip induction var.
  newArgs.reserve(loop.getInits().size());
  for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
    if (i == argNumber)
      continue;

    newArgs.emplace_back(init);
  }

  Location loc = loop.getLoc();

  // With `builder == nullptr`, ForOp::build will try to insert terminator at
  // the end of newly created block and we don't want it. Provide empty
  // dummy builder instead.
  auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
  auto newLoop =
      rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);

  Block *newBody = newLoop.getBody();

  // Populate block args for `scf.for` body, move induction var to the front.
  // `scf.for` body args are (iv, iter_args...), while the `after` block
  // expects the induction var at position `argNumber`.
  newArgs.clear();
  ValueRange newBodyArgs = newBody->getArguments();
  for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
    if (i < argNumber) {
      newArgs.emplace_back(newBodyArgs[i + 1]);
    } else if (i == argNumber) {
      newArgs.emplace_back(newBodyArgs.front());
    } else {
      newArgs.emplace_back(newBodyArgs[i]);
    }
  }

  rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
                             newArgs);

  auto term = cast<scf::YieldOp>(newBody->getTerminator());

  // Populate new yield args, skipping the induction var.
  newArgs.clear();
  for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
    if (i == argNumber)
      continue;

    newArgs.emplace_back(arg);
  }

  rewriter.setInsertionPoint(term);
  rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);

  // Compute induction var value after loop execution.
  //
  // The `scf.while` result is the value forwarded by `scf.condition` when the
  // condition evaluates to false, i.e. the first `lb + k * step` for which
  // `iv < ub` does not hold. This is `lb + tripCount * step` where
  // `tripCount = max(ceildiv(ub - lb, step), 0)`; the clamp covers loops that
  // never execute, for which the result must be `lb` itself.
  // NOTE: assumes a positive step, same as the created `scf.for`.
  rewriter.setInsertionPointAfter(newLoop);
  Value zero;
  Value one;
  if (isa<IndexType>(step.getType())) {
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  } else {
    zero = rewriter.create<arith::ConstantIntOp>(loc, 0, step.getType());
    one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
  }

  // ceildiv(ub - lb, step) computed as (ub - lb + step - 1) / step.
  Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
  Value tripCount = rewriter.create<arith::SubIOp>(loc, ub, lb);
  tripCount = rewriter.create<arith::AddIOp>(loc, tripCount, stepDec);
  tripCount = rewriter.create<arith::DivSIOp>(loc, tripCount, step);
  // `divsi` yields a non-positive value when `ub <= lb`; clamp it to zero.
  tripCount = rewriter.create<arith::MaxSIOp>(loc, tripCount, zero);
  Value res = rewriter.create<arith::MulIOp>(loc, tripCount, step);
  res = rewriter.create<arith::AddIOp>(loc, lb, res);

  // Reconstruct `scf.while` results, inserting final induction var value
  // into proper place.
  newArgs.clear();
  llvm::append_range(newArgs, newLoop.getResults());
  newArgs.insert(newArgs.begin() + argNumber, res);
  rewriter.replaceOp(loop, newArgs);
  return newLoop;
}

/// Adds the `scf.while` to `scf.for` uplifting pattern to `patterns`.
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<UpliftWhileOp>(context);
}