#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
using namespace mlir;
#include "mlir/Dialect/DLTI/DLTIDialect.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/DLTI/DLTIAttrs.cpp.inc"
#define DEBUG_TYPE "dlti"
namespace mlir {
namespace detail {
class DataLayoutEntryAttrStorage : public AttributeStorage {
public:
using KeyTy = std::pair<DataLayoutEntryKey, Attribute>;
DataLayoutEntryAttrStorage(DataLayoutEntryKey entryKey, Attribute value)
: entryKey(entryKey), value(value) {}
static DataLayoutEntryAttrStorage *
construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
return new (allocator.allocate<DataLayoutEntryAttrStorage>())
DataLayoutEntryAttrStorage(key.first, key.second);
}
bool operator==(const KeyTy &other) const {
return other.first == entryKey && other.second == value;
}
DataLayoutEntryKey entryKey;
Attribute value;
};
}
}
/// Returns the entry attribute keyed by the string identifier `key`, created
/// in `key`'s context.
DataLayoutEntryAttr DataLayoutEntryAttr::get(StringAttr key, Attribute value) {
  MLIRContext *ctx = key.getContext();
  return Base::get(ctx, key, value);
}

/// Returns the entry attribute keyed by the type `key`, created in `key`'s
/// context.
DataLayoutEntryAttr DataLayoutEntryAttr::get(Type key, Attribute value) {
  MLIRContext *ctx = key.getContext();
  return Base::get(ctx, key, value);
}

/// Returns the key (type or string identifier) stored in this entry.
DataLayoutEntryKey DataLayoutEntryAttr::getKey() const {
  auto *storage = getImpl();
  return storage->entryKey;
}

