#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::x86vector;
/// Returns the element type of the vector-typed `src` operand of `op`.
///
/// The operand type is unconditionally cast to `VectorType`, so `OpTy` must
/// provide a `getSrc()` accessor whose value has a vector type.
template <typename OpTy>
static Type getSrcVectorElementType(OpTy op) {
  auto srcVectorType = cast<VectorType>(op.getSrc().getType());
  return srcVectorType.getElementType();
}
/// Specialization for `Vp2IntersectOp`: the element type is taken from the
/// vector-typed `a` operand instead of a `src` operand.
template <>
Type getSrcVectorElementType(Vp2IntersectOp op) {
  auto aVectorType = cast<VectorType>(op.getA().getType());
  return aVectorType.getElementType();
}
namespace {
/// Generic one-to-one lowering of `OpTy` to an intrinsic op, selected by the
/// bit width of the source vector's element type:
///   * 32-bit elements -> `Intr32OpTy`
///   * 64-bit elements -> `Intr64OpTy`
/// Any other element type results in a match failure.
///
/// Operands (already converted, via the adaptor) and attributes are forwarded
/// unchanged to the intrinsic op.
template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
  explicit LowerToIntrinsic(LLVMTypeConverter &converter)
      : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}

  /// Returns the type converter this pattern was constructed with, downcast
  /// to `LLVMTypeConverter` (the constructor only accepts that type).
  const LLVMTypeConverter &getTypeConverter() const {
    return *static_cast<const LLVMTypeConverter *>(
        OpConversionPattern<OpTy>::getTypeConverter());
  }

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elementType = getSrcVectorElementType<OpTy>(op);

    // `getIntOrFloatBitWidth()` asserts on anything that is not an integer or
    // float type; report a match failure instead of aborting.
    if (!elementType.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "expected source vector element type to be an integer or float");

    // Rewrites `op` into the intrinsic op named `intrinsicName`, keeping the
    // converted operands and the original attributes.
    auto lowerTo = [&](StringRef intrinsicName) -> LogicalResult {
      return LLVM::detail::oneToOneRewrite(
          op, intrinsicName, adaptor.getOperands(), op->getAttrs(),
          this->getTypeConverter(), rewriter);
    };

    switch (elementType.getIntOrFloatBitWidth()) {
    case 32:
      return lowerTo(Intr32OpTy::getOperationName());
    case 64:
      return lowerTo(Intr64OpTy::getOperationName());
    default:
      // The pattern is shared by ops whose source operand is not named `src`
      // (see getSrcVectorElementType), so keep the message operand-agnostic.
      return rewriter.notifyMatchFailure(
          op, "expected source vector element type to be 32 or 64 bits wide");
    }
  }
};
/// Lowers `MaskCompressOp` to `MaskCompressIntrOp`.
///
/// The intrinsic always takes an explicit source value, chosen as follows:
///   1. the op's `src` operand, if present;
///   2. otherwise a constant built from the `constant_src` attribute, if set;
///   3. otherwise a zero constant of the (converted) type of `a`.
struct MaskCompressOpConversion
    : public ConvertOpToLLVMPattern<MaskCompressOp> {
  using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type vectorType = adaptor.getA().getType();
    Location loc = op.getLoc();

    Value source;
    if (!op.getSrc()) {
      // No explicit source operand: materialize one as a constant.
      if (op.getConstantSrc())
        source = rewriter.create<arith::ConstantOp>(loc, vectorType,
                                                    op.getConstantSrcAttr());
      else
        source = rewriter.create<arith::ConstantOp>(
            loc, vectorType, rewriter.getZeroAttr(vectorType));
    } else {
      source = adaptor.getSrc();
    }

    rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(
        op, vectorType, adaptor.getA(), source, adaptor.getK());
    return success();
  }
};
/// Lowers `RsqrtOp` to `RsqrtIntrOp`; the result type is the type of the
/// converted `a` operand.
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
  using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getA();
    rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, input.getType(), input);
    return success();
  }
};
/// Lowers `DotOp` to `DotIntrOp`.
///
/// In addition to the two converted operands, the intrinsic receives an i8
/// constant with all bits set (0xff) as its third operand.
struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
  using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Build the i8 0xff constant operand.
    MLIRContext *context = &getTypeConverter()->getContext();
    Type i8Type = IntegerType::get(context, 8);
    auto allOnes = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
    Value scale =
        rewriter.create<LLVM::ConstantOp>(op.getLoc(), i8Type, allOnes);

    Value lhs = adaptor.getA();
    rewriter.replaceOpWithNewOp<DotIntrOp>(op, lhs.getType(), lhs,
                                           adaptor.getB(), scale);
    return success();
  }
};
/// Compile-time record tying an op to the two intrinsic ops it may lower to:
/// `Intr32Op` for 32-bit elements and `Intr64Op` for 64-bit elements.
template <typename SourceOp, typename Intrinsic32, typename Intrinsic64>
struct RegEntry {
  /// The op being lowered.
  using MainOp = SourceOp;
  /// Intrinsic op used for 32-bit element types.
  using Intr32Op = Intrinsic32;
  /// Intrinsic op used for 64-bit element types.
  using Intr64Op = Intrinsic64;
};
/// Applies pattern registration and target configuration to a list of
/// entries, each providing `MainOp`, `Intr32Op` and `Intr64Op` member types.
template <typename... Entries>
struct RegistryImpl {
  /// Adds one `LowerToIntrinsic` pattern per entry to `patterns`.
  static void registerPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns) {
    patterns.add<LowerToIntrinsic<typename Entries::MainOp,
                                  typename Entries::Intr32Op,
                                  typename Entries::Intr64Op>...>(converter);
  }

  /// Marks every entry's main op illegal and both of its intrinsic ops legal.
  static void configureTarget(LLVMConversionTarget &target) {
    target.addIllegalOp<typename Entries::MainOp...>();
    target.addLegalOp<typename Entries::Intr32Op...>();
    target.addLegalOp<typename Entries::Intr64Op...>();
  }
};
/// Ops lowered through the generic `LowerToIntrinsic` pattern. Each entry
/// lists the op followed by its 32-bit-element and 64-bit-element intrinsic.
using Registry = RegistryImpl<
RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
}
/// Adds to `patterns` every conversion pattern defined in this file: the
/// registry-driven `LowerToIntrinsic` patterns first, then the dedicated
/// patterns for `MaskCompressOp`, `RsqrtOp` and `DotOp`.
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Ops that share the generic 32/64-bit intrinsic lowering.
  Registry::registerPatterns(converter, patterns);

  // Ops that need a hand-written lowering.
  patterns.add<MaskCompressOpConversion>(converter);
  patterns.add<RsqrtOpConversion>(converter);
  patterns.add<DotOpConversion>(converter);
}
/// Configures `target` so that the ops lowered by this file are illegal and
/// the intrinsic ops they lower to are legal.
void mlir::configureX86VectorLegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // Ops handled by the generic registry.
  Registry::configureTarget(target);

  // Ops with dedicated conversion patterns.
  target.addLegalOp<MaskCompressIntrOp, RsqrtIntrOp, DotIntrOp>();
  target.addIllegalOp<MaskCompressOp, RsqrtOp, DotOp>();
}