//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks  --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/raw_ostream.h"
#include <list>

using namespace mlir;
using namespace llvm;

namespace {
/// Command-line parser for options of type `test::TestDialectVersion`.
/// Accepts values written as `<major>.<minor>`, where both components are
/// base-10 integers (e.g. `1.12`).
class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
public:
  TestDialectVersionParser(cl::Option &o)
      : cl::parser<test::TestDialectVersion>(o) {}

  /// Parses `arg` into `v`.
  ///
  /// Returns true on error (after reporting it through `o.error`) and false
  /// on success. `v` is only written on success.
  bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg,
             test::TestDialectVersion &v) {
    // Split once at the first '.'; when no '.' is present the second half is
    // empty and is rejected by the integer parse below.
    auto [majorStr, minorStr] = arg.split('.');
    long long majorVersion, minorVersion;
    // `getAsSignedInteger` returns true on failure.
    if (getAsSignedInteger(majorStr, 10, majorVersion) ||
        getAsSignedInteger(minorStr, 10, minorVersion))
      return o.error("Invalid argument '" + arg + "'");
    v = test::TestDialectVersion(majorVersion, minorVersion);
    return false;
  }

  /// Prints `v` as `<major>.<minor>`.
  static void print(raw_ostream &os, const test::TestDialectVersion &v) {
    os << v.major_ << "." << v.minor_;
  }
};

/// This is a test pass which uses callbacks to encode attributes and types in a
/// custom fashion.
///
/// The `test-kind` option selects which roundtrip scenario is executed (see the
/// `runTestN` methods below), and `test-dialect-version` provides the test
/// dialect version used by the version-aware scenarios (kinds 0 and 6). Every
/// scenario writes the module to an in-memory bytecode buffer, parses it back,
/// and prints the parsed module to stdout.
struct TestBytecodeRoundtripPass
    : public PassWrapper<TestBytecodeRoundtripPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass)

  StringRef getArgument() const final { return "test-bytecode-roundtrip"; }
  StringRef getDescription() const final {
    return "Test pass to implement bytecode roundtrip tests.";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<test::TestDialect>();
  }
  TestBytecodeRoundtripPass() = default;
  // The option members bind to `*this` through their default member
  // initializers, so only the base class and the cached dialect pointer are
  // taken from `other`.
  TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass &other)
      : PassWrapper(other), testDialect(other.testDialect) {}

  /// Loads the test dialect into `context` and caches a pointer to it.
  LogicalResult initialize(MLIRContext *context) override {
    testDialect = context->getOrLoadDialect<test::TestDialect>();
    return success();
  }

  /// Dispatches to the scenario selected by the `test-kind` option. An
  /// unknown kind is reported as an error and fails the pass, since the value
  /// comes straight from the command line.
  void runOnOperation() override {
    Operation *op = getOperation();
    const int kind = testKind;
    switch (kind) {
      // Tests 0-5 implement a custom roundtrip with callbacks.
    case 0:
      return runTest0(op);
    case 1:
      return runTest1(op);
    case 2:
      return runTest2(op);
    case 3:
      return runTest3(op);
    case 4:
      return runTest4(op);
    case 5:
      return runTest5(op);
    case 6:
      // test-kind 6 is a plain roundtrip with downgrade/upgrade to/from
      // `targetVersion`.
      return runTest6(op);
    default:
      op->emitError() << "unhandled test kind for TestBytecodeCallbacks pass: "
                      << kind;
      return signalPassFailure();
    }
  }

  mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
      targetVersion{*this, "test-dialect-version",
                    llvm::cl::desc(
                        "Specifies the test dialect version to emit and parse"),
                    cl::init(test::TestDialectVersion())};

  mlir::Pass::Option<int> testKind{
      *this, "test-kind", llvm::cl::desc("Specifies the test kind to execute"),
      cl::init(0)};

