#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include <limits>
using namespace mlir;
using namespace mlir::detail;
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"
/// Registers every builtin type with the BuiltinDialect.
///
/// The template argument list is produced by TableGen: re-including
/// BuiltinTypes.cpp.inc with GET_TYPEDEF_LIST defined expands to the list of
/// generated type classes, which is passed straight to addTypes<>.
void BuiltinDialect::registerTypes() {
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
>();
}
/// Verifies a complex type: its element type must be an integer or a
/// floating-point type; anything else is diagnosed.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (elementType.isIntOrFloat())
    return success();
  return emitError() << "invalid element type for complex";
}
// Out-of-line definition of the static constexpr data member declared on
// IntegerType. NOTE(review): presumably kept so that ODR-uses of kMaxWidth
// link under pre-C++17 rules (C++17 makes it implicitly inline) — confirm.
constexpr unsigned IntegerType::kMaxWidth;
/// Verifies integer type parameters. The bitwidth must not exceed
/// IntegerType::kMaxWidth; every signedness value is accepted.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  const bool withinLimit = width <= IntegerType::kMaxWidth;
  if (withinLimit)
    return success();
  return emitError() << "integer bitwidth is limited to "
                     << IntegerType::kMaxWidth << " bits";
}
/// Returns the bitwidth held by the uniqued storage of this type.
unsigned IntegerType::getWidth() const {
  const auto *storage = getImpl();
  return storage->width;
}

/// Returns the signedness semantics held by the uniqued storage of this type.
IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  const auto *storage = getImpl();
  return storage->signedness;
}
/// Returns an integer type with the same signedness whose width is `scale`
/// times the width of this type.
///
/// Returns a null IntegerType when `scale` is zero, or when the scaled width
/// would exceed IntegerType::kMaxWidth. In that second case the type cannot be
/// represented: IntegerType::verify rejects it, and the unsigned product could
/// also wrap around.
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  unsigned width = getWidth();
  // Compare through division so `scale * width` is only computed when it is
  // known to fit within the verifier's limit (and hence cannot wrap).
  if (width > IntegerType::kMaxWidth / scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * width, getSignedness());
}
unsigned FloatType::getWidth() {
if (isa<Float16Type, BFloat16Type>())
return 16;
if (isa<Float32Type>())
return 32;
if (isa<Float64Type>())
return 64;
if (isa<Float80Type>())
return 80;
if (isa<Float128Type>())
return 128;
llvm_unreachable("unexpected float type");
}
/// Maps each builtin floating-point type to its APFloat semantics.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  // 16-bit formats.
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  // 32- and 64-bit IEEE formats.
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  // Wider formats.
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}
/// Returns the floating-point type whose width is `scale` times this one.
///
/// Only these widenings are supported: f16/bf16 by 2 gives f32, f16/bf16 by 4
/// gives f64, and f32 by 2 gives f64. Every other request, including
/// `scale == 0`, yields a null FloatType.
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (scale == 0)
    return FloatType();
  MLIRContext *context = getContext();
  const bool isHalfWidth = isF16() || isBF16();
  if (isHalfWidth && scale == 2)
    return FloatType::getF32(context);
  if ((isHalfWidth && scale == 4) || (isF32() && scale == 2))
    return FloatType::getF64(context);
  return FloatType();
}
unsigned FloatType::getFPMantissaWidth() {
return APFloat::semanticsPrecision(getFloatSemantics());
}
/// Number of input types, as recorded in the uniqued storage.
unsigned FunctionType::getNumInputs() const {
  auto *storage = getImpl();
  return storage->numInputs;
}

/// The function's input types.
ArrayRef<Type> FunctionType::getInputs() const {
  auto *storage = getImpl();
  return storage->getInputs();
}

/// Number of result types, as recorded in the uniqued storage.
unsigned FunctionType::getNumResults() const {
  auto *storage = getImpl();
  return storage->numResults;
}

