//===- TestDenseForwardDataFlowAnalysis.cpp -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of test passes exercising dense forward data flow analysis.
//
//===----------------------------------------------------------------------===//

#include "TestDenseDataFlowAnalysis.h"
#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>

using namespace mlir;
using namespace mlir::dataflow;
using namespace mlir::dataflow::test;

namespace {

/// This lattice represents, for a given memory resource, the potential last
/// operations that modified the resource.
///
/// The lattice storage and the merge/print logic come from
/// `AccessLatticeBase`; this class only adapts them to the
/// `AbstractDenseLattice` interface expected by the dense analysis framework.
class LastModification : public AbstractDenseLattice, public AccessLatticeBase {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)

  using AbstractDenseLattice::AbstractDenseLattice;

  /// Join the last modifications recorded in `lattice` into this lattice.
  ///
  /// `lattice` must be a `LastModification`. Returns the change result
  /// reported by `AccessLatticeBase::merge`.
  ChangeResult join(const AbstractDenseLattice &lattice) override {
    // Upcast by reference. Casting to `AccessLatticeBase` by value would
    // slice-copy the entire right-hand side just to read from it.
    const AccessLatticeBase &rhs =
        static_cast<const LastModification &>(lattice);
    return AccessLatticeBase::merge(rhs);
  }

  /// Print the lattice by delegating to `AccessLatticeBase::print`.
  void print(raw_ostream &os) const override { AccessLatticeBase::print(os); }
};

/// Dense forward data flow analysis that computes a `LastModification`
/// lattice.
///
/// When `assumeFuncWrites` is set, a call to an external callee is treated as
/// modifying every one of its argument operands.
class LastModifiedAnalysis
    : public DenseForwardDataFlowAnalysis<LastModification> {
public:
  /// Create the analysis on `solver`. `assumeFuncWrites` selects how calls to
  /// external callees are handled in `visitCallControlFlowTransfer`.
  explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites)
      : DenseForwardDataFlowAnalysis(solver),
        assumeFuncWrites(assumeFuncWrites) {}

  /// Visit an operation. If the operation does not implement
  /// `MemoryEffectOpInterface`, or one of its effects is not attached to a
  /// value, the lattice is set to the entry state. If the underlying value of
  /// an affected value is not known yet, nothing is propagated. Otherwise the
  /// state before the operation is joined in and, for every effect other than
  /// a read, the operation is recorded for the affected underlying value.
  void visitOperation(Operation *op, const LastModification &before,
                      LastModification *after) override;

  /// Transfer function for call operations. Handles external callees when
  /// `assumeFuncWrites` is set and `TestCallAndStoreOp`; every other case is
  /// forwarded to the base class implementation.
  void visitCallControlFlowTransfer(CallOpInterface call,
                                    CallControlFlowAction action,
                                    const LastModification &before,
                                    LastModification *after) override;

  /// Transfer function for region control flow. `TestStoreWithARegion` and
  /// `TestStoreWithALoopRegion` are additionally visited as regular
  /// operations on one of their edges; the base class implementation is then
  /// applied in all cases.
  void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
                                            std::optional<unsigned> regionFrom,
                                            std::optional<unsigned> regionTo,
                                            const LastModification &before,
                                            LastModification *after) override;

  /// At an entry point, the last modifications of all memory resources are
  /// unknown.
  void setToEntryState(LastModification *lattice) override {
    // Reset the lattice and notify dependents only if that changed it.
    propagateIfChanged(lattice, lattice->reset());
  }

private:
  /// Whether calls to external callees are assumed to write to all of their
  /// argument operands.
  const bool assumeFuncWrites;
};
} // end anonymous namespace

