#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>
using namespace mlir;
/// Returns true if `shape1` and `shape2` are guaranteed to be broadcast
/// compatible at compile time. This is a convenience wrapper that forwards
/// both shapes to the N-ary overload, which operates on owned extent vectors.
bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<SmallVector<int64_t, 6>, 2> ownedShapes;
  // Copy each input shape into its own owned vector, preserving order.
  for (ArrayRef<int64_t> shape : {shape1, shape2})
    ownedShapes.emplace_back(shape.begin(), shape.end());
  return staticallyKnownBroadcastable(ownedShapes);
}
/// Returns true if all `shapes` are guaranteed to be broadcast compatible
/// without needing runtime information.
///
/// Shapes are right-aligned; a shape shorter than the longest one behaves as
/// if padded with leading 1s. For every aligned column, ignoring 1s, the
/// column must contain either
///   - at most one dynamic extent and nothing else, or
///   - only copies of a single static extent.
/// Any other mix (two dynamic extents, a dynamic extent next to a static
/// non-1 extent, or two different static extents) cannot be proven compatible
/// and yields false. Requires at least one shape (asserted).
bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");

  size_t longestRank = shapes.front().size();
  for (ArrayRef<int64_t> shape : shapes.drop_front())
    longestRank = std::max(longestRank, shape.size());

  // `fromBack` indexes columns starting at the trailing dimension.
  for (size_t fromBack = 0; fromBack < longestRank; ++fromBack) {
    std::optional<int64_t> columnExtent;
    bool columnHasDynamic = false;

    for (ArrayRef<int64_t> shape : shapes) {
      // Shapes of lower rank are implicitly padded with 1s, which never
      // constrain the column.
      if (fromBack >= shape.size())
        continue;
      int64_t dim = shape[shape.size() - 1 - fromBack];
      if (dim == 1)
        continue;

      if (ShapedType::isDynamic(dim)) {
        // A dynamic extent is only provably compatible when it is alone.
        if (columnHasDynamic || columnExtent.has_value())
          return false;
        columnHasDynamic = true;
      } else if (columnExtent.has_value() && *columnExtent != dim) {
        // Either a different static extent or a previously seen dynamic one.
        return false;
      }
      columnExtent = dim;
    }
  }
  return true;
}
/// Computes the broadcast of `shape1` and `shape2` into `resultShape`.
///
/// Shapes are compared element-wise starting from the trailing dimension.
/// The result has the rank of the longer input; its leading (unmatched)
/// dimensions are copied from that longer input.
///
/// For each pair of aligned dimensions:
///   - both static: they must be equal, or one of them must be 1; the result
///     is the non-1 value (or the common value);
///   - at least one dynamic: following TensorFlow, a static extent > 1 is
///     assumed to win; otherwise the result is dynamic.
///
/// Returns false and leaves `resultShape` empty if two static dimensions are
/// incompatible; returns true otherwise.
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // Seed the result with the higher-ranked shape (shape2 on ties).
  ArrayRef<int64_t> seed = shape1.size() > shape2.size() ? shape1 : shape2;
  resultShape.clear();
  resultShape.append(seed.begin(), seed.end());

  size_t overlap = std::min(shape1.size(), shape2.size());
  // `k` counts aligned dimensions from the back (1 == trailing dimension).
  for (size_t k = 1; k <= overlap; ++k) {
    int64_t lhs = shape1[shape1.size() - k];
    int64_t rhs = shape2[shape2.size() - k];
    int64_t &out = resultShape[resultShape.size() - k];

    if (ShapedType::isDynamic(lhs) || ShapedType::isDynamic(rhs)) {
      // Pairing a 1 with a dynamic extent yields that dynamic extent, and
      // a static 0 or two dynamic extents also stay dynamic, so only a
      // static extent > 1 can make the result static here.
      if (lhs > 1)
        out = lhs;
      else if (rhs > 1)
        out = rhs;
      else
        out = ShapedType::kDynamic;
      continue;
    }

    if (lhs == rhs || rhs == 1) {
      out = lhs;
    } else if (lhs == 1) {
      out = rhs;
    } else {
      // Two distinct static extents, neither of which is 1.
      resultShape.clear();
      return false;
    }
  }
  return true;
}
/// Returns the shape of `type` if it is a ShapedType, or an empty shape
/// (i.e. a scalar) for any other type.
static ArrayRef<int64_t> getShape(Type type) {
  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType)
    return {};
  return shapedType.getShape();
}
/// Returns the type produced by broadcasting `type1` with `type2`, or a null
/// Type if they cannot be broadcast.
///
/// If `elementType` is null, both inputs must share the same element type
/// (via getElementTypeOrSelf), which is then used for the result.
/// - Any unranked tensor input yields an unranked tensor, unless the other
///   input is a vector (vector/tensor mixing is rejected).
/// - A vector and a ranked tensor cannot be mixed.
/// - Otherwise the result shape comes from getBroadcastedShape, and the
///   result is a vector or ranked tensor matching whichever composite kind
///   the inputs carry, or just `elementType` if neither input is composite.
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
                                       Type elementType) {
  if (!elementType) {
    Type lhsElementType = getElementTypeOrSelf(type1);
    if (lhsElementType != getElementTypeOrSelf(type2))
      return {};
    elementType = lhsElementType;
  }

  bool anyUnranked =
      isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2);
  if (anyUnranked) {
    bool anyVector = isa<VectorType>(type1) || isa<VectorType>(type2);
    if (anyVector)
      return {};
    return UnrankedTensorType::get(elementType);
  }

  // TypeID of a vector or ranked tensor type; std::nullopt for anything else.
  auto rankedContainerKind = [](Type candidate) -> std::optional<TypeID> {
    if (!isa<VectorType, RankedTensorType>(candidate))
      return std::nullopt;
    return candidate.getTypeID();
  };

  std::optional<TypeID> lhsKind = rankedContainerKind(type1);
  std::optional<TypeID> rhsKind = rankedContainerKind(type2);
  // Vectors and tensors must not be mixed.
  if (lhsKind && rhsKind && *lhsKind != *rhsKind)
    return {};
  std::optional<TypeID> resultKind = lhsKind ? lhsKind : rhsKind;

  SmallVector<int64_t, 4> broadcastShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), broadcastShape))
    return {};

  if (resultKind == VectorType::getTypeID())
    return VectorType::get(broadcastShape, elementType);
  if (resultKind == RankedTensorType::getTypeID())
    return RankedTensorType::get(broadcastShape, elementType);
  return elementType;
}
/// Scans `types` and reports, as (hasTensor, hasVector), whether any element
/// is a TensorType (ranked or unranked) and whether any is a VectorType.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
  bool sawTensor = llvm::any_of(types, llvm::IsaPred<TensorType>);
  bool sawVector = llvm::any_of(types, llvm::IsaPred<VectorType>);
  return std::make_tuple(sawTensor, sawVector);
}
/// Returns true if `existing` matches `inferred` dimension-by-dimension.
/// Both shapes must have the same rank; a pair of dimensions matches when
/// either one is dynamic or the two static values are equal.
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) {
  if (inferred.size() != existing.size())
    return false;
  for (size_t idx = 0, rank = inferred.size(); idx != rank; ++idx) {
    int64_t inferredDim = inferred[idx];
    int64_t existingDim = existing[idx];
    bool eitherDynamic = ShapedType::isDynamic(inferredDim) ||
                         ShapedType::isDynamic(existingDim);
    if (!eitherDynamic && inferredDim != existingDim)
      return false;
  }
  return true;
}
/// Formats `shape` for diagnostics as a single-quoted, 'x'-separated list,
/// printing dynamic dimensions as '?', e.g. '2x?x4'. An empty shape
/// prints as ''.
static std::string getShapeString(ArrayRef<int64_t> shape) {
  std::string text;
  llvm::raw_string_ostream os(text);
  os << '\'';
  bool isFirst = true;
  for (int64_t dim : shape) {
    if (!isFirst)
      os << "x";
    isFirst = false;
    if (ShapedType::isDynamic(dim))
      os << '?';
    else
      os << dim;
  }
  os << '\'';
  return os.str();
}
/// Verifies that `op`'s operands are mutually broadcast compatible and that
/// each ranked tensor result is compatible with the broadcast operand shape.
///
/// - Fails if tensors and vectors are mixed anywhere among operands/results.
/// - Only ranked tensor operands participate in the shape computation; if
///   there are none, every result shape is accepted.
/// - For each ranked tensor result, its trailing dimensions (as many as the
///   broadcast shape has) must match the broadcast shape, with dynamic
///   dimensions matching anything.
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  auto [operandHasTensor, operandHasVector] =
      hasTensorOrVectorType(op->getOperandTypes());
  auto [resultHasTensor, resultHasVector] =
      hasTensorOrVectorType(op->getResultTypes());
  bool anyTensor = operandHasTensor || resultHasTensor;
  bool anyVector = operandHasVector || resultHasVector;
  if (anyTensor && anyVector)
    return op->emitError("cannot broadcast vector with tensor");

  auto rankedOperands =
      make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
  if (rankedOperands.empty())
    return success();

  // Start from the first ranked operand's shape and fold in every ranked
  // operand (including the first, which is a no-op).
  ArrayRef<int64_t> firstShape = getShape(*rankedOperands.begin());
  SmallVector<int64_t, 4> broadcastShape(firstShape.begin(), firstShape.end());
  for (Type operandType : rankedOperands) {
    // Copy first: getBroadcastedShape clears its output before reading inputs.
    SmallVector<int64_t, 4> accumulated(broadcastShape.begin(),
                                        broadcastShape.end());
    if (!util::getBroadcastedShape(accumulated, getShape(operandType),
                                   broadcastShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  auto rankedResults =
      make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);
  for (Type resultType : rankedResults) {
    ArrayRef<int64_t> fullShape = getShape(resultType);
    ArrayRef<int64_t> trailingDims = fullShape.take_back(broadcastShape.size());
    if (isCompatibleInferredReturnShape(broadcastShape, trailingDims))
      continue;
    return op->emitOpError()
           << "result type " << getShapeString(fullShape)
           << " not broadcast compatible with broadcasted operands's shapes "
           << getShapeString(broadcastShape);
  }
  return success();
}