//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"

using namespace mlir;

namespace mlir {
namespace transform {
#define GEN_PASS_DEF_INTERPRETERPASS
#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
} // namespace transform
} // namespace mlir

/// Selects the operation that serves as the payload root.
///
/// When `tag` is empty, `passRoot` itself is returned without any lookup.
/// Otherwise `passRoot` is walked in search of operations carrying the
/// transform dialect target tag attribute with a value equal to `tag`:
///   - exactly one match: that operation is returned;
///   - no match: an error is emitted on `passRoot` and nullptr is returned;
///   - more than one match: an error is emitted on the second match, with a
///     note pointing at the first one, and nullptr is returned.
static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
  // Nothing to look up: the pass root is the payload root.
  if (tag.empty())
    return passRoot;

  StringAttr tagAttrName = StringAttr::get(
      passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);

  // Scan for tagged operations, remembering the first hit and aborting the
  // walk as soon as a second one shows up.
  Operation *taggedOp = nullptr;
  WalkResult status = passRoot->walk([&](Operation *candidate) -> WalkResult {
    auto candidateTag = candidate->getAttrOfType<StringAttr>(tagAttrName);
    bool matches = candidateTag && candidateTag.getValue() == tag;
    if (!matches)
      return WalkResult::advance();

    if (taggedOp) {
      // The tag must be unique; report the duplicate and where the first
      // occurrence was.
      InFlightDiagnostic diag = candidate->emitError()
                                << "repeated operation with the target tag '"
                                << tag << "'";
      diag.attachNote(taggedOp->getLoc()) << "previously seen operation";
      return WalkResult::interrupt();
    }

    taggedOp = candidate;
    return WalkResult::advance();
  });

  if (!taggedOp) {
    passRoot->emitError()
        << "could not find the operation with transform.target_tag=\"" << tag
        << "\" attribute";
    return nullptr;
  }

  // An interrupted walk means a duplicate was diagnosed above.
  if (status.wasInterrupted())
    return nullptr;
  return taggedOp;
}

namespace {
/// Pass that locates a payload root and a transform entry point, builds the
/// entry point argument bindings from the pass options, and applies the
/// transform named sequence to the payload.
class InterpreterPass
    : public transform::impl::InterpreterPassBase<InterpreterPass> {
  /// Builds the argument bindings for the entry point from the
  /// `debugBindTrailingArgs` pass option.
  ///
  /// The first binding is always `payloadRoot`. Each option entry then
  /// contributes one trailing binding, depending on its leading character:
  ///   - `#a;b;c`: the base-10 integers, bound as I64 integer attributes;
  ///   - `^name`: the results of every operation named `name` found by
  ///     walking `payloadRoot`;
  ///   - `name`: every operation named `name` found by walking `payloadRoot`.
  ///
  /// Emits an error at an unknown location and returns std::nullopt if an
  /// integer literal cannot be parsed.
  std::optional<RaggedArray<transform::MappedValue>>
  parseArguments(Operation *payloadRoot) {
    MLIRContext *ctx = payloadRoot->getContext();
    const auto numTrailing = debugBindTrailingArgs.size();

    SmallVector<SmallVector<transform::MappedValue>, 2> extraBindings;
    extraBindings.resize(numTrailing);

    // Per-argument operation name to match (absent for integer lists) and
    // whether the argument asks for op results rather than the ops themselves.
    SmallVector<std::optional<OperationName>> opNamesToMatch;
    opNamesToMatch.reserve(numTrailing);
    SmallVector<bool> wantsResults;
    wantsResults.reserve(numTrailing);

    Builder builder(ctx);
    for (auto &&[index, rawArg] : llvm::enumerate(debugBindTrailingArgs)) {
      StringRef arg = rawArg;
      bool isResultBinding = arg.starts_with("^");
      wantsResults.push_back(isResultBinding);

      // Operation-name arguments: strip the optional '^' marker.
      if (!arg.starts_with("#")) {
        StringRef opName = isResultBinding ? arg.drop_front() : arg;
        opNamesToMatch.emplace_back(OperationName(opName, ctx));
        continue;
      }

      // Integer list arguments: consume ';'-separated literals until nothing
      // remains after the last separator.
      opNamesToMatch.push_back(std::nullopt);
      StringRef remaining = arg.drop_front();
      while (true) {
        auto [literal, rest] = remaining.split(';');
        int64_t parsed;
        if (literal.getAsInteger(10, parsed)) {
          emitError(UnknownLoc::get(ctx))
              << "couldn't parse integer pass argument " << arg;
          return std::nullopt;
        }
        extraBindings[index].push_back(builder.getI64IntegerAttr(parsed));
        if (rest.empty())
          break;
        remaining = rest;
      }
    }

    // Gather the matching payload operations (or their results) for every
    // name-based argument.
    payloadRoot->walk([&](Operation *payload) {
      for (auto &&[index, expected] : llvm::enumerate(opNamesToMatch)) {
        if (!expected || payload->getName() != *expected)
          continue;

        if (wantsResults[index])
          llvm::append_range(extraBindings[index], payload->getResults());
        else
          extraBindings[index].push_back(payload);
      }
    });

    // The payload root comes first, followed by the trailing bindings in
    // option order.
    RaggedArray<transform::MappedValue> result;
    result.push_back(ArrayRef<Operation *>{payloadRoot});
    for (auto &perArgument : extraBindings)
      result.push_back(std::move(perArgument));
    return result;
  }

public:
  using Base::Base;

  /// Resolves the payload root, the transform entry point and the argument
  /// bindings in that order, then applies the named sequence. Signals pass
  /// failure as soon as any of these steps fails.
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    ModuleOp transformModule =
        transform::detail::getPreloadedTransformModule(ctx);

    Operation *payloadRoot =
        findPayloadRoot(getOperation(), debugPayloadRootTag);
    if (!payloadRoot) {
      signalPassFailure();
      return;
    }

    Operation *entryPointOp = transform::detail::findTransformEntryPoint(
        getOperation(), transformModule, entryPoint);
    if (!entryPointOp) {
      signalPassFailure();
      return;
    }

    std::optional<RaggedArray<transform::MappedValue>> bindings =
        parseArguments(payloadRoot);
    if (!bindings) {
      signalPassFailure();
      return;
    }

    // Expensive checks are on unless explicitly disabled via the pass option.
    options.enableExpensiveChecks(!disableExpensiveChecks);
    auto transformRoot = cast<transform::TransformOpInterface>(entryPointOp);
    if (failed(transform::applyTransformNamedSequence(
            *bindings, transformRoot, transformModule, options)))
      signalPassFailure();
  }

private:
  /// Transform interpreter options.
  transform::TransformOptions options;
};
} // namespace