#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
#define MLIR_BINDINGS_PYTHON_IRMODULES_H
#include <optional>
#include <utility>
#include <vector>
#include "Globals.h"
#include "PybindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
namespace python {
// Forward declarations of binding classes that are referenced in this header
// before (or without) their full definitions.
class PyBlock;
class PyDiagnostic;
class PyDiagnosticHandler;
class PyInsertionPoint;
class PyLocation;
class DefaultingPyLocation;
class PyMlirContext;
class DefaultingPyMlirContext;
class PyModule;
class PyOperation;
class PyOperationBase;
class PyType;
class PySymbolTable;
class PyValue;
/// Couples a raw pointer to a C++ object of type `T` with a
/// `pybind11::object`. Both halves must be non-null when the reference is
/// constructed; a moved-from or released reference holds neither and
/// converts to `false`.
template <typename T>
class PyObjectRef {
public:
  /// Stores `referrent` and takes over `object` (moved in). Asserts that
  /// neither is null.
  PyObjectRef(T *referrent, pybind11::object object)
      : referrent(referrent), object(std::move(object)) {
    assert(this->referrent &&
           "cannot construct PyObjectRef with null referrent");
    assert(this->object && "cannot construct PyObjectRef with null object");
  }

  /// Move construction: the Python object is moved out of `other` and
  /// `other`'s pointer is cleared, leaving it empty.
  PyObjectRef(PyObjectRef &&other) noexcept
      : referrent(other.referrent), object(std::move(other.object)) {
    other.referrent = nullptr;
    assert(!other.object);
  }

  /// Copy construction: both references end up holding the same pointer and
  /// their own copy of the `pybind11::object` handle.
  PyObjectRef(const PyObjectRef &other)
      : referrent(other.referrent), object(other.object /* copies */) {}

  ~PyObjectRef() = default;

  /// Returns the reference count reported by the held Python object, or 0
  /// when no object is held.
  int getRefCount() { return object ? object.ref_count() : 0; }

  /// Gives up the held Python object to the caller and empties this
  /// reference. Must only be called on a non-empty reference.
  pybind11::object releaseObject() {
    assert(referrent && object);
    referrent = nullptr;
    return std::move(object);
  }

  /// Raw pointer access without any validity check (may be null).
  T *get() { return referrent; }

  /// Checked member access; asserts the reference is non-empty.
  T *operator->() {
    assert(referrent && object);
    return referrent;
  }

  /// Returns a copy of the held Python object; asserts the reference is
  /// non-empty.
  pybind11::object getObject() {
    assert(referrent && object);
    return object;
  }

  /// True only when both the pointer and the Python object are set.
  operator bool() const {
    return referrent != nullptr && static_cast<bool>(object);
  }

private:
  T *referrent;
  pybind11::object object;
};
/// A single frame of the stack returned by `getStack()`. Each frame holds
/// three Python objects (context, insertion point, location) and a tag
/// recording which kind of frame it is. All out-of-line members are defined
/// elsewhere.
class PyThreadContextEntry {
public:
  /// Tag describing what kind of object a frame was created for.
  enum class FrameKind {
    Context,
    InsertionPoint,
    Location,
  };

  /// Builds a frame, taking over the three Python objects by move.
  PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
                       pybind11::object insertionPoint,
                       pybind11::object location)
      : context(std::move(context)),
        insertionPoint(std::move(insertionPoint)),
        location(std::move(location)), frameKind(frameKind) {}

  // Static accessors for the current defaults (defined out of line).
  static PyMlirContext *getDefaultContext();
  static PyInsertionPoint *getDefaultInsertionPoint();
  static PyLocation *getDefaultLocation();

  // Per-frame accessors (defined out of line).
  PyMlirContext *getContext();
  PyInsertionPoint *getInsertionPoint();
  PyLocation *getLocation();

  /// Returns the tag this frame was constructed with.
  FrameKind getFrameKind() { return frameKind; }

  // Stack manipulation (defined out of line).
  static PyThreadContextEntry *getTopOfStack();
  static pybind11::object pushContext(PyMlirContext &context);
  static void popContext(PyMlirContext &context);
  static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
  static void popInsertionPoint(PyInsertionPoint &insertionPoint);
  static pybind11::object pushLocation(PyLocation &location);
  static void popLocation(PyLocation &location);

  /// Returns the vector of frames that the push/pop functions operate on.
  static std::vector<PyThreadContextEntry> &getStack();

private:
  static void push(FrameKind frameKind, pybind11::object context,
                   pybind11::object insertionPoint, pybind11::object location);

  pybind11::object context;
  pybind11::object insertionPoint;
  pybind11::object location;
  FrameKind frameKind;
};
using PyMlirContextRef = PyObjectRef<PyMlirContext>;

/// Python-side wrapper around an `MlirContext`. Instances cannot be
/// default-constructed, copied or moved; the only constructor is private, so
/// instances come from the static factory functions declared here.
class PyMlirContext {
public:
  PyMlirContext() = delete;
  PyMlirContext(const PyMlirContext &) = delete;
  PyMlirContext(PyMlirContext &&) = delete;

