#include <cstdint>
#include <optional>
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <string>
#include <utility>
#include <vector>
#include "IRModule.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Support.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
namespace py = pybind11;
namespace mlir {
namespace python {
// Docstrings attached to the Python bindings defined further down in this
// file. They are passed verbatim to pybind11 `def`/`def_property_readonly`
// calls, so their text is user-visible from Python.

// Docstring for the interface constructor (`py::init` binding).
constexpr static const char *constructorDoc =
R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";
// Docstring for the read-only `operation` property.
constexpr static const char *operationDoc =
R"(Returns an Operation for which the interface was constructed.)";
// Docstring for the read-only `opview` property.
constexpr static const char *opviewDoc =
R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";
// Docstring for the `inferReturnTypes` method.
constexpr static const char *inferReturnTypesDoc =
R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";
// Docstring for the `inferReturnTypeComponents` method.
constexpr static const char *inferReturnTypeComponentsDoc =
R"(Given the arguments required to build an operation, attempts to infer
its return shaped type components. Raises ValueError on failure.)";
namespace {
/// Flattens an optional Python list of operands into a vector of MlirValue.
///
/// Each list element may be:
///   * None, which is skipped;
///   * an object castable to PyValue, which contributes one value;
///   * a Python sequence whose elements are all castable to PyValue, which
///     contributes one value per element.
///
/// Returns an empty vector when `operandList` is absent or empty.
/// Throws py::value_error naming the offending top-level operand index when
/// an element is neither a Value nor a sequence of Values.
llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
  llvm::SmallVector<MlirValue> mlirOperands;
  if (!operandList || operandList->empty())
    return mlirOperands;

  // Casts `object` to a PyValue and appends the wrapped MlirValue. Throws
  // py::cast_error if the cast fails or yields a null pointer.
  auto appendValue = [&mlirOperands](py::handle object) {
    PyValue *val = py::cast<PyValue *>(object);
    if (!val)
      throw py::cast_error();
    mlirOperands.push_back(val->get());
  };

  // Builds the ValueError reported for a malformed operand at `index`,
  // embedding the message of the underlying cast failure.
  auto makeOperandError = [](size_t index, const py::cast_error &err) {
    return py::value_error((llvm::Twine("Operand ") + llvm::Twine(index) +
                            " must be a Value or Sequence of Values (" +
                            err.what() + ")")
                               .str());
  };