/// Returns the value stored in this entry.
Attribute DataLayoutEntryAttr::getValue() const {
  auto *storage = getImpl();
  return storage->value;
}
/// Parses `<` (type | quoted-string) `,` attribute `>`.
///
/// The key is first tried as a type. If no type is present, a quoted string
/// identifier is required instead. Returns a null attribute on any parse
/// failure; an error is emitted at the key location when neither form of key
/// is found.
Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type ty) {
  if (failed(parser.parseLess()))
    return {};

  Type keyType = nullptr;
  std::string keyName;
  SMLoc keyLoc = parser.getCurrentLocation();

  OptionalParseResult typeResult = parser.parseOptionalType(keyType);
  if (typeResult.has_value()) {
    // A type was started but is malformed; the parser already reported it.
    if (failed(typeResult.value()))
      return {};
  } else {
    OptionalParseResult nameResult = parser.parseOptionalString(&keyName);
    if (!nameResult.has_value() || failed(nameResult.value())) {
      parser.emitError(keyLoc) << "expected a type or a quoted string";
      return {};
    }
  }

  Attribute value;
  if (failed(parser.parseComma()) || failed(parser.parseAttribute(value)) ||
      failed(parser.parseGreater()))
    return {};

  if (keyType)
    return get(keyType, value);
  return get(parser.getBuilder().getStringAttr(keyName), value);
}
/// Prints the entry as `<` key `,` value `>`, the form accepted by
/// DataLayoutEntryAttr::parse.
///
/// A type key is printed as a type. A string key is printed as a StringAttr,
/// which quotes and escapes it. The previous version emitted the raw bytes
/// between quotes, so identifiers containing `"`, `\` or non-printable
/// characters (all accepted by parseOptionalString) did not round-trip.
void DataLayoutEntryAttr::print(AsmPrinter &os) const {
  os << "<";
  if (auto type = llvm::dyn_cast_if_present<Type>(getKey())) {
    os << type;
  } else {
    // Same printing path as device IDs in target_system_spec: escaped and
    // quoted.
    os << llvm::cast<StringAttr>(getKey());
  }
  os << ", " << getValue() << ">";
}
/// Verifies that no two entries share a key. Type keys and string-identifier
/// keys are tracked separately; the first repeated key of either kind
/// produces an error.
LogicalResult
DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           ArrayRef<DataLayoutEntryInterface> entries) {
  DenseSet<Type> seenTypes;
  DenseSet<StringAttr> seenIds;
  for (DataLayoutEntryInterface entry : entries) {
    DataLayoutEntryKey key = entry.getKey();
    if (auto type = llvm::dyn_cast_if_present<Type>(key)) {
      if (seenTypes.insert(type).second)
        continue;
      return emitError() << "repeated layout entry key: " << type;
    }
    auto id = llvm::cast<StringAttr>(key);
    if (!seenIds.insert(id).second)
      return emitError() << "repeated layout entry key: " << id.getValue();
  }
  return success();
}
/// Merges `newEntries` into `oldEntries`.
///
/// Each new entry replaces the first pre-existing entry with the same key.
/// Otherwise it is appended. Only the entries present on entry to this
/// function are searched, so entries appended here are never replaced, even
/// if `newEntries` itself repeats a key.
static void
overwriteDuplicateEntries(SmallVectorImpl<DataLayoutEntryInterface> &oldEntries,
                          ArrayRef<DataLayoutEntryInterface> newEntries) {
  const unsigned numOriginal = oldEntries.size();
  for (DataLayoutEntryInterface incoming : newEntries) {
    unsigned idx = 0;
    while (idx < numOriginal &&
           !(oldEntries[idx].getKey() == incoming.getKey()))
      ++idx;

    if (idx < numOriginal)
      oldEntries[idx] = incoming;
    else
      oldEntries.push_back(incoming);
  }
}
/// Folds the entries of `spec` into the accumulated maps `entriesForType`
/// (type-keyed entries, bucketed by TypeID) and `entriesForID` (string-keyed
/// entries). A null `spec` is a no-op.
///
/// Type-keyed entries:
/// - If a TypeID is seen for the first time, its new entries are taken as-is.
/// - Otherwise the type's DataLayoutTypeInterface must report the old and new
///   lists as compatible. Entries with equal keys are then overwritten and
///   the rest appended.
///
/// String-keyed entries:
/// - If an identifier was seen before, the two entries are merged through
///   the referenced dialect's DataLayoutDialectInterface::combine.
/// - When getReferencedDialect() returns null, defaultCombine is used.
///
/// Returns failure if the type lists are incompatible or a combination yields
/// a null entry.
static LogicalResult
combineOneSpec(DataLayoutSpecInterface spec,
               DenseMap<TypeID, DataLayoutEntryList> &entriesForType,
               DenseMap<StringAttr, DataLayoutEntryInterface> &entriesForID) {
  if (!spec)
    return success();

  DenseMap<TypeID, DataLayoutEntryList> newEntriesForType;
  DenseMap<StringAttr, DataLayoutEntryInterface> newEntriesForID;
  spec.bucketEntriesByType(newEntriesForType, newEntriesForID);

  for (auto &kvp : newEntriesForType) {
    // One hash lookup instead of count() + operator[]. `existing` stays valid
    // below because nothing is inserted into `entriesForType` on that path.
    auto existing = entriesForType.find(kvp.first);
    if (existing == entriesForType.end()) {
      entriesForType.try_emplace(kvp.first, std::move(kvp.second));
      continue;
    }

    Type typeSample = llvm::cast<Type>(kvp.second.front().getKey());
    assert(&typeSample.getDialect() !=
               typeSample.getContext()->getLoadedDialect<BuiltinDialect>() &&
           "unexpected data layout entry for built-in type");

    auto interface = llvm::cast<DataLayoutTypeInterface>(typeSample);
    // Pass the stored list directly (it converts to an ArrayRef) rather than
    // via DenseMap::lookup, which returns a full copy of the SmallVector.
    if (!interface.areCompatible(existing->second, kvp.second))
      return failure();

    overwriteDuplicateEntries(existing->second, kvp.second);
  }

  for (const auto &kvp : newEntriesForID) {
    StringAttr id = llvm::cast<StringAttr>(kvp.second.getKey());
    auto [it, inserted] = entriesForID.try_emplace(id, kvp.second);
    if (inserted)
      continue;

    DataLayoutEntryInterface &current = it->second;
    Dialect *dialect = id.getReferencedDialect();
    current = dialect ? cast<DataLayoutDialectInterface>(dialect)->combine(
                            current, kvp.second)
                      : DataLayoutDialectInterface::defaultCombine(current,
                                                                   kvp.second);
    if (!current)
      return failure();
  }
  return success();
}
/// Combines `specs` and then this spec into a single DataLayoutSpecAttr.
///
/// Returns null if any element of `specs` is not a DataLayoutSpecAttr, or if
/// combineOneSpec fails for any of them. The specs in `specs` are folded in
/// order, followed by this spec. In the result, string-keyed entries are
/// placed before type-keyed entries.
DataLayoutSpecAttr
DataLayoutSpecAttr::combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
  bool allDltiSpecs = llvm::all_of(specs, [](DataLayoutSpecInterface spec) {
    return llvm::isa<DataLayoutSpecAttr>(spec);
  });
  if (!allDltiSpecs)
    return {};

  DenseMap<TypeID, DataLayoutEntryList> byType;
  DenseMap<StringAttr, DataLayoutEntryInterface> byId;
  for (DataLayoutSpecInterface spec : specs) {
    if (failed(combineOneSpec(spec, byType, byId)))
      return nullptr;
  }
  if (failed(combineOneSpec(*this, byType, byId)))
    return nullptr;

  SmallVector<DataLayoutEntryInterface> merged;
  for (const auto &idEntry : byId)
    merged.push_back(idEntry.second);
  for (const auto &typeBucket : byType)
    llvm::append_range(merged, typeBucket.getSecond());
  return DataLayoutSpecAttr::get(getContext(), merged);
}
/// Returns the StringAttr for DLTIDialect::kDataLayoutEndiannessKey.
StringAttr
DataLayoutSpecAttr::getEndiannessIdentifier(MLIRContext *context) const {
  return StringAttr::get(context, DLTIDialect::kDataLayoutEndiannessKey);
}

