#include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <array>
#include <cstdint>
namespace mlir {
namespace spirv {
#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
}
}
namespace mlir {
namespace spirv {
namespace {
static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
if (auto intTy = dyn_cast<IntegerType>(type))
return IntegerAttr::get(intTy, sizedValue);
return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
}
static Value lowerExtendedMultiplication(Operation *mulOp,
PatternRewriter &rewriter, Value lhs,
Value rhs, bool signExtendArguments) {
Location loc = mulOp->getLoc();
Type argTy = lhs.getType();
Value cstLowMask = rewriter.create<ConstantOp>(
loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
};
Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
getScalarOrSplatAttr(argTy, 16));
auto getHighDigit = [&rewriter, loc, cst16](Value val) {
return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
};
auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
return getHighDigit(
rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
};
Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
getScalarOrSplatAttr(argTy, 0));
Value lhsLow = getLowDigit(lhs);
Value lhsHigh = getHighDigit(lhs);
Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
Value rhsLow = getLowDigit(rhs);
Value rhsHigh = getHighDigit(rhs);
Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
if (i + j >= resultDigits.size())
continue;
if (lhsDigit == cst0 || rhsDigit == cst0)
continue;
Value &thisResDigit = resultDigits[i + j];
Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
thisResDigit = getLowDigit(current);
if (i + j + 1 != resultDigits.size()) {
Value &nextResDigit = resultDigits[i + j + 1];
Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
getHighDigit(current));
nextResDigit = carry;
}
}
}
auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
return rewriter.create<BitwiseOrOp>(loc, low, highBits);
};
Value low = combineDigits(resultDigits[0], resultDigits[1]);
Value high = combineDigits(resultDigits[2], resultDigits[3]);
return rewriter.create<CompositeConstructOp>(
loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
}
template <typename MulExtendedOp, bool SignExtendArguments>
struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MulExtendedOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value lhs = op.getOperand1();
Value rhs = op.getOperand2();
auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
if (elemTy.getIntOrFloatBitWidth() != 32)
return rewriter.notifyMatchFailure(
loc,
llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
SignExtendArguments);
rewriter.replaceOp(op, mul);
return success();
}
};
using ExpandSMulExtendedPattern =
ExpandMulExtendedPattern<SMulExtendedOp, true>;
using ExpandUMulExtendedPattern =
ExpandMulExtendedPattern<UMulExtendedOp, false>;
struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
using OpRewritePattern<IAddCarryOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IAddCarryOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value lhs = op.getOperand1();
Value rhs = op.getOperand2();
Type argTy = lhs.getType();
auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
if (elemTy.getIntOrFloatBitWidth() != 32)
return rewriter.notifyMatchFailure(
loc,
llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
Value one =
rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
Value zero =
rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
Value add = rewriter.create<CompositeConstructOp>(
loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
rewriter.replaceOp(op, add);
return success();
}
};
struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(IsInfOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
return success();
}
};
struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(IsNanOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
return success();
}
};
struct WebGPUPreparePass final
: impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
}
void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns) {
patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
ExpandAddCarryPattern>(patterns.getContext());
}
void populateSPIRVExpandNonFiniteArithmeticPatterns(
RewritePatternSet &patterns) {
patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
}
}
}