  // Nested sequences may contribute more than one value each, so this is
  // only a lower bound on the final size.
  mlirOperands.reserve(operandList->size());
  for (const auto &it : llvm::enumerate(*operandList)) {
    if (it.value().is_none())
      continue;

    // First attempt: the element is itself a single Value.
    try {
      appendValue(it.value());
      continue;
    } catch (const py::cast_error &) {
      // Not a Value; fall through and try to interpret it as a sequence.
    }

    // Second attempt: the element is a sequence of Values.
    try {
      auto vals = py::cast<py::sequence>(it.value());
      for (py::object v : vals) {
        try {
          appendValue(v);
        } catch (const py::cast_error &err) {
          throw makeOperandError(it.index(), err);
        }
      }
    } catch (const py::cast_error &err) {
      // Raised only by the sequence cast itself: the inner failures are
      // already converted to py::value_error, which this handler does not
      // catch.
      throw makeOperandError(it.index(), err);
    }
  }
  return mlirOperands;
}
/// Converts an optional vector of PyRegion into a vector of MlirRegion.
///
/// Returns an empty vector when `regions` is absent. Each PyRegion is
/// converted to MlirRegion when it is pushed into the result vector.
llvm::SmallVector<MlirRegion>
wrapRegions(std::optional<std::vector<PyRegion>> regions) {
  llvm::SmallVector<MlirRegion> mlirRegions;
  if (regions) {
    mlirRegions.reserve(regions->size());
    for (PyRegion &region : *regions) {
      mlirRegions.push_back(region);
    }
  }
  return mlirRegions;
}
}
template <typename ConcreteIface>
class PyConcreteOpInterface {
protected:
using ClassTy = py::class_<ConcreteIface>;
using GetTypeIDFunctionTy = MlirTypeID (*)();
public:
PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
: obj(std::move(object)) {
try {
operation = &py::cast<PyOperation &>(obj);
} catch (py::cast_error &) {
}
try {
operation = &py::cast<PyOpView &>(obj).getOperation();
} catch (py::cast_error &) {
}
if (operation != nullptr) {
if (!mlirOperationImplementsInterface(*operation,
ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
throw py::value_error(msg + ConcreteIface::pyClassName);
}
MlirIdentifier identifier = mlirOperationGetName(*operation);
MlirStringRef stringRef = mlirIdentifierStr(identifier);
opName = std::string(stringRef.data, stringRef.length);
} else {
try {
opName = obj.attr("OPERATION_NAME").template cast<std::string>();
} catch (py::cast_error &) {
throw py::type_error(
"Op interface does not refer to an operation or OpView class");
}
if (!mlirOperationImplementsInterfaceStatic(
mlirStringRefCreate(opName.data(), opName.length()),
context.resolve().get(), ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
throw py::value_error(msg + ConcreteIface::pyClassName);
}
}
}
static void bind(py::module &m) {
py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
py::module_local());
cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
py::arg("context") = py::none(), constructorDoc)
.def_property_readonly("operation",
&PyConcreteOpInterface::getOperationObject,
operationDoc)
.def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
opviewDoc);
ConcreteIface::bindDerived(cls);
}
static void bindDerived(ClassTy &cls) {}
bool isStatic() { return operation == nullptr; }
py::object getOperationObject() {
if (operation == nullptr) {
throw py::type_error("Cannot get an operation from a static interface");
}
return operation->getRef().releaseObject();
}
py::object getOpView() {
if (operation == nullptr) {
throw py::type_error("Cannot get an opview from a static interface");
}
return operation->createOpView();
}
const std::string &getOpName() { return opName; }
private:
PyOperation *operation = nullptr;
std::string opName;
py::object obj;
};
/// Python wrapper for the InferTypeOpInterface, bound as
/// `InferTypeOpInterface`.
class PyInferTypeOpInterface
    : public PyConcreteOpInterface<PyInferTypeOpInterface> {
public:
  using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;

  constexpr static const char *pyClassName = "InferTypeOpInterface";
  constexpr static GetTypeIDFunctionTy getInterfaceID =
      &mlirInferTypeOpInterfaceTypeID;

  /// State handed to `appendResultsCallback` through its `userData` pointer:
  /// the output vector and the context used to construct each PyType.
  struct AppendResultsCallbackData {
    std::vector<PyType> &inferredTypes;
    PyMlirContext &pyMlirContext;
  };

  /// C callback that appends `nTypes` entries from `types` to the vector
  /// referenced by `userData` (an AppendResultsCallbackData).
  static void appendResultsCallback(intptr_t nTypes, MlirType *types,
                                    void *userData) {
    auto &state = *static_cast<AppendResultsCallbackData *>(userData);
    std::vector<PyType> &sink = state.inferredTypes;
    sink.reserve(sink.size() + nTypes);
    for (intptr_t idx = 0; idx < nTypes; ++idx)
      sink.emplace_back(state.pyMlirContext.getRef(), types[idx]);
  }

  /// Calls mlirInferTypeOpInterfaceInferReturnTypes for the operation name
  /// this interface was built for and returns the collected types.
  ///
  /// Absent `operandList`/`regions` become empty arrays; an absent
  /// `attributes` becomes the null attribute.
  /// Throws py::value_error if the C API call reports failure.
  std::vector<PyType>
  inferReturnTypes(std::optional<py::list> operandList,
                   std::optional<PyAttribute> attributes, void *properties,
                   std::optional<std::vector<PyRegion>> regions,
                   DefaultingPyMlirContext context,
                   DefaultingPyLocation location) {
    llvm::SmallVector<MlirValue> operandStorage =
        wrapOperands(std::move(operandList));
    llvm::SmallVector<MlirRegion> regionStorage =
        wrapRegions(std::move(regions));

    PyMlirContext &ctx = context.resolve();
    std::vector<PyType> inferredTypes;
    AppendResultsCallbackData callbackState{inferredTypes, ctx};

    const std::string &name = getOpName();
    MlirStringRef nameRef = mlirStringRefCreate(name.data(), name.length());

    MlirAttribute attrDict = mlirAttributeGetNull();
    if (attributes)
      attrDict = attributes->get();

    MlirLogicalResult status = mlirInferTypeOpInterfaceInferReturnTypes(
        nameRef, ctx.get(), location.resolve(), operandStorage.size(),
        operandStorage.data(), attrDict, properties, regionStorage.size(),
        regionStorage.data(), &appendResultsCallback, &callbackState);
    if (mlirLogicalResultIsFailure(status)) {
      throw py::value_error("Failed to infer result types");
    }
    return inferredTypes;
  }

  /// Adds the `inferReturnTypes` method to the bound Python class. Keyword
  /// names are listed in the same order as the C++ parameters.
  static void bindDerived(ClassTy &cls) {
    cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
            py::arg("operands") = py::none(),
            py::arg("attributes") = py::none(),
            py::arg("properties") = py::none(),
            py::arg("regions") = py::none(),
            py::arg("context") = py::none(),
            py::arg("loc") = py::none(),
            inferReturnTypesDoc);
  }
};
/// Holds the components of a shaped type: an element type, and for the
/// ranked case a shape list and an optional attribute. Bound to Python as
/// `ShapedTypeComponents`.
class PyShapedTypeComponents {
public:
  /// Creates unranked components with only an element type.
  PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}

  /// Creates ranked components from a shape list and an element type.
  PyShapedTypeComponents(py::list shape, MlirType elementType)
      : shape(std::move(shape)), elementType(elementType), ranked(true) {}

  /// Creates ranked components from a shape list, element type and attribute.
  PyShapedTypeComponents(py::list shape, MlirType elementType,
                         MlirAttribute attribute)
      : shape(std::move(shape)), elementType(elementType), attribute(attribute),
        ranked(true) {}

  PyShapedTypeComponents(PyShapedTypeComponents &) = delete;

  /// Move constructor: takes over the shape list reference and copies the
  /// remaining plain-data members.
  PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
      : shape(std::move(other.shape)), elementType(other.elementType),
        attribute(other.attribute), ranked(other.ranked) {}

  /// Registers the `ShapedTypeComponents` Python class on module `m`, with
  /// three `get` factory overloads and the `element_type`, `has_rank`,
  /// `rank` and `shape` read-only properties.
  static void bind(py::module &m) {
    py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
                                       py::module_local())
        .def_property_readonly(
            "element_type",
            [](PyShapedTypeComponents &self) { return self.elementType; },
            "Returns the element type of the shaped type components.")
        .def_static(
            "get",
            [](PyType &elementType) {
              return PyShapedTypeComponents(elementType);
            },
            py::arg("element_type"),
            "Create an shaped type components object with only the element "
            "type.")
        .def_static(
            "get",
            [](py::list shape, PyType &elementType) {
              return PyShapedTypeComponents(std::move(shape), elementType);
            },
            py::arg("shape"), py::arg("element_type"),
            "Create a ranked shaped type components object.")
        .def_static(
            "get",
            [](py::list shape, PyType &elementType, PyAttribute &attribute) {
              return PyShapedTypeComponents(std::move(shape), elementType,
                                            attribute);
            },
            py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
            "Create a ranked shaped type components object with attribute.")
        .def_property_readonly(
            "has_rank",
            [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
            "Returns whether the given shaped type component is ranked.")
        .def_property_readonly(
            "rank",
            [](PyShapedTypeComponents &self) -> py::object {
              if (!self.ranked) {
                return py::none();
              }
              return py::int_(self.shape.size());
            },
            "Returns the rank of the given ranked shaped type components. If "
            "the shaped type components does not have a rank, None is "
            "returned.")
        .def_property_readonly(
            "shape",
            [](PyShapedTypeComponents &self) -> py::object {
              if (!self.ranked) {
                return py::none();
              }
              return py::list(self.shape);
            },
            "Returns the shape of the ranked shaped type components as a list "
            "of integers. Returns none if the shaped type component does not "
            "have a rank.");
  }

  // Declared here; the definitions are not part of this block.
  pybind11::object getCapsule();
  static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);