/// Returns the StringAttr for DLTIDialect::kDataLayoutAllocaMemorySpaceKey.
StringAttr
DataLayoutSpecAttr::getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
  return StringAttr::get(context,
                         DLTIDialect::kDataLayoutAllocaMemorySpaceKey);
}

/// Returns the StringAttr for DLTIDialect::kDataLayoutProgramMemorySpaceKey.
StringAttr DataLayoutSpecAttr::getProgramMemorySpaceIdentifier(
    MLIRContext *context) const {
  return StringAttr::get(context,
                         DLTIDialect::kDataLayoutProgramMemorySpaceKey);
}

/// Returns the StringAttr for DLTIDialect::kDataLayoutGlobalMemorySpaceKey.
StringAttr
DataLayoutSpecAttr::getGlobalMemorySpaceIdentifier(MLIRContext *context) const {
  return StringAttr::get(context,
                         DLTIDialect::kDataLayoutGlobalMemorySpaceKey);
}

/// Returns the StringAttr for DLTIDialect::kDataLayoutStackAlignmentKey.
StringAttr
DataLayoutSpecAttr::getStackAlignmentIdentifier(MLIRContext *context) const {
  return StringAttr::get(context, DLTIDialect::kDataLayoutStackAlignmentKey);
}
/// Parses `<` `>` (an empty spec) or `<` entry (`,` entry)* `>`.
/// The result is built with getChecked, so DataLayoutSpecAttr::verify
/// diagnostics are reported at the parser's name location.
Attribute DataLayoutSpecAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};

  if (succeeded(parser.parseOptionalGreater()))
    return get(parser.getContext(), {});

  SmallVector<DataLayoutEntryInterface> entries;
  auto parseOneEntry = [&]() -> ParseResult {
    return parser.parseAttribute(entries.emplace_back());
  };
  if (failed(parser.parseCommaSeparatedList(parseOneEntry)) ||
      failed(parser.parseGreater()))
    return {};

  auto emitDiag = [&] { return parser.emitError(parser.getNameLoc()); };
  return getChecked(emitDiag, parser.getContext(), entries);
}
/// Prints the spec as `<` followed by the comma-separated entries and `>`.
void DataLayoutSpecAttr::print(AsmPrinter &os) const {
  auto entries = getEntries();
  os << "<";
  llvm::interleaveComma(entries, os);
  os << ">";
}
namespace mlir {
/// Parses one `"<device-id>" : <target-device-spec>` element. An error is
/// emitted and failure returned if the string, the colon, or the target
/// device spec cannot be parsed.
template <>
struct FieldParser<DeviceIDTargetDeviceSpecPair> {
  static FailureOr<DeviceIDTargetDeviceSpecPair> parse(AsmParser &parser) {
    std::string idText;
    if (failed(parser.parseString(&idText))) {
      parser.emitError(parser.getCurrentLocation())
          << "DeviceID is missing, or is not of string type";
      return failure();
    }

    if (failed(parser.parseColon())) {
      parser.emitError(parser.getCurrentLocation()) << "Missing colon";
      return failure();
    }

    auto deviceSpec = FieldParser<TargetDeviceSpecInterface>::parse(parser);
    if (failed(deviceSpec)) {
      parser.emitError(parser.getCurrentLocation())
          << "Error in parsing target device spec";
      return failure();
    }

    StringAttr deviceId = parser.getBuilder().getStringAttr(idText);
    return std::make_pair(deviceId, *deviceSpec);
  }
};

/// Prints a (device ID, target device spec) pair as `id : spec`, the form
/// read by FieldParser<DeviceIDTargetDeviceSpecPair>.
inline AsmPrinter &operator<<(AsmPrinter &printer,
                              DeviceIDTargetDeviceSpecPair param) {
  printer << param.first << " : " << param.second;
  return printer;
}
} // namespace mlir
/// Verifies a target device spec. Keys must be string identifiers (type keys
/// are rejected), and no identifier may appear twice.
LogicalResult
TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                             ArrayRef<DataLayoutEntryInterface> entries) {
  DenseSet<StringAttr> seenIds;
  for (DataLayoutEntryInterface entry : entries) {
    DataLayoutEntryKey key = entry.getKey();
    if (auto type = llvm::dyn_cast_if_present<Type>(key))
      return emitError()
             << "dlti.target_device_spec does not allow type as a key: "
             << type;

    auto id = llvm::cast<StringAttr>(key);
    if (!seenIds.insert(id).second)
      return emitError() << "repeated layout entry key: " << id.getValue();
  }
  return success();
}
/// Verifies a target system spec. Each device spec is checked with
/// TargetDeviceSpecAttr::verify, then its device ID is checked for
/// uniqueness. The first failure stops verification.
LogicalResult
TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                             ArrayRef<DeviceIDTargetDeviceSpecPair> entries) {
  DenseSet<TargetSystemSpecInterface::DeviceID> seenDeviceIds;
  for (const auto &entry : entries) {
    TargetDeviceSpecInterface deviceSpec = entry.second;
    LogicalResult specResult =
        TargetDeviceSpecAttr::verify(emitError, deviceSpec.getEntries());
    if (failed(specResult))
      return failure();

    TargetSystemSpecInterface::DeviceID deviceId = entry.first;
    if (seenDeviceIds.insert(deviceId).second)
      continue;
    return emitError() << "repeated Device ID in dlti.target_system_spec: "
                       << deviceId;
  }
  return success();
}
// Out-of-line definitions for some of DLTIDialect's static constexpr string
// members.
// NOTE(review): since C++17, static constexpr data members are implicitly
// inline, so these redeclarations are redundant (and deprecated) there. Other
// keys used in this file (e.g. kDataLayoutAllocaMemorySpaceKey,
// kTargetSystemDescAttrName) have no such definition here. Consider dropping
// these once pre-C++17 ODR rules are no longer a concern.
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessKey;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessBig;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessLittle;
namespace {
/// DLTI's DataLayoutDialectInterface implementation. It validates the
/// string-keyed data layout entries passed to verifyEntry.
class TargetDataLayoutInterface : public DataLayoutDialectInterface {
public:
  using DataLayoutDialectInterface::DataLayoutDialectInterface;

