#ifndef ATTRIBUTEDETAIL_H_
#define ATTRIBUTEDETAIL_H_
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/StorageUniquer.h"
#include "mlir/Support/ThreadLocalCache.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
namespace detail {
/// Returns the bit width used for a single dense element of type `eltType`.
///
/// - Complex types: the width of the component type (computed recursively) is
///   rounded up to a multiple of 8, then doubled to account for both parts.
/// - Index types: `IndexType::kInternalStorageBitWidth`.
/// - Anything else: the value reported by `getIntOrFloatBitWidth()`.
inline size_t getDenseElementBitWidth(Type eltType) {
  if (auto complexTy = llvm::dyn_cast<ComplexType>(eltType)) {
    size_t componentBits = getDenseElementBitWidth(complexTy.getElementType());
    return llvm::alignTo<8>(componentBits) * 2;
  }
  return eltType.isIndex() ? IndexType::kInternalStorageBitWidth
                           : eltType.getIntOrFloatBitWidth();
}
struct DenseElementsAttributeStorage : public AttributeStorage {
public:
DenseElementsAttributeStorage(ShapedType type, bool isSplat)
: type(type), isSplat(isSplat) {}
ShapedType type;
bool isSplat;
};
/// Storage for a dense elements attribute whose element data is held as a raw
/// `char` buffer. Instances are identified by a `KeyTy` built through
/// `getKey`, which also detects whether the buffer is a splat.
struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
/// Builds the storage from a shaped type, a raw data buffer and a splat flag.
/// The buffer is referenced, not copied, by this constructor.
DenseIntOrFPElementsAttrStorage(ShapedType ty, ArrayRef<char> data,
bool isSplat = false)
: DenseElementsAttributeStorage(ty, isSplat), data(data) {}
/// Lookup/construction key: the type, the raw data, a hash precomputed over
/// the data, and the splat flag.
struct KeyTy {
KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
bool isSplat = false)
: type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
/// The shaped type of the dense elements.
ShapedType type;
/// The raw data buffer. For keys flagged as splat by the detection paths
/// below, this holds only the single splat element.
ArrayRef<char> data;
/// Hash computed over the data by `getKey` and its helpers; `hashKey`
/// combines it with the type.
llvm::hash_code hashCode;
/// Whether the data is a splat.
bool isSplat;
};
/// Returns true if this storage has the same type and data as `key`.
/// Neither `hashCode` nor `isSplat` takes part in the comparison.
bool operator==(const KeyTy &key) const {
return key.type == type && key.data == data;
}
/// Builds a key from a shaped type and a raw data buffer.
///
/// `isKnownSplat` tells the function that `data` is already a splat, which
/// skips splat detection. When it is false the type must not have exactly one
/// element (asserted): such inputs are expected to be passed as known splats.
static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) {
// Empty buffer: hash is 0 and the key is not flagged as a splat.
if (data.empty())
return KeyTy(ty, data, 0);
// 1-bit integer elements are handled by the dedicated boolean helpers.
bool isBoolData = ty.getElementType().isInteger(1);
if (isKnownSplat) {
// Known boolean splat: canonicalize to the static splat constants, using
// whether the first byte is non-zero as the splat value.
if (isBoolData)
return getKeyForSplatBoolData(ty, data[0] != 0);
// Known non-boolean splat: hash the buffer exactly as provided.
return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
}
// Not known to be a splat: inspect the data to find out.
size_t numElements = ty.getNumElements();
assert(numElements != 1 && "splat of 1 element should already be detected");
if (isBoolData)
return getKeyForBoolData(ty, data, numElements);
// Each element occupies a whole number of bytes: the element bit width
// rounded up to a multiple of CHAR_BIT.
size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
assert(((data.size() / storageSize) == numElements) &&
"data does not hold expected number of elements");
// Start the hash from the first element only.
auto firstElt = data.take_front(storageSize);
auto hashVal = llvm::hash_value(firstElt);
// Compare every later element against the first one. At the first mismatch
// the data is not a splat: keep the full buffer and fold the remainder of
// the data, starting at the mismatching element, into the hash.
for (size_t i = storageSize, e = data.size(); i != e; i += storageSize)
if (memcmp(data.data(), &data[i], storageSize))
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
// All elements are equal: the key keeps only the first element and is
// flagged as a splat.
return KeyTy(ty, firstElt, hashVal, true);
}
/// Builds a key for boolean (1-bit) data that is not already known to be a
/// splat, detecting the all-true / all-false cases.
///
/// The logic treats `numElements % CHAR_BIT` as the number of used bits in the
/// last byte, i.e. it expects elements to be packed CHAR_BIT per byte.
static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data,
size_t numElements) {
ArrayRef<char> splatData = data;
// Candidate splat value: the lowest bit of the first byte.
bool splatValue = splatData.front() & 1;
// Fast path: the data is exactly the one-byte canonical splat constant.
if (splatData == ArrayRef<char>(splatValue ? kSplatTrue : kSplatFalse))
return getKeyForSplatBoolData(ty, splatValue);
// For an all-true candidate whose element count is not a multiple of
// CHAR_BIT, the last byte is only partially used and is checked separately.
size_t numOddElements = numElements % CHAR_BIT;
if (splatValue && numOddElements != 0) {
// The last byte must have exactly its low `numOddElements` bits set.
char lastElt = splatData.back();
if (lastElt != llvm::maskTrailingOnes<unsigned char>(numOddElements))
return KeyTy(ty, data, llvm::hash_value(data));
// A single, fully matching byte is a splat.
if (splatData.size() == 1)
return getKeyForSplatBoolData(ty, splatValue);
// Exclude the partial byte from the full-byte check below.
splatData = splatData.drop_back();
}
// Every remaining byte must be all ones (true) or all zeros (false).
char mask = splatValue ? ~0 : 0;
return llvm::all_of(splatData, [mask](char c) { return c == mask; })
? getKeyForSplatBoolData(ty, splatValue)
: KeyTy(ty, data, llvm::hash_value(data));
}
/// Returns the canonical key for a boolean splat of `splatValue`.
///
/// `splatData` is a reference to one of the static members `kSplatTrue` /
/// `kSplatFalse`, so the key's data refers to that static rather than to a
/// local copy. The hash is computed from that single `char` value.
static KeyTy getKeyForSplatBoolData(ShapedType type, bool splatValue) {
const char &splatData = splatValue ? kSplatTrue : kSplatFalse;
return KeyTy(type, splatData, llvm::hash_value(splatData),
true);
}
/// Hashes a key by combining its type with the precomputed data hash.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine(key.type, key.hashCode);
}
/// Constructs a new storage instance in memory obtained from `allocator`.
/// A non-empty data buffer is first copied into allocator memory requested
/// with `alignof(uint64_t)` alignment; an empty buffer is stored as an empty
/// `ArrayRef`.
static DenseIntOrFPElementsAttrStorage *
construct(AttributeStorageAllocator &allocator, KeyTy key) {
ArrayRef<char> copy, data = key.data;
if (!data.empty()) {
char *rawData = reinterpret_cast<char *>(
allocator.allocate(data.size(), alignof(uint64_t)));
std::memcpy(rawData, data.data(), data.size());
copy = ArrayRef<char>(rawData, data.size());
}
return new (allocator.allocate<DenseIntOrFPElementsAttrStorage>())
DenseIntOrFPElementsAttrStorage(key.type, copy, key.isSplat);
}
/// The raw element data referenced by this storage.
ArrayRef<char> data;
/// One-byte constants used as the data of boolean splat keys. They are only
/// declared here; their values are defined elsewhere.
static const char kSplatTrue;
static const char kSplatFalse;
};
/// Storage for a dense elements attribute whose elements are strings.
/// Instances are identified by a `KeyTy` built through `getKey`, which also
/// detects whether all strings are equal (a splat).
struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
  /// Builds the storage from a shaped type, an array of strings and a splat
  /// flag. The array is referenced, not copied, by this constructor.
  DenseStringElementsAttrStorage(ShapedType ty, ArrayRef<StringRef> data,
                                 bool isSplat = false)
      : DenseElementsAttributeStorage(ty, isSplat), data(data) {}

  /// Lookup/construction key: the type, the string data, a hash precomputed
  /// over the data, and the splat flag.
  struct KeyTy {
    KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode,
          bool isSplat = false)
        : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}

    /// The shaped type of the dense elements.
    ShapedType type;

    /// The string elements. For keys detected as splat by `getKey`, this holds
    /// only the first element.
    ArrayRef<StringRef> data;

    /// Hash computed over the data by `getKey`; `hashKey` combines it with the
    /// type.
    llvm::hash_code hashCode;

    /// Whether the data is a splat.
    bool isSplat;
  };

  /// Returns true if this storage has the same type and data as `key`.
  /// Neither `hashCode` nor `isSplat` takes part in the comparison.
  bool operator==(const KeyTy &key) const {
    if (key.type != type)
      return false;
    return key.data == data;
  }

  /// Builds a key from a shaped type and an array of strings.
  ///
  /// `isKnownSplat` tells the function that `data` is already a splat, which
  /// skips splat detection. When it is false the type must not have exactly
  /// one element (asserted): such inputs are expected to be passed as known
  /// splats.
  static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data,
                      bool isKnownSplat) {
    // Empty data: hash is 0 and the key is not flagged as a splat.
    if (data.empty())
      return KeyTy(ty, data, 0);

    // Known splat: only the first string contributes to the hash.
    if (isKnownSplat)
      return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);

    assert(ty.getNumElements() != 1 &&
           "splat of 1 element should already be detected");

    // Start the hash from the first element only.
    const auto &firstElt = data.front();
    auto hashVal = llvm::hash_value(firstElt);

    // Compare every later element against the first one. At the first mismatch
    // the data is not a splat: keep the full array and fold the remainder of
    // the data, starting at the mismatching element, into the hash.
    for (size_t i = 1, e = data.size(); i != e; ++i)
      if (firstElt != data[i])
        return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));

    // All elements are equal: the key keeps only the first element and is
    // flagged as a splat.
    return KeyTy(ty, data.take_front(), hashVal, true);
  }

  /// Hashes a key by combining its type with the precomputed data hash.
  static llvm::hash_code hashKey(const KeyTy &key) {
    return llvm::hash_combine(key.type, key.hashCode);
  }

  /// Constructs a new storage instance in memory obtained from `allocator`.
  ///
  /// For non-empty data, a single buffer is allocated that holds the
  /// `StringRef` table followed by the character data of every string; the
  /// stored `StringRef`s point into that buffer. Only the first string is
  /// copied when the key is flagged as a splat.
  static DenseStringElementsAttrStorage *
  construct(AttributeStorageAllocator &allocator, KeyTy key) {
    ArrayRef<StringRef> copy, data = key.data;
    if (data.empty()) {
      return new (allocator.allocate<DenseStringElementsAttrStorage>())
          DenseStringElementsAttrStorage(key.type, copy, key.isSplat);
    }

    // Use size_t throughout: the count comes from `data.size()` and is used in
    // size computations, so a narrower signed type could truncate it.
    const size_t numEntries = key.isSplat ? 1 : data.size();

    // Total bytes: the StringRef table plus the characters of each string.
    size_t dataSize = sizeof(StringRef) * numEntries;
    for (size_t i = 0; i < numEntries; ++i)
      dataSize += data[i].size();

    char *rawData = reinterpret_cast<char *>(
        allocator.allocate(dataSize, alignof(uint64_t)));

    // The StringRef table lives at the start of the buffer; the character data
    // follows immediately after it.
    auto mutableCopy = MutableArrayRef<StringRef>(
        reinterpret_cast<StringRef *>(rawData), numEntries);
    char *stringData = rawData + numEntries * sizeof(StringRef);

    for (size_t i = 0; i < numEntries; ++i) {
      const StringRef source = data[i];
      // An empty StringRef may carry a null data pointer, and passing a null
      // source to memcpy is undefined even for a zero-byte copy.
      if (!source.empty())
        std::memcpy(stringData, source.data(), source.size());
      mutableCopy[i] = StringRef(stringData, source.size());
      stringData += source.size();
    }

    copy =
        ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries);
    return new (allocator.allocate<DenseStringElementsAttrStorage>())
        DenseStringElementsAttrStorage(key.type, copy, key.isSplat);
  }

  /// The string elements referenced by this storage.
  ArrayRef<StringRef> data;
};
struct StringAttrStorage : public AttributeStorage {
StringAttrStorage(StringRef value, Type type)
: type(type), value(value), referencedDialect(nullptr) {}
using KeyTy = std::pair<StringRef, Type>;
bool operator==(const KeyTy &key) const {
return value == key.first && type == key.second;
}
static ::llvm::hash_code hashKey(const KeyTy &key) {
return DenseMapInfo<KeyTy>::getHashValue(key);
}
static StringAttrStorage *construct(AttributeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<StringAttrStorage>())
StringAttrStorage(allocator.copyInto(key.first), key.second);
}
void initialize(MLIRContext *context);
Type type;
StringRef value;
Dialect *referencedDialect;
};
struct DistinctAttrStorage : public AttributeStorage {
using KeyTy = Attribute;
DistinctAttrStorage(Attribute referencedAttr)
: referencedAttr(referencedAttr) {}
KeyTy getAsKey() const { return KeyTy(referencedAttr); }
Attribute referencedAttr;
};
class DistinctAttributeUniquer {
public:
template <typename T, typename... Args>
static T get(MLIRContext *context, Args &&...args) {
static_assert(std::is_same_v<typename T::ImplType, DistinctAttrStorage>,
"expects a distinct attribute storage");
DistinctAttrStorage *storage = DistinctAttributeUniquer::allocateStorage(
context, std::forward<Args>(args)...);
storage->initializeAbstractAttribute(
AbstractAttribute::lookup(DistinctAttr::getTypeID(), context));
return storage;
}
private:
static DistinctAttrStorage *allocateStorage(MLIRContext *context,
Attribute referencedAttr);
};
/// Allocator for `DistinctAttrStorage` objects. It is default-constructible
/// and can be neither copied nor moved.
class DistinctAttributeAllocator {
public:
  DistinctAttributeAllocator() = default;

  DistinctAttributeAllocator(const DistinctAttributeAllocator &) = delete;
  DistinctAttributeAllocator(DistinctAttributeAllocator &&) = delete;
  DistinctAttributeAllocator &
  operator=(const DistinctAttributeAllocator &) = delete;

  /// Creates a `DistinctAttrStorage` that references `referencedAttr`. The
  /// memory comes from the `llvm::BumpPtrAllocator` that `allocatorCache`
  /// returns from `get()`.
  DistinctAttrStorage *allocate(Attribute referencedAttr) {
    llvm::BumpPtrAllocator &bumpAllocator = allocatorCache.get();
    void *memory = bumpAllocator.Allocate<DistinctAttrStorage>();
    return new (memory) DistinctAttrStorage(referencedAttr);
  }

private:
  /// Bump pointer allocators held in a `ThreadLocalCache`.
  ThreadLocalCache<llvm::BumpPtrAllocator> allocatorCache;
};
}
}
#endif