private:
  py::list shape;
  MlirType elementType;
  // Defaults to the null attribute so that constructors which take no
  // attribute never leave this member indeterminate; the move constructor
  // reads it unconditionally.
  MlirAttribute attribute = mlirAttributeGetNull();
  bool ranked{false};
};
/// Python wrapper for the InferShapedTypeOpInterface, bound as
/// `InferShapedTypeOpInterface`.
class PyInferShapedTypeOpInterface
    : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
public:
  using PyConcreteOpInterface<
      PyInferShapedTypeOpInterface>::PyConcreteOpInterface;

  constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
  constexpr static GetTypeIDFunctionTy getInterfaceID =
      &mlirInferShapedTypeOpInterfaceTypeID;

  /// State handed to `appendResultsCallback` through its `userData` pointer.
  struct AppendResultsCallbackData {
    std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
  };

  /// C callback that appends one PyShapedTypeComponents to the vector
  /// referenced by `userData` (an AppendResultsCallbackData).
  ///
  /// When `hasRank` is false only `elementType` is recorded; otherwise the
  /// first `rank` entries of `shape` are copied into a Python list and stored
  /// together with `elementType` and `attribute`.
  static void appendResultsCallback(bool hasRank, intptr_t rank,
                                    const int64_t *shape, MlirType elementType,
                                    MlirAttribute attribute, void *userData) {
    auto *data = static_cast<AppendResultsCallbackData *>(userData);
    if (!hasRank) {
      data->inferredShapedTypeComponents.emplace_back(elementType);
    } else {
      py::list shapeList;
      for (intptr_t i = 0; i < rank; ++i) {
        shapeList.append(shape[i]);
      }
      data->inferredShapedTypeComponents.emplace_back(std::move(shapeList),
                                                      elementType, attribute);
    }
  }

  /// Calls mlirInferShapedTypeOpInterfaceInferReturnTypes for the operation
  /// name this interface was built for and returns the collected components.
  ///
  /// Absent `operandList`/`regions` become empty arrays; an absent
  /// `attributes` becomes the null attribute.
  /// Throws py::value_error if the C API call reports failure.
  std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
      std::optional<py::list> operandList,
      std::optional<PyAttribute> attributes, void *properties,
      std::optional<std::vector<PyRegion>> regions,
      DefaultingPyMlirContext context, DefaultingPyLocation location) {
    llvm::SmallVector<MlirValue> mlirOperands =
        wrapOperands(std::move(operandList));
    llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));

    std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
    PyMlirContext &pyContext = context.resolve();
    AppendResultsCallbackData data{inferredShapedTypeComponents};

    MlirStringRef opNameRef =
        mlirStringRefCreate(getOpName().data(), getOpName().length());
    MlirAttribute attributeDict =
        attributes ? attributes->get() : mlirAttributeGetNull();

    MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
        opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
        mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
        mlirRegions.data(), &appendResultsCallback, &data);
    if (mlirLogicalResultIsFailure(result)) {
      throw py::value_error("Failed to infer result shape type components");
    }
    return inferredShapedTypeComponents;
  }

  /// Adds the `inferReturnTypeComponents` method to the bound Python class.
  ///
  /// pybind11 matches `py::arg` entries to C++ parameters by position, so
  /// the keyword names must follow the signature order: `properties` comes
  /// before `regions`.
  static void bindDerived(ClassTy &cls) {
    cls.def("inferReturnTypeComponents",
            &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
            py::arg("operands") = py::none(),
            py::arg("attributes") = py::none(),
            py::arg("properties") = py::none(), py::arg("regions") = py::none(),
            py::arg("context") = py::none(), py::arg("loc") = py::none(),
            inferReturnTypeComponentsDoc);
  }
};
// Registers the Python classes defined in this file on module `m` by calling
// each class's static `bind` in turn: InferTypeOpInterface,
// ShapedTypeComponents, then InferShapedTypeOpInterface.
void populateIRInterfaces(py::module &m) {
PyInferTypeOpInterface::bind(m);
PyShapedTypeComponents::bind(m);
PyInferShapedTypeOpInterface::bind(m);
}
}
}