#include <optional>
#include <utility>
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
using namespace mlir;
/// Extracts a compile-time boolean from `attr`.
///
/// Accepts either a scalar `BoolAttr` or a `SplatElementsAttr` whose element
/// type is `i1`. Returns std::nullopt for a null attribute or any other kind
/// of attribute.
static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
  if (attr) {
    if (auto scalar = llvm::dyn_cast<BoolAttr>(attr))
      return scalar.getValue();
    auto splat = llvm::dyn_cast<SplatElementsAttr>(attr);
    if (splat && splat.getElementType().isInteger(1))
      return splat.getSplatValue<bool>();
  }
  return std::nullopt;
}
/// Walks `indices` into a constant composite attribute and returns the
/// addressed element.
///
/// `ArrayAttr` levels consume one index each. An `ElementsAttr` level must be
/// addressed by exactly one (final) index. Returns a null attribute when
/// `composite` is null or an unsupported attribute kind is reached. With no
/// indices, `composite` itself is returned.
static Attribute extractCompositeElement(Attribute composite,
                                         ArrayRef<unsigned> indices) {
  Attribute current = composite;
  while (current && !indices.empty()) {
    if (auto elements = llvm::dyn_cast<ElementsAttr>(current)) {
      assert(indices.size() == 1 && "must have exactly one index for a vector");
      return elements.getValues<Attribute>()[indices.front()];
    }
    auto array = llvm::dyn_cast<ArrayAttr>(current);
    if (!array)
      return {};
    // Descend one level and drop the index we just used.
    current = array.getValue()[indices.front()];
    indices = indices.drop_front();
  }
  return current;
}
/// Returns true when a signed division of `a` by `b` must not be folded.
///
/// That is the case when the divisor is zero, or when the quotient overflows
/// (`a` is the minimum signed value and `b` is -1, i.e. all ones).
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
  if (b.isZero())
    return true;
  return a.isMinSignedValue() && b.isAllOnes();
}
// Generated rewrite patterns. These presumably include the
// ConvertLogicalNotOf* patterns registered further down; confirm against the
// TableGen source. The anonymous namespace gives them internal linkage.
namespace {
#include "SPIRVCanonicalization.inc"
} // namespace
namespace {
/// Merges nested access chains:
///   access_chain(access_chain(base, i...), j...) -> access_chain(base, i..., j...)
struct CombineChainedAccessChain final
    : OpRewritePattern<spirv::AccessChainOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
                                PatternRewriter &rewriter) const override {
    auto inner =
        accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
    if (!inner)
      return failure();

    // The inner chain's indices come first, followed by this op's indices.
    SmallVector<Value, 4> combined;
    llvm::append_range(combined, inner.getIndices());
    llvm::append_range(combined, accessChainOp.getIndices());

    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
        accessChainOp, inner.getBasePtr(), combined);
    return success();
  }
};
} // namespace
void spirv::AccessChainOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<CombineChainedAccessChain>(context);
}
/// Canonicalizes `spirv.IAddCarry`.
///
/// Member 0 of the result struct holds the low-order bits of the sum and
/// member 1 holds the carry. This matches the constant path below, which
/// inserts the sum at index 0 and the carry at index 1.
///   * iaddcarry(x, 0)   -> {x, 0}
///   * iaddcarry(c0, c1) -> {c0 + c1, carry} for constant operands
struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getOperand1();
    Value rhs = op.getOperand2();
    Type constituentType = lhs.getType();

    // iaddcarry(x, 0) = {x, 0}: the sum is x and nothing carries out.
    // `rhs` is the zero constant, so it doubles as the carry member.
    // (Previously this built {0, x}, swapping sum and carry.)
    if (matchPattern(rhs, m_Zero())) {
      Value constituents[2] = {lhs, rhs};
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                               constituents);
      return success();
    }

    Attribute lhsAttr;
    Attribute rhsAttr;
    if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
        !matchPattern(rhs, m_Constant(&rhsAttr)))
      return failure();

    // Wrapping (modulo 2^N) sum.
    auto adds = constFoldBinaryOp<IntegerAttr>(
        {lhsAttr, rhsAttr},
        [](const APInt &a, const APInt &b) { return a + b; });
    if (!adds)
      return failure();

    // An unsigned add overflowed iff the wrapped sum is smaller than an input.
    auto carrys = constFoldBinaryOp<IntegerAttr>(
        ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
          APInt zero = APInt::getZero(a.getBitWidth());
          return a.ult(b) ? (zero + 1) : zero;
        });
    if (!carrys)
      return failure();

    Value addsVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
    Value carrysVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);

    // Build {sum, carry} by inserting into an undef struct.
    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
    Value intermediate =
        rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
                                                          intermediate, 1);
    return success();
  }
};
void spirv::IAddCarryOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<IAddCarryFold>(context);
}
/// Folds the extended multiplies (`spirv.SMulExtended` when `IsSigned`,
/// otherwise `spirv.UMulExtended`).
///
/// The result struct holds the low-order product bits in member 0 and the
/// high-order bits in member 1.
///   * mulextended(x, 0)   -> {0, 0}
///   * mulextended(c0, c1) -> {low(c0*c1), high(c0*c1)} for constant operands
template <typename MulOp, bool IsSigned>
struct MulExtendedFold final : OpRewritePattern<MulOp> {
  using OpRewritePattern<MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MulOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getOperand1();
    Value rhs = op.getOperand2();
    Type elemType = lhs.getType();

