//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
// automatically. It is used by NVVM to LLVM pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"

#define DEBUG_TYPE "ptx-builder"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

//===----------------------------------------------------------------------===//
// BasicPtxBuilderInterface
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"

using namespace mlir;
using namespace NVVM;

static constexpr int64_t kSharedMemorySpace = 3;

static char getRegisterType(Type type) {
  if (type.isInteger(1))
    return 'b';
  if (type.isInteger(16))
    return 'h';
  if (type.isInteger(32))
    return 'r';
  if (type.isInteger(64))
    return 'l';
  if (type.isF32())
    return 'f';
  if (type.isF64())
    return 'd';
  if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
    // Shared address spaces is addressed with 32-bit pointers.
    if (ptr.getAddressSpace() == kSharedMemorySpace) {
      return 'r';
    }
    return 'l';
  }
  // register type for struct is not supported.
  llvm_unreachable("The register type could not deduced from MLIR type");
  return '?';
}

static char getRegisterType(Value v) {
  if (v.getDefiningOp<LLVM::ConstantOp>())
    return 'n';
  return getRegisterType(v.getType());
}

void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
  auto getModifier = [&]() -> const char * {
    if (itype == PTXRegisterMod::ReadWrite) {
      assert(false && "Read-Write modifier is not supported. Try setting the "
                      "same value as Write and Read separately.");
      return "+";
    }
    if (itype == PTXRegisterMod::Write) {
      return "=";
    }
    return "";
  };
  auto addValue = [&](Value v) {
    if (itype == PTXRegisterMod::Read) {
      ptxOperands.push_back(v);
      return;
    }
    if (itype == PTXRegisterMod::ReadWrite)
      ptxOperands.push_back(v);
    hasResult = true;
  };

  llvm::raw_string_ostream ss(registerConstraints);
  // Handle Structs
  if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
    if (itype == PTXRegisterMod::Write) {
      addValue(v);
    }
    for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
      if (itype != PTXRegisterMod::Write) {
        Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
            interfaceOp->getLoc(), v, idx);
        addValue(extractValue);
      }
      if (itype == PTXRegisterMod::ReadWrite) {
        ss << idx << ",";
      } else {
        ss << getModifier() << getRegisterType(t) << ",";
      }
      ss.flush();
    }
    return;
  }
  // Handle Scalars
  addValue(v);
  ss << getModifier() << getRegisterType(v) << ",";
  ss.flush();
}

LLVM::InlineAsmOp PtxBuilder::build() {
  auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
                                                  LLVM::AsmDialect::AD_ATT);

  auto resultTypes = interfaceOp->getResultTypes();

  // Remove the last comma from the constraints string.
  if (!registerConstraints.empty() &&
      registerConstraints[registerConstraints.size() - 1] == ',')
    registerConstraints.pop_back();

  std::string ptxInstruction = interfaceOp.getPtx();

  // Add the predicate to the asm string.
  if (interfaceOp.getPredicate().has_value() &&
      interfaceOp.getPredicate().value()) {
    std::string predicateStr = "@%";
    predicateStr += std::to_string((ptxOperands.size() - 1));
    ptxInstruction = predicateStr + " " + ptxInstruction;
  }

  // Tablegen doesn't accept $, so we use %, but inline assembly uses $.
  // Replace all % with $
  std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$');

  return rewriter.create<LLVM::InlineAsmOp>(
      interfaceOp->getLoc(),
      /*result types=*/resultTypes,
      /*operands=*/ptxOperands,
      /*asm_string=*/llvm::StringRef(ptxInstruction),
      /*constraints=*/registerConstraints.data(),
      /*has_side_effects=*/interfaceOp.hasSideEffect(),
      /*is_align_stack=*/false,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
}

void PtxBuilder::buildAndReplaceOp() {
  LLVM::InlineAsmOp inlineAsmOp = build();
  LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
  if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
    rewriter.replaceOp(interfaceOp, inlineAsmOp);
  } else {
    rewriter.eraseOp(interfaceOp);
  }
}