//===- TransformInterpreterUtils.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lightweight transform dialect interpreter utilities.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

#define DEBUG_TYPE "transform-dialect-interpreter-utils"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

/// Expands the given list of `paths` to a list of `.mlir` files.
///
/// Each entry in `paths` may either be a regular file, in which case it ends up
/// in the result list, or a directory, in which case all regular files and
/// symlinks directly inside it whose name ends in `.mlir` are added (sorted by
/// path, so the result does not depend on the directory listing order). Any
/// other file types lead to a failure.
///
/// Found files are appended to `fileNames`; errors are reported as diagnostics
/// located at the offending path.
LogicalResult transform::detail::expandPathsToMLIRFiles(
    ArrayRef<std::string> paths, MLIRContext *context,
    SmallVectorImpl<std::string> &fileNames) {
  for (const std::string &path : paths) {
    auto loc = FileLineColLoc::get(context, path, 0, 0);

    if (llvm::sys::fs::is_regular_file(path)) {
      LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
      fileNames.push_back(path);
      continue;
    }

    if (!llvm::sys::fs::is_directory(path)) {
      return emitError(loc)
             << "'" << path << "' is neither a file nor a directory";
    }

    LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");

    // Remember where this directory's files start so they can be sorted below.
    const size_t firstFileInDir = fileNames.size();

    std::error_code ec;
    for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
         it != itEnd && !ec; it.increment(ec)) {
      const std::string &fileName = it->path();

      // The entry type comes from the directory listing. File systems and
      // platforms that do not report it yield `type_unknown`; query the file
      // itself in that case rather than silently skipping every entry. The
      // symlink is not followed so the accepted kinds stay the same as when the
      // listing does report the type.
      llvm::sys::fs::file_type type = it->type();
      if (type == llvm::sys::fs::file_type::type_unknown)
        type = llvm::sys::fs::get_file_type(fileName, /*Follow=*/false);

      if (type != llvm::sys::fs::file_type::regular_file &&
          type != llvm::sys::fs::file_type::symlink_file) {
        LLVM_DEBUG(DBGS() << "  Skipping non-regular file '" << fileName
                          << "'\n");
        continue;
      }

      if (!StringRef(fileName).ends_with(".mlir")) {
        LLVM_DEBUG(DBGS() << "  Skipping '" << fileName
                          << "' because it does not end with '.mlir'\n");
        continue;
      }

      LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
      fileNames.push_back(fileName);
    }

    if (ec)
      return emitError(loc) << "error while opening files in '" << path
                            << "': " << ec.message();

    // `directory_iterator` yields entries in an unspecified, file-system
    // dependent order; sort this directory's files so the result (and the order
    // in which libraries are later merged) is reproducible.
    llvm::sort(fileNames.begin() + firstFileInDir, fileNames.end());
  }

  return success();
}

/// Parses the file `transformFileName` into `transformModule` and verifies the
/// resulting module.
///
/// An empty file name is not an error: the transform IR is then assumed to be
/// embedded next to the payload, nothing is parsed, and `transformModule` is
/// left untouched. Failing to open, parse or verify the file yields `failure()`
/// with a diagnostic.
LogicalResult transform::detail::parseTransformModuleFromFile(
    MLIRContext *context, llvm::StringRef transformFileName,
    OwningOpRef<ModuleOp> &transformModule) {
  // Nothing to load: the transform module lives alongside the payload IR.
  if (transformFileName.empty()) {
    LLVM_DEBUG(
        DBGS() << "no transform file name specified, assuming the transform "
                  "module is embedded in the IR next to the top-level\n");
    return success();
  }

  std::string openError;
  auto fileBuffer = mlir::openInputFile(transformFileName, &openError);
  if (!fileBuffer) {
    auto fileLoc =
        FileLineColLoc::get(StringAttr::get(context, transformFileName), 0, 0);
    return emitError(fileLoc) << "failed to open transform file: " << openError;
  }

  // The parser reads its input through the source manager, which takes over
  // the buffer.
  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(fileBuffer), llvm::SMLoc());
  transformModule = parseSourceFile<ModuleOp>(sourceMgr, context);

  // A null module means parsing failed; the parser has already reported why,
  // so only the verification of a successfully parsed module remains.
  return transformModule ? mlir::verify(*transformModule) : failure();
}

/// Returns the library module held by the transform dialect of `context`,
/// loading the dialect first if needed. The result may presumably be a null
/// `ModuleOp` when no library was preloaded — TODO confirm with
/// `TransformDialect::getLibraryModule`.
ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
  auto *transformDialect =
      context->getOrLoadDialect<transform::TransformDialect>();
  return transformDialect->getLibraryModule();
}

/// Looks up the `transform.named_sequence` called `entryPoint`.
///
/// `root` is searched first (pre-order, so the outermost match wins), then
/// `module` if it is non-null. Emits an error on `root` and returns null when
/// no such sequence exists in either place.
transform::TransformOpInterface
transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
                                           StringRef entryPoint) {
  // Returns the first named sequence under `scope` whose symbol name matches
  // `entryPoint`, or null.
  auto findIn = [&](Operation *scope) -> transform::TransformOpInterface {
    transform::TransformOpInterface match = nullptr;
    scope->walk<WalkOrder::PreOrder>(
        [&](transform::NamedSequenceOp candidate) -> WalkResult {
          if (candidate.getSymName() != entryPoint)
            return WalkResult::advance();
          match = cast<transform::TransformOpInterface>(
              candidate.getOperation());
          return WalkResult::interrupt();
        });
    return match;
  };

  if (transform::TransformOpInterface fromRoot = findIn(root))
    return fromRoot;

  if (module) {
    if (transform::TransformOpInterface fromModule = findIn(module))
      return fromModule;
  }

  root->emitError() << "could not find a nested named sequence with name: "
                    << entryPoint;
  return nullptr;
}