/// Transfer function for a single operation.
///
/// Falls back to the entry state when the memory effects of `op` cannot be
/// reasoned about, defers when an underlying value is not resolved yet, and
/// otherwise joins `before` into `after` and records `op` for every
/// underlying value touched by a non-read effect.
void LastModifiedAnalysis::visitOperation(Operation *op,
                                          const LastModification &before,
                                          LastModification *after) {
  // Operations without a memory effect interface are opaque to us, so be
  // conservative and give up on everything known so far.
  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!effectInterface) {
    setToEntryState(after);
    return;
  }

  SmallVector<MemoryEffects::EffectInstance> opEffects;
  effectInterface.getEffects(opEffects);

  // Lattice lookup used while chasing underlying values; it registers `op` as
  // a dependent of every lattice it queries.
  auto lookupUnderlying = [&](Value queried) {
    return getOrCreateFor<UnderlyingValueLattice>(op, queried);
  };

  // Phase 1: resolve the underlying value of every effect before touching
  // `after`. Updating the lattice with partially resolved information could
  // propagate values that are overwritten later, which would be problematic
  // for convergence based on monotonicity of lattice updates.
  SmallVector<Value> resolved;
  resolved.reserve(opEffects.size());
  for (const MemoryEffects::EffectInstance &instance : opEffects) {
    Value affected = instance.getValue();

    // An effect that is not tied to a value defeats the analysis.
    if (!affected) {
      setToEntryState(after);
      return;
    }

    std::optional<Value> root =
        UnderlyingValueAnalysis::getMostUnderlyingValue(affected,
                                                        lookupUnderlying);

    // Not resolved yet: stay in the current state and wait for a revisit.
    if (!root)
      return;

    resolved.push_back(*root);
  }

  // Phase 2: every underlying value is resolved, so the state can be updated.
  // `resolved` holds exactly one entry per effect, in the same order.
  ChangeResult changed = after->join(before);
  for (size_t idx = 0, count = resolved.size(); idx != count; ++idx) {
    Value root = resolved[idx];

    // A resolved-but-null underlying value means "known to be unknown".
    if (!root) {
      setToEntryState(after);
      return;
    }

    // Reads do not modify the resource.
    if (isa<MemoryEffects::Read>(opEffects[idx].getEffect()))
      continue;

    changed |= after->set(root, op);
  }
  propagateIfChanged(after, changed);
}

/// Transfer function for call operations.
///
/// - External callee with `assumeFuncWrites` set: join `before` into `after`
///   and record the call for the underlying value of each argument operand.
/// - `TestCallAndStoreOp`: visit the call as a regular operation either when
///   entering or when exiting the callee, depending on its
///   `getStoreBeforeCall()` flag.
/// - Everything else: defer to the base class implementation.
void LastModifiedAnalysis::visitCallControlFlowTransfer(
    CallOpInterface call, CallControlFlowAction action,
    const LastModification &before, LastModification *after) {
  if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) {
    // Lattice lookup that registers the call as a dependent of each lattice
    // it queries.
    auto lookupUnderlying = [&](Value queried) {
      return getOrCreateFor<UnderlyingValueLattice>(call.getOperation(),
                                                    queried);
    };

    // Resolve every argument first; bail out without updating anything if one
    // of them is not known yet.
    SmallVector<Value> resolvedArgs;
    resolvedArgs.reserve(call->getNumOperands());
    for (Value argument : call.getArgOperands()) {
      std::optional<Value> root =
          UnderlyingValueAnalysis::getMostUnderlyingValue(argument,
                                                          lookupUnderlying);
      if (!root)
        return;
      resolvedArgs.push_back(*root);
    }

    ChangeResult changed = after->join(before);
    for (Value root : resolvedArgs)
      changed |= after->set(root, call);
    propagateIfChanged(after, changed);
    return;
  }

  if (auto callAndStore =
          dyn_cast<::test::TestCallAndStoreOp>(call.getOperation())) {
    // The store effect is applied on exactly one side of the call.
    const bool storeOnEntry = action == CallControlFlowAction::EnterCallee &&
                              callAndStore.getStoreBeforeCall();
    const bool storeOnExit = action == CallControlFlowAction::ExitCallee &&
                             !callAndStore.getStoreBeforeCall();
    if (storeOnEntry || storeOnExit) {
      visitOperation(call, before, after);
      return;
    }
  }

  AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer(
      call, action, before, after);
}