  // Factories (defined out of line).
  static PyMlirContext *createNewContextForInit();
  static PyMlirContextRef forContext(MlirContext context);
  ~PyMlirContext();

  /// Returns the wrapped C API context handle.
  MlirContext get() { return context; }

  /// Builds a `PyMlirContextRef` pairing `this` with the Python object that
  /// `pybind11::cast(this)` yields.
  PyMlirContextRef getRef() {
    pybind11::object self = pybind11::cast(this);
    return PyMlirContextRef(this, std::move(self));
  }

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static pybind11::object createFromCapsule(pybind11::object capsule);

  // Bookkeeping queries and maintenance for live contexts, operations and
  // modules (defined out of line).
  static size_t getLiveCount();
  std::vector<PyOperation *> getLiveOperationObjects();
  size_t getLiveOperationCount();
  size_t clearLiveOperations();
  void clearOperation(MlirOperation op);
  void clearOperationsInside(PyOperationBase &op);
  void clearOperationsInside(MlirOperation op);
  void clearOperationAndInside(PyOperationBase &op);
  size_t getLiveModuleCount();

  // Python context-manager protocol hooks (defined out of line).
  pybind11::object contextEnter();
  void contextExit(const pybind11::object &excType,
                   const pybind11::object &excVal,
                   const pybind11::object &excTb);

  pybind11::object attachDiagnosticHandler(pybind11::object callback);

  /// Sets the `emitErrorDiagnostics` flag (initially false).
  void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; }

  /// Defined later in this header.
  struct ErrorCapture;

private:
  PyMlirContext(MlirContext context);

  /// Map from an opaque pointer key to the owning `PyMlirContext`.
  using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
  static LiveContextMap &getLiveContexts();

  /// Map from an opaque pointer key to a (Python handle, PyModule*) pair.
  using LiveModuleMap =
      llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
  LiveModuleMap liveModules;

  /// Map from an opaque pointer key to a (Python handle, PyOperation*) pair.
  using LiveOperationMap =
      llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
  LiveOperationMap liveOperations;

  bool emitErrorDiagnostics = false;
  MlirContext context;

  friend class PyModule;
  friend class PyOperation;
};
/// `Defaulting` adaptor for `PyMlirContext`. `resolve()` (defined out of
/// line) supplies the context to use when none is given explicitly.
class DefaultingPyMlirContext
    : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
public:
  using Defaulting::Defaulting;

  /// Python-facing type name for this argument kind.
  static constexpr const char kTypeDescription[] = "mlir.ir.Context";

  static PyMlirContext &resolve();
};
/// Base class for wrappers that hold a `PyMlirContextRef`. The reference
/// must be non-empty at construction (checked with assert).
class BaseContextObject {
public:
  BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
    assert(this->contextRef &&
           "context object constructed with null context ref");
  }

  /// Returns the stored context reference.
  PyMlirContextRef &getContext() { return contextRef; }

private:
  PyMlirContextRef contextRef;
};
/// Wrapper around an `MlirLocation` that also holds a context reference via
/// `BaseContextObject`.
class PyLocation : public BaseContextObject {
public:
  PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
      : BaseContextObject(std::move(contextRef)), loc(loc) {}

  /// Implicit conversion to the wrapped C API handle.
  operator MlirLocation() const { return get(); }

  /// Returns the wrapped C API handle.
  MlirLocation get() const { return loc; }

  // Python context-manager protocol hooks (defined out of line).
  pybind11::object contextEnter();
  void contextExit(const pybind11::object &excType,
                   const pybind11::object &excVal,
                   const pybind11::object &excTb);

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static PyLocation createFromCapsule(pybind11::object capsule);

private:
  MlirLocation loc;
};
/// Wrapper around an `MlirDiagnostic` with a validity flag. The flag starts
/// out true; `invalidate()` and the accessors are defined out of line.
class PyDiagnostic {
public:
  PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}

  void invalidate();

  /// Reports the current value of the validity flag.
  bool isValid() { return valid; }

  MlirDiagnosticSeverity getSeverity();
  PyLocation getLocation();
  pybind11::str getMessage();
  pybind11::tuple getNotes();

  /// Plain-data snapshot of a diagnostic: severity, location, message text
  /// and a list of nested notes of the same shape.
  struct DiagnosticInfo {
    MlirDiagnosticSeverity severity;
    PyLocation location;
    std::string message;
    std::vector<DiagnosticInfo> notes;
  };

  DiagnosticInfo getInfo();

private:
  MlirDiagnostic diagnostic;

  void checkValid();

  /// Optional cached tuple of notes; empty until something sets it.
  std::optional<pybind11::tuple> materializedNotes;
  bool valid = true;
};
/// Holds a Python callback plus the optional ID under which it is
/// registered as a diagnostic handler. Constructor, destructor and
/// `detach()` are defined out of line.
class PyDiagnosticHandler {
public:
  PyDiagnosticHandler(MlirContext context, pybind11::object callback);
  ~PyDiagnosticHandler();

