#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringMap.h"
#include <cstdint>
namespace mlir {
class BytecodeDialectInterface;
class BytecodeWriterConfig;
namespace bytecode {
namespace detail {
struct DialectNumbering;
struct AttrTypeNumbering {
AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {}
PointerUnion<Attribute, Type> value;
unsigned number = 0;
unsigned refCount = 1;
DialectNumbering *dialect = nullptr;
};
struct AttributeNumbering : public AttrTypeNumbering {
AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {}
Attribute getValue() const { return value.get<Attribute>(); }
};
struct TypeNumbering : public AttrTypeNumbering {
TypeNumbering(Type value) : AttrTypeNumbering(value) {}
Type getValue() const { return value.get<Type>(); }
};
struct OpNameNumbering {
OpNameNumbering(DialectNumbering *dialect, OperationName name)
: dialect(dialect), name(name) {}
DialectNumbering *dialect;
OperationName name;
unsigned number = 0;
unsigned refCount = 1;
};
/// Numbering entry for a dialect resource, identified by its string key.
struct DialectResourceNumbering {
  // Take the key by value and move it in so callers may pass temporaries
  // without an extra copy.
  DialectResourceNumbering(std::string key) : key(std::move(key)) {}

  /// Key identifying the resource (owned copy).
  std::string key;

  /// Assigned number. Defaults to 0.
  unsigned number{0};

  /// Starts out true. Presumably cleared once the resource is found to carry
  /// more than a declaration; TODO confirm against the out-of-line writer code.
  bool isDeclaration{true};
};
struct DialectNumbering {
DialectNumbering(StringRef name, unsigned number)
: name(name), number(number) {}
StringRef name;
unsigned number;
const BytecodeDialectInterface *interface = nullptr;
const OpAsmDialectInterface *asmInterface = nullptr;
SetVector<AsmDialectResourceHandle> resources;
llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap;
};
/// Numbering entry for a single operation.
struct OperationNumbering {
  OperationNumbering(unsigned number) : number{number} {}

  /// Number assigned to the operation.
  unsigned number;

  /// Whether the operation is isolated from above. An empty optional means the
  /// property was not recorded; readers in this header treat that as false.
  std::optional<bool> isIsolatedFromAbove{};
};
class IRNumberingState {
public:
IRNumberingState(Operation *op, const BytecodeWriterConfig &config);
auto getDialects() {
return llvm::make_pointee_range(llvm::make_second_range(dialects));
}
auto getAttributes() { return llvm::make_pointee_range(orderedAttrs); }
auto getOpNames() { return llvm::make_pointee_range(orderedOpNames); }
auto getTypes() { return llvm::make_pointee_range(orderedTypes); }
unsigned getNumber(Attribute attr) {
assert(attrs.count(attr) && "attribute not numbered");
return attrs[attr]->number;
}
unsigned getNumber(Block *block) {
assert(blockIDs.count(block) && "block not numbered");
return blockIDs[block];
}
unsigned getNumber(Operation *op) {
assert(operations.count(op) && "operation not numbered");
return operations[op]->number;
}
unsigned getNumber(OperationName opName) {
assert(opNames.count(opName) && "opName not numbered");
return opNames[opName]->number;
}
unsigned getNumber(Type type) {
assert(types.count(type) && "type not numbered");
return types[type]->number;
}
unsigned getNumber(Value value) {
assert(valueIDs.count(value) && "value not numbered");
return valueIDs[value];
}
unsigned getNumber(const AsmDialectResourceHandle &resource) {
assert(dialectResources.count(resource) && "resource not numbered");
return dialectResources[resource]->number;
}
std::pair<unsigned, unsigned> getBlockValueCount(Region *region) {
assert(regionBlockValueCounts.count(region) && "value not numbered");
return regionBlockValueCounts[region];
}
unsigned getOperationCount(Block *block) {
assert(blockOperationCounts.count(block) && "block not numbered");
return blockOperationCounts[block];
}
bool isIsolatedFromAbove(Operation *op) {
assert(operations.count(op) && "operation not numbered");
return operations[op]->isIsolatedFromAbove.value_or(false);
}
int64_t getDesiredBytecodeVersion() const;
private:
struct NumberingDialectWriter;
void computeGlobalNumberingState(Operation *rootOp);
void number(Attribute attr);
void number(Block &block);
DialectNumbering &numberDialect(Dialect *dialect);
DialectNumbering &numberDialect(StringRef dialect);
void number(Operation &op);
void number(OperationName opName);
void number(Region ®ion);
void number(Type type);
void number(Dialect *dialect, ArrayRef<AsmDialectResourceHandle> resources);
void finalizeDialectResourceNumberings(Operation *rootOp);
DenseMap<Attribute, AttributeNumbering *> attrs;
DenseMap<Operation *, OperationNumbering *> operations;
DenseMap<OperationName, OpNameNumbering *> opNames;
DenseMap<Type, TypeNumbering *> types;
DenseMap<Dialect *, DialectNumbering *> registeredDialects;
llvm::MapVector<StringRef, DialectNumbering *> dialects;
std::vector<AttributeNumbering *> orderedAttrs;
std::vector<OpNameNumbering *> orderedOpNames;
std::vector<TypeNumbering *> orderedTypes;
llvm::DenseMap<AsmDialectResourceHandle, DialectResourceNumbering *>
dialectResources;
llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator;
llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
DenseMap<Block *, unsigned> blockIDs;
DenseMap<Value, unsigned> valueIDs;
DenseMap<Block *, unsigned> blockOperationCounts;
DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts;
unsigned nextValueID = 0;
const BytecodeWriterConfig &config;
};
}
}
}
#endif