/// Builds a single transform library module from the `.mlir` files found at
/// `transformLibraryPaths` (files or directories, see
/// `expandPathsToMLIRFiles`).
///
/// Every file is parsed and verified, then its symbols are merged, in list
/// order, into a fresh module named `__transform` that carries the
/// `transform.with_named_sequence` unit attribute. On success the merged module
/// is stored in `transformModule`; on failure a diagnostic has been emitted and
/// `transformModule` is left unchanged.
LogicalResult transform::detail::assembleTransformLibraryFromPaths(
    MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
    OwningOpRef<ModuleOp> &transformModule) {
  // Assemble list of library files.
  SmallVector<std::string> libraryFileNames;
  if (failed(detail::expandPathsToMLIRFiles(transformLibraryPaths, context,
                                            libraryFileNames)))
    return failure();

  // Parse modules from library files.
  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
  for (const std::string &libraryFileName : libraryFileNames) {
    OwningOpRef<ModuleOp> parsedLibrary;
    if (failed(transform::detail::parseTransformModuleFromFile(
            context, libraryFileName, parsedLibrary)))
      return failure();
    parsedLibraries.push_back(std::move(parsedLibrary));
  }

  // Merge parsed libraries into one module.
  auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
  OwningOpRef<ModuleOp> mergedParsedLibraries =
      ModuleOp::create(loc, "__transform");
  mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
                                       UnitAttr::get(context));
  // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
  for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
    // `mergeSymbolsInto` takes ownership of the library, which leaves
    // `parsedLibrary` null. Capture the location to report against beforehand
    // instead of dereferencing the moved-from reference on failure.
    Location libraryLoc = parsedLibrary->getLoc();
    if (failed(transform::detail::mergeSymbolsInto(
            mergedParsedLibraries.get(), std::move(parsedLibrary))))
      return emitError(libraryLoc)
             << "failed to merge symbols into shared library module";
  }

  transformModule = std::move(mergedParsedLibraries);
  return success();
}

/// Convenience overload: binds `payload` as the only object of the first
/// argument of the entry point `transformRoot` and defers to the overload that
/// takes explicit bindings. `transformRoot` must implement
/// `TransformOpInterface` (checked by `cast`).
LogicalResult transform::applyTransformNamedSequence(
    Operation *payload, Operation *transformRoot, ModuleOp transformModule,
    const TransformOptions &options) {
  auto entryPoint = cast<TransformOpInterface>(transformRoot);

  RaggedArray<MappedValue> payloadBindings;
  payloadBindings.push_back(ArrayRef<Operation *>(payload));

  return applyTransformNamedSequence(std::move(payloadBindings), entryPoint,
                                     transformModule, options);
}

/// Applies the transform entry point `transformRoot` using `bindings`.
///
/// The first binding must hold exactly one object, an operation, which becomes
/// the payload root. The remaining bindings are passed on to `applyTransforms`
/// as extra mappings. If `transformModule` is provided and does not already
/// contain `transformRoot`, a clone of it is merged into the nearest symbol
/// table enclosing `transformRoot`; `transformModule` itself is never modified.
LogicalResult transform::applyTransformNamedSequence(
    RaggedArray<MappedValue> bindings, TransformOpInterface transformRoot,
    ModuleOp transformModule, const TransformOptions &options) {
  if (bindings.empty()) {
    return transformRoot.emitError()
           << "expected at least one binding for the root";
  }
  if (bindings.at(0).size() != 1) {
    return transformRoot.emitError()
           << "expected one payload to be bound to the first argument, got "
           << bindings.at(0).size();
  }
  auto *payloadRoot = bindings.at(0).front().dyn_cast<Operation *>();
  if (!payloadRoot) {
    return transformRoot->emitError() << "expected the object bound to the "
                                         "first argument to be an operation";
  }

  // The payload root is passed separately below; only the remaining bindings
  // are forwarded as extra mappings.
  bindings.removeFront();

  // `transformModule` may not be modified, so a clone is merged instead.
  if (transformModule && !transformModule->isAncestor(transformRoot)) {
    // Without an enclosing symbol table there is nowhere to merge into;
    // report it instead of handing a null target to `mergeSymbolsInto`.
    Operation *symbolTable = SymbolTable::getNearestSymbolTable(transformRoot);
    if (!symbolTable) {
      return transformRoot->emitError()
             << "expected the transform entry point to be nested in a symbol "
                "table to merge the transform module into";
    }
    OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
    if (failed(detail::mergeSymbolsInto(symbolTable,
                                        std::move(clonedTransformModule)))) {
      return payloadRoot->emitError() << "failed to merge symbols";
    }
  }

  LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");
  LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n");

  return applyTransforms(payloadRoot, transformRoot, bindings, options,
                         /*enforceToplevelTransformOp=*/false);
}