/// The function's result types.
ArrayRef<Type> FunctionType::getResults() const {
  auto *storage = getImpl();
  return storage->getResults();
}
/// Builds a function type in this type's context with the given inputs and
/// results.
FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
  MLIRContext *ctx = getContext();
  return FunctionType::get(ctx, inputs, results);
}
/// Returns a copy of this function type with `argTypes` inserted at
/// `argIndices` among the inputs and `resultTypes` inserted at
/// `resultIndices` among the results.
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  // The storage vectors stay alive until clone() runs, because the returned
  // ranges may refer to them.
  SmallVector<Type> newInputsStorage;
  SmallVector<Type> newResultsStorage;
  TypeRange newInputs = function_interface_impl::insertTypesInto(
      getInputs(), argIndices, argTypes, newInputsStorage);
  TypeRange newResults = function_interface_impl::insertTypesInto(
      getResults(), resultIndices, resultTypes, newResultsStorage);
  return clone(newInputs, newResults);
}
/// Returns a copy of this function type without the inputs selected in
/// `argIndices` and the results selected in `resultIndices`.
FunctionType
FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
                                       const BitVector &resultIndices) {
  // The storage vectors stay alive until clone() runs, because the returned
  // ranges may refer to them.
  SmallVector<Type> keptInputsStorage;
  SmallVector<Type> keptResultsStorage;
  TypeRange keptInputs = function_interface_impl::filterTypesOut(
      getInputs(), argIndices, keptInputsStorage);
  TypeRange keptResults = function_interface_impl::filterTypesOut(
      getResults(), resultIndices, keptResultsStorage);
  return clone(keptInputs, keptResults);
}
/// Visits every input type, then every result type. Function types carry no
/// attribute sub-elements.
void FunctionType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type input : getInputs())
    walkTypesFn(input);
  for (Type result : getResults())
    walkTypesFn(result);
}
/// Rebuilds the function type from replacement types laid out in walk order:
/// the first getNumInputs() entries are inputs, the remainder are results.
Type FunctionType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
                                               ArrayRef<Type> replTypes) const {
  unsigned inputCount = getNumInputs();
  ArrayRef<Type> newInputs = replTypes.take_front(inputCount);
  ArrayRef<Type> newResults = replTypes.drop_front(inputCount);
  return get(getContext(), newInputs, newResults);
}
/// Verifies an opaque type. The dialect namespace must be valid. Unless the
/// context allows unregistered dialects, that dialect must also be loaded.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  StringRef dialectName = dialect.strref();
  if (!Dialect::isValidNamespace(dialectName))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  MLIRContext *context = dialect.getContext();
  bool dialectAvailable = context->allowsUnregisteredDialects() ||
                          context->getLoadedDialect(dialectName);
  if (dialectAvailable)
    return success();
  return emitError()
         << "`!" << dialect << "<\"" << typeData << "\">"
         << "` type created with unregistered dialect. If this is "
            "intended, please call allowUnregisteredDialects() on the "
            "MLIRContext, or use -allow-unregistered-dialect with "
            "the MLIR opt tool used";
}
/// Verifies vector type parameters. The element type must be accepted by
/// isValidElementType, and every dimension size must be strictly positive.
/// `numScalableDims` is not checked here.
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 unsigned numScalableDims) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  bool allPositive =
      llvm::all_of(shape, [](int64_t dimSize) { return dimSize > 0; });
  if (!allPositive)
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;
  return success();
}
/// Returns a vector of the same shape and scalability whose element type has
/// been widened by `scale`. The element type is scaled with
/// IntegerType/FloatType::scaleElementBitwidth. Returns a null VectorType when
/// `scale` is zero or the element type cannot be scaled.
VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  Type elementType = getElementType();
  Type scaledElementType;
  if (auto intTy = elementType.dyn_cast<IntegerType>())
    scaledElementType = intTy.scaleElementBitwidth(scale);
  else if (auto floatTy = elementType.dyn_cast<FloatType>())
    scaledElementType = floatTy.scaleElementBitwidth(scale);
  if (!scaledElementType)
    return VectorType();
  return VectorType::get(getShape(), scaledElementType, getNumScalableDims());
}
/// Visits the element type, which is the only sub-element of a vector type.
void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  Type elementType = getElementType();
  walkTypesFn(elementType);
}
/// Rebuilds the vector with the first replacement type as its element type,
/// keeping the shape and the number of scalable dimensions.
Type VectorType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
                                             ArrayRef<Type> replTypes) const {
  Type newElementType = replTypes.front();
  return VectorType::get(getShape(), newElementType, getNumScalableDims());
}
/// Clones this vector type with `elementType`. The new shape is `shape` when
/// provided, otherwise the current shape; the scalable-dim count is kept.
VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  ArrayRef<int64_t> newShape = shape ? *shape : getShape();
  return VectorType::get(newShape, elementType, getNumScalableDims());
}
/// Returns the element type of either a ranked or an unranked tensor.
Type TensorType::getElementType() const {
  if (auto rankedTy = dyn_cast<RankedTensorType>())
    return rankedTy.getElementType();
  return cast<UnrankedTensorType>().getElementType();
}
/// Every tensor except the unranked variant has a rank.
bool TensorType::hasRank() const {
  const bool isUnranked = isa<UnrankedTensorType>();
  return !isUnranked;
}

