/*
 * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files
 * (the "Software"), to deal in the Software without restriction,
 * including without limitation the rights to use, copy, modify, merge,
 * publish, distribute, sublicense, and/or sell copies of the Software,
 * and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

#include <queue>

#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"

namespace ttg = mlir::triton::gpu;

namespace mlir {
namespace triton {
namespace nvidia_gpu {

#define GEN_PASS_DEF_TRITONGPUPLANCTAPASS
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"

namespace {

// TODO: use ConvertLayoutOp
using CastOp = ::mlir::UnrealizedConversionCastOp;

unsigned getNumUsers(Value value) {
  return std::distance(value.user_begin(), value.user_end());
}

Type replaceLayout(const Type &type, const Attribute &newLayout) {
  Type curType = type;
  auto ptrTy = dyn_cast<triton::PointerType>(curType);
  if (ptrTy)
    curType = ptrTy.getPointeeType();
  if (auto tensorTy = dyn_cast<RankedTensorType>(curType))
    curType = tensorTy.cloneWithEncoding(newLayout);
  if (ptrTy)
    curType = triton::PointerType::get(curType, ptrTy.getAddressSpace());
  return curType;
}

ttg::DistributedEncodingTrait
replaceCTALayout(ttg::DistributedEncodingTrait layout,
                 llvm::ArrayRef<int64_t> shape, int numWarps,
                 const ttg::CTALayoutAttr &newCTALayout) {
  if (auto blockedLayout = mlir::dyn_cast<ttg::BlockedEncodingAttr>(layout)) {
    return ttg::BlockedEncodingAttr::get(
        layout.getContext(), shape, blockedLayout.getSizePerThread(),
        blockedLayout.getOrder(), numWarps, 32, newCTALayout);
  } else if (auto sliceLayout =
                 mlir::dyn_cast<ttg::SliceEncodingAttr>(layout)) {
    return ttg::SliceEncodingAttr::get(
        layout.getContext(), sliceLayout.getDim(),
        replaceCTALayout(sliceLayout.getParent(), shape, numWarps,
                         newCTALayout));
  } else {
    // Other layouts are generated by passes after PlanCTAPass
    llvm::report_fatal_error("replaceCTALayout not implemented");
    return layout;
  }
}

class CTAPlanner {
public:
  CTAPlanner(ClusterInfo *clusterInfo_);
  ~CTAPlanner();

  void run(triton::FuncOp &funcOp);

private:
  CastOp markBackward(CastOp cast) const;
  CastOp markForward(CastOp cast) const;
  bool isBackward(CastOp cast) const;
  bool isForward(CastOp cast) const;

  void setTiling(llvm::ArrayRef<unsigned> CTAsPerCGA);
  bool processDot(triton::FuncOp &funcOp);
  bool processReduce(triton::FuncOp &funcOp);
  void processStoreLikeOps(triton::FuncOp &funcOp);

  bool propagate(CastOp cast);
  bool propagateBackward(CastOp cast);
  bool propagateForward(CastOp cast);

  void eraseCastOp(CastOp cast);
  void eraseCastOpFromQueue(CastOp cast);
  void eraseCastOpsFromQueue(llvm::ArrayRef<CastOp> casts);

  void insertCasts(Operation *op, llvm::ArrayRef<Attribute> newOperandLayouts,
                   llvm::ArrayRef<Attribute> newResultLayouts);
  void eliminateAdjacentCasts(CastOp cast0, CastOp cast1);

  bool isLoadStoreOp(Operation *op) const;
  bool processLoadStore(Operation *op, Attribute layout);

  bool isElementwiseOp(Operation *op) const;
  bool processElementwise(Operation *op, Attribute layout);

  bool processConstant(arith::ConstantOp constant, Attribute layout);
  bool processSplat(triton::SplatOp splat, Attribute layout);
  bool processMakeRange(triton::MakeRangeOp makeRange, Attribute layout);
  bool processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr,
                            Attribute layout);

  bool processBroadcast(triton::BroadcastOp broadcast, Attribute layout);
  bool processExpandDimsBackward(triton::ExpandDimsOp expandDims,
                                 ttg::DistributedEncodingTrait newResultLayout);
  bool processExpandDimsForward(triton::ExpandDimsOp expandDims,
                                ttg::DistributedEncodingTrait newSrcLayout);

  bool processConvertLayoutBackward(ttg::ConvertLayoutOp convertLayout,
                                    CastOp cast);
  bool processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout,
                                   CastOp cast);

  bool processIfOp(scf::IfOp ifOp, int index, const Type &newType);
  bool processForOp(scf::ForOp forOp, int index, const Type &newType);

  bool processIfOpBackward(scf::IfOp ifOp, CastOp cast);
  bool processForOpBackward(scf::ForOp forOp, CastOp cast);
  bool processBlockArgBackward(BlockArgument arg, CastOp cast);
  bool processForOpForward(scf::ForOp forOp, CastOp cast);
  bool processYieldOpForward(scf::YieldOp yieldOp, CastOp cast);

  bool processOpFallback(Operation *op);

  bool processMultiUsersBackward(Value input, CastOp cast);
  bool processMultiUsersForward(Value output, CastOp cast);

  // This flag indicates whether clusterInfo needs to be deleted in the
  // destructor of CTAPlanner. The flag `ownInfo` is set to false when a
  // non-null pointer to clusterInfo is passed to the constructor of CTAPlanner.
  // Otherwise, a self-managed ClusterInfo will be created and the ownInfo will
  // be set to true.
  bool ownInfo;
  ClusterInfo *clusterInfo;
  bool tiled;
  unsigned step;
  unsigned stepUnchanged;
  std::queue<CastOp> queue;
};

CTAPlanner::CTAPlanner(ClusterInfo *clusterInfo_)
    : ownInfo(false), clusterInfo(clusterInfo_), tiled(false), step(0),
      stepUnchanged(0) {
  if (clusterInfo == nullptr) {
    clusterInfo = new ClusterInfo();
    ownInfo = true;
  }
}

CTAPlanner::~CTAPlanner() {
  if (ownInfo) {
    delete clusterInfo;
    // Actually not necessary but safer
    ownInfo = false;
    clusterInfo = nullptr;
  }
}

void CTAPlanner::run(triton::FuncOp &funcOp) {
  assert(!tiled && "Please create a new CTAPlanner");
  static const unsigned maxSteps = 10000;

  auto nextStep = [&]() {
    ++step;
    assert(step < maxSteps && "Maximum number of steps exceeded");
  };

  processDot(funcOp);
  nextStep();

  processReduce(funcOp);
  nextStep();

  if (!tiled) {
    processStoreLikeOps(funcOp);
    nextStep();
  }

  while (!queue.empty()) {
    CastOp cast = queue.front();
    queue.pop();
    bool changed = propagate(cast);
    if (changed) {
      stepUnchanged = 0;
    } else {
      queue.push(cast);
      ++stepUnchanged;
    }
    nextStep();
  }
}

CastOp CTAPlanner::markBackward(CastOp cast) const {
  cast->setAttr("direction", StringAttr::get(cast.getContext(), "backward"));
  return cast;
}

CastOp CTAPlanner::markForward(CastOp cast) const {
  cast->setAttr("direction", StringAttr::get(cast.getContext(), "forward"));
  return cast;
}

bool CTAPlanner::isBackward(CastOp cast) const {
  return cast->getAttrOfType<StringAttr>("direction") == "backward";
}

bool CTAPlanner::isForward(CastOp cast) const {
  return cast->getAttrOfType<StringAttr>("direction") == "forward";
}

void CTAPlanner::setTiling(llvm::ArrayRef<unsigned> CTAsPerCGA) {
  assert(!tiled && "CTA tiling is already determinted");
  assert(clusterInfo && "ClusterInfo pointer is null");
  tiled = true;
  unsigned numCTAs = 1;
  for (unsigned cta : CTAsPerCGA)
    numCTAs *= cta;
  if (numCTAs == 2) {
    // For 2 CTAs always use 2x1x1.
    // TODO: can we always serialize the CTAs on X dimension?
    clusterInfo->clusterDimX = 2;
    return;
  }

  if (CTAsPerCGA.size() > 0)
    clusterInfo->clusterDimX = CTAsPerCGA[0];
  if (CTAsPerCGA.size() > 1)
    clusterInfo->clusterDimY = CTAsPerCGA[1];
  if (CTAsPerCGA.size() > 2)
    clusterInfo->clusterDimZ = CTAsPerCGA[2];
  for (auto i = 3; i < CTAsPerCGA.size(); ++i)
    if (CTAsPerCGA[i] != 1)
      llvm::report_fatal_error("tiling > 3 dims is not implemented");
}

bool CTAPlanner::processDot(triton::FuncOp &funcOp) {
  // TODO: This is a naive implementation and should be refactored
  auto getCTATiling = [](int64_t M, int64_t N, int64_t K,
                         unsigned numCTAs) -> std::pair<unsigned, unsigned> {
    // prefer a larger chunk size, at most 128; first assign splitM.
    unsigned chunk_m = 128;
    auto isLegal = [](unsigned chunk) { return chunk >= 64; };
    unsigned splitM, splitN;
    for (; isLegal(chunk_m); chunk_m /= 2) {
      splitM = std::clamp<unsigned>(M / chunk_m, 1, numCTAs);
      splitN = numCTAs / splitM;
      if (isLegal(N / splitN)) // chunk_n;
        break;
    }
    return {splitM, splitN};
  };

  funcOp.walk([&](triton::DotOp dot) {
    MLIRContext *ctx = dot.getContext();

    auto aTy = cast<RankedTensorType>(dot.getA().getType());
    auto bTy = cast<RankedTensorType>(dot.getB().getType());
    auto dTy = cast<RankedTensorType>(dot.getD().getType());

    assert(isa<ttg::DotOperandEncodingAttr>(aTy.getEncoding()) &&
           isa<ttg::DotOperandEncodingAttr>(bTy.getEncoding()) &&
           isa<ttg::BlockedEncodingAttr>(dTy.getEncoding()) &&
           "PlanCTAPass should follow immediately after CoalescePass");

    auto aLayout = cast<ttg::DotOperandEncodingAttr>(aTy.getEncoding());
    auto bLayout = cast<ttg::DotOperandEncodingAttr>(bTy.getEncoding());
    auto dLayout = cast<ttg::BlockedEncodingAttr>(dTy.getEncoding());

    unsigned M = dTy.getShape()[0];
    unsigned N = dTy.getShape()[1];
    unsigned K = aTy.getShape()[1];

    unsigned splitM, splitN;
    std::tie(splitM, splitN) = getCTATiling(M, N, K, ttg::getNumCTAs(dLayout));
    // FIXME: Should consider IR with more than one DotOps
    setTiling({splitM, splitN, 1});

    OpBuilder builder(dot);
    auto numThreads = ttg::lookupThreadsPerWarp(builder);
    auto numWarps = ttg::lookupNumWarps(dot);

    auto newCTALayout = ttg::CTALayoutAttr::get(ctx, {splitM, splitN},
                                                {splitM, splitN}, {1, 0});
    auto newDLayout = ttg::BlockedEncodingAttr::get(
        ctx, dTy.getShape(), dLayout.getSizePerThread(), dLayout.getOrder(),
        numWarps, numThreads, newCTALayout);
    auto newALayout = ttg::DotOperandEncodingAttr::get(ctx, aLayout.getOpIdx(),
                                                       newDLayout, 0);
    auto newBLayout = ttg::DotOperandEncodingAttr::get(ctx, bLayout.getOpIdx(),
                                                       newDLayout, 0);

    insertCasts(dot.getOperation(), {newALayout, newBLayout, newDLayout},
                {newDLayout});
  });

  return true;
}

bool CTAPlanner::processReduce(triton::FuncOp &funcOp) {
  ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
  unsigned numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod);

  funcOp.walk([&](triton::ReduceOp reduce) {
    MLIRContext *context = reduce.getContext();
    Value src = reduce.getOperands()[0];
    unsigned axis = reduce.getAxis();

    auto srcTy = cast<RankedTensorType>(src.getType());
    auto srcShape = srcTy.getShape();
    auto srcLayout = srcTy.getEncoding();

    auto rank = srcShape.size();
    auto order = ttg::getOrder(srcTy);
    auto sizePerThread = ttg::getContigPerThread(srcTy);
    auto CTAOrder = ttg::getCTAOrder(srcLayout);

    llvm::SmallVector<unsigned> CTAsPerCGA(rank, 0);
    unsigned remainingCTAs = numCTAs;
    for (int i = rank - 1; i >= 0; --i) {
      unsigned dim = order[i];
      if (dim == axis) {
        CTAsPerCGA[dim] = 1;
      } else {
        CTAsPerCGA[dim] = std::min<unsigned>(srcShape[dim] / sizePerThread[dim],
                                             remainingCTAs);
        remainingCTAs /= CTAsPerCGA[dim];
      }
    }

    for (int i = rank - 1; i >= 0; --i) {
      unsigned dim = order[i];
      if (dim != axis) {
        CTAsPerCGA[dim] *= remainingCTAs;
        break;
      }
    }

    llvm::SmallVector<unsigned> CTASplitNum = CTAsPerCGA;

    // If numCTAs > 1 and the only dimension is the reduced dimension, after the
    // above two for-loops, CTAsPerCGA = [0] and remainingCTAs = numCTAs. We set
    // CTAsPerCGA[0] = numCTAs and keep CTASplitNum[0] = 1 to ensure that no
    // cross-CTA reduction is required, although this will introduce duplicated
    // calculation
    if (remainingCTAs > 0)
      CTAsPerCGA[order[rank - 1]] *= remainingCTAs;

    auto numWarps = ttg::lookupNumWarps(reduce);
    auto CTALayout =
        ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
    if (!tiled)
      setTiling(CTALayout.getCTAsPerCGA());
    auto newSrcLayout =
        replaceCTALayout(cast<ttg::DistributedEncodingTrait>(srcLayout),
                         srcShape, numWarps, CTALayout);
    auto newResultLayout =
        ttg::SliceEncodingAttr::get(context, axis, newSrcLayout);
    unsigned numOperands = reduce.getNumOperands();
    SmallVector<Attribute> newSrcLayoutVec(numOperands, newSrcLayout);
    SmallVector<Attribute> newResultLayoutVec(numOperands, newResultLayout);

    insertCasts(reduce.getOperation(), newSrcLayoutVec, newResultLayoutVec);
  });
  return true;
}

void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
  assert(!tiled && "CTA tiling is already determinted");

  llvm::SmallVector<Operation *> stores;
  funcOp.walk([&](Operation *op) {
    if (llvm::isa<triton::StoreOp, triton::AtomicRMWOp, triton::AtomicCASOp,
                  triton::DescriptorStoreLikeOpInterface>(op))
      stores.push_back(op);
  });
  assert(stores.size() > 0 && "Cannot find store-like ops");
  auto numWarps = ttg::lookupNumWarps(funcOp);

  ttg::CTALayoutAttr CTALayout;
  for (Operation *store : stores) {
    auto val = [store]() -> Value {
      if (auto descStore =
              dyn_cast<triton::DescriptorStoreLikeOpInterface>(store))
        return descStore.getSrc();
      return store->getOperand(0);
    }();
    if (auto tensorTy = dyn_cast<RankedTensorType>(val.getType())) {
      if (!tiled) {
        // Use CTA tiling of the first store-like op as global CTA tiling
        CTALayout = ttg::getCTALayout(tensorTy.getEncoding());
        setTiling(CTALayout.getCTAsPerCGA());
      }
      auto newLayout = replaceCTALayout(
          cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding()),
          tensorTy.getShape(), numWarps, CTALayout);
      processElementwise(store, newLayout);
    }
  }

  // If all store-like ops are processing scalar values and no ReduceOp is
  // found, we can conclude that this is an all-scalar computation, since
  // ReduceOp is the only op that converts tensor values to scalar values.
  if (!tiled)
    setTiling({1, 1, 1});
}

bool CTAPlanner::propagate(CastOp cast) {
  return isBackward(cast) ? propagateBackward(cast) : propagateForward(cast);
}

bool CTAPlanner::propagateBackward(CastOp cast) {
  Value input = cast.getOperand(0);
  Value output = cast.getResult(0);
  unsigned numUsers = getNumUsers(input);
  if (numUsers == 0) {
    llvm::report_fatal_error("Unreachable branch");
    return false;
  } else if (numUsers == 1) {
    Type outTy = output.getType();
    if (auto ptrTy = dyn_cast<triton::PointerType>(outTy))
      outTy = ptrTy.getPointeeType();
    auto layout = mlir::cast<ttg::DistributedEncodingTrait>(
        mlir::cast<RankedTensorType>(outTy).getEncoding());
    Operation *op = input.getDefiningOp();
    if (op == nullptr) {
      assert(isa<BlockArgument>(input) &&
             "Unexpected Value without defining op");
      processBlockArgBackward(llvm::cast<BlockArgument>(input), cast);
    } else if (auto prevCast = llvm::dyn_cast<CastOp>(op)) {
      eliminateAdjacentCasts(prevCast, cast);
    } else if (isLoadStoreOp(op)) {
      processLoadStore(op, layout);
    } else if (isElementwiseOp(op)) {
      processElementwise(op, layout);
    } else if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
      processConstant(constant, layout);
    } else if (auto splat = llvm::dyn_cast<triton::SplatOp>(op)) {
      processSplat(splat, layout);
    } else if (auto makeRange = llvm::dyn_cast<triton::MakeRangeOp>(op)) {
      processMakeRange(makeRange, layout);
    } else if (auto makeTensorPtr =
                   llvm::dyn_cast<triton::MakeTensorPtrOp>(op)) {
      processMakeTensorPtr(makeTensorPtr, layout);
    } else if (llvm::isa<triton::AdvanceOp>(op)) {
      // ptr operand and result have the same layout, while other operands are
      // scalar values
      processElementwise(op, layout);
    } else if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
      processBroadcast(broadcast, layout);
    } else if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
      processExpandDimsBackward(expandDims, layout);
    } else if (auto ifOp = llvm::dyn_cast<scf::IfOp>(op)) {
      processIfOpBackward(ifOp, cast);
    } else if (auto forOp = llvm::dyn_cast<scf::ForOp>(op)) {
      processForOpBackward(forOp, cast);
    } else if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(op)) {
      return processConvertLayoutBackward(convertLayout, cast);
    } else {
      // Keep original layouts. This may result in a loss of performance.
      return processOpFallback(op);
    }
    return true;
  } else {
    return processMultiUsersBackward(input, cast);
  }
}

bool CTAPlanner::propagateForward(CastOp cast) {
  Value input = cast.getOperand(0);
  Value output = cast.getResult(0);
  unsigned numUsers = getNumUsers(output);
  if (numUsers == 0) {
    cast.erase();
  } else if (numUsers == 1) {
    Type inTy = input.getType();
    if (auto ptrTy = dyn_cast<triton::PointerType>(inTy))
      inTy = ptrTy.getPointeeType();
    Attribute layout = mlir::cast<RankedTensorType>(inTy).getEncoding();
    Operation *op = *output.user_begin();
    if (auto nextCast = llvm::dyn_cast<CastOp>(op)) {
      eliminateAdjacentCasts(cast, nextCast);
    } else if (isLoadStoreOp(op)) {
      processLoadStore(op, layout);
    } else if (isElementwiseOp(op)) {
      processElementwise(op, layout);
    } else if (llvm::isa<triton::AdvanceOp>(op)) {
      // ptr operand and result have the same layout, while other operands are
      // scalar values
      processElementwise(op, layout);
    } else if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(op)) {
      return processConvertLayoutForward(convertLayout, cast);
    } else if (auto forOp = llvm::dyn_cast<scf::ForOp>(op)) {
      processForOpForward(forOp, cast);
    } else if (auto yieldOp = llvm::dyn_cast<scf::YieldOp>(op)) {
      processYieldOpForward(yieldOp, cast);
    } else {
      // Keep original layouts. This may result in a loss of performance.
      return processOpFallback(op);
    }
  } else {
    processMultiUsersForward(output, cast);
  }
  return true;
}

void CTAPlanner::eraseCastOp(CastOp cast) {
  Value output = cast.getResult(0);
  assert(getNumUsers(output) == 0 &&
         "Cannot erase CastOp because it is still in use");
  cast.erase();
}

void CTAPlanner::eraseCastOpFromQueue(CastOp cast) {
  eraseCastOpsFromQueue({cast});
}

void CTAPlanner::eraseCastOpsFromQueue(llvm::ArrayRef<CastOp> casts) {
  llvm::DenseSet<CastOp> erased;
  for (CastOp cast : casts) {
    eraseCastOp(cast);
    erased.insert(cast);
  }

  decltype(queue) tempQueue;
  std::swap(queue, tempQueue);

  // This is only a naive implementation. Should refactor with linked-list.
  while (!tempQueue.empty()) {
    auto cast = tempQueue.front();
    tempQueue.pop();
    if (!erased.contains(cast))
      queue.push(cast);
  }
}

void CTAPlanner::insertCasts(Operation *op,
                             llvm::ArrayRef<Attribute> newOperandLayouts,
                             llvm::ArrayRef<Attribute> newResultLayouts) {
  assert(op->getNumOperands() == newOperandLayouts.size() &&
         "NumOperands mismatched");
  assert(op->getNumResults() == newResultLayouts.size() &&
         "NumResults mismatched");

  Location loc = op->getLoc();
  OpBuilder builder(op->getContext());

  builder.setInsertionPoint(op);
  for (unsigned i = 0; i < op->getNumOperands(); ++i) {
    Value operand = op->getOperand(i);
    auto operandTy = operand.getType();
    if (triton::isTensorOrTensorPointerType(operandTy)) {
      operandTy = replaceLayout(operandTy, newOperandLayouts[i]);
      auto cast = markBackward(builder.create<CastOp>(loc, operandTy, operand));
      op->setOperand(i, cast.getResult(0));
      queue.push(cast);
    }
  }

  builder.setInsertionPointAfter(op);
  for (unsigned i = 0; i < op->getNumResults(); ++i) {
    Value result = op->getResult(i);
    auto resultTy = result.getType();
    if (triton::isTensorOrTensorPointerType(resultTy)) {
      resultTy = replaceLayout(resultTy, newResultLayouts[i]);
      auto cast =
          markForward(builder.create<CastOp>(loc, result.getType(), result));
      result.setType(resultTy);
      result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation());
      queue.push(cast);
    }
  }
}

void CTAPlanner::eliminateAdjacentCasts(CastOp cast0, CastOp cast1) {
  assert(cast0.getResult(0) == cast1.getOperand(0) &&
         "The two casts are not adjacent");
  assert(isForward(cast0) && isBackward(cast1) &&
         "Expected pattern of adjacent casts: forward + backward");

  Value input = cast0.getOperand(0);
  Value output = cast1.getResult(0);

  if (input.getType() == output.getType()) {
    output.replaceAllUsesWith(input);
    eraseCastOpsFromQueue({cast1, cast0});
  } else {
    OpBuilder builder(cast1.getOperation());
    auto cvt = builder.create<ttg::ConvertLayoutOp>(cast1.getLoc(),
                                                    output.getType(), input);
    output.replaceAllUsesWith(cvt.getResult());
    eraseCastOpsFromQueue({cast1, cast0});
  }
}

bool CTAPlanner::isLoadStoreOp(Operation *op) const {
  return llvm::isa<triton::LoadOp, triton::StoreOp, triton::AtomicRMWOp,
                   triton::AtomicCASOp, triton::DescriptorLoadOp,
                   triton::DescriptorStoreLikeOpInterface,
                   triton::DescriptorGatherOp>(op);
}

bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
  // Special logic for:
  //     LoadOp -> SliceLayout
  // Transform to:
  //     LoadOp -> originalLayout -> ConvertLayout(DSmem) -> SliceLayout
  if (auto sliceLayout = mlir::dyn_cast<ttg::SliceEncodingAttr>(layout)) {
    auto dim = sliceLayout.getDim();
    auto CTAsPerCGA = ttg::getCTAsPerCGA(sliceLayout.getParent());
    if (CTAsPerCGA[dim] > 1) {
      // Find an input or output value of LoadOp or StoreOp to get its layout
      Value val =
          op->getNumResults() > 0 ? op->getResult(0) : op->getOperand(0);
      Attribute originalLayout =
          cast<RankedTensorType>(val.getType()).getEncoding();
      // Insert casts using originalLayout. Adjacent casts will be eliminated
      // and generate a ConvertLayoutOp with DSmem access
      return processLoadStore(op, originalLayout);
    }
  }

  auto CTALayout = ttg::getCTALayout(layout);
  auto numWarps = ttg::lookupNumWarps(op);

  llvm::SmallVector<Attribute> newOperandLayouts;
  for (unsigned i = 0; i < op->getNumOperands(); ++i) {
    auto type = op->getOperand(i).getType();
    if (auto ptrTy = dyn_cast<triton::PointerType>(type))
      type = ptrTy.getPointeeType();
    auto tensorTy = dyn_cast<RankedTensorType>(type);
    if (!tensorTy) {
      newOperandLayouts.push_back(Attribute());
      continue;
    }
    auto oldLayout =
        cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding());
    auto newLayout =
        replaceCTALayout(oldLayout, tensorTy.getShape(), numWarps, CTALayout);
    newOperandLayouts.push_back(newLayout);
  }

  llvm::SmallVector<Attribute> newResultLayouts;
  for (unsigned i = 0; i < op->getNumResults(); ++i) {
    auto type = op->getResult(i).getType();
    if (auto ptrTy = dyn_cast<triton::PointerType>(type))
      type = ptrTy.getPointeeType();
    auto tensorTy = cast<RankedTensorType>(type);
    auto oldLayout =
        cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding());
    auto newLayout =
        replaceCTALayout(oldLayout, tensorTy.getShape(), numWarps, CTALayout);
    newResultLayouts.push_back(newLayout);
  }

  insertCasts(op, newOperandLayouts, newResultLayouts);
  return true;
}

bool CTAPlanner::isElementwiseOp(Operation *op) const {
  if (llvm::isa<arith::AddFOp, arith::AddIOp, arith::AndIOp, arith::CeilDivSIOp,
                arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
                arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
                arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
                arith::MaximumFOp, arith::MaxNumFOp, arith::MaxSIOp,
                arith::MaxUIOp, arith::MinimumFOp, arith::MinNumFOp,
                arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
                arith::MulUIExtendedOp, arith::MulSIExtendedOp, arith::NegFOp,
                arith::OrIOp, arith::RemFOp, arith::RemSIOp, arith::RemUIOp,
                arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp,
                arith::SubFOp, arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
                arith::UIToFPOp, arith::XOrIOp>(op))
    return true;
  if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
                math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,
                math::CountLeadingZerosOp, math::CountTrailingZerosOp,
                math::CtPopOp, math::ErfOp, math::ExpOp, math::Exp2Op,
                math::FloorOp, math::ExpM1Op, math::FmaOp, math::LogOp,
                math::Log10Op, math::Log1pOp, math::Log2Op, math::PowFOp,
                math::SqrtOp, math::RsqrtOp, math::TanhOp>(op))
    return true;
  if (llvm::isa<triton::IntToPtrOp, triton::PtrToIntOp, triton::BitcastOp,
                triton::FpToFpOp, triton::AddPtrOp, triton::PreciseSqrtOp,
                triton::PreciseDivFOp>(op))
    return true;
  if (auto externElementwiseOp = dyn_cast<triton::ExternElementwiseOp>(op))
    return externElementwiseOp.getPure();
  if (llvm::isa<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>(op))
    return true;
  return false;
}

bool CTAPlanner::processElementwise(Operation *op, Attribute layout) {
  llvm::SmallVector<Attribute> newOperandLayouts(op->getNumOperands(), layout);
  llvm::SmallVector<Attribute> newResultLayouts(op->getNumResults(), layout);
  insertCasts(op, newOperandLayouts, newResultLayouts);
  return true;
}

bool CTAPlanner::processConstant(arith::ConstantOp constant, Attribute layout) {
  if (auto tensorTy = dyn_cast<RankedTensorType>(constant.getType())) {
    if (auto attr = dyn_cast<SplatElementsAttr>(constant.getValue())) {

      auto newTensorTy = tensorTy.cloneWithEncoding(layout);
      constant.setValueAttr(
          SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>()));
    }
  }
  insertCasts(constant.getOperation(), {}, {layout});
  return true;
}

bool CTAPlanner::processSplat(triton::SplatOp splat, Attribute layout) {
  insertCasts(splat.getOperation(), {{}}, {layout});
  return true;
}

bool CTAPlanner::processMakeRange(triton::MakeRangeOp makeRange,
                                  Attribute layout) {
  insertCasts(makeRange.getOperation(), {}, {layout});
  return true;
}

bool CTAPlanner::processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr,
                                      Attribute layout) {
  // All inputs of `makeTensorPtr` are scalar types
  llvm::SmallVector<Attribute> dummyInAttrs(makeTensorPtr.getNumOperands(), {});
  insertCasts(makeTensorPtr.getOperation(), dummyInAttrs, {layout});
  return true;
}

bool CTAPlanner::processBroadcast(triton::BroadcastOp broadcast,
                                  Attribute layout) {
  insertCasts(broadcast.getOperation(), {layout}, {layout});
  return true;
}

bool CTAPlanner::processExpandDimsBackward(
    triton::ExpandDimsOp expandDims,
    ttg::DistributedEncodingTrait newResultLayout) {
  auto newSrcLayout = ttg::SliceEncodingAttr::get(
      newResultLayout.getContext(), expandDims.getAxis(), newResultLayout);
  insertCasts(expandDims.getOperation(), {newSrcLayout}, {newResultLayout});
  return true;
}

bool CTAPlanner::processExpandDimsForward(
    triton::ExpandDimsOp expandDims,
    ttg::DistributedEncodingTrait newSrcLayout) {
  llvm::report_fatal_error("processExpandDimsForward not implemented yet");
  return true;
}

bool CTAPlanner::processConvertLayoutBackward(
    ttg::ConvertLayoutOp convertLayout, CastOp cast) {
  Value src = convertLayout.getSrc();
  Value result = convertLayout.getResult();
  assert(getNumUsers(result) == 1 &&
         "Expect to call processMultiUsersBackward first");
  result.replaceAllUsesWith(src);
  convertLayout.erase();
  queue.push(cast);
  return true;
}

bool CTAPlanner::processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout,
                                             CastOp cast) {
  Value src = convertLayout.getSrc();
  Value result = convertLayout.getResult();
  assert(getNumUsers(src) == 1 &&
         "Expect to call processMultiUsersForward first");
  src.setType(result.getType());
  result.replaceAllUsesWith(src);
  convertLayout.erase();
  queue.push(cast);
  return true;
}

bool CTAPlanner::processIfOp(scf::IfOp ifOp, int index, const Type &newType) {
  // Check index
  assert(index < ifOp.getNumResults() && "Invalid result index of IfOp");
  assert(index < ifOp.thenYield().getNumOperands() &&
         "Invalid operand index of YieldOp");
  assert(index < ifOp.elseYield().getNumOperands() &&
         "Invalid operand index of YieldOp");

  Location loc = ifOp.getLoc();
  OpBuilder builder(ifOp.getContext());

  // Insert forward cast after ifOp
  Value result = ifOp.getResult(index);
  builder.setInsertionPointAfter(ifOp.getOperation());
  auto newCast =
      markForward(builder.create<CastOp>(loc, result.getType(), result));
  result.setType(newType);
  result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation());
  queue.push(newCast);

  // Insert backward casts before yield
  for (scf::YieldOp yield : {ifOp.thenYield(), ifOp.elseYield()}) {
    Value yieldSrc = yield.getOperand(index);
    builder.setInsertionPoint(yield.getOperation());
    newCast = markBackward(builder.create<CastOp>(loc, newType, yieldSrc));
    yield->setOperand(index, newCast.getResult(0));
    queue.push(newCast);
  }

  return true;
}

bool CTAPlanner::processForOp(scf::ForOp forOp, int index,
                              const Type &newType) {
  Block *body = forOp.getBody();
  auto yield = llvm::cast<scf::YieldOp>(forOp.getBody()->getTerminator());

  // Check index
  assert(index + forOp.getNumControlOperands() < forOp.getNumOperands() &&
         "Invalid operand index of ForOp");
  assert(index + forOp.getNumInductionVars() < body->getNumArguments() &&
         "Invalid block arg index of ForOp");
  assert(index < yield.getNumOperands() && "Invalid operand index of YieldOp");
  assert(index < forOp.getNumResults() && "Invalid result index of IfOp");

  Location loc = forOp.getLoc();
  OpBuilder builder(forOp.getContext());

  // Insert backward cast before forOp
  OpOperand &operand =
      forOp->getOpOperand(index + forOp.getNumControlOperands());
  builder.setInsertionPoint(forOp.getOperation());
  auto newCast =
      markBackward(builder.create<CastOp>(loc, newType, operand.get()));
  operand.set(newCast.getResult(0));
  queue.push(newCast);

  // Insert forward cast after block arg
  Value arg = body->getArgument(index + forOp.getNumInductionVars());
  builder.setInsertionPointToStart(body);
  newCast = markForward(builder.create<CastOp>(loc, arg.getType(), arg));
  arg.setType(newType);
  arg.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation());
  queue.push(newCast);

  // Insert backward cast before yield
  Value yieldSrc = yield.getOperand(index);
  builder.setInsertionPoint(yield.getOperation());
  newCast = markBackward(builder.create<CastOp>(loc, newType, yieldSrc));
  yield->setOperand(index, newCast.getResult(0));
  queue.push(newCast);

  // Insert forward cast after forOp
  Value result = forOp.getResult(index);
  builder.setInsertionPointAfter(forOp.getOperation());
  newCast = markForward(builder.create<CastOp>(loc, result.getType(), result));
  result.setType(newType);
  result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation());
  queue.push(newCast);

  return true;
}

int findResultIndex(Operation *op, Value result) {
  for (int i = 0; i < op->getNumResults(); ++i)
    if (op->getResult(i) == result)
      return i;
  llvm::report_fatal_error("Invalid index of op result");
  return -1;
}

bool CTAPlanner::processIfOpBackward(scf::IfOp ifOp, CastOp cast) {
  int index = findResultIndex(ifOp.getOperation(), cast.getOperand(0));
  auto newType = cast.getResult(0).getType();
  return processIfOp(ifOp, index, newType);
}

bool CTAPlanner::processForOpBackward(scf::ForOp forOp, CastOp cast) {
  int index = findResultIndex(forOp.getOperation(), cast.getOperand(0));
  auto newType = cast.getResult(0).getType();
  return processForOp(forOp, index, newType);
}

bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) {
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(arg.getOwner()->getParentOp())) {
    int index = int(arg.getArgNumber()) - forOp.getNumInductionVars();
    auto newType = cast.getResult(0).getType();
    return processForOp(forOp, index, newType);
  } else {
    llvm::report_fatal_error("Unexpected parent op of block argument");
    return true;
  }
}

bool CTAPlanner::processForOpForward(scf::ForOp forOp, CastOp cast) {
  int index = cast.getResult(0).use_begin()->getOperandNumber() -
              forOp.getNumControlOperands();
  auto newType = cast.getOperand(0).getType();
  return processForOp(forOp, index, newType);
}

bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) {
  int index = cast.getResult(0).use_begin()->getOperandNumber();
  auto newType = cast.getOperand(0).getType();
  if (auto ifOp = llvm::dyn_cast<scf::IfOp>(yieldOp->getParentOp()))
    return processIfOp(ifOp, index, newType);
  else if (auto forOp = llvm::dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
    return processForOp(forOp, index, newType);
  else
    llvm::report_fatal_error("Unexpected parent op of YieldOp");
  return true;
}

bool CTAPlanner::processOpFallback(Operation *op) {
  Location loc = op->getLoc();
  OpBuilder builder(op->getContext());

  builder.setInsertionPoint(op);
  for (unsigned i = 0; i < op->getNumOperands(); ++i) {
    Value operand = op->getOperand(i);
    auto operandTy = operand.getType();
    if (triton::isTensorOrTensorPointerType(operandTy)) {
      auto cast = markBackward(builder.create<CastOp>(loc, operandTy, operand));
      op->setOperand(i, cast.getResult(0));
      queue.push(cast);
    }
  }

  builder.setInsertionPointAfter(op);
  for (unsigned i = 0; i < op->getNumResults(); ++i) {
    Value result = op->getResult(i);
    auto resultTy = result.getType();
    if (triton::isTensorOrTensorPointerType(resultTy)) {
      auto cast = markForward(builder.create<CastOp>(loc, resultTy, result));
      result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation());
      queue.push(cast);
    }
  }

  return true;
}

bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) {
  Location loc = input.getLoc();
  OpBuilder builder(input.getContext());

  llvm::DenseMap<Type, llvm::SmallVector<CastOp>> typeToIndices;
  for (OpOperand &operand : input.getUses()) {
    auto brotherCast = llvm::dyn_cast<CastOp>(operand.getOwner());
    if (!brotherCast) {
      if (stepUnchanged <= queue.size())
        return false;
      builder.setInsertionPoint(operand.getOwner());
      brotherCast = markBackward(
          builder.create<CastOp>(loc, cast.getResult(0).getType(), input));
      auto newCast = markForward(builder.create<CastOp>(
          loc, input.getType(), brotherCast.getResult(0)));
      operand.set(newCast.getResult(0));
      queue.push(brotherCast);
      queue.push(newCast);
    }
    auto type = brotherCast.getResult(0).getType();
    typeToIndices[type].push_back(brotherCast);
  }

  bool first = true;
  for (auto it : typeToIndices) {
    Type &type = it.first;
    llvm::SmallVector<CastOp> &casts = it.second;
    Value newInput = input;
    if (!first) {
      if (Operation *defOp = input.getDefiningOp()) {
        builder.setInsertionPointAfter(defOp);
        Operation *clonedOp = builder.clone(*defOp);
        newInput = clonedOp->getResult(0);
      } else {
        llvm::report_fatal_error("Layout conflict for block arg"); // TODO
        return false;
      }
    }
    first = false;
    if (Operation *defOp = newInput.getDefiningOp()) {
      builder.setInsertionPointAfter(defOp);
    } else {
      assert(isa<BlockArgument>(newInput) &&
             "Unexpected Value without defining op");
      builder.setInsertionPointToStart(
          llvm::cast<BlockArgument>(newInput).getOwner());
    }
    auto newCast = markBackward(builder.create<CastOp>(loc, type, newInput));
    queue.push(newCast);
    auto newResult = newCast.getResult(0);
    for (CastOp &brotherCast : casts) {
      brotherCast.getResult(0).replaceAllUsesWith(newResult);
      eraseCastOpFromQueue(brotherCast);
    }
  }
  return true;
}

bool CTAPlanner::processMultiUsersForward(Value castResult, CastOp cast) {
  Value castSrc = cast.getOperand(0);

  Location loc = cast.getLoc();
  OpBuilder builder(cast.getContext());
  builder.setInsertionPointAfter(cast.getOperation());

  while (!castResult.use_empty()) {
    auto newCast =
        markForward(builder.create<CastOp>(loc, castResult.getType(), castSrc));
    castResult.use_begin()->set(newCast.getResult(0));
    queue.push(newCast);
  }

  eraseCastOp(cast);
  return true;
}

} // anonymous namespace

struct PlanCTAPass : public impl::TritonGPUPlanCTAPassBase<PlanCTAPass> {
  PlanCTAPass(ClusterInfo *clusterInfo_ = nullptr)
      : clusterInfo(clusterInfo_) {}
  void runOnOperation() override {
    ModuleOp mod = getOperation();

    // Skip PlanCTAPass when numCTAs == 1
    if (ttg::TritonGPUDialect::getNumCTAs(mod) == 1)
      return;

    mod.walk([&](triton::FuncOp funcOp) {
      CTAPlanner planner(clusterInfo);
      planner.run(funcOp);

      // FIXME: Clone funcOp so that the IR change can be identified after
      // PlanCTAPass. Without this, the change after PlanCTAPass will not be
      // displayed when MLIR_ENABLE_DUMP=1. This is not reasonable and should
      // be fixed later.
      OpBuilder builder(funcOp);
      builder.clone(*funcOp.getOperation());
      funcOp.erase();
    });
  }

  ClusterInfo *clusterInfo;
};

std::unique_ptr<Pass>
createTritonNvidiaGPUPlanCTAPass(ClusterInfo *clusterInfo) {
  return std::make_unique<PlanCTAPass>(clusterInfo);
}

} // namespace nvidia_gpu
} // namespace triton
} // namespace mlir

/* TODO
 * - Use ConvertLayoutOp instead of UnrealizedConversionCastOp.
 * - Move PlanCTAPass to the front of CoalescePass.
 * - Design better tiling strategy for DotOp and ReduceOp.
 * - Consider cases where there are more than one DotOps.
 * - Use better data structure for erasing CastOps from queue (linked list?).
 * - Process eliminable CastOps in higher priority.
 * - Fix the clone func bug in PlanCTAPass::runOnOperation.
 * - Add some comments to introduce the overall idea of this pass.
 * - Add some lit tests for this pass.
 */