  /// True while a registration ID is stored.
  bool isAttached() { return registeredID.has_value(); }

  /// Returns the `hadError` flag (initially false).
  bool getHadError() { return hadError; }

  void detach();

  /// Context-manager entry: returns this handler as a Python object.
  pybind11::object contextEnter() { return pybind11::cast(this); }

  /// Context-manager exit: detaches the handler. The exception arguments
  /// are not inspected.
  void contextExit(const pybind11::object &excType,
                   const pybind11::object &excVal,
                   const pybind11::object &excTb) {
    detach();
  }

private:
  MlirContext context;
  pybind11::object callback;
  std::optional<MlirDiagnosticHandlerID> registeredID;
  bool hadError = false;

  friend class PyMlirContext;
};
/// Scoped diagnostic capture. The constructor attaches `handler` to the
/// context with `this` as user data; the destructor detaches it again and
/// asserts that no captured errors remain. Callers are therefore expected
/// to drain the list with `take()` before the object is destroyed.
struct PyMlirContext::ErrorCapture {
  /// Note: the handler registration below reads the constructor *parameter*
  /// `ctx`, so the parameter is copied (not moved) into the member.
  ErrorCapture(PyMlirContextRef ctx)
      : ctx(ctx),
        handlerID(mlirContextAttachDiagnosticHandler(
            ctx->get(), handler, /*userData=*/this,
            /*deleteUserData=*/nullptr)) {}

  ~ErrorCapture() {
    mlirContextDetachDiagnosticHandler(ctx->get(), handlerID);
    assert(errors.empty() && "unhandled captured errors");
  }

  /// Moves the captured errors out to the caller.
  std::vector<PyDiagnostic::DiagnosticInfo> take() {
    return std::move(errors);
  }

private:
  PyMlirContextRef ctx;
  MlirDiagnosticHandlerID handlerID;
  std::vector<PyDiagnostic::DiagnosticInfo> errors;

  /// Callback registered with the context (defined out of line).
  static MlirLogicalResult handler(MlirDiagnostic diag, void *userData);
};
/// Wrapper around an `MlirDialect` handle that also holds a context
/// reference via `BaseContextObject`.
class PyDialectDescriptor : public BaseContextObject {
public:
  PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
      : BaseContextObject(std::move(contextRef)), dialect(dialect) {}

  /// Returns the wrapped C API handle.
  MlirDialect get() { return dialect; }

private:
  MlirDialect dialect;
};
/// Context-bound object exposing dialect lookup by string key. The lookup
/// itself is defined out of line.
class PyDialects : public BaseContextObject {
public:
  PyDialects(PyMlirContextRef contextRef)
      : BaseContextObject(std::move(contextRef)) {}

  /// Looks up a dialect by `key`. `attrError` presumably selects which
  /// Python exception type is raised on failure — confirm against the
  /// definition.
  MlirDialect getDialectForKey(const std::string &key, bool attrError);
};
/// Thin holder for a Python object referred to as the dialect's descriptor.
class PyDialect {
public:
  PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}

  /// Returns a copy of the stored descriptor object.
  pybind11::object getDescriptor() { return descriptor; }

private:
  pybind11::object descriptor;
};
/// Owning wrapper around an `MlirDialectRegistry`. The destructor destroys
/// the registry unless it is null; moving transfers the handle and leaves
/// the source holding a null registry. Copying is disabled.
class PyDialectRegistry {
public:
  /// Creates a fresh registry through the C API.
  PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {}

  /// Adopts an existing registry handle.
  PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {}

  ~PyDialectRegistry() {
    // A moved-from wrapper holds a null registry and has nothing to free.
    if (mlirDialectRegistryIsNull(registry))
      return;
    mlirDialectRegistryDestroy(registry);
  }

  PyDialectRegistry(PyDialectRegistry &) = delete;

  PyDialectRegistry(PyDialectRegistry &&other) noexcept
      : registry(other.registry) {
    other.registry = {nullptr};
  }

  /// Implicit conversion to the wrapped C API handle.
  operator MlirDialectRegistry() const { return get(); }

  /// Returns the wrapped C API handle.
  MlirDialectRegistry get() const { return registry; }

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static PyDialectRegistry createFromCapsule(pybind11::object capsule);

private:
  MlirDialectRegistry registry;
};
/// `Defaulting` adaptor for `PyLocation`. `resolve()` (defined out of line)
/// supplies the location to use when none is given explicitly.
class DefaultingPyLocation
    : public Defaulting<DefaultingPyLocation, PyLocation> {
public:
  using Defaulting::Defaulting;

  /// Python-facing type name for this argument kind.
  static constexpr const char kTypeDescription[] = "mlir.ir.Location";

  static PyLocation &resolve();

  /// Converts to the C API handle by dereferencing the `PyLocation` pointer
  /// returned by `get()` and applying its own conversion operator.
  operator MlirLocation() const { return *get(); }
};
class PyModule;
using PyModuleRef = PyObjectRef<PyModule>;