/// Returns the shape; only valid on ranked tensors (cast asserts otherwise).
ArrayRef<int64_t> TensorType::getShape() const {
  RankedTensorType rankedTy = cast<RankedTensorType>();
  return rankedTy.getShape();
}
/// Clones this tensor type with a new element type and, optionally, a new
/// shape.
///
/// - An unranked tensor stays unranked unless a shape is given, in which case
///   it becomes a ranked tensor without encoding.
/// - A ranked tensor keeps its encoding, and keeps its shape unless a new one
///   is given.
TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  if (isa<UnrankedTensorType>()) {
    if (!shape)
      return UnrankedTensorType::get(elementType);
    return RankedTensorType::get(*shape, elementType);
  }

  auto rankedTy = cast<RankedTensorType>();
  ArrayRef<int64_t> newShape = shape ? *shape : rankedTy.getShape();
  return RankedTensorType::get(newShape, elementType, rankedTy.getEncoding());
}
/// Shared tensor verifier helper: diagnoses element types rejected by
/// TensorType::isValidElementType.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (TensorType::isValidElementType(elementType))
    return success();
  return emitError() << "invalid tensor element type: " << elementType;
}
/// Any type defined outside the builtin dialect is accepted as a tensor
/// element. Builtin types are accepted only if they are complex, float,
/// integer, opaque, vector, or index types.
bool TensorType::isValidElementType(Type type) {
  if (!llvm::isa<BuiltinDialect>(type.getDialect()))
    return true;
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>();
}
/// Verifies ranked tensor parameters:
/// - every dimension size is >= -1;
/// - if the encoding implements VerifiableTensorEncoding, it accepts the
///   shape and element type;
/// - the element type passes checkTensorElementType.
LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  if (llvm::any_of(shape, [](int64_t dimSize) { return dimSize < -1; }))
    return emitError() << "invalid tensor dimension size";

  if (auto verifiable =
          encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
    if (failed(verifiable.verifyEncoding(shape, elementType, emitError)))
      return failure();
  }
  return checkTensorElementType(emitError, elementType);
}
/// Visits the element type, then the encoding attribute if one is present.
void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  Attribute encoding = getEncoding();
  if (!encoding)
    return;
  walkAttrsFn(encoding);
}
/// Rebuilds the tensor from replacement sub-elements. The encoding is the
/// last replacement attribute, or null when there is none (matching the walk,
/// which only visits a present encoding).
Type RankedTensorType::replaceImmediateSubElements(
    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
  Attribute newEncoding;
  if (!replAttrs.empty())
    newEncoding = replAttrs.back();
  return RankedTensorType::get(getShape(), replTypes.front(), newEncoding);
}
/// Verifies an unranked tensor. With no shape, only the element type needs
/// checking.
LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  LogicalResult elementCheck = checkTensorElementType(emitError, elementType);
  return elementCheck;
}
/// Visits the element type, which is the only sub-element of an unranked
/// tensor.
void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  Type elementType = getElementType();
  walkTypesFn(elementType);
}
/// Rebuilds the unranked tensor with the first replacement type as its
/// element type.
Type UnrankedTensorType::replaceImmediateSubElements(
    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
  Type newElementType = replTypes.front();
  return UnrankedTensorType::get(newElementType);
}
/// Returns the element type of either a ranked or an unranked memref.
Type BaseMemRefType::getElementType() const {
  if (auto rankedTy = dyn_cast<MemRefType>())
    return rankedTy.getElementType();
  return cast<UnrankedMemRefType>().getElementType();
}
/// Every memref except the unranked variant has a rank.
bool BaseMemRefType::hasRank() const {
  const bool isUnranked = isa<UnrankedMemRefType>();
  return !isUnranked;
}

