//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
// generate the corresponding Python binding classes.
//
//===----------------------------------------------------------------------===//
#include "OpGenHelpers.h"

#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

/// Preamble written once, before any enum, at the top of the generated Python
/// module (see `emitPythonEnums`). It holds a do-not-edit banner and imports
/// the names the emitted code refers to:
///  * `IntEnum`, `IntFlag` and `auto`, used by the generated enum classes;
///  * `register_attribute_builder` and `_ods_ir`, used by the generated
///    attribute builders.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir

)Py";

/// Makes an enum case symbol usable as a Python member name.
///
/// The symbol's spelling, including its letter case, is kept exactly as
/// written in ODS; no case conversion is performed. The only change is a
/// trailing underscore appended when `isPythonReserved` reports the symbol as
/// reserved in Python (a reserved `X` is emitted as `X_`).
static std::string makePythonEnumCaseName(StringRef name) {
  if (isPythonReserved(name.str()))
    return (name + "_").str();
  return name.str();
}

/// Writes the Python class mirroring `enumAttr`.
///
/// The class derives from `IntFlag` for bit enums and from `IntEnum` otherwise.
/// Its contents are:
///  * the ODS summary as a docstring, when one is present;
///  * one member per ODS case, with the case value when non-negative and
///    `auto()` otherwise;
///  * for bit enums, `__iter__` and `__len__` helpers over the set flags;
///  * a `__str__` that maps each member back to its ODS string. For bit enums,
///    multi-flag values are joined with the ODS `separator`.
static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
  const bool isBitEnum = enumAttr.isBitEnum();
  const StringRef className = enumAttr.getEnumClassName();
  const auto cases = enumAttr.getAllCases();

  // Class header plus optional docstring.
  os << "class " << className << '(' << (isBitEnum ? "IntFlag" : "IntEnum")
     << "):\n";
  if (!enumAttr.getSummary().empty())
    os << "    \"\"\"" << enumAttr.getSummary() << "\"\"\"\n";
  os << '\n';

  // One Python member per ODS case.
  for (const EnumAttrCase &enumCase : cases) {
    os << "    " << makePythonEnumCaseName(enumCase.getSymbol()) << " = ";
    const int64_t value = enumCase.getValue();
    if (value < 0)
      os << "auto()";
    else
      os << value;
    os << '\n';
  }
  os << '\n';

  // Flag enums get iteration over, and a count of, their set bits.
  if (isBitEnum) {
    os << "    def __iter__(self):\n"
          "        return iter([case for case in type(self) if "
          "(self & case) is case])\n"
          "    def __len__(self):\n"
          "        return bin(self).count(\"1\")\n"
          "\n";
  }

  // __str__ recovers the ODS string form of the value.
  os << "    def __str__(self):\n";
  if (isBitEnum) {
    os << "        if len(self) > 1:\n"
          "            return \""
       << enumAttr.getDef().getValueAsString("separator")
       << "\".join(map(str, self))\n";
  }
  for (const EnumAttrCase &enumCase : cases) {
    os << "        if self is " << className << '.'
       << makePythonEnumCaseName(enumCase.getSymbol()) << ":\n"
       << "            return \"" << enumCase.getStr() << "\"\n";
  }
  os << "        raise ValueError(\"Unknown " << className
     << " enum entry.\")\n\n\n"
     << '\n';
}

/// Parses the bitwidth B out of an underlying-type spelling of the form
/// "uintB_t". ODS does not expose this bitwidth directly.
///
/// On success, stores B in `bitwidth` and returns `false`. Returns `true` if
/// `uintType` lacks the "uint" prefix or "_t" suffix, or if the text between
/// them does not parse as a base-10 integer.
static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
  StringRef digits = uintType;
  const bool hasUIntShape =
      digits.consume_front("uint") && digits.consume_back("_t");
  if (!hasUIntShape)
    return true;
  return digits.getAsInteger(/*Radix=*/10, bitwidth);
}