    // Multiplying by zero zeroes both halves of the product.
    if (matchPattern(rhs, m_Zero())) {
      Value zero = spirv::ConstantOp::getZero(elemType, loc, rewriter);
      Value halves[2] = {zero, zero};
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                               halves);
      return success();
    }

    Attribute lhsAttr;
    Attribute rhsAttr;
    if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
        !matchPattern(rhs, m_Constant(&rhsAttr)))
      return failure();

    auto mulLow = [](const APInt &a, const APInt &b) { return a * b; };
    auto mulHigh = [](const APInt &a, const APInt &b) {
      if constexpr (IsSigned)
        return llvm::APIntOps::mulhs(a, b);
      else
        return llvm::APIntOps::mulhu(a, b);
    };

    auto lowAttr = constFoldBinaryOp<IntegerAttr>({lhsAttr, rhsAttr}, mulLow);
    if (!lowAttr)
      return failure();
    auto highAttr = constFoldBinaryOp<IntegerAttr>({lhsAttr, rhsAttr}, mulHigh);
    if (!highAttr)
      return failure();

    Value lowVal = rewriter.create<spirv::ConstantOp>(loc, elemType, lowAttr);
    Value highVal = rewriter.create<spirv::ConstantOp>(loc, elemType, highAttr);

    // Assemble {low, high} by inserting into an undef struct.
    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
    Value withLow =
        rewriter.create<spirv::CompositeInsertOp>(loc, lowVal, undef, 0);
    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highVal, withLow,
                                                          1);
    return success();
  }
};
using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
void spirv::SMulExtendedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<SMulExtendedOpFold>(context);
}
/// Rewrites `umulextended(x, 1)` to `{x, 0}`: the low half of the product is
/// `x` and the high half is zero.
struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
                                PatternRewriter &rewriter) const override {
    Value multiplicand = op.getOperand1();
    if (!matchPattern(op.getOperand2(), m_One()))
      return failure();

    Value highHalf = spirv::ConstantOp::getZero(multiplicand.getType(),
                                                op.getLoc(), rewriter);
    Value halves[2] = {multiplicand, highHalf};
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                             halves);
    return success();
  }
};
using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
void spirv::UMulExtendedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
}
/// Simplifies `umod(umod(x, c0), c1)` to `umod(x, c1)` for constant divisors.
///
/// This is only valid when `c1` evenly divides `c0`. Then `x mod c0` differs
/// from `x` by a multiple of `c0`, hence by a multiple of `c1`, so both reduce
/// to the same value modulo `c1`. The converse (c0 divides c1) is NOT enough:
/// (7 % 2) % 4 == 1 but 7 % 4 == 3.
struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::UModOp umodOp,
                                PatternRewriter &rewriter) const override {
    auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
    if (!prevUMod)
      return failure();

    IntegerAttr prevValue;
    IntegerAttr currValue;
    if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
        !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
      return failure();

    APInt prevDivisor = prevValue.getValue();
    APInt currDivisor = currValue.getValue();

    // A zero divisor has no well-defined result, and APInt::urem asserts on
    // it; leave such IR untouched.
    if (prevDivisor.isZero() || currDivisor.isZero())
      return failure();

    // Require c1 | c0 (see the comment on the struct).
    if (!prevDivisor.urem(currDivisor).isZero())
      return failure();

    rewriter.replaceOpWithNewOp<spirv::UModOp>(
        umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
    return success();
  }
};
void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.insert<UModSimplification>(context);
}
OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
  Value input = getOperand();
  Type resultType = getType();

  // bitcast(x : T) : T -> x
  if (input.getType() == resultType)
    return input;

  auto producer = input.getDefiningOp<spirv::BitcastOp>();
  if (!producer)
    return {};

  // bitcast(bitcast(x : T) : U) : T -> x
  Value source = producer.getOperand();
  if (source.getType() == resultType)
    return source;

  // Otherwise skip the intermediate cast by rewiring this op's operand in
  // place: bitcast(bitcast(x) : U) : V -> bitcast(x) : V.
  getOperandMutable().assign(source);
  return getResult();
}
OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
  ArrayAttr extractIndices = getIndices();
  Value compositeOp = getComposite();

  // Look through a chain of CompositeInsert ops:
  //  * an insert at exactly the extracted position yields its object;
  //  * an insert at a disjoint position cannot affect the result: skip it;
  //  * if one index path is a strict prefix of the other, the insert only
  //    partially overlaps the extracted element (or vice versa), so looking
  //    further up would return stale data. Stop there.
  while (auto insertOp =
             compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
    ArrayAttr insertIndices = insertOp.getIndices();
    if (extractIndices == insertIndices)
      return insertOp.getObject();
    ArrayRef<Attribute> extractPath = extractIndices.getValue();
    ArrayRef<Attribute> insertPath = insertIndices.getValue();
    // take_front(N) clamps to the size, so this compares the common prefix.
    if (extractPath.take_front(insertPath.size()) ==
        insertPath.take_front(extractPath.size()))
      break;
    compositeOp = insertOp.getComposite();
  }

  // extract(construct(c0, ..., cN-1), i) -> ci when the construct supplies
  // every element of the composite individually.
  if (auto constructOp =
          compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
    auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
    if (extractIndices.size() == 1 &&
        constructOp.getConstituents().size() == type.getNumElements()) {
      auto i = llvm::cast<IntegerAttr>(*extractIndices.begin());
      if (i.getValue().getSExtValue() <
          static_cast<int64_t>(constructOp.getConstituents().size()))
        return constructOp.getConstituents()[i.getValue().getSExtValue()];
    }
  }

  // Fall back to folding a constant composite operand.
  auto indexVector = llvm::map_to_vector(extractIndices, [](Attribute attr) {
    return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
  });
  return extractCompositeElement(adaptor.getComposite(), indexVector);
}
OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
  // A constant folds to its own value attribute.
  Attribute value = getValue();
  return value;
}
OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
  // x + 0 -> x
  if (matchPattern(getOperand2(), m_Zero()))
    return getOperand1();

  // Constant folding. APInt addition wraps modulo 2^N, which yields the
  // low-order N bits of the exact sum.
  auto wrappingAdd = [](APInt lhs, const APInt &rhs) {
    lhs += rhs;
    return lhs;
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), wrappingAdd);
}
OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
  Value multiplier = getOperand2();
  // x * 0 -> 0 (reuse the existing zero operand).
  if (matchPattern(multiplier, m_Zero()))
    return multiplier;
  // x * 1 -> x
  if (matchPattern(multiplier, m_One()))
    return getOperand1();

  auto wrappingMul = [](const APInt &x, const APInt &y) { return x * y; };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), wrappingMul);
}
OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
  // x - x -> 0. getZeroAttr yields a scalar zero for integer types and a
  // splat zero for vector types. The previous getIntegerAttr(type, 0) queries
  // the scalar bit width and so breaks on vector-typed ISub.
  if (getOperand1() == getOperand2())
    return Builder(getContext()).getZeroAttr(getType());

  // Constant folding with wrap-around (modulo 2^N) semantics.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) - b; });
}
OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
  // x / 1 -> x
  if (matchPattern(getOperand2(), m_One()))
    return getOperand1();

  // Constant-fold lane by lane. If any lane divides by zero or overflows
  // (INT_MIN / -1), abandon the whole fold.
  bool unfoldable = false;
  auto sdiv = [&](const APInt &lhs, const APInt &rhs) -> APInt {
    if (!unfoldable && !isDivZeroOrOverflow(lhs, rhs))
      return lhs.sdiv(rhs);
    unfoldable = true;
    return lhs;
  };
  Attribute folded =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), sdiv);
  if (unfoldable)
    return Attribute();
  return folded;
}
OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
  // x smod 1 -> 0
  if (matchPattern(getOperand2(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // SMod's result takes the sign of the divisor. Compute the remainder's
  // magnitude, then adjust it to the divisor's sign.
  bool unfoldable = false;
  auto smod = [&](const APInt &lhs, const APInt &rhs) -> APInt {
    if (unfoldable || isDivZeroOrOverflow(lhs, rhs)) {
      unfoldable = true;
      return lhs;
    }
    APInt rem = lhs.abs().urem(rhs.abs());
    if (rem.isZero())
      return rem;
    bool lhsNegative = lhs.isNegative();
    if (rhs.isNegative())
      return lhsNegative ? -rem : rhs + rem;
    return lhsNegative ? rhs - rem : rem;
  };
  Attribute folded =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), smod);
  return unfoldable ? Attribute() : folded;
}
OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
  // x srem 1 -> 0
  if (matchPattern(getOperand2(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // SRem's result takes the sign of the dividend, which is APInt::srem.
  // Give up if any lane divides by zero or overflows.
  bool unfoldable = false;
  auto srem = [&](const APInt &lhs, const APInt &rhs) -> APInt {
    if (!unfoldable && !isDivZeroOrOverflow(lhs, rhs))
      return lhs.srem(rhs);
    unfoldable = true;
    return lhs;
  };
  Attribute folded =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), srem);
  return unfoldable ? Attribute() : folded;
}
OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
  // x udiv 1 -> x
  if (matchPattern(getOperand2(), m_One()))
    return getOperand1();

  // Constant folding; a zero divisor in any lane blocks the fold.
  bool sawZeroDivisor = false;
  auto udiv = [&](const APInt &lhs, const APInt &rhs) -> APInt {
    if (!sawZeroDivisor && !rhs.isZero())
      return lhs.udiv(rhs);
    sawZeroDivisor = true;
    return lhs;
  };
  Attribute folded =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), udiv);
  return sawZeroDivisor ? Attribute() : folded;
}
OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
  // x umod 1 -> 0
  if (matchPattern(getOperand2(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // Constant folding; a zero divisor in any lane blocks the fold.
  bool sawZeroDivisor = false;
  auto urem = [&](const APInt &lhs, const APInt &rhs) -> APInt {
    if (!sawZeroDivisor && !rhs.isZero())
      return lhs.urem(rhs);
    sawZeroDivisor = true;
    return lhs;
  };
  Attribute folded =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), urem);
  return sawZeroDivisor ? Attribute() : folded;
}
OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
  // -(-x) -> x
  if (auto inner = getOperand().getDefiningOp<spirv::SNegateOp>())
    return inner->getOperand(0);

  // Constant folding: two's-complement negation (wraps for INT_MIN).
  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
                                       [](const APInt &x) { return -x; });
}
OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
  // ~~x -> x
  if (auto inner = getOperand().getDefiningOp<spirv::NotOp>())
    return inner->getOperand(0);

  // Constant folding: flip every bit.
  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
                                       [](const APInt &x) { return ~x; });
}
OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
  std::optional<bool> rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2());
  if (!rhs)
    return Attribute();
  // x && true -> x
  if (*rhs)
    return getOperand1();
  // x && false -> false
  return adaptor.getOperand2();
}
OpFoldResult
spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
  // x == x -> true, as a scalar or a splat depending on the result type.
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    Type resultType = getType();
    if (isa<IntegerType>(resultType))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  // Constant operands: compare lane by lane into i1 results.
  auto sameBool = [](const APInt &lhs, const APInt &rhs) {
    return lhs == rhs ? APInt::getAllOnes(1) : APInt::getZero(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), sameBool);
}
OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
  // x != false -> x
  std::optional<bool> rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2());
  if (rhs && !*rhs)
    return getOperand1();

  // x != x -> false, as a scalar or a splat depending on the result type.
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    Type resultType = getType();
    if (isa<IntegerType>(resultType))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  auto differentBool = [](const APInt &lhs, const APInt &rhs) {
    return lhs == rhs ? APInt::getZero(1) : APInt::getAllOnes(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), differentBool);
}
OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
  // !!x -> x
  if (auto inner = getOperand().getDefiningOp<spirv::LogicalNotOp>())
    return inner->getOperand(0);

  // Constant operand: invert the single-bit boolean.
  return constFoldUnaryOp<IntegerAttr>(
      adaptor.getOperands(), [](const APInt &bit) {
        return bit == 1 ? APInt::getZero(1) : APInt::getAllOnes(1);
      });
}
void spirv::LogicalNotOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results
.add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
context);
}
OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
  Attribute rhsAttr = adaptor.getOperand2();
  std::optional<bool> rhs = getScalarOrSplatBoolAttr(rhsAttr);
  if (!rhs)
    return Attribute();
  // x || true -> true
  if (*rhs)
    return rhsAttr;
  // x || false -> x
  return getOperand1();
}
OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
  Value onTrue = getTrueValue();
  Value onFalse = getFalseValue();
  // select(c, x, x) -> x whatever the condition.
  if (onTrue == onFalse)
    return onTrue;

  ArrayRef<Attribute> constOperands = adaptor.getOperands();
  Attribute condConst = constOperands[0];
  Attribute trueConst = constOperands[1];
  Attribute falseConst = constOperands[2];

  // A constant scalar (or splat) condition picks one side wholesale.
  if (auto pick = getScalarOrSplatBoolAttr(condConst))
    return *pick ? onTrue : onFalse;

  // Lane-wise folding needs every operand to be a dense constant.
  if (!condConst || !trueConst || !falseConst)
    return Attribute();
  auto condElems = dyn_cast<DenseElementsAttr>(condConst);
  auto trueElems = dyn_cast<DenseElementsAttr>(trueConst);
  auto falseElems = dyn_cast<DenseElementsAttr>(falseConst);
  if (!condElems || !trueElems || !falseElems)
    return Attribute();

  // Start from the true-side lanes and overwrite those whose condition is
  // false. zip_equal yields references, so assigning `lane` updates `merged`.
  auto merged = llvm::to_vector<4>(trueElems.getValues<Attribute>());
  for (auto [lane, laneCond, laneFalse] :
       llvm::zip_equal(merged, condElems.getValues<BoolAttr>(),
                       falseElems.getValues<Attribute>())) {
    if (!laneCond.getValue())
      lane = laneFalse;
  }

  return DenseElementsAttr::get(cast<ShapedType>(trueElems.getType()), merged);
}
OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
  Type resultType = getType();

  // x == x -> true (scalar bool or splat vector of bools).
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(resultType))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  // Constant operands: compare lane by lane into an i1 result.
  auto isEqual = [](const APInt &lhs, const APInt &rhs) {
    return lhs == rhs ? APInt::getAllOnes(1) : APInt::getZero(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), resultType,
                                        isEqual);
}
OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
  Type resultType = getType();

  // x != x -> false (scalar bool or splat vector of bools).
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(resultType))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  // Constant operands: compare lane by lane into an i1 result.
  auto isNotEqual = [](const APInt &lhs, const APInt &rhs) {
    return lhs == rhs ? APInt::getZero(1) : APInt::getAllOnes(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), resultType,
                                        isNotEqual);
}
OpFoldResult
spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
  Type resultType = getType();

  // x > x -> false (scalar bool or splat vector of bools).
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(resultType))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  // Constant operands: signed comparison per lane into an i1 result.
  auto signedGreater = [](const APInt &lhs, const APInt &rhs) {
    return lhs.sgt(rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), resultType,
                                        signedGreater);
}
OpFoldResult spirv::SGreaterThanEqualOp::fold(
    spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
  Type resultType = getType();

  // x >= x -> true (scalar bool or splat vector of bools).
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(resultType))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  // Constant operands: signed comparison per lane into an i1 result.
  auto signedGreaterEq = [](const APInt &lhs, const APInt &rhs) {
    return lhs.sge(rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), resultType,
                                        signedGreaterEq);
}
OpFoldResult
spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
  Type resultType = getType();

  // x > x -> false (scalar bool or splat vector of bools).
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(resultType))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  // Constant operands: unsigned comparison per lane into an i1 result.
  auto unsignedGreater = [](const APInt &lhs, const APInt &rhs) {
    return lhs.ugt(rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), resultType,
                                        unsignedGreater);
}
OpFoldResult spirv::UGreaterThanEqualOp::fold(
    spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
  Type resultType = getType();

  // x >= x -> true (scalar bool or splat vector of bools).
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(resultType))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  // Constant operands: unsigned comparison per lane into an i1 result.
  auto unsignedGreaterEq = [](const APInt &lhs, const APInt &rhs) {
    return lhs.uge(rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), resultType,
                                        unsignedGreaterEq);
}
OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
  Type resultType = getType();

  // x < x -> false (scalar bool or splat vector of bools).
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(resultType))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  // Constant operands: signed comparison per lane into an i1 result.
  auto signedLess = [](const APInt &lhs, const APInt &rhs) {
    return lhs.slt(rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), resultType,
                                        signedLess);
}
OpFoldResult
spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
  Type resultType = getType();

  // x <= x -> true (scalar bool or splat vector of bools).
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(resultType))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  // Constant operands: signed comparison per lane into an i1 result.
  auto signedLessEq = [](const APInt &lhs, const APInt &rhs) {
    return lhs.sle(rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), resultType,
                                        signedLessEq);
}
OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
  Type resultType = getType();

  // x < x -> false (scalar bool or splat vector of bools).
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(resultType))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  // Constant operands: unsigned comparison per lane into an i1 result.
  auto unsignedLess = [](const APInt &lhs, const APInt &rhs) {
    return lhs.ult(rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), resultType,
                                        unsignedLess);
}
OpFoldResult
spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
  Type resultType = getType();

  // x <= x -> true (scalar bool or splat vector of bools).
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(resultType))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(resultType))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  // Constant operands: unsigned comparison per lane into an i1 result.
  auto unsignedLessEq = [](const APInt &lhs, const APInt &rhs) {
    return lhs.ule(rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), resultType,
                                        unsignedLessEq);
}
OpFoldResult spirv::ShiftLeftLogicalOp::fold(
    spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
  // x << 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero()))
    return getOperand1();

  // Refuse to fold if any lane shifts by at least the bit width.
  bool oversizedShift = false;
  auto shl = [&](const APInt &value, const APInt &amount) -> APInt {
    if (!oversizedShift && amount.ult(value.getBitWidth()))
      return value << amount;
    oversizedShift = true;
    return value;
  };
  Attribute folded =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), shl);
  return oversizedShift ? Attribute() : folded;
}
OpFoldResult spirv::ShiftRightArithmeticOp::fold(
    spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
  // x >> 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero()))
    return getOperand1();

  // Sign-propagating shift; refuse to fold if any lane shifts by at least the
  // bit width.
  bool oversizedShift = false;
  auto ashr = [&](const APInt &value, const APInt &amount) -> APInt {
    if (!oversizedShift && amount.ult(value.getBitWidth()))
      return value.ashr(amount);
    oversizedShift = true;
    return value;
  };
  Attribute folded =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), ashr);
  return oversizedShift ? Attribute() : folded;
}
OpFoldResult spirv::ShiftRightLogicalOp::fold(
    spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
  // x >> 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero()))
    return getOperand1();

  // Zero-filling shift; refuse to fold if any lane shifts by at least the bit
  // width.
  bool oversizedShift = false;
  auto lshr = [&](const APInt &value, const APInt &amount) -> APInt {
    if (!oversizedShift && amount.ult(value.getBitWidth()))
      return value.lshr(amount);
    oversizedShift = true;
    return value;
  };
  Attribute folded =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), lshr);
  return oversizedShift ? Attribute() : folded;
}
OpFoldResult
spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
  Value lhs = getOperand1();
  Value rhs = getOperand2();

  // x & x -> x
  if (lhs == rhs)
    return lhs;

  APInt mask;
  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&mask))) {
    // x & 0 -> 0
    if (mask.isZero())
      return rhs;
    // x & ~0 -> x
    if (mask.isAllOnes())
      return lhs;
    // uconvert(y : iN) & m -> uconvert(y) when the low N bits of m are all
    // ones. The bits above N are zero after a widening unsigned conversion.
    if (auto widened = lhs.getDefiningOp<spirv::UConvertOp>()) {
      unsigned narrowBits =
          getElementTypeOrSelf(widened.getOperand()).getIntOrFloatBitWidth();
      if (mask.zextOrTrunc(narrowBits).isAllOnes())
        return lhs;
    }
  }

  auto bitAnd = [](const APInt &x, const APInt &y) { return x & y; };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), bitAnd);
}
OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
  Value lhs = getOperand1();
  Value rhs = getOperand2();

  // x | x -> x
  if (lhs == rhs)
    return lhs;

  APInt mask;
  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&mask))) {
    // x | 0 -> x
    if (mask.isZero())
      return lhs;
    // x | ~0 -> ~0
    if (mask.isAllOnes())
      return rhs;
  }

  auto bitOr = [](const APInt &x, const APInt &y) { return x | y; };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), bitOr);
}
OpFoldResult
spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
  // x ^ 0 -> x (checked first, so 0 ^ 0 still returns the operand value).
  if (matchPattern(adaptor.getOperand2(), m_Zero()))
    return getOperand1();

  // x ^ x -> 0
  if (getOperand1() == getOperand2())
    return Builder(getContext()).getZeroAttr(getType());

  auto bitXor = [](const APInt &x, const APInt &y) { return x ^ y; };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), bitXor);
}
namespace {
/// Rewrites a `spirv.mlir.selection` whose two branches each only store to
/// the same pointer into a `spirv.Select` followed by a single `spirv.Store`:
///
///   header: spirv.BranchConditional %cond, ^true, ^false
///   ^true:  spirv.Store %ptr, %a   ; spirv.Branch ^merge
///   ^false: spirv.Store %ptr, %b   ; spirv.Branch ^merge
/// becomes
///   %v = spirv.Select %cond, %a, %b
///   spirv.Store %ptr, %v
///
/// The new store reuses the attributes of the true-branch store. The
/// selection op is then erased.
struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
PatternRewriter &rewriter) const override {
Operation *op = selectionOp.getOperation();
Region &body = op->getRegion(0);
// Nothing to match in an empty region.
if (body.empty()) {
return failure();
}
// Exactly four blocks are expected (presumably header, true, false and
// merge, as used below).
if (llvm::range_size(body) != 4) {
return failure();
}
// The header must consist of nothing but the conditional branch.
Block *headerBlock = selectionOp.getHeaderBlock();
if (!onlyContainsBranchConditionalOp(headerBlock)) {
return failure();
}
auto brConditionalOp =
cast<spirv::BranchConditionalOp>(headerBlock->front());
Block *trueBlock = brConditionalOp.getSuccessor(0);
Block *falseBlock = brConditionalOp.getSuccessor(1);
Block *mergeBlock = selectionOp.getMergeBlock();
if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
return failure();
// Both branches store to the same pointer (checked above), so either
// branch's store supplies the pointer and attributes.
Value trueValue = getSrcValue(trueBlock);
Value falseValue = getSrcValue(falseBlock);
Value ptrValue = getDstPtr(trueBlock);
auto storeOpAttributes =
cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
auto selectOp = rewriter.create<spirv::SelectOp>(
selectionOp.getLoc(), trueValue.getType(),
brConditionalOp.getCondition(), trueValue, falseValue);
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
selectOp.getResult(), storeOpAttributes);
// The selection region is fully replaced by the select + store above.
rewriter.eraseOp(op);
return success();
}
private:
/// Succeeds when each branch block is exactly `spirv.Store` + `spirv.Branch`
/// to `mergeBlock`, both stores target the same pointer with identical
/// attributes, and the stored value has a scalar or vector type.
LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
Block *mergeBlock) const;
/// Returns true if `block` holds a single `spirv.BranchConditional` op.
bool onlyContainsBranchConditionalOp(Block *block) const {
return llvm::hasSingleElement(*block) &&
isa<spirv::BranchConditionalOp>(block->front());
}
/// Compares both discardable attributes and inherent properties of two
/// stores.
bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
return lhs->getDiscardableAttrDictionary() ==
rhs->getDiscardableAttrDictionary() &&
lhs.getProperties() == rhs.getProperties();
}
/// Value stored by the `spirv.Store` that starts `block`.
Value getSrcValue(Block *block) const {
auto storeOp = cast<spirv::StoreOp>(block->front());
return storeOp.getValue();
}
/// Pointer written by the `spirv.Store` that starts `block`.
Value getDstPtr(Block *block) const {
auto storeOp = cast<spirv::StoreOp>(block->front());
return storeOp.getPtr();
}
};
LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
// Each branch block must contain exactly two operations.
if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
return failure();
}
// ...namely a store followed by an unconditional branch.
auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
auto trueBrBranchOp =
dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
auto falseBrBranchOp =
dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
!falseBrBranchOp) {
return failure();
}
// The replacement uses spirv.Select, which is only formed here for
// scalar/vector values.
bool isScalarOrVector =
llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
.isScalarOrVector();
// Both stores must write the same pointer with the same attributes.
if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
return failure();
}
// Both branches must rejoin at the selection's merge block.
if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
(falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
return failure();
}
return success();
}
} // namespace
void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConvertSelectionOpToSelect>(context);
}