private:
  /// Returns the bytecode interface registered on the builtin dialect loaded
  /// in the context of `op`.
  static BytecodeDialectInterface *getBuiltinBytecodeInterface(Operation *op) {
    auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
    return builtin->getRegisteredInterface<BytecodeDialectInterface>();
  }

  /// Writes `op` to an in-memory bytecode buffer using `writeConfig`, parses
  /// the buffer back using `parseConfig`, and prints the parsed module to
  /// stdout. Emits an error on `op` and signals pass failure if either the
  /// write or the read fails.
  void doRoundtripWithConfigs(Operation *op,
                              const BytecodeWriterConfig &writeConfig,
                              const ParserConfig &parseConfig) {
    std::string bytecode;
    llvm::raw_string_ostream os(bytecode);
    if (failed(writeBytecodeToFile(op, os, writeConfig))) {
      op->emitError() << "failed to write bytecode\n";
      signalPassFailure();
      return;
    }
    auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig);
    if (!newModuleOp.get()) {
      op->emitError() << "failed to read bytecode\n";
      signalPassFailure();
      return;
    }
    // Print the module to the output stream, so that we can filecheck the
    // result.
    newModuleOp->print(llvm::outs());
  }

  // Test0: let's assume that versions older than 2.0 were relying on a special
  // integer attribute of a deprecated dialect called "funky". Assume that its
  // encoding was made by two varInts, the first was the ID (999) and the second
  // contained width and signedness info. We can emit it using a callback
  // writing a custom encoding for the "funky" dialect group, and parse it back
  // with a custom parser reading the same encoding in the same dialect group.
  // Note that the ID 999 does not correspond to a valid integer type in the
  // current encodings of builtin types.
  void runTest0(Operation *op) {
    test::TestDialectVersion targetEmissionVersion = targetVersion;
    BytecodeWriterConfig writeConfig;
    // Set the emission version for the test dialect.
    writeConfig.setDialectVersion<test::TestDialect>(
        std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
    writeConfig.attachTypeCallback(
        [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
            DialectBytecodeWriter &writer) -> LogicalResult {
          // Do not override anything if version greater than 2.0.
          auto versionOr = writer.getDialectVersion<test::TestDialect>();
          assert(succeeded(versionOr) && "expected writer to be able to access "
                                         "the version for test dialect");
          const auto *version =
              static_cast<const test::TestDialectVersion *>(*versionOr);
          if (version->major_ >= 2)
            return failure();

          // For version less than 2.0, override the encoding of IntegerType.
          if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) {
            llvm::outs() << "Overriding IntegerType encoding...\n";
            dialectGroupName = StringLiteral("funky");
            writer.writeVarInt(/* IntegerType */ 999);
            // Pack the signedness in the two low bits, the width above them.
            writer.writeVarInt(type.getWidth() << 2 | type.getSignedness());
            return success();
          }
          return failure();
        });

    // Parse into a fresh context that shares the dialect registry of the
    // original one. The context only needs to outlive the roundtrip below.
    MLIRContext newCtx;
    newCtx.appendDialectRegistry(op->getContext()->getDialectRegistry());
    newCtx.allowUnregisteredDialects();
    ParserConfig parseConfig(&newCtx, /*verifyAfterParse=*/true);
    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
        [&](DialectBytecodeReader &reader, StringRef dialectName,
            Type &entry) -> LogicalResult {
          // Get test dialect version from the version map.
          auto versionOr = reader.getDialectVersion<test::TestDialect>();
          assert(succeeded(versionOr) && "expected reader to be able to access "
                                         "the version for test dialect");
          const auto *version =
              static_cast<const test::TestDialectVersion *>(*versionOr);
          if (version->major_ >= 2)
            return success();

          // `dialectName` is the name of the group we have the opportunity to
          // override. In this case, override only the dialect group "funky",
          // which does not exist in memory.
          if (dialectName != StringLiteral("funky"))
            return success();

          uint64_t encoding;
          if (failed(reader.readVarInt(encoding)) || encoding != 999)
            return success();
          llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
          uint64_t widthAndSignedness;
          if (succeeded(reader.readVarInt(widthAndSignedness))) {
            // Undo the packing done by the writer callback above.
            auto width = static_cast<unsigned>(widthAndSignedness >> 2);
            auto signedness = static_cast<IntegerType::SignednessSemantics>(
                widthAndSignedness & 0x3);
            entry = IntegerType::get(reader.getContext(), width, signedness);
          }
          // Return success with `entry` left unset to fall through the rest of
          // the parsing code path.
          return success();
        });
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  // Test1: When writing bytecode, we override the encoding of TestI32Type with
  // the encoding of builtin IntegerType. We can natively parse this without
  // the use of a callback, relying on the existing builtin reader mechanism.
  void runTest1(Operation *op) {
    BytecodeDialectInterface *iface = getBuiltinBytecodeInterface(op);
    BytecodeWriterConfig writeConfig;
    writeConfig.attachTypeCallback(
        [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
            DialectBytecodeWriter &writer) -> LogicalResult {
          // Emit TestI32Type using the builtin dialect encoding.
          if (llvm::isa<test::TestI32Type>(entryValue)) {
            llvm::outs() << "Overriding TestI32Type encoding...\n";
            auto builtinI32Type =
                IntegerType::get(op->getContext(), 32,
                                 IntegerType::SignednessSemantics::Signless);
            // Specify that this type will need to be written as part of the
            // builtin group. This will override the default dialect group of
            // the type (test).
            dialectGroupName = StringLiteral("builtin");
            if (succeeded(iface->writeType(builtinI32Type, writer)))
              return success();
          }
          return failure();
        });
    // We natively parse the type as a builtin, so no callback needed.
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  // Test2: When writing bytecode, we write standard builtin IntegerTypes. At
  // parsing, we use the encoding of IntegerType to intercept all i32. Then,
  // instead of creating i32s, we assemble TestI32Type and return it.
  void runTest2(Operation *op) {
    BytecodeDialectInterface *iface = getBuiltinBytecodeInterface(op);
    BytecodeWriterConfig writeConfig;
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
        [&](DialectBytecodeReader &reader, StringRef dialectName,
            Type &entry) -> LogicalResult {
          // Only entries of the builtin dialect group are intercepted.
          if (dialectName != StringLiteral("builtin"))
            return success();
          Type builtinType = iface->readType(reader);
          if (auto integerType =
                  llvm::dyn_cast_or_null<IntegerType>(builtinType)) {
            if (integerType.getWidth() == 32 && integerType.isSignless()) {
              llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
              entry = test::TestI32Type::get(reader.getContext());
            }
          }
          return success();
        });
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  // Test3: When writing bytecode, we override the encoding of
  // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We
  // can natively parse this without the use of a callback, relying on the
  // existing builtin reader mechanism.
  void runTest3(Operation *op) {
    BytecodeDialectInterface *iface = getBuiltinBytecodeInterface(op);
    auto i32Type = IntegerType::get(op->getContext(), 32,
                                    IntegerType::SignednessSemantics::Signless);
    BytecodeWriterConfig writeConfig;
    writeConfig.attachAttributeCallback(
        [&](Attribute entryValue, std::optional<StringRef> &dialectGroupName,
            DialectBytecodeWriter &writer) -> LogicalResult {
          // Emit TestAttrParamsAttr using the builtin dialect encoding.
          if (auto testParamAttrs =
                  llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) {
            llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
            // Specify that this attribute will need to be written as part of
            // the builtin group. This will override the default dialect group
            // of the attribute (test).
            dialectGroupName = StringLiteral("builtin");
            // Carry the two parameters as the elements of a tensor<2xi32>.
            auto denseAttr = DenseIntElementsAttr::get(
                RankedTensorType::get({2}, i32Type),
                {testParamAttrs.getV0(), testParamAttrs.getV1()});
            if (succeeded(iface->writeAttribute(denseAttr, writer)))
              return success();
          }
          return failure();
        });
    // We natively parse the attribute as a builtin, so no callback needed.
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  // Test4: When writing bytecode, we write standard builtin
  // DenseIntElementsAttr. At parsing, we use the encoding of
  // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of
  // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble
  // TestAttrParamsAttr and return it.
  void runTest4(Operation *op) {
    BytecodeDialectInterface *iface = getBuiltinBytecodeInterface(op);
    auto i32Type = IntegerType::get(op->getContext(), 32,
                                    IntegerType::SignednessSemantics::Signless);
    BytecodeWriterConfig writeConfig;
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
    parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
        [&](DialectBytecodeReader &reader, StringRef dialectName,
            Attribute &entry) -> LogicalResult {
          // Override only the case where the return type of the builtin reader
          // is an i32 and fall through on all the other cases, since we want to
          // still use TestDialect normal codepath to parse the other types.
          Attribute builtinAttr = iface->readAttribute(reader);
          if (auto denseAttr =
                  llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
            // `ArrayRef<int64_t>(2)` is the one-element shape `[2]`.
            if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
                denseAttr.getElementType() == i32Type) {
              llvm::outs()
                  << "Overriding parsing of TestAttrParamsAttr encoding...\n";
              auto values = denseAttr.getValues<IntegerAttr>();
              int v0 = values[0].getInt();
              int v1 = values[1].getInt();
              entry =
                  test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
            }
          }
          return success();
        });
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  // Test5: When writing bytecode, we want TestDialect to use nothing else than
  // the builtin types and attributes and take full control of the encoding,
  // returning failure if any type or attribute is not part of builtin.
  void runTest5(Operation *op) {
    BytecodeDialectInterface *iface = getBuiltinBytecodeInterface(op);
    BytecodeWriterConfig writeConfig;
    writeConfig.attachAttributeCallback(
        [&](Attribute attr, std::optional<StringRef> &dialectGroupName,
            DialectBytecodeWriter &writer) -> LogicalResult {
          return iface->writeAttribute(attr, writer);
        });
    writeConfig.attachTypeCallback(
        [&](Type type, std::optional<StringRef> &dialectGroupName,
            DialectBytecodeWriter &writer) -> LogicalResult {
          return iface->writeType(type, writer);
        });
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
    parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
        [&](DialectBytecodeReader &reader, StringRef dialectName,
            Attribute &entry) -> LogicalResult {
          // A null result from the builtin reader is reported as a failure.
          Attribute builtinAttr = iface->readAttribute(reader);
          if (!builtinAttr)
            return failure();
          entry = builtinAttr;
          return success();
        });
    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
        [&](DialectBytecodeReader &reader, StringRef dialectName,
            Type &entry) -> LogicalResult {
          // A null result from the builtin reader is reported as a failure.
          Type builtinType = iface->readType(reader);
          if (!builtinType)
            return failure();
          entry = builtinType;
          return success();
        });
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  /// Checks whether the IR nested under `op` can be emitted as `version` of
  /// the test dialect.
  ///
  /// Succeeds immediately for version 2.0, emits an error and fails for any
  /// newer version, and for older versions fails (emitting an op error) if a
  /// `TestVersionedOpA` has its `modifier` property set.
  LogicalResult downgradeToVersion(Operation *op,
                                   const test::TestDialectVersion &version) {
    if ((version.major_ == 2) && (version.minor_ == 0))
      return success();
    if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
      return op->emitError() << "current test dialect version is 2.0, "
                                "can't downgrade to version: "
                             << version.major_ << "." << version.minor_;
    }
    // Prior version 2.0, the old op supported only a single attribute called
    // "dimensions". We need to check that the modifier is false, otherwise we
    // can't do the downgrade.
    auto status = op->walk([&](test::TestVersionedOpA versionedOp) {
      auto &prop = versionedOp.getProperties();
      if (prop.modifier.getValue()) {
        versionedOp->emitOpError()
            << "cannot downgrade to version " << version.major_ << "."
            << version.minor_ << " since the modifier is not compatible";
        return WalkResult::interrupt();
      }
      llvm::outs() << "downgrading op...\n";
      return WalkResult::advance();
    });
    return failure(status.wasInterrupted());
  }

  // Test6: Downgrade IR to `targetVersion`, write to bytecode. Then, read and
  // upgrade IR when back in memory. The module is expected to be unmodified at
  // the end of the function.
  void runTest6(Operation *op) {
    test::TestDialectVersion targetEmissionVersion = targetVersion;

    // Downgrade IR constructs before writing the IR to bytecode. The downgrade
    // can fail on user-controlled input (target version or IR contents), and
    // it has already emitted a diagnostic in that case, so fail the pass
    // instead of asserting.
    if (failed(downgradeToVersion(op, targetEmissionVersion)))
      return signalPassFailure();

    BytecodeWriterConfig writeConfig;
    writeConfig.setDialectVersion<test::TestDialect>(
        std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  /// Test dialect loaded by `initialize`; null until then.
  test::TestDialect *testDialect = nullptr;
};
} // namespace

namespace mlir {
void registerTestBytecodeRoundtripPasses() {
  PassRegistration<TestBytecodeRoundtripPass>();
}
} // namespace mlir