#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include <numeric>
using namespace mlir;
/// Returns the element type of `type` when it is a ShapedType; otherwise
/// returns `type` itself unchanged.
Type mlir::getElementTypeOrSelf(Type type) {
  auto shaped = llvm::dyn_cast<ShapedType>(type);
  return shaped ? shaped.getElementType() : type;
}
/// Returns the element type of `val`'s type when that type is shaped;
/// otherwise returns the value's type itself.
Type mlir::getElementTypeOrSelf(Value val) {
  Type valueType = val.getType();
  return getElementTypeOrSelf(valueType);
}
/// Returns the element type of a typed attribute's type (or that type itself
/// when it is not shaped). Attributes that carry no type yield a null Type.
Type mlir::getElementTypeOrSelf(Attribute attr) {
  auto typed = llvm::dyn_cast<TypedAttr>(attr);
  if (!typed)
    return Type();
  return getElementTypeOrSelf(typed.getType());
}
/// Returns the types of `t` with nested tuples recursively expanded, as
/// produced by TupleType::getFlattenedTypes.
SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
  SmallVector<Type, 10> flattened;
  t.getFlattenedTypes(flattened);
  return flattened;
}
/// Returns true if `type` is an OpaqueType whose dialect namespace equals
/// `dialect` and whose type data equals `typeData`.
bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
                                StringRef typeData) {
  auto opaque = llvm::dyn_cast<mlir::OpaqueType>(type);
  if (!opaque)
    return false;
  // Both the namespace and the raw type payload must match.
  return opaque.getDialectNamespace() == dialect &&
         opaque.getTypeData() == typeData;
}
/// Succeeds when `shape1` and `shape2` have the same rank and agree on every
/// dimension for which both extents are static. A dynamic extent on either
/// side is compatible with any extent on the other.
LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
                                          ArrayRef<int64_t> shape2) {
  if (shape1.size() != shape2.size())
    return failure();
  for (size_t idx = 0, rank = shape1.size(); idx < rank; ++idx) {
    int64_t lhs = shape1[idx];
    int64_t rhs = shape2[idx];
    if (ShapedType::isDynamic(lhs) || ShapedType::isDynamic(rhs))
      continue;
    if (lhs != rhs)
      return failure();
  }
  return success();
}
/// Succeeds when `type1` and `type2` are shape-compatible:
///  - two non-shaped types are compatible;
///  - a shaped type is never compatible with a non-shaped one;
///  - if either shaped type is unranked, they are compatible;
///  - otherwise their shapes are compared dimension by dimension.
LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
  auto shaped1 = llvm::dyn_cast<ShapedType>(type1);
  auto shaped2 = llvm::dyn_cast<ShapedType>(type2);
  if (!shaped1 || !shaped2)
    return success(!shaped1 && !shaped2);
  if (shaped1.hasRank() && shaped2.hasRank())
    return verifyCompatibleShape(shaped1.getShape(), shaped2.getShape());
  return success();
}
/// Succeeds when `types1` and `types2` have equal length and each pair of
/// corresponding types is shape-compatible.
LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) {
  if (types1.size() != types2.size())
    return failure();
  for (size_t i = 0, e = types1.size(); i < e; ++i) {
    if (failed(verifyCompatibleShape(types1[i], types2[i])))
      return failure();
  }
  return success();
}
/// Succeeds when all static extents in `dims` are equal to one another.
/// Dynamic extents are compatible with anything; an empty list, or one with
/// no static extents, trivially succeeds.
LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
  bool haveStatic = false;
  int64_t expected = 0;
  for (int64_t dim : dims) {
    if (ShapedType::isDynamic(dim))
      continue;
    // The first static extent fixes the value every later one must match.
    if (haveStatic && dim != expected)
      return failure();
    haveStatic = true;
    expected = dim;
  }
  return success();
}
/// Verifies that all types in `types` have mutually compatible shapes.
///
/// Succeeds trivially when none of the types is shaped. Fails when only some
/// of them are shaped, when scalable vectors are mixed with any non-scalable
/// type, when ranked types disagree on rank, or when some dimension has two
/// differing static extents. Unranked types are ignored by the rank and
/// dimension checks.
LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
  auto shapedTypes = llvm::map_to_vector<8>(
      types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
  // Either every type is shaped or none is.
  if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
    return success();
  if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
    return failure();

  // Scalable vectors may only be compared with other scalable vectors. Any
  // other type (fixed-length vector, tensor, ...) counts as non-scalable here.
  bool hasScalableVecTypes = false;
  bool hasNonScalableVecTypes = false;
  for (Type t : types) {
    auto vType = llvm::dyn_cast<VectorType>(t);
    if (vType && vType.isScalable())
      hasScalableVecTypes = true;
    else
      hasNonScalableVecTypes = true;
    if (hasScalableVecTypes && hasNonScalableVecTypes)
      return failure();
  }

  // Unranked shapes carry no rank/dimension information; drop them.
  auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
      shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
  if (shapes.empty())
    return success();

  // All ranked shapes must share the same rank.
  int64_t firstRank = shapes.front().getRank();
  if (llvm::any_of(shapes,
                   [&](auto shape) { return firstRank != shape.getRank(); }))
    return failure();

  // Every remaining shape has rank `firstRank`, so each dimension index below
  // is in bounds for all of them. No per-dimension filtering is needed. The
  // scratch buffer is reused across dimensions.
  SmallVector<int64_t, 8> dims;
  for (unsigned i = 0, e = static_cast<unsigned>(firstRank); i < e; ++i) {
    dims.clear();
    for (ShapedType shape : shapes)
      dims.push_back(shape.getDimSize(i));
    if (failed(verifyCompatibleDims(dims)))
      return failure();
  }
  return success();
}
/// Maps an operand value to the element type of its (required) shaped type.
Type OperandElementTypeIterator::mapElement(Value value) const {
  ShapedType shaped = llvm::cast<ShapedType>(value.getType());
  return shaped.getElementType();
}
/// Maps a result value to the element type of its (required) shaped type.
Type ResultElementTypeIterator::mapElement(Value value) const {
  ShapedType shaped = llvm::cast<ShapedType>(value.getType());
  return shaped.getElementType();
}
/// Inserts `newTypes` into `oldTypes` and returns the resulting range.
///
/// Each `newTypes[k]` is placed immediately before position `indices[k]` of
/// `oldTypes`. `indices` must have the same length as `newTypes`, be sorted in
/// non-decreasing order, and contain no value greater than `oldTypes.size()`.
///
/// When `indices` is empty, `oldTypes` is returned as-is and `storage` is not
/// touched. Otherwise the merged sequence is appended to `storage` and a range
/// over the whole of `storage` is returned.
TypeRange mlir::insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices,
                                TypeRange newTypes,
                                SmallVectorImpl<Type> &storage) {
  assert(indices.size() == newTypes.size() &&
         "mismatch between indice and type count");
  // The copy loop walks `oldTypes` forward only. An unsorted or out-of-range
  // index would form an invalid iterator range below.
  assert(llvm::is_sorted(indices) && "insertion indices must be sorted");
  assert((indices.empty() || indices.back() <= oldTypes.size()) &&
         "insertion index out of range");
  if (indices.empty())
    return oldTypes;

  storage.reserve(storage.size() + oldTypes.size() + newTypes.size());
  auto fromIt = oldTypes.begin();
  for (auto it : llvm::zip(indices, newTypes)) {
    const auto toIt = oldTypes.begin() + std::get<0>(it);
    // Copy the untouched run of old types, then the inserted type.
    storage.append(fromIt, toIt);
    storage.push_back(std::get<1>(it));
    fromIt = toIt;
  }
  storage.append(fromIt, oldTypes.end());
  return storage;
}
/// Returns `types` with every position whose bit is set in `indices` removed.
///
/// When no bit is set, `types` is returned directly and `storage` is not
/// touched. Otherwise the kept types are appended to `storage` and a range
/// over the whole of `storage` is returned.
TypeRange mlir::filterTypesOut(TypeRange types, const BitVector &indices,
                               SmallVectorImpl<Type> &storage) {
  if (!indices.any())
    return types;
  unsigned pos = 0;
  for (Type type : types) {
    if (!indices[pos])
      storage.emplace_back(type);
    ++pos;
  }
  return storage;
}