/// Returns the shape; only valid on ranked memrefs (cast asserts otherwise).
ArrayRef<int64_t> BaseMemRefType::getShape() const {
  MemRefType rankedTy = cast<MemRefType>();
  return rankedTy.getShape();
}
/// Clones this memref type with a new element type and, optionally, a new
/// shape.
///
/// - An unranked memref stays unranked unless a shape is given, in which case
///   it becomes a ranked memref. Either way it keeps its memory space.
/// - A ranked memref keeps its layout and memory space, and keeps its shape
///   unless a new one is given.
BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                         Type elementType) const {
  if (isa<UnrankedMemRefType>()) {
    if (shape) {
      MemRefType::Builder rankedBuilder(*shape, elementType);
      rankedBuilder.setMemorySpace(getMemorySpace());
      return rankedBuilder;
    }
    return UnrankedMemRefType::get(elementType, getMemorySpace());
  }

  MemRefType::Builder builder(cast<MemRefType>());
  if (shape)
    builder.setShape(*shape);
  builder.setElementType(elementType);
  return builder;
}
/// Returns the memory space attribute of a ranked or unranked memref.
Attribute BaseMemRefType::getMemorySpace() const {
  if (auto unrankedTy = dyn_cast<UnrankedMemRefType>())
    return unrankedTy.getMemorySpace();
  return cast<MemRefType>().getMemorySpace();
}
/// Returns the integer memory space of a ranked or unranked memref.
unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto unrankedTy = dyn_cast<UnrankedMemRefType>())
    return unrankedTy.getMemorySpaceAsInt();
  return cast<MemRefType>().getMemorySpaceAsInt();
}
/// Computes which dimensions of `originalShape` must be dropped to obtain
/// `reducedShape`.
///
/// Dimensions are matched greedily left to right. Any original dimension that
/// does not match the next reduced dimension is dropped, and only size-1
/// dimensions may be dropped. Returns llvm::None if a non-unit dimension would
/// have to be dropped, or if not all of `reducedShape` was matched.
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  llvm::SmallDenseSet<unsigned> droppedDims;
  size_t nextReduced = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, originalShape.size())) {
    int64_t dimSize = originalShape[dim];
    bool matches = nextReduced < reducedShape.size() &&
                   dimSize == reducedShape[nextReduced];
    if (matches) {
      ++nextReduced;
      continue;
    }
    if (dimSize != 1)
      return llvm::None;
    droppedDims.insert(dim);
  }
  if (nextReduced != reducedShape.size())
    return llvm::None;
  return droppedDims;
}
/// Checks whether `candidateReducedType` is a rank-reduced form of
/// `originalType`: the same type, or one that drops only unit dimensions and
/// has the same element type.
SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
                        ShapedType candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;

  ArrayRef<int64_t> originalShape = originalType.getShape();
  ArrayRef<int64_t> reducedShape = candidateReducedType.getShape();
  if (reducedShape.size() > originalShape.size())
    return SliceVerificationResult::RankTooLarge;

  if (!computeRankReductionMask(originalShape, reducedShape))
    return SliceVerificationResult::SizeMismatch;

  bool sameElementType =
      originalType.getElementType() == candidateReducedType.getElementType();
  return sameElementType ? SliceVerificationResult::Success
                         : SliceVerificationResult::ElemTypeMismatch;
}
/// A memory space is supported if it is absent (null), is a builtin
/// integer/string/dictionary attribute, or comes from a non-builtin dialect.
bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  if (!memorySpace ||
      memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;
  return !isa<BuiltinDialect>(memorySpace.getDialect());
}
/// Wraps an integer memory space in a 64-bit IntegerAttr. Memory space 0 is
/// represented by a null attribute.
Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace != 0)
    return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
  return nullptr;
}
/// Replaces an explicit integer memory space of 0 by the null attribute; any
/// other value (including null) is returned unchanged.
Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  if (auto intAttr = memorySpace.dyn_cast_or_null<IntegerAttr>())
    if (intAttr.getValue() == 0)
      return nullptr;
  return memorySpace;
}
/// Converts a memory space attribute to an unsigned integer.
///
/// A null attribute maps to 0. Any other value must be an IntegerAttr; this is
/// asserted, and the assert message now names this function correctly
/// (it previously referred to a nonexistent `getMemorySpaceInteger`).
unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceAsInt` with non-Integer attribute");
  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}
/// Sets the builder's memory space from an integer. The value is wrapped with
/// wrapIntegerMemorySpace, which encodes 0 as the null attribute.
MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  MLIRContext *ctx = elementType.getContext();
  memorySpace = wrapIntegerMemorySpace(newMemorySpace, ctx);
  return *this;
}
unsigned MemRefType::getMemorySpaceAsInt() const {
return detail::getMemorySpaceAsInt(getMemorySpace());
}
/// Builds a memref type. A null `layout` defaults to the identity affine map
/// of the shape's rank. An explicit integer-0 memory space is canonicalized to
/// null by skipDefaultMemorySpace.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  MLIRContext *context = elementType.getContext();
  if (!layout) {
    AffineMap identity =
        AffineMap::getMultiDimIdentityMap(shape.size(), context);
    layout = AffineMapAttr::get(identity);
  }
  Attribute canonicalMemorySpace = skipDefaultMemorySpace(memorySpace);
  return Base::get(context, shape, elementType, layout, canonicalMemorySpace);
}
/// Verifying variant of get(): same defaulting of `layout` and
/// canonicalization of the memory space, with errors reported through
/// `emitErrorFn`.
MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
  MLIRContext *context = elementType.getContext();
  if (!layout) {
    AffineMap identity =
        AffineMap::getMultiDimIdentityMap(shape.size(), context);
    layout = AffineMapAttr::get(identity);
  }
  Attribute canonicalMemorySpace = skipDefaultMemorySpace(memorySpace);
  return Base::getChecked(emitErrorFn, context, shape, elementType, layout,
                          canonicalMemorySpace);
}
/// Builds a memref type from an affine-map layout. A null `map` defaults to
/// the identity map of the shape's rank. An explicit integer-0 memory space is
/// canonicalized to null.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {
  MLIRContext *context = elementType.getContext();
  AffineMap layoutMap =
      map ? map : AffineMap::getMultiDimIdentityMap(shape.size(), context);
  Attribute layout = AffineMapAttr::get(layoutMap);
  Attribute canonicalMemorySpace = skipDefaultMemorySpace(memorySpace);
  return Base::get(context, shape, elementType, layout, canonicalMemorySpace);
}
/// Verifying variant of get() with an affine-map layout: a null `map` defaults
/// to the identity, and an explicit integer-0 memory space becomes null.
MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {
  MLIRContext *context = elementType.getContext();
  AffineMap layoutMap =
      map ? map : AffineMap::getMultiDimIdentityMap(shape.size(), context);
  Attribute layout = AffineMapAttr::get(layoutMap);
  Attribute canonicalMemorySpace = skipDefaultMemorySpace(memorySpace);
  return Base::getChecked(emitErrorFn, context, shape, elementType, layout,
                          canonicalMemorySpace);
}
/// Builds a memref type from an affine-map layout and an integer memory space.
/// A null `map` defaults to the identity. The integer is wrapped by
/// wrapIntegerMemorySpace, which encodes 0 as null.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {
  MLIRContext *context = elementType.getContext();
  AffineMap layoutMap =
      map ? map : AffineMap::getMultiDimIdentityMap(shape.size(), context);
  Attribute layout = AffineMapAttr::get(layoutMap);
  Attribute memorySpace = wrapIntegerMemorySpace(memorySpaceInd, context);
  return Base::get(context, shape, elementType, layout, memorySpace);
}
/// Verifying variant of get() with an affine-map layout and an integer memory
/// space. A null `map` defaults to the identity; memory space 0 becomes null.
MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {
  MLIRContext *context = elementType.getContext();
  AffineMap layoutMap =
      map ? map : AffineMap::getMultiDimIdentityMap(shape.size(), context);
  Attribute layout = AffineMapAttr::get(layoutMap);
  Attribute memorySpace = wrapIntegerMemorySpace(memorySpaceInd, context);
  return Base::getChecked(emitErrorFn, context, shape, elementType, layout,
                          memorySpace);
}
/// Verifies memref parameters, checking in order:
/// - the element type is accepted by BaseMemRefType::isValidElementType;
/// - every size is >= -1;
/// - the (mandatory) layout accepts the shape;
/// - the memory space is supported.
LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  bool hasInvalidSize =
      llvm::any_of(shape, [](int64_t dimSize) { return dimSize < -1; });
  if (hasInvalidSize)
    return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (isSupportedMemorySpace(memorySpace))
    return success();
  return emitError() << "unsupported memory space Attribute";
}
/// Visits the element type, then the layout (only when it is not the
/// identity), then the memory space (always, even when null).
/// MemRefType::replaceImmediateSubElements relies on this order, using the
/// attribute count to tell whether a layout was visited.
void MemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  MemRefLayoutAttrInterface layout = getLayout();
  if (!layout.isIdentity())
    walkAttrsFn(layout);
  walkAttrsFn(getMemorySpace());
}
/// Rebuilds the memref from replacements in walk order: [layout], memory space.
/// With two or more attributes the first is the layout. Otherwise the layout is
/// null, and get() then defaults it to the identity.
Type MemRefType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
                                             ArrayRef<Type> replTypes) const {
  MemRefLayoutAttrInterface newLayout;
  Attribute newMemorySpace;
  if (replAttrs.size() > 1) {
    newLayout = replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>();
    newMemorySpace = replAttrs[1];
  } else {
    newMemorySpace = replAttrs[0];
  }
  return get(getShape(), replTypes[0], newLayout, newMemorySpace);
}
unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
return detail::getMemorySpaceAsInt(getMemorySpace());
}
/// Verifies an unranked memref. The element type must be valid for memrefs
/// and the memory space must be supported.
LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";
  if (isSupportedMemorySpace(memorySpace))
    return success();
  return emitError() << "unsupported memory space Attribute";
}
/// Accumulates a non-binary term into the stride/offset decomposition.
///
/// A dimension expression adds `multiplicativeFactor` to that dimension's
/// stride. Any other term, scaled by the factor, is added to `offset`.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  auto dimExpr = e.dyn_cast<AffineDimExpr>();
  if (!dimExpr) {
    offset = offset + e * multiplicativeFactor;
    return;
  }
  AffineExpr &stride = strides[dimExpr.getPosition()];
  stride = stride + multiplicativeFactor;
}
/// Recursively decomposes `e` into per-dimension strides and an offset, with
/// every term scaled by `multiplicativeFactor`.
///
/// Fails on ceildiv, floordiv, and mod, which are not strided. For a sum, both
/// operands are always processed (LHS first), and the result is success only
/// if both succeed.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto binExpr = e.dyn_cast<AffineBinaryOpExpr>();
  if (!binExpr) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  AffineExpr lhs = binExpr.getLHS();
  AffineExpr rhs = binExpr.getRHS();
  switch (binExpr.getKind()) {
  case AffineExprKind::CeilDiv:
  case AffineExprKind::FloorDiv:
  case AffineExprKind::Mod:
    return failure();
  case AffineExprKind::Mul: {
    // `d_i * rhs`: the RHS (scaled) is that dimension's stride contribution.
    if (auto dimExpr = lhs.dyn_cast<AffineDimExpr>()) {
      unsigned pos = dimExpr.getPosition();
      strides[pos] = strides[pos] + rhs * multiplicativeFactor;
      return success();
    }
    // Otherwise fold the symbolic/constant side into the factor and recurse
    // into the other side.
    if (lhs.isSymbolicOrConstant())
      return extractStrides(rhs, multiplicativeFactor * lhs, strides, offset);
    return extractStrides(lhs, multiplicativeFactor * rhs, strides, offset);
  }
  case AffineExprKind::Add: {
    LogicalResult lhsResult =
        extractStrides(lhs, multiplicativeFactor, strides, offset);
    LogicalResult rhsResult =
        extractStrides(rhs, multiplicativeFactor, strides, offset);
    return success(succeeded(lhsResult) && succeeded(rhsResult));
  }
  default:
    llvm_unreachable("unexpected binary operation");
  }
}
/// Computes the strides and offset of `t`'s layout as affine expressions.
///
/// On success, `strides` holds one expression per dimension of `t` and
/// `offset` holds the offset expression.
/// - For an identity layout, these are derived from the canonical row-major
///   strided expression for `t`'s shape (0-D: offset 0, no strides).
/// - Otherwise, they are extracted from the single layout result and then
///   simplified against the map's dims/symbols.
///
/// Fails when the layout map has more than one result and is not the
/// identity; in that case `strides`/`offset` are left untouched. Also fails
/// when the result contains floordiv/ceildiv/mod, or when any stride
/// simplifies to the constant 0. Those later failures reset `offset` to a null
/// AffineExpr and clear `strides`.
LogicalResult mlir::getStridesAndOffset(MemRefType t,
SmallVectorImpl<AffineExpr> &strides,
AffineExpr &offset) {
AffineMap m = t.getLayout().getAffineMap();
// Only the identity or a single-result map can describe a strided layout.
if (m.getNumResults() != 1 && !m.isIdentity())
return failure();
auto zero = getAffineConstantExpr(0, t.getContext());
auto one = getAffineConstantExpr(1, t.getContext());
offset = zero;
strides.assign(t.getRank(), zero);
if (m.isIdentity()) {
// 0-D identity: offset 0 and no strides.
if (t.getRank() == 0)
return success();
// Identity layout == canonical row-major strided layout for this shape.
auto stridedExpr =
makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
return success();
// Invariant check only: with NDEBUG this assert compiles away and control
// falls through to the general path below.
assert(false && "unexpected failure: extract strides in canonical layout");
}
auto stridedExpr =
simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
if (failed(extractStrides(stridedExpr, one, strides, offset))) {
offset = AffineExpr();
strides.clear();
return failure();
}
unsigned numDims = m.getNumDims();
unsigned numSymbols = m.getNumSymbols();
offset = simplifyAffineExpr(offset, numDims, numSymbols);
for (auto &stride : strides)
stride = simplifyAffineExpr(stride, numDims, numSymbols);
// Reject layouts with a constant-zero stride. NOTE(review): presumably
// because such a layout maps distinct indices to the same element — confirm.
if (llvm::any_of(strides, [](AffineExpr e) {
return e == getAffineConstantExpr(0, e.getContext());
})) {
offset = AffineExpr();
strides.clear();
return failure();
}
return success();
}
/// Integer form of getStridesAndOffset.
///
/// Constant stride/offset expressions become their values. Non-constant ones
/// become ShapedType::kDynamicStrideOrOffset. Strides are appended to
/// `strides`. On failure, `strides` and `offset` are not modified.
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();

  auto toStaticValue = [](AffineExpr expr) -> int64_t {
    if (auto cst = expr.dyn_cast<AffineConstantExpr>())
      return cst.getValue();
    return ShapedType::kDynamicStrideOrOffset;
  };
  offset = toStaticValue(offsetExpr);
  for (AffineExpr strideExpr : strideExprs)
    strides.push_back(toStaticValue(strideExpr));
  return success();
}
/// Visits the element type, then the memory space (always, even when null).
void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  Type elementType = getElementType();
  walkTypesFn(elementType);
  Attribute memorySpace = getMemorySpace();
  walkAttrsFn(memorySpace);
}
/// Rebuilds the unranked memref from the first replacement type (element type)
/// and the first replacement attribute (memory space).
Type UnrankedMemRefType::replaceImmediateSubElements(
    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
  Type newElementType = replTypes.front();
  Attribute newMemorySpace = replAttrs.front();
  return UnrankedMemRefType::get(newElementType, newMemorySpace);
}
/// Returns the tuple's element types.
ArrayRef<Type> TupleType::getTypes() const {
  auto *storage = getImpl();
  return storage->getTypes();
}

/// Appends the element types to `types`, recursively expanding nested tuples
/// in order.
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type element : getTypes()) {
    auto nestedTuple = element.dyn_cast<TupleType>();
    if (!nestedTuple) {
      types.push_back(element);
      continue;
    }
    nestedTuple.getFlattenedTypes(types);
  }
}

/// Returns the number of (non-flattened) element types.
size_t TupleType::size() const {
  auto *storage = getImpl();
  return storage->size();
}
/// Visits each (non-flattened) element type in order.
void TupleType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  ArrayRef<Type> elements = getTypes();
  for (Type element : elements)
    walkTypesFn(element);
}
/// Rebuilds the tuple with the replacement types as its elements.
Type TupleType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
                                            ArrayRef<Type> replTypes) const {
  MLIRContext *ctx = getContext();
  return TupleType::get(ctx, replTypes);
}
/// Builds the layout map `offset + sum_i d_i * stride_i` over `strides.size()`
/// dimensions.
///
/// Static values become constants. Each dynamic value
/// (MemRefType::getDynamicStrideOrOffset()) becomes a fresh symbol, numbered
/// in order: offset first, then strides by dimension. Zero strides are
/// asserted against.
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  const int64_t dynamic = MemRefType::getDynamicStrideOrOffset();
  unsigned nSymbols = 0;
  auto staticOrSymbol = [&](int64_t value) -> AffineExpr {
    if (value == dynamic)
      return getAffineSymbolExpr(nSymbols++, context);
    return getAffineConstantExpr(value, context);
  };

  AffineExpr expr = staticOrSymbol(offset);
  for (unsigned dim = 0, e = strides.size(); dim < e; ++dim) {
    int64_t stride = strides[dim];
    assert(stride != 0 && "Invalid stride specification");
    AffineExpr dimExpr = getAffineDimExpr(dim, context);
    expr = expr + dimExpr * staticOrSymbol(stride);
  }
  return AffineMap::get(strides.size(), nSymbols, expr);
}
/// Returns `t` with its layout in canonical form.
///
/// - A layout matching the canonical row-major strided layout for `t`'s shape
///   is dropped (replaced by the default layout).
/// - Any other single-result layout is replaced by its simplified form.
/// - Identity layouts, multi-result layouts, and 0-D memrefs whose map still
///   carries dims/symbols are returned unchanged.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();
  if (m.isIdentity() || m.getNumResults() > 1)
    return t;

  // Map with neither dims nor symbols: only a constant-zero result can be
  // replaced by the default layout.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>();
    if (!cst || cst.getValue() != 0)
      return t;
    return MemRefType::Builder(t).setLayout({});
  }

  // 0-D memref whose map has dims/symbols (e.g. a symbolic offset): keep it.
  if (t.getShape().empty())
    return t;

  AffineExpr simplified =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  AffineExpr canonical =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  MemRefType::Builder builder(t);
  if (simplified == canonical)
    return builder.setLayout({});
  return builder.setLayout(AffineMapAttr::get(
      AffineMap::get(m.getNumDims(), m.getNumSymbols(), simplified)));
}
/// Builds the canonical row-major strided expression `sum_i exprs[i] * s_i`
/// for the given `sizes`.
///
/// Strides are computed from the innermost dimension outwards as the running
/// product of the sizes seen so far. Once a dynamic (negative) size has been
/// seen, every further-outer stride becomes a fresh symbol. The result is
/// simplified against the dims/symbols inferred from `exprs`. An empty shape,
/// or one containing a 0 size, yields the constant 0.
///
/// The overflow check is done before multiplying. The previous
/// `runningSize > 0` check after the multiply relied on signed overflow, which
/// is undefined behavior, so the compiler was free to drop it.
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case (also covers the empty shape): the layout is constant.
  if (sizes.empty() || llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  assert(!exprs.empty() && "expected exprs");
  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    // No size can be 0 here: the early return above handles that case.
    int64_t size = std::get<1>(en);
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      assert(runningSize <= std::numeric_limits<int64_t>::max() / size &&
             "integer overflow in size computation");
      runningSize *= size;
    } else {
      // Dynamic size: all outer strides are unknown from here on.
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}
/// Replaces `t`'s layout by a fully dynamic strided map: a symbolic offset and
/// a symbolic stride for every dimension.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  const int64_t dynamic = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t, 4> dynamicStrides(t.getRank(), dynamic);
  AffineMap layoutMap =
      makeStridedLinearLayoutMap(dynamicStrides, dynamic, t.getContext());
  return MemRefType::Builder(t).setLayout(AffineMapAttr::get(layoutMap));
}
/// Convenience overload that uses d0 ... d(N-1) as the per-dimension
/// expressions for the N entries of `sizes`.
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> dimExprs;
  dimExprs.reserve(sizes.size());
  for (unsigned i = 0, e = sizes.size(); i < e; ++i)
    dimExprs.push_back(getAffineDimExpr(i, context));
  return makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
}
/// Returns true when getStridesAndOffset can decompose `t`'s layout.
bool mlir::isStrided(MemRefType t) {
  SmallVector<int64_t, 4> strides;
  int64_t offset;
  return succeeded(getStridesAndOffset(t, strides, offset));
}
/// Returns the strided linear layout map equivalent to `t`'s layout, or a null
/// AffineMap when the layout is not strided.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  LogicalResult status = getStridesAndOffset(t, strides, offset);
  if (succeeded(status))
    return makeStridedLinearLayoutMap(strides, offset, t.getContext());
  return AffineMap();
}
/// Returns the offset expression of a strided memref. Callers must pass a
/// strided memref; this is only asserted. With NDEBUG, a non-strided input
/// yields the null offset left by getStridesAndOffset.
static AffineExpr getOffsetExpr(MemRefType memrefType) {
  SmallVector<AffineExpr> strideExprs;
  AffineExpr offsetExpr;
  LogicalResult status =
      getStridesAndOffset(memrefType, strideExprs, offsetExpr);
  assert(succeeded(status) && "expected strided memref");
  (void)status;
  return offsetExpr;
}
/// Builds a memref of `shape` and `elementType`. Its layout is the canonical
/// row-major strided expression for `shape`, shifted by `offset`.
static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
                                                   ArrayRef<int64_t> shape,
                                                   Type elementType,
                                                   AffineExpr offset) {
  AffineExpr layoutExpr =
      makeCanonicalStridedLayoutExpr(shape, context) + offset;
  auto inferredMaps = AffineMap::inferFromExprList({layoutExpr});
  return MemRefType::get(shape, elementType, inferredMaps.front());
}
/// Returns true when `memrefType` has a static shape and its layout is a
/// contiguous row-major strided layout (possibly with an offset).
///
/// Non-strided layouts, such as those using floordiv/ceildiv/mod, are
/// rejected up front. Previously they reached getOffsetExpr, which asserts in
/// debug builds; in release builds it returned a null offset that was then
/// used to build the comparison layout.
bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
  if (!memrefType.hasStaticShape())
    return false;

  // A layout that cannot be decomposed into strides is never contiguous
  // row-major.
  if (!isStrided(memrefType))
    return false;

  AffineExpr offset = getOffsetExpr(memrefType);
  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
      memrefType.getContext(), memrefType.getShape(),
      memrefType.getElementType(), offset);
  return canonicalizeStridedLayout(memrefType) ==
         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
}