#include "IRNumbering.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
using namespace mlir;
using namespace mlir::bytecode::detail;
/// A DialectBytecodeWriter that produces no bytes. Every attribute, type, and
/// resource handle handed to it is forwarded to the owning IRNumberingState so
/// it gets numbered. Scalar payloads (integers, floats, strings, blobs, bools)
/// are ignored because they need no numbering.
struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
  NumberingDialectWriter(
      IRNumberingState &state,
      llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap)
      : state(state), dialectVersionMap(dialectVersionMap) {}

  //===--------------------------------------------------------------------===//
  // Components that must be numbered.

  void writeAttribute(Attribute attr) override { state.number(attr); }
  void writeOptionalAttribute(Attribute attr) override {
    // A null attribute has nothing to number.
    if (!attr)
      return;
    state.number(attr);
  }
  void writeType(Type type) override { state.number(type); }
  void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
    state.number(resource.getDialect(), resource);
  }

  //===--------------------------------------------------------------------===//
  // Raw payloads: intentionally no-ops during numbering.

  void writeVarInt(uint64_t) override {}
  void writeSignedVarInt(int64_t) override {}
  void writeAPIntWithKnownWidth(const APInt &) override {}
  void writeAPFloatWithKnownSemantics(const APFloat &) override {}
  void writeOwnedString(StringRef) override {}
  void writeOwnedBlob(ArrayRef<char>) override {}
  void writeOwnedBool(bool) override {}

  //===--------------------------------------------------------------------===//
  // Version queries.

  /// The bytecode version the enclosing numbering state was configured for.
  int64_t getBytecodeVersion() const override {
    return state.getDesiredBytecodeVersion();
  }

  /// Looks up the version recorded for `dialectName`, failing when the map
  /// has no entry for it.
  FailureOr<const DialectVersion *>
  getDialectVersion(StringRef dialectName) const override {
    auto versionIt = dialectVersionMap.find(dialectName);
    if (versionIt != dialectVersionMap.end())
      return versionIt->getValue().get();
    return failure();
  }

  /// The numbering state that receives every forwarded component.
  IRNumberingState &state;
  /// Dialect versions, keyed by dialect name (borrowed, not owned).
  llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap;
};
/// Regroups `range` by dialect within each "byte group" and then assigns
/// every element's `number` field its final index in `range`.
///
/// A byte group is the set of indices whose varint encoding uses the same
/// number of bytes (7 payload bits per byte): [0, 2^7), [2^7, 2^14),
/// [2^14, 2^21), and so on. Inside each group, elements are stable-sorted by
/// dialect number. The relative order the caller established (e.g. by
/// reference count) is therefore kept within each dialect. The dialect that
/// ended the previous group is ordered first, so its run can continue across
/// the group boundary.
///
/// Elements are expected to be pointer-like, exposing `->dialect->number` and
/// a writable `->number`.
template <typename T>
static void groupByDialectPerByte(T range) {
  if (range.empty())
    return;

  // Orders by dialect number, except that `dialectToOrderFirst` sorts before
  // every other dialect.
  auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs,
                          const auto &rhs) {
    if (lhs->dialect->number == dialectToOrderFirst)
      return rhs->dialect->number != dialectToOrderFirst;
    if (rhs->dialect->number == dialectToOrderFirst)
      return false;
    return lhs->dialect->number < rhs->dialect->number;
  };

  unsigned dialectToOrderFirst = 0;
  // Exclusive upper index bound of the previous byte group.
  uint64_t prevGroupEnd = 0;
  auto iterRange = range;
  for (unsigned i = 1; i < 9; ++i) {
    // Indices below 2^(7*i) fit in `i` varint bytes. The group size is the
    // difference between consecutive *bounds* (2^(7*i) - 2^(7*(i-1)), or 2^7
    // for the first group), not bound minus the previous group's *size*.
    uint64_t groupEnd = 1ULL << (7ULL * i);
    uint64_t elementsInByteGroup = groupEnd - prevGroupEnd;
    prevGroupEnd = groupEnd;

    // Clamp to what is left so the count never exceeds the remaining range
    // and never truncates when narrowed to size_t.
    size_t remaining = iterRange.size();
    size_t takeCount = elementsInByteGroup < remaining
                           ? static_cast<size_t>(elementsInByteGroup)
                           : remaining;

    // Slice out this byte group's elements and sort them by dialect.
    auto byteSubRange = iterRange.take_front(takeCount);
    iterRange = iterRange.drop_front(byteSubRange.size());
    llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) {
      return sortByDialect(dialectToOrderFirst, lhs, rhs);
    });

    // Start the next group with the dialect this group ended on, so a
    // dialect run can span the byte boundary.
    dialectToOrderFirst = byteSubRange.back()->dialect->number;

    if (iterRange.empty())
      break;
  }

  // Final numbering is simply the position after regrouping.
  for (auto [idx, value] : llvm::enumerate(range))
    value->number = idx;
}
/// Builds the complete numbering for the IR rooted at `op`, covering:
///   - global operation IDs and isolation status;
///   - value IDs (restarting at zero inside isolated-from-above ops);
///   - block IDs;
///   - the final indices of dialects, attributes, op names, types, and
///     dialect resources.
IRNumberingState::IRNumberingState(Operation *op,
                                   const BytecodeWriterConfig &config)
    : config(config) {
  computeGlobalNumberingState(op);

  // Number the root operation itself.
  number(*op);

  // Worklist of regions still to number, each paired with the value ID that
  // numbering resumes from inside that region.
  SmallVector<std::pair<Region *, unsigned>, 8> numberContext;

  // Pushes every region of `op` onto the worklist.
  auto addOpRegionsToNumber = [&](Operation *op) {
    MutableArrayRef<Region> regions = op->getRegions();
    if (regions.empty())
      return;

    // Regions of isolated-from-above ops cannot reference outer values, so
    // their value numbering restarts at zero. Other regions continue from
    // the current value ID.
    unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID;
    for (Region &region : regions)
      numberContext.emplace_back(&region, opFirstValueID);
  };
  addOpRegionsToNumber(op);

  // Number regions iteratively. A region's nested regions are queued only
  // after the region itself is fully numbered, so they start from the value
  // ID just past that region's values.
  while (!numberContext.empty()) {
    Region *region;
    std::tie(region, nextValueID) = numberContext.pop_back_val();
    number(*region);

    for (Operation &op : region->getOps())
      addOpRegionsToNumber(&op);
  }

  // Dialects are numbered in the iteration order of `dialects`.
  for (auto [idx, dialect] : llvm::enumerate(dialects))
    dialect.second->number = idx;

  // First bias the most referenced components toward the front, so their
  // varint indices are shorter. The sort is stable, so equal counts keep
  // first-seen order.
  auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) {
    return lhs->refCount > rhs->refCount;
  };
  llvm::stable_sort(orderedAttrs, sortByRefCountFn);
  llvm::stable_sort(orderedOpNames, sortByRefCountFn);
  llvm::stable_sort(orderedTypes, sortByRefCountFn);

  // Then group each varint byte-length range by dialect and assign the
  // final indices.
  groupByDialectPerByte(llvm::MutableArrayRef(orderedAttrs));
  groupByDialectPerByte(llvm::MutableArrayRef(orderedOpNames));
  groupByDialectPerByte(llvm::MutableArrayRef(orderedTypes));

  // Resource IDs are assigned last, once every resource has been recorded.
  finalizeDialectResourceNumberings(op);
}
/// Walks all IR under (and including) `rootOp` and creates an
/// OperationNumbering for every operation.
///
/// - IDs are assigned sequentially in the order operations are first visited
///   (the "before all regions" stage), i.e. pre-order.
/// - For operations with regions, this also decides `isIsolatedFromAbove`:
///   - ops with the IsIsolatedFromAbove trait are isolated up front;
///   - an op is marked non-isolated when an operation nested in it uses a
///     value defined outside its regions;
///   - any op still undecided when its regions are finished is marked
///     isolated.
void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
  // Per-operation walk state for each enclosing op that has regions.
  struct StackState {
    // The operation this entry tracks.
    Operation *op;
    // The numbering created for `op`.
    OperationNumbering *numbering;
    // True while the isolation status of `op` (or of an op between it and
    // the nearest resolved ancestor) is still undecided, meaning operand
    // uses inside it still need to be checked.
    bool hasUnresolvedIsolation;
  };
  // Next pre-order operation ID to hand out.
  unsigned operationID = 0;
  // Stack of enclosing region-holding operations, innermost at the back.
  SmallVector<StackState> opStack;
  rootOp->walk([&](Operation *op, const WalkStage &stage) {
    // Leaving an op with regions: pop it. If no outside use was found while
    // walking its regions, it is isolated. The getNumRegions() guard keeps
    // region-less ops (which were never pushed) out of this path.
    if (op->getNumRegions() && stage.isAfterAllRegions()) {
      OperationNumbering *numbering = opStack.pop_back_val().numbering;
      if (!numbering->isIsolatedFromAbove.has_value())
        numbering->isIsolatedFromAbove = true;
      return;
    }
    // Everything below runs only on the first visit of an operation.
    if (!stage.isBeforeAllRegions())
      return;
    // Only look at operands while the innermost enclosing op is unresolved.
    if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
      Region *parentRegion = op->getParentRegion();
      for (Value operand : op->getOperands()) {
        Region *operandRegion = operand.getParentRegion();
        // Operands defined in the same region as `op` say nothing about the
        // isolation of enclosing ops.
        if (operandRegion == parentRegion)
          continue;
        // Search outward (innermost first) for the op whose region defines
        // the operand, stopping early at the first entry whose status is
        // already resolved.
        Operation *operandContainerOp = operandRegion->getParentOp();
        auto it = std::find_if(
            opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
              return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
            });
        // NOTE(review): this requires the defining op to be on the stack (or
        // a resolved entry to be hit first). Uses of values defined outside
        // `rootOp` from a non-isolated root would trip this assert.
        assert(it != opStack.rend() && "expected to find the container");
        // Every op strictly between `op` and the stopping entry captures the
        // value from above, so it is not isolated. These ops inherit the
        // stopping entry's unresolved flag: if that entry is resolved, no
        // further checks are needed for them.
        for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
          state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
          state.numbering->isIsolatedFromAbove = false;
        }
      }
    }
    // Create this op's numbering with the next pre-order ID.
    auto *numbering =
        new (opAllocator.Allocate()) OperationNumbering(operationID++);
    // The trait settles isolation immediately.
    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
      numbering->isIsolatedFromAbove = true;
    operations.try_emplace(op, numbering);
    // Ops with regions are pushed so their nested ops can update them. They
    // start unresolved unless the trait already decided their status.
    if (op->getNumRegions()) {
      opStack.emplace_back(StackState{
          op, numbering, !numbering->isIsolatedFromAbove.has_value()});
    }
  });
}
/// Records one use of `attr`.
///
/// On first sight the attribute gets a numbering entry and is filed under a
/// dialect (or a callback-provided group). Every component its encoding
/// references is then numbered through a dry-run write, or, for the textual
/// fallback, only the dialect resources it prints. Later sightings just bump
/// the reference count.
void IRNumberingState::number(Attribute attr) {
  auto [slot, isNew] = attrs.insert({attr, nullptr});
  if (!isNew) {
    ++slot->second->refCount;
    return;
  }

  auto *attrNumbering = new (attrAllocator.Allocate()) AttributeNumbering(attr);
  slot->second = attrNumbering;
  orderedAttrs.push_back(attrNumbering);

  // An OpaqueAttr is filed under the dialect namespace it names rather than
  // the dialect that owns OpaqueAttr. Nothing nested is numbered for it.
  if (auto opaque = dyn_cast<OpaqueAttr>(attr)) {
    attrNumbering->dialect = &numberDialect(opaque.getDialectNamespace());
    return;
  }
  attrNumbering->dialect = &numberDialect(&attr.getDialect());

  // Custom encodings are attempted only for immutable attributes.
  const bool mayEncodeCustom = !attr.hasTrait<AttributeTrait::IsMutable>();
  if (mayEncodeCustom) {
    // User callbacks are tried first. The first one that succeeds wins, and
    // it may redirect the attribute to another group.
    for (const auto &callback : config.getAttributeWriterCallbacks()) {
      NumberingDialectWriter dryRunWriter(*this, config.getDialectVersionMap());
      std::optional<StringRef> groupOverride;
      if (!succeeded(callback->write(attr, groupOverride, dryRunWriter)))
        continue;
      if (groupOverride)
        attrNumbering->dialect = &numberDialect(*groupOverride);
      return;
    }

    // Otherwise try the owning dialect's bytecode interface, if it has one.
    if (const auto *bytecodeIface = attrNumbering->dialect->interface) {
      NumberingDialectWriter dryRunWriter(*this, config.getDialectVersionMap());
      if (succeeded(bytecodeIface->writeAttribute(attr, dryRunWriter)))
        return;
    }
  }

  // Textual fallback: print into a discarding stream purely to learn which
  // dialect resources the attribute references, then number those.
  AsmState printState(attr.getContext());
  llvm::raw_null_ostream sink;
  attr.print(sink, printState);
  for (const auto &dialectAndResources : printState.getDialectResources())
    number(dialectAndResources.getFirst(),
           dialectAndResources.getSecond().getArrayRef());
}
/// Numbers the contents of `block`:
///   - each argument gets the next value ID, and its location and type are
///     numbered;
///   - each operation is numbered in order, and the number of operations is
///     recorded in blockOperationCounts.
/// Nested regions are not visited here; the caller handles them.
void IRNumberingState::number(Block &block) {
  for (BlockArgument arg : block.getArguments()) {
    valueIDs.try_emplace(arg, nextValueID++);
    number(arg.getLoc());
    number(arg.getType());
  }

  unsigned opCount = 0;
  for (Operation &op : block) {
    number(op);
    ++opCount;
  }
  // Add onto any existing count (normally zero; each block is numbered once).
  blockOperationCounts[&block] += opCount;
}
/// Returns the numbering entry for a loaded `dialect`. The entry is shared
/// with its namespace entry. The first time a dialect is seen, its bytecode
/// and asm interfaces (possibly null) are cached on the entry.
auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
  DialectNumbering *&cached = registeredDialects[dialect];
  if (cached)
    return *cached;

  DialectNumbering &entry = numberDialect(dialect->getNamespace());
  entry.interface = dyn_cast<BytecodeDialectInterface>(dialect);
  entry.asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
  cached = &entry;
  return entry;
}
/// Returns the numbering entry for the dialect (or group) named `dialect`.
/// The entry is created on first request; its initial number is the count of
/// entries that existed before it (its insertion index).
auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & {
  DialectNumbering *&slot = dialects[dialect];
  if (slot)
    return *slot;

  // `dialects[dialect]` above already inserted the new key, hence the `- 1`.
  slot = new (dialectAllocator.Allocate())
      DialectNumbering(dialect, dialects.size() - 1);
  return *slot;
}
void IRNumberingState::number(Region ®ion) {
if (region.empty())
return;
size_t firstValueID = nextValueID;
size_t blockCount = 0;
for (auto it : llvm::enumerate(region)) {
blockIDs.try_emplace(&it.value(), it.index());
number(it.value());
++blockCount;
}
regionBlockValueCounts.try_emplace(®ion, blockCount,
nextValueID - firstValueID);
}
/// Numbers the parts of `op` that belong to the operation itself:
///   - its name;
///   - its results (value IDs and types);
///   - its attribute dictionary;
///   - its properties;
///   - its location.
/// Operands, successors, and regions are not numbered here.
void IRNumberingState::number(Operation &op) {
  number(op.getName());
  for (OpResult result : op.getResults()) {
    valueIDs.try_emplace(result, nextValueID++);
    number(result.getType());
  }

  // At or above kNativePropertiesEncoding, the raw attribute dictionary is
  // numbered and properties are handled separately below. Older versions
  // number the full attribute dictionary instead.
  const bool usesNativeProperties =
      config.getDesiredBytecodeVersion() >= bytecode::kNativePropertiesEncoding;
  DictionaryAttr dictAttr = usesNativeProperties ? op.getRawDictionaryAttrs()
                                                 : op.getAttrDictionary();
  // An empty dictionary is not numbered.
  if (!dictAttr.empty())
    number(dictAttr);

  if (usesNativeProperties && op.getPropertiesStorageSize()) {
    if (!op.isRegistered()) {
      // Unregistered ops store their properties as a single (possibly null)
      // attribute.
      if (Attribute prop = *op.getPropertiesStorage().as<Attribute *>())
        number(prop);
    } else {
      // Registered ops with properties are cast to BytecodeOpInterface, and a
      // dry-run write numbers whatever their properties reference.
      NumberingDialectWriter dryRunWriter(*this, config.getDialectVersionMap());
      cast<BytecodeOpInterface>(op).writeProperties(dryRunWriter);
    }
  }

  number(op.getLoc());
}
/// Records one use of `opName`.
///
/// The first use creates its numbering entry. The entry is tied to the loaded
/// dialect, or to the dialect namespace when getDialect() returns null. Later
/// uses only bump the reference count.
void IRNumberingState::number(OperationName opName) {
  OpNameNumbering *&entry = opNames[opName];
  if (entry) {
    ++entry->refCount;
    return;
  }

  Dialect *owningDialect = opName.getDialect();
  DialectNumbering &dialectNumbering =
      owningDialect ? numberDialect(owningDialect)
                    : numberDialect(opName.getDialectNamespace());
  entry = new (opNameAllocator.Allocate())
      OpNameNumbering(&dialectNumbering, opName);
  orderedOpNames.push_back(entry);
}
/// Records one use of `type`.
///
/// On first sight the type gets a numbering entry and is filed under a
/// dialect (or a callback-provided group). Every component its encoding
/// references is then numbered through a dry-run write, or, for the textual
/// fallback, only the dialect resources it prints. Later sightings just bump
/// the reference count.
void IRNumberingState::number(Type type) {
  auto [slot, isNew] = types.insert({type, nullptr});
  if (!isNew) {
    ++slot->second->refCount;
    return;
  }

  auto *typeNumbering = new (typeAllocator.Allocate()) TypeNumbering(type);
  slot->second = typeNumbering;
  orderedTypes.push_back(typeNumbering);

  // An OpaqueType is filed under the dialect namespace it names rather than
  // the dialect that owns OpaqueType. Nothing nested is numbered for it.
  if (auto opaque = dyn_cast<OpaqueType>(type)) {
    typeNumbering->dialect = &numberDialect(opaque.getDialectNamespace());
    return;
  }
  typeNumbering->dialect = &numberDialect(&type.getDialect());

  // Custom encodings are attempted only for immutable types.
  const bool mayEncodeCustom = !type.hasTrait<TypeTrait::IsMutable>();
  if (mayEncodeCustom) {
    // User callbacks are tried first. The first one that succeeds wins, and
    // it may redirect the type to another group.
    for (const auto &callback : config.getTypeWriterCallbacks()) {
      NumberingDialectWriter dryRunWriter(*this, config.getDialectVersionMap());
      std::optional<StringRef> groupOverride;
      if (!succeeded(callback->write(type, groupOverride, dryRunWriter)))
        continue;
      if (groupOverride)
        typeNumbering->dialect = &numberDialect(*groupOverride);
      return;
    }

    // Otherwise try the owning dialect's bytecode interface, if it has one.
    if (const auto *bytecodeIface = typeNumbering->dialect->interface) {
      NumberingDialectWriter dryRunWriter(*this, config.getDialectVersionMap());
      if (succeeded(bytecodeIface->writeType(type, dryRunWriter)))
        return;
    }
  }

  // Textual fallback: print into a discarding stream purely to learn which
  // dialect resources the type references, then number those.
  AsmState printState(type.getContext());
  llvm::raw_null_ostream sink;
  type.print(sink, printState);
  for (const auto &dialectAndResources : printState.getDialectResources())
    number(dialectAndResources.getFirst(),
           dialectAndResources.getSecond().getArrayRef());
}
/// Records the dialect resources in `resources`, all owned by `dialect`.
///
/// Each resource not seen before gets a DialectResourceNumbering, keyed by the
/// resource key from the dialect's OpAsmDialectInterface. Resources already
/// seen are skipped. The dialect must implement OpAsmDialectInterface
/// (asserted).
void IRNumberingState::number(Dialect *dialect,
                              ArrayRef<AsmDialectResourceHandle> resources) {
  DialectNumbering &dialectNumber = numberDialect(dialect);
  assert(
      dialectNumber.asmInterface &&
      "expected dialect owning a resource to implement OpAsmDialectInterface");

  for (const auto &resource : resources) {
    // Skip an already-numbered resource, but keep scanning: later entries in
    // `resources` may still be new and must not be left without a numbering.
    if (!dialectNumber.resources.insert(resource))
      continue;

    auto *numbering =
        new (resourceAllocator.Allocate()) DialectResourceNumbering(
            dialectNumber.asmInterface->getResourceKey(resource));
    dialectNumber.resourceMap.insert({numbering->key, numbering});
    dialectResources.try_emplace(resource, numbering);
  }
}
int64_t IRNumberingState::getDesiredBytecodeVersion() const {
return config.getDesiredBytecodeVersion();
}
namespace {
/// An AsmResourceBuilder that emits nothing. For each resource key the
/// dialect builds that has a recorded numbering entry, it:
///   - assigns the next resource ID to that entry;
///   - marks the entry as carrying data (not a bare declaration).
struct NumberingResourceBuilder : public AsmResourceBuilder {
  NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID)
      : dialect(dialect), nextResourceID(nextResourceID) {}
  ~NumberingResourceBuilder() override = default;