/// Wrapper around an `MlirModule`. The only constructor is private, so
/// instances are obtained through `forModule` / `createFromCapsule`
/// (defined out of line). Instances are neither copyable nor movable.
class PyModule : public BaseContextObject {
public:
  /// Returns a reference for the given module (defined out of line).
  static PyModuleRef forModule(MlirModule module);

  PyModule(PyModule &) = delete;
  // Fix: this previously read `PyModule(PyMlirContext &&) = delete;`, which
  // deleted an unrelated converting constructor rather than the move
  // constructor it was meant to name.
  PyModule(PyModule &&) = delete;
  ~PyModule();

  /// Returns the wrapped C API handle.
  MlirModule get() { return module; }

  /// Builds a `PyModuleRef` pairing `this` with a new borrowed reference to
  /// the Python object identified by `handle`.
  PyModuleRef getRef() {
    return PyModuleRef(this,
                       pybind11::reinterpret_borrow<pybind11::object>(handle));
  }

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static pybind11::object createFromCapsule(pybind11::object capsule);

private:
  PyModule(PyMlirContextRef contextRef, MlirModule module);

  MlirModule module;
  /// Non-owning handle to the Python object for this module; nothing in
  /// this header assigns it, so it is set by out-of-line code.
  pybind11::handle handle;
};
class PyAsmState;

/// Abstract base providing the operation-level API (printing, bytecode
/// output, walking, moving, verification). Subclasses supply
/// `getOperation()`; every other member is defined out of line.
class PyOperationBase {
public:
  virtual ~PyOperationBase() = default;

  /// Prints the operation to `fileObject` using the given printing options.
  void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
             bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
             bool assumeVerified, py::object fileObject, bool binary,
             bool skipRegions);

  /// Prints the operation to `fileObject` using an existing `PyAsmState`.
  void print(PyAsmState &state, py::object fileObject, bool binary);

  /// Returns the printed form as a Python object, using the same options as
  /// the first `print` overload.
  pybind11::object getAsm(bool binary,
                          std::optional<int64_t> largeElementsLimit,
                          bool enableDebugInfo, bool prettyDebugInfo,
                          bool printGenericOpForm, bool useLocalScope,
                          bool assumeVerified, bool skipRegions);

  void writeBytecode(const pybind11::object &fileObject,
                     std::optional<int64_t> bytecodeVersion);

  void walk(std::function<MlirWalkResult(MlirOperation)> callback,
            MlirWalkOrder walkOrder);

  void moveAfter(PyOperationBase &other);
  void moveBefore(PyOperationBase &other);

  bool verify();

  /// Returns the underlying `PyOperation`; implemented by each subclass.
  virtual PyOperation &getOperation() = 0;
};
class PyOperation;
using PyOperationRef = PyObjectRef<PyOperation>;

/// Wrapper around an `MlirOperation`. Besides the handle it tracks an
/// `attached` flag, a `valid` flag, a non-owning handle to its Python object
/// and a keep-alive Python object for its parent. The only constructor is
/// private; instances come from the static factories declared here.
class PyOperation : public PyOperationBase, public BaseContextObject {
public:
  ~PyOperation() override;

  PyOperation &getOperation() override { return *this; }

  // Factories (defined out of line).
  static PyOperationRef
  forOperation(PyMlirContextRef contextRef, MlirOperation operation,
               pybind11::object parentKeepAlive = pybind11::object());
  static PyOperationRef
  createDetached(PyMlirContextRef contextRef, MlirOperation operation,
                 pybind11::object parentKeepAlive = pybind11::object());
  static PyOperationRef parse(PyMlirContextRef contextRef,
                              const std::string &sourceStr,
                              const std::string &sourceName);

  /// Removes the operation from its parent through the C API, marks it
  /// detached and drops the parent keep-alive object.
  void detachFromParent() {
    mlirOperationRemoveFromParent(getOperation());
    setDetached();
    parentKeepAlive = pybind11::object();
  }

  /// Implicit conversion to the C API handle; goes through `get()` and so
  /// runs `checkValid()` first.
  operator MlirOperation() const { return get(); }

  /// Returns the C API handle after calling `checkValid()`.
  MlirOperation get() const {
    checkValid();
    return operation;
  }

  /// Builds a `PyOperationRef` pairing `this` with a new borrowed reference
  /// to the Python object identified by `handle`.
  PyOperationRef getRef() {
    return PyOperationRef(
        this, pybind11::reinterpret_borrow<pybind11::object>(handle));
  }

  bool isAttached() { return attached; }

  /// Marks the operation attached; asserts it was detached before. The
  /// `parent` argument is accepted but not used by this inline body.
  void setAttached(const pybind11::object &parent = pybind11::object()) {
    assert(!attached && "operation already attached");
    attached = true;
  }