/// Transfer function for region control flow.
///
/// For `TestStoreWithARegion` and `TestStoreWithALoopRegion`, the op is
/// additionally visited as a regular operation on one edge:
///   - when `regionTo` is unset and the op does not store before its region;
///   - when `regionFrom` is unset and the op stores before its region.
/// The base class implementation is then applied in every case, including
/// for all other region branch operations.
void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer(
    RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
    std::optional<unsigned> regionTo, const LastModification &before,
    LastModification *after) {
  auto defaultHandling = [&]() {
    AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
        branch, regionFrom, regionTo, before, after);
  };
  // Capture by reference: the callbacks are invoked synchronously within this
  // call, so nothing outlives its referent. Capturing by copy would copy the
  // whole `before` lattice (a reference captured by value copies the referenced
  // object) and would implicitly capture `this`, which is deprecated in C++20.
  TypeSwitch<Operation *>(branch.getOperation())
      .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>(
          [&](auto storeWithRegion) {
            if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) ||
                (!regionFrom && storeWithRegion.getStoreBeforeRegion()))
              visitOperation(branch, before, after);
            defaultHandling();
          })
      .Default([&](auto) { defaultHandling(); });
}

namespace {
/// Test pass that runs `LastModifiedAnalysis` and, for every operation
/// carrying a `tag` string attribute, prints to stderr the operations that
/// last modified the underlying value of each operand.
struct TestLastModifiedPass
    : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)

  TestLastModifiedPass() = default;
  /// Copy the base pass state, then carry both option values over.
  TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) {
    interprocedural = other.interprocedural;
    assumeFuncWrites = other.assumeFuncWrites;
  }

  /// Command-line name of the pass.
  StringRef getArgument() const override { return "test-last-modified"; }

  /// Forwarded to the solver configuration.
  Option<bool> interprocedural{
      *this, "interprocedural", llvm::cl::init(true),
      llvm::cl::desc("perform interprocedural analysis")};
  /// Forwarded to `LastModifiedAnalysis`.
  Option<bool> assumeFuncWrites{
      *this, "assume-func-writes", llvm::cl::init(false),
      llvm::cl::desc(
          "assume external functions have write effect on all arguments")};

  void runOnOperation() override {
    Operation *root = getOperation();

    // Configure and run the solver with every analysis the test relies on.
    DataFlowConfig config;
    config.setInterprocedural(interprocedural);
    DataFlowSolver solver(config);
    solver.load<DeadCodeAnalysis>();
    solver.load<SparseConstantPropagation>();
    solver.load<LastModifiedAnalysis>(assumeFuncWrites);
    solver.load<UnderlyingValueAnalysis>();
    if (failed(solver.initializeAndRun(root))) {
      signalPassFailure();
      return;
    }

    raw_ostream &os = llvm::errs();

    // Read-only lookup of an already computed underlying-value lattice.
    auto lookupUnderlying = [&](Value queried) {
      return solver.lookupState<UnderlyingValueLattice>(queried);
    };

    // An operand whose underlying value could not be computed, or whose
    // underlying value has no known last modification, is reported as
    // unknown.
    root->walk([&](Operation *taggedOp) {
      auto tag = taggedOp->getAttrOfType<StringAttr>("tag");
      if (!tag)
        return;

      os << "test_tag: " << tag.getValue() << ":\n";
      const LastModification *lastMods =
          solver.lookupState<LastModification>(taggedOp);
      assert(lastMods && "expected a dense lattice");

      for (auto [index, operand] : llvm::enumerate(taggedOp->getOperands())) {
        os << " operand #" << index << "\n";

        std::optional<Value> underlyingValue =
            UnderlyingValueAnalysis::getMostUnderlyingValue(operand,
                                                            lookupUnderlying);
        if (!underlyingValue) {
          os << " - <unknown>\n";
          continue;
        }

        Value value = *underlyingValue;
        assert(value && "expected an underlying value");

        // No entry at all for this value.
        const AdjacentAccess *lastMod = lastMods->getAdjacentAccess(value);
        if (!lastMod) {
          os << "  - <unknown>\n";
          continue;
        }

        // An entry exists but it is not in a known state.
        if (!lastMod->isKnown()) {
          os << " - <unknown>\n";
          continue;
        }

        // Prefer the `tag_name` attribute of a modifier over its op name.
        for (Operation *lastModifier : lastMod->get()) {
          os << "  - ";
          if (auto tagName =
                  lastModifier->getAttrOfType<StringAttr>("tag_name"))
            os << tagName.getValue();
          else
            os << lastModifier->getName();
          os << "\n";
        }
      }
    });
  }
};
} // end anonymous namespace

namespace mlir::test {
/// Register `TestLastModifiedPass` with the pass registry.
void registerTestLastModifiedPass() {
  PassRegistration<TestLastModifiedPass>();
}
} // end namespace mlir::test