  // Only the key matters for numbering; payloads are ignored.
  void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final {
    numberEntry(key);
  }
  void buildBool(StringRef key, bool) final { numberEntry(key); }
  void buildString(StringRef key, StringRef) final { numberEntry(key); }

  /// Numbers the entry recorded for `key`. Keys with no recorded entry are
  /// ignored.
  void numberEntry(StringRef key) {
    auto found = dialect->resourceMap.find(key);
    if (found == dialect->resourceMap.end())
      return;
    found->second->number = nextResourceID++;
    found->second->isDeclaration = false;
  }

  /// The dialect whose resources are being numbered.
  DialectNumbering *dialect;
  /// Shared running resource-ID counter, owned by the caller.
  unsigned &nextResourceID;
};
} // namespace
/// Assigns final resource IDs from one counter shared across dialects,
/// visiting dialects in getDialects() order. Dialects without an
/// OpAsmDialectInterface are skipped. Within a dialect:
///   - entries reported by the dialect's buildResources() are numbered first,
///     in the order they are reported;
///   - any entries still marked as declarations (no data was built for them)
///     are numbered next, so their references still round-trip.
void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
  unsigned nextResourceID = 0;
  for (DialectNumbering &dialect : getDialects()) {
    if (auto *asmIface = dialect.asmInterface) {
      NumberingResourceBuilder builder(&dialect, nextResourceID);
      asmIface->buildResources(rootOp, dialect.resources, builder);

      for (const auto &keyAndNumbering : dialect.resourceMap) {
        auto *resNumbering = keyAndNumbering.second;
        if (resNumbering->isDeclaration)
          resNumbering->number = nextResourceID++;
      }
    }
  }
}