  /// Marks the operation detached; asserts it was attached before.
  void setDetached() {
    assert(attached && "operation already detached");
    attached = false;
  }

  void checkValid() const;

  PyBlock getBlock();
  std::optional<PyOperationRef> getParentOperation();

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static pybind11::object createFromCapsule(pybind11::object capsule);

  /// Generic operation construction entry point (defined out of line).
  static pybind11::object
  create(const std::string &name, std::optional<std::vector<PyType *>> results,
         std::optional<std::vector<PyValue *>> operands,
         std::optional<pybind11::dict> attributes,
         std::optional<std::vector<PyBlock *>> successors, int regions,
         DefaultingPyLocation location, const pybind11::object &ip,
         bool inferType);

  pybind11::object createOpView();

  void erase();

  /// Clears the `valid` flag.
  void setInvalid() { valid = false; }

  pybind11::object clone(const pybind11::object &ip);

private:
  PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
  static PyOperationRef createInstance(PyMlirContextRef contextRef,
                                       MlirOperation operation,
                                       pybind11::object parentKeepAlive);

  MlirOperation operation;
  pybind11::handle handle;
  pybind11::object parentKeepAlive;
  bool attached = true;
  bool valid = true;

  friend class PyOperationBase;
  friend class PySymbolTable;
};
/// `PyOperationBase` implementation that refers to a `PyOperation` and also
/// holds a Python object for that operation. The constructor and the static
/// builders are defined out of line.
class PyOpView : public PyOperationBase {
public:
  PyOpView(const pybind11::object &operationObject);

  PyOperation &getOperation() override { return operation; }

  /// Returns a copy of the stored Python operation object.
  pybind11::object getOperationObject() { return operationObject; }

  static pybind11::object buildGeneric(
      const pybind11::object &cls, std::optional<pybind11::list> resultTypeList,
      pybind11::list operandList, std::optional<pybind11::dict> attributes,
      std::optional<std::vector<PyBlock *>> successors,
      std::optional<int> regions, DefaultingPyLocation location,
      const pybind11::object &maybeIp);

  static pybind11::object constructDerived(const pybind11::object &cls,
                                           const PyOperation &operation);

private:
  PyOperation &operation;
  pybind11::object operationObject;
};
/// Wrapper around a non-null `MlirRegion` together with a reference to its
/// parent operation.
class PyRegion {
public:
  PyRegion(PyOperationRef parentOperation, MlirRegion region)
      : parentOperation(std::move(parentOperation)), region(region) {
    assert(!mlirRegionIsNull(region) && "python region cannot be null");
  }

  /// Implicit conversion to the wrapped C API handle.
  operator MlirRegion() const { return region; }

  /// Returns the wrapped C API handle.
  MlirRegion get() { return region; }

  PyOperationRef &getParentOperation() { return parentOperation; }

  /// Forwards to the parent operation's `checkValid()`.
  void checkValid() { parentOperation->checkValid(); }

private:
  PyOperationRef parentOperation;
  MlirRegion region;
};
class PyAsmState {
public:
PyAsmState(MlirValue value, bool useLocalScope) {
flags = mlirOpPrintingFlagsCreate();
if (useLocalScope)
mlirOpPrintingFlagsUseLocalScope(flags);
state = mlirAsmStateCreateForValue(value, flags);
}
PyAsmState(PyOperationBase &operation, bool useLocalScope) {
flags = mlirOpPrintingFlagsCreate();
if (useLocalScope)
mlirOpPrintingFlagsUseLocalScope(flags);
state =
mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
}
~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
PyAsmState(PyAsmState &other) = delete;
PyAsmState(const PyAsmState &other) = delete;
MlirAsmState get() { return state; }
private:
MlirAsmState state;
MlirOpPrintingFlags flags;
};
/// Wrapper around a non-null `MlirBlock` together with a reference to its
/// parent operation.
class PyBlock {
public:
  PyBlock(PyOperationRef parentOperation, MlirBlock block)
      : parentOperation(std::move(parentOperation)), block(block) {
    assert(!mlirBlockIsNull(block) && "python block cannot be null");
  }

  /// Returns the wrapped C API handle.
  MlirBlock get() { return block; }

  PyOperationRef &getParentOperation() { return parentOperation; }

  /// Forwards to the parent operation's `checkValid()`.
  void checkValid() { parentOperation->checkValid(); }

  /// Capsule interop (defined out of line).
  pybind11::object getCapsule();

private:
  PyOperationRef parentOperation;
  MlirBlock block;
};
/// An insertion position described by a block plus an optional reference
/// operation. Public constructors and the named factories are defined out
/// of line.
class PyInsertionPoint {
public:
  PyInsertionPoint(PyBlock &block);
  PyInsertionPoint(PyOperationBase &beforeOperationBase);