/// Emits a Python attribute builder for `enumAttr`, registered under its ODS
/// attribute def name. The builder converts an enum value `x` into a signless
/// `IntegerAttr` via `int(x)`, sized by the bitwidth of the enum's "uintB_t"
/// underlying type. This lets Python enum values be passed where the attribute
/// is expected.
///
/// Returns `false` on success. Returns `true` if the bitwidth cannot be
/// determined from the underlying type; in that case a diagnostic is printed
/// and nothing is written to `os`.
static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
  int64_t bitwidth = 0;
  if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
    // Name the offending def and terminate the line so this diagnostic does
    // not run into whatever is printed next.
    llvm::errs() << "failed to identify bitwidth of '"
                 << enumAttr.getUnderlyingType()
                 << "' (underlying type of enum attribute "
                 << enumAttr.getAttrDefName() << ")\n";
    return true;
  }

  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
                      enumAttr.getAttrDefName());
  os << llvm::formatv("def _{0}(x, context):\n",
                      enumAttr.getAttrDefName().lower());
  os << llvm::formatv(
      "    return "
      "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
      "context=context), int(x))\n\n",
      bitwidth);
  return false;
}

/// Emits a Python attribute builder registered under `attrDefName`.
///
/// The builder creates the attribute with `_ods_ir.Attribute.parse`, passing
/// `formatString` wrapped in a Python f-string so it can interpolate the enum
/// value `x`. Always returns `false` (success); the `bool` result follows the
/// other emitters' false-on-success / true-on-failure convention.
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
                                            StringRef formatString,
                                            raw_ostream &os) {
  const std::string builderName = attrDefName.lower();
  os << "@register_attribute_builder(\"" << attrDefName << "\")\n"
     << "def _" << builderName << "(x, context):\n"
     << "    return _ods_ir.Attribute.parse(f'" << formatString
     << "', context=context)\n\n";
  return false;
}

/// Emits Python bindings for all enums in the record keeper.
///
/// Writes `fileHeader`, then:
///  * for each `EnumAttrInfo` def, a Python enum class followed (when the
///    underlying type's bitwidth can be determined) by an `IntegerAttr`
///    attribute builder;
///  * for each dialect `EnumAttr` def, an attribute builder that parses the
///    attribute's textual form. Only the assembly formats "`<` $value `>`"
///    and "$value" are supported.
///
/// Returns `false` on success. Returns `true` on failure, after printing a
/// newline-terminated diagnostic that names the offending def.
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
                            raw_ostream &os) {
  os << fileHeader;
  for (auto &it :
       recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
    EnumAttr enumAttr(*it);
    emitEnumClass(enumAttr, os);
    // NOTE(review): failing to emit the IntegerAttr builder is non-fatal here
    // (it only prints a diagnostic), preserving existing behavior. The enum
    // class above is still emitted. Confirm whether this should abort
    // generation instead.
    (void)emitAttributeBuilder(enumAttr, os);
  }
  for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
    AttrOrTypeDef attr(&*it);
    std::optional<StringRef> maybeMnemonic = attr.getMnemonic();
    if (!maybeMnemonic) {
      llvm::errs() << "enum attribute " << attr.getName()
                   << " needs mnemonic for python enum bindings generation\n";
      return true;
    }
    StringRef mnemonic = *maybeMnemonic;
    std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
    StringRef dialect = attr.getDialect().getName();
    if (assemblyFormat == "`<` $value `>`") {
      emitDialectEnumAttributeBuilder(
          attr.getName(),
          llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
    } else if (assemblyFormat == "$value") {
      emitDialectEnumAttributeBuilder(
          attr.getName(),
          llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
    } else {
      llvm::errs() << "unsupported assembly format '"
                   << assemblyFormat.value_or("") << "' of enum attribute "
                   << attr.getName()
                   << " for python enum bindings generation\n";
      return true;
    }
  }

  return false;
}

// Makes `emitPythonEnums` available to mlir-tblgen as the
// "gen-python-enum-bindings" generator.
static mlir::GenRegistration genPythonEnumBindings{
    "gen-python-enum-bindings", "Generate Python bindings for enum attributes",
    emitPythonEnums};