  /// Accepts the endianness entry only when its value is a StringAttr equal
  /// to kDataLayoutEndiannessBig or kDataLayoutEndiannessLittle. Accepts the
  /// alloca/program/global memory space and stack alignment keys without
  /// inspecting their values. Any other key name is an error.
  LogicalResult verifyEntry(DataLayoutEntryInterface entry,
                            Location loc) const final {
    StringRef entryName = llvm::cast<StringAttr>(entry.getKey()).strref();

    if (entryName != DLTIDialect::kDataLayoutEndiannessKey) {
      // Recognized keys whose values are not validated here.
      bool isKnownKey =
          entryName == DLTIDialect::kDataLayoutAllocaMemorySpaceKey ||
          entryName == DLTIDialect::kDataLayoutProgramMemorySpaceKey ||
          entryName == DLTIDialect::kDataLayoutGlobalMemorySpaceKey ||
          entryName == DLTIDialect::kDataLayoutStackAlignmentKey;
      if (isKnownKey)
        return success();
      return emitError(loc) << "unknown data layout entry name: " << entryName;
    }

    auto value = llvm::dyn_cast<StringAttr>(entry.getValue());
    bool isValidEndianness =
        value &&
        (value.getValue() == DLTIDialect::kDataLayoutEndiannessBig ||
         value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle);
    if (isValidEndianness)
      return success();
    return emitError(loc) << "'" << entryName
                          << "' data layout entry is expected to be either '"
                          << DLTIDialect::kDataLayoutEndiannessBig << "' or '"
                          << DLTIDialect::kDataLayoutEndiannessLittle << "'";
  }
};
} // namespace
/// Registers the dialect's attributes and its data layout interface.
void DLTIDialect::initialize() {
// The attribute class list is expanded from DLTIAttrs.cpp.inc with
// GET_ATTRDEF_LIST defined. The macro and #include must stay inside the
// template argument list.
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/DLTI/DLTIAttrs.cpp.inc"
>();
// Provides verifyEntry for string-keyed DLTI data layout entries.
addInterfaces<TargetDataLayoutInterface>();
}
/// Verifies a DLTI-named attribute attached to `op`:
/// - kDataLayoutAttrName must hold a DataLayoutSpecAttr. On a ModuleOp the
///   data layout is additionally checked by detail::verifyDataLayoutOp.
/// - kTargetSystemDescAttrName must hold a TargetSystemSpecAttr.
/// - Any other attribute name is rejected.
LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr name = attr.getName();

  if (name == DLTIDialect::kDataLayoutAttrName) {
    if (!llvm::isa<DataLayoutSpecAttr>(attr.getValue()))
      return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName
                             << "' is expected to be a #dlti.dl_spec attribute";
    return isa<ModuleOp>(op) ? detail::verifyDataLayoutOp(op) : success();
  }

  if (name == DLTIDialect::kTargetSystemDescAttrName) {
    if (llvm::isa<TargetSystemSpecAttr>(attr.getValue()))
      return success();
    return op->emitError()
           << "'" << DLTIDialect::kTargetSystemDescAttrName
           << "' is expected to be a #dlti.target_system_spec attribute";
  }

  return op->emitError() << "attribute '" << name.getValue()
                         << "' not supported by dialect";
}