  static PyInsertionPoint atBlockBegin(PyBlock &block);
  static PyInsertionPoint atBlockTerminator(PyBlock &block);

  /// Inserts an operation at this insertion point (defined out of line).
  void insert(PyOperationBase &operationBase);

  // Python context-manager protocol hooks (defined out of line).
  pybind11::object contextEnter();
  void contextExit(const pybind11::object &excType,
                   const pybind11::object &excVal,
                   const pybind11::object &excTb);

  PyBlock &getBlock() { return block; }
  std::optional<PyOperationRef> &getRefOperation() { return refOperation; }

private:
  /// Member-wise constructor; both arguments are moved into place.
  PyInsertionPoint(PyBlock block, std::optional<PyOperationRef> refOperation)
      : refOperation(std::move(refOperation)), block(std::move(block)) {}

  std::optional<PyOperationRef> refOperation;
  PyBlock block;
};
/// Wrapper around an `MlirType` that also holds a context reference via
/// `BaseContextObject`.
class PyType : public BaseContextObject {
public:
  PyType(PyMlirContextRef contextRef, MlirType type)
      : BaseContextObject(std::move(contextRef)), type(type) {}

  bool operator==(const PyType &other) const;

  /// Implicit conversion to the wrapped C API handle.
  operator MlirType() const { return get(); }

  /// Returns the wrapped C API handle.
  MlirType get() const { return type; }

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static PyType createFromCapsule(pybind11::object capsule);

private:
  MlirType type;
};
/// Wrapper around an `MlirTypeID` handle.
class PyTypeID {
public:
  PyTypeID(MlirTypeID typeID) : typeID(typeID) {}

  bool operator==(const PyTypeID &other) const;

  /// Implicit conversion to the wrapped C API handle.
  operator MlirTypeID() const { return typeID; }

  /// Returns the wrapped C API handle.
  MlirTypeID get() { return typeID; }

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static PyTypeID createFromCapsule(pybind11::object capsule);

private:
  MlirTypeID typeID;
};
// CRTP base for concrete type wrappers. `DerivedTy` is expected to provide
// `pyClassName` and `isaFunction` (both are referenced below) and may
// provide its own `getTypeIdFunction` and `bindDerived`; the defaults
// declared here are a null function pointer and a no-op respectively.
template <typename DerivedTy, typename BaseTy = PyType>
class PyConcreteType : public BaseTy {
public:
// pybind11 class type used to register DerivedTy with BaseTy as its base.
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirType);
using GetTypeIDFunctionTy = MlirTypeID (*)();
// Default: no type ID getter.
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
: BaseTy(std::move(contextRef), t) {}
// Converting constructor from a generic PyType; castFrom() throws if the
// type does not satisfy DerivedTy::isaFunction.
PyConcreteType(PyType &orig)
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
// Returns the MlirType held by `orig` if DerivedTy::isaFunction accepts it;
// otherwise throws py::value_error with a message that includes the Python
// repr of `orig`.
static MlirType castFrom(PyType &orig) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
throw py::value_error((llvm::Twine("Cannot cast type to ") +
DerivedTy::pyClassName + " (from " + origRepr +
")")
.str());
}
return orig;
}
// Registers DerivedTy on module `m` under DerivedTy::pyClassName and adds
// the common members, then hands the class object to
// DerivedTy::bindDerived for type-specific additions.
static void bind(pybind11::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
// Constructor from an existing PyType (uses the converting constructor).
cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(),
pybind11::arg("cast_from_type"));
// Static predicate: does `other` satisfy DerivedTy::isaFunction?
cls.def_static(
"isinstance",
[](PyType &otherType) -> bool {
return DerivedTy::isaFunction(otherType);
},
pybind11::arg("other"));
// Class-level type ID; throws py::attribute_error when DerivedTy has no
// type ID getter.
cls.def_property_readonly_static(
"static_typeid", [](py::object & /*class*/) -> MlirTypeID {
if (DerivedTy::getTypeIdFunction)
return DerivedTy::getTypeIdFunction();
throw py::attribute_error(
(DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
});
// Instance-level type ID: reads the "typeid" attribute of the Python object
// obtained by casting `self` as a PyType.
cls.def_property_readonly("typeid", [](PyType &self) {
return py::cast(self).attr("typeid").cast<MlirTypeID>();
});
// repr is "<pyClassName>(<printed type>)".
cls.def("__repr__", [](DerivedTy &self) {
PyPrintAccumulator printAccum;
printAccum.parts.append(DerivedTy::pyClassName);
printAccum.parts.append("(");
mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
});
// When a type ID getter exists, register a caster with PyGlobals that
// converts a PyType into DerivedTy.
if (DerivedTy::getTypeIdFunction) {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
pybind11::cpp_function(
[](PyType pyType) -> DerivedTy { return pyType; }));
}
DerivedTy::bindDerived(cls);
}
// Default hook: adds nothing. Derived classes may define their own.
static void bindDerived(ClassTy &m) {}
};
/// Wrapper around an `MlirAttribute` that also holds a context reference
/// via `BaseContextObject`.
class PyAttribute : public BaseContextObject {
public:
  PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
      : BaseContextObject(std::move(contextRef)), attr(attr) {}

  bool operator==(const PyAttribute &other) const;

  /// Implicit conversion to the wrapped C API handle.
  operator MlirAttribute() const { return get(); }

  /// Returns the wrapped C API handle.
  MlirAttribute get() const { return attr; }

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static PyAttribute createFromCapsule(pybind11::object capsule);

private:
  MlirAttribute attr;
};
/// Holds an `MlirNamedAttribute` plus a heap-allocated string. The
/// constructor is defined out of line.
class PyNamedAttribute {
public:
  PyNamedAttribute(MlirAttribute attr, std::string ownedName);

  /// The C API named attribute; publicly accessible.
  MlirNamedAttribute namedAttr;

private:
  /// Presumably backs the name referenced by `namedAttr` — confirm against
  /// the constructor definition.
  std::unique_ptr<std::string> ownedName;
};
// CRTP base for concrete attribute wrappers. `DerivedTy` is expected to
// provide `pyClassName` and `isaFunction` (both are referenced below) and
// may provide its own `getTypeIdFunction` and `bindDerived`; the defaults
// declared here are a null function pointer and a no-op respectively.
template <typename DerivedTy, typename BaseTy = PyAttribute>
class PyConcreteAttribute : public BaseTy {
public:
// pybind11 class type used to register DerivedTy with BaseTy as its base.
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirAttribute);
using GetTypeIDFunctionTy = MlirTypeID (*)();
// Default: no type ID getter.
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
PyConcreteAttribute() = default;
PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
: BaseTy(std::move(contextRef), attr) {}
// Converting constructor from a generic PyAttribute; castFrom() throws if
// the attribute does not satisfy DerivedTy::isaFunction.
PyConcreteAttribute(PyAttribute &orig)
: PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
// Returns the MlirAttribute held by `orig` if DerivedTy::isaFunction accepts
// it; otherwise throws py::value_error with a message that includes the
// Python repr of `orig`.
static MlirAttribute castFrom(PyAttribute &orig) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
throw py::value_error((llvm::Twine("Cannot cast attribute to ") +
DerivedTy::pyClassName + " (from " + origRepr +
")")
.str());
}
return orig;
}
// Registers DerivedTy on module `m` under DerivedTy::pyClassName (with the
// buffer protocol enabled) and adds the common members, then hands the class
// object to DerivedTy::bindDerived for attribute-specific additions.
static void bind(pybind11::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
pybind11::module_local());
// Constructor from an existing PyAttribute (uses the converting
// constructor).
cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(),
pybind11::arg("cast_from_attr"));
// Static predicate: does `other` satisfy DerivedTy::isaFunction?
cls.def_static(
"isinstance",
[](PyAttribute &otherAttr) -> bool {
return DerivedTy::isaFunction(otherAttr);
},
pybind11::arg("other"));
// The attribute's type, obtained through the C API.
cls.def_property_readonly(
"type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
// Class-level type ID; throws py::attribute_error when DerivedTy has no
// type ID getter.
cls.def_property_readonly_static(
"static_typeid", [](py::object & /*class*/) -> MlirTypeID {
if (DerivedTy::getTypeIdFunction)
return DerivedTy::getTypeIdFunction();
throw py::attribute_error(
(DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
});
// Instance-level type ID: reads the "typeid" attribute of the Python object
// obtained by casting `self` as a PyAttribute.
cls.def_property_readonly("typeid", [](PyAttribute &self) {
return py::cast(self).attr("typeid").cast<MlirTypeID>();
});
// repr is "<pyClassName>(<printed attribute>)".
cls.def("__repr__", [](DerivedTy &self) {
PyPrintAccumulator printAccum;
printAccum.parts.append(DerivedTy::pyClassName);
printAccum.parts.append("(");
mlirAttributePrint(self, printAccum.getCallback(),
printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
});
// When a type ID getter exists, register a caster with PyGlobals that
// converts a PyAttribute into DerivedTy.
if (DerivedTy::getTypeIdFunction) {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
return pyAttribute;
}));
}
DerivedTy::bindDerived(cls);
}
// Default hook: adds nothing. Derived classes may define their own.
static void bindDerived(ClassTy &m) {}
};
/// Wrapper around an `MlirValue` together with a reference to its parent
/// operation. Has a virtual destructor, so it can serve as a polymorphic
/// base.
class PyValue {
public:
  virtual ~PyValue() = default;

  PyValue(PyOperationRef parentOperation, MlirValue value)
      : parentOperation(std::move(parentOperation)), value(value) {}

  /// Implicit conversion to the wrapped C API handle.
  operator MlirValue() const { return value; }

  /// Returns the wrapped C API handle.
  MlirValue get() { return value; }

  PyOperationRef &getParentOperation() { return parentOperation; }

