#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// Rewrites a `gpu.shuffle` of a 64-bit scalar (signless integer or float)
/// into two `gpu.shuffle` ops on the low and high i32 halves of the value,
/// then reassembles the 64-bit result from the two shuffled halves.
///
/// The replacement's validity flag is the logical AND of the validity flags
/// of the two half-shuffles. Values of any other type (including 32-bit
/// ones, which need no splitting) are not matched.
struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
  using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;

  void initialize() {
    // The rewrite creates new gpu.shuffle ops. They operate on i32 and so
    // never match this pattern again, which keeps the recursion bounded.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value value = op.getValue();
    Type valueType = value.getType();
    Location valueLoc = value.getLoc();

    // Only scalar 64-bit signless integers and 64-bit floats can be split
    // into two i32 halves. Other widths would produce invalid arith casts
    // (e.g. trunci i16 -> i32), and arith ops reject signed/unsigned
    // integers. The type-kind test comes first because
    // getIntOrFloatBitWidth() asserts on types such as vectors or index.
    const bool isFloat = isa<FloatType>(valueType);
    if (!(isFloat || valueType.isSignlessInteger()) ||
        valueType.getIntOrFloatBitWidth() != 64)
      return failure();

    Type i32 = rewriter.getI32Type();
    Type i64 = rewriter.getI64Type();

    // The integer ops below need an integer operand, so reinterpret the
    // float's bits as i64.
    if (isFloat)
      value = rewriter.create<arith::BitcastOp>(valueLoc, i64, value);

    // Low half: trunc(value).
    Value lo = rewriter.create<arith::TruncIOp>(valueLoc, i32, value);

    // High half: trunc(value >> 32), using a logical (unsigned) shift.
    Value c32 = rewriter.create<arith::ConstantOp>(
        valueLoc, rewriter.getIntegerAttr(i64, 32));
    Value hi = rewriter.create<arith::ShRUIOp>(valueLoc, value, c32);
    hi = rewriter.create<arith::TruncIOp>(valueLoc, i32, hi);

    // Shuffle each half with the original offset, width and mode.
    // Result 0 is the shuffled value; result 1 is its validity flag.
    ValueRange loRes =
        rewriter
            .create<gpu::ShuffleOp>(loc, lo, op.getOffset(), op.getWidth(),
                                    op.getMode())
            .getResults();
    ValueRange hiRes =
        rewriter
            .create<gpu::ShuffleOp>(loc, hi, op.getOffset(), op.getWidth(),
                                    op.getMode())
            .getResults();

    // Reassemble: (zext(hi) << 32) | zext(lo).
    lo = rewriter.create<arith::ExtUIOp>(valueLoc, i64, loRes[0]);
    hi = rewriter.create<arith::ExtUIOp>(valueLoc, i64, hiRes[0]);
    hi = rewriter.create<arith::ShLIOp>(valueLoc, hi, c32);
    value = rewriter.create<arith::OrIOp>(loc, hi, lo);

    // Undo the earlier bit reinterpretation for float inputs.
    if (isFloat)
      value = rewriter.create<arith::BitcastOp>(valueLoc, valueType, value);

    // The combined value is valid only if both half-shuffles were valid.
    Value validity = rewriter.create<arith::AndIOp>(loc, loRes[1], hiRes[1]);

    rewriter.replaceOp(op, {value, validity});
    return success();
  }
};
} // namespace
/// Adds the GpuShuffleRewriter pattern to `patterns`, constructing it with
/// the MLIRContext that the pattern set itself was created with.
void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<GpuShuffleRewriter>(context);
}