  /// Forwards to the parent operation's `checkValid()`.
  void checkValid() { parentOperation->checkValid(); }

  // Capsule interop and down-casting (defined out of line).
  pybind11::object getCapsule();
  pybind11::object maybeDownCast();
  static PyValue createFromCapsule(pybind11::object capsule);

private:
  PyOperationRef parentOperation;
  MlirValue value;
};
/// Wrapper around an `MlirAffineExpr` that also holds a context reference
/// via `BaseContextObject`. Comparison and arithmetic helpers are defined
/// out of line.
class PyAffineExpr : public BaseContextObject {
public:
  PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
      : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}

  bool operator==(const PyAffineExpr &other) const;

  /// Implicit conversion to the wrapped C API handle.
  operator MlirAffineExpr() const { return get(); }

  /// Returns the wrapped C API handle.
  MlirAffineExpr get() const { return affineExpr; }

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static PyAffineExpr createFromCapsule(pybind11::object capsule);

  // Binary operations producing a new expression (defined out of line).
  PyAffineExpr add(const PyAffineExpr &other) const;
  PyAffineExpr mul(const PyAffineExpr &other) const;
  PyAffineExpr floorDiv(const PyAffineExpr &other) const;
  PyAffineExpr ceilDiv(const PyAffineExpr &other) const;
  PyAffineExpr mod(const PyAffineExpr &other) const;

private:
  MlirAffineExpr affineExpr;
};
/// Wrapper around an `MlirAffineMap` that also holds a context reference
/// via `BaseContextObject`.
class PyAffineMap : public BaseContextObject {
public:
  PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
      : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}

  bool operator==(const PyAffineMap &other) const;

  /// Implicit conversion to the wrapped C API handle.
  operator MlirAffineMap() const { return get(); }

  /// Returns the wrapped C API handle.
  MlirAffineMap get() const { return affineMap; }

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static PyAffineMap createFromCapsule(pybind11::object capsule);

private:
  MlirAffineMap affineMap;
};
/// Wrapper around an `MlirIntegerSet` that also holds a context reference
/// via `BaseContextObject`.
class PyIntegerSet : public BaseContextObject {
public:
  PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
      : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}

  bool operator==(const PyIntegerSet &other) const;

  /// Implicit conversion to the wrapped C API handle.
  operator MlirIntegerSet() const { return get(); }

  /// Returns the wrapped C API handle.
  MlirIntegerSet get() const { return integerSet; }

  // Capsule interop (defined out of line).
  pybind11::object getCapsule();
  static PyIntegerSet createFromCapsule(pybind11::object capsule);

private:
  MlirIntegerSet integerSet;
};
/// Wrapper around an `MlirSymbolTable` plus a reference to an operation.
/// The destructor destroys the symbol table through the C API.
/// NOTE(review): no copy operations are declared, so a copy would share the
/// same handle and both destructors would destroy it — confirm that
/// instances are never copied.
class PySymbolTable {
public:
  explicit PySymbolTable(PyOperationBase &operation);

  ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); }

  // Table operations (defined out of line).
  pybind11::object dunderGetItem(const std::string &name);
  void erase(PyOperationBase &symbol);
  void dunderDel(const std::string &name);
  MlirAttribute insert(PyOperationBase &symbol);

  // Static helpers operating on a symbol operation (defined out of line).
  static MlirAttribute getSymbolName(PyOperationBase &symbol);
  static void setSymbolName(PyOperationBase &symbol, const std::string &name);
  static MlirAttribute getVisibility(PyOperationBase &symbol);
  static void setVisibility(PyOperationBase &symbol,
                            const std::string &visibility);
  static void replaceAllSymbolUses(const std::string &oldSymbol,
                                   const std::string &newSymbol,
                                   PyOperationBase &from);
  static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
                               pybind11::object callback);

  /// Implicit conversion to the wrapped C API handle.
  operator MlirSymbolTable() { return symbolTable; }

private:
  PyOperationRef operation;
  MlirSymbolTable symbolTable;
};
struct MLIRError {
MLIRError(llvm::Twine message,
std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {})
: message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {}
std::string message;
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
// Registration entry points, each taking the pybind11 module to populate.
// Their definitions are not in this header.
void populateIRAffine(pybind11::module &m);
void populateIRAttributes(pybind11::module &m);
void populateIRCore(pybind11::module &m);
void populateIRInterfaces(pybind11::module &m);
void populateIRTypes(pybind11::module &m);
}
}
namespace pybind11 {
namespace detail {
/// pybind11 caster for `DefaultingPyMlirContext`, implemented entirely by
/// `MlirDefaultingCaster`.
template <>
struct type_caster<mlir::python::DefaultingPyMlirContext>
    : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};

/// pybind11 caster for `DefaultingPyLocation`, implemented entirely by
/// `MlirDefaultingCaster`.
template <>
struct type_caster<mlir::python::DefaultingPyLocation>
    : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
}
}
#endif