#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "spirv-to-llvm-pattern"
using namespace mlir;
/// LLVM address space returned for any client API / storage class pair that
/// has no dedicated mapping.
static constexpr unsigned defaultAddressSpace{0};
/// Returns true if `type` is a signed integer type, or a vector whose element
/// type is a signed integer type.
static bool isSignedIntegerOrVector(Type type) {
  // A vector is never itself an integer, so only its element type matters.
  if (auto vecType = dyn_cast<VectorType>(type))
    return vecType.getElementType().isSignedInteger();
  return type.isSignedInteger();
}
/// Returns true if `type` is an unsigned integer type, or a vector whose
/// element type is an unsigned integer type.
static bool isUnsignedIntegerOrVector(Type type) {
  auto vecType = dyn_cast<VectorType>(type);
  Type scalarType = vecType ? vecType.getElementType() : type;
  return scalarType.isUnsignedInteger();
}
/// Returns the bit width of an integer type or of the integer element type of
/// a vector; returns std::nullopt for anything else (e.g. floats).
static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
  Type elementType = type;
  if (auto vecType = dyn_cast<VectorType>(type))
    elementType = vecType.getElementType();
  if (auto intType = dyn_cast<IntegerType>(elementType))
    return intType.getWidth();
  return std::nullopt;
}
/// Returns the bit width of an integer/float type, or of the element type of a
/// vector of integers/floats. Other types trip the assertions.
static unsigned getBitWidth(Type type) {
  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
         "bitwidth is not supported for this type");
  if (!type.isIntOrFloat()) {
    // Only vectors reach this point; measure their element type.
    Type elementType = cast<VectorType>(type).getElementType();
    assert(elementType.isIntOrFloat() &&
           "only integers and floats have a bitwidth");
    return elementType.getIntOrFloatBitWidth();
  }
  return type.getIntOrFloatBitWidth();
}
/// Returns the integer bit width of an LLVM-compatible integer type or of the
/// element type of an LLVM-compatible vector of integers.
static unsigned getLLVMTypeBitWidth(Type type) {
  Type scalarType = type;
  if (LLVM::isCompatibleVectorType(type))
    scalarType = LLVM::getVectorElementType(type);
  return cast<IntegerType>(scalarType).getWidth();
}
/// Builds an integer attribute holding -1 (all bits set) whose type is the
/// integer type of `type` or, for vectors, its integer element type.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
  Type scalarType = type;
  if (auto vecType = dyn_cast<VectorType>(type))
    scalarType = vecType.getElementType();
  return builder.getIntegerAttr(cast<IntegerType>(scalarType), -1);
}
/// Creates an `llvm.mlir.constant` of type `dstType` with every bit set. For
/// vector `srcType` the -1 value is splatted across all elements.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
                                      PatternRewriter &rewriter) {
  IntegerAttr allOnes = minusOneIntegerAttribute(srcType, rewriter);
  if (!isa<VectorType>(srcType))
    return rewriter.create<LLVM::ConstantOp>(loc, dstType, allOnes);
  return rewriter.create<LLVM::ConstantOp>(
      loc, dstType, SplatElementsAttr::get(cast<ShapedType>(srcType), allOnes));
}
/// Creates an `llvm.mlir.constant` of type `dstType` holding the floating
/// point `value`, splatted across all elements when `srcType` is a vector.
static Value createFPConstant(Location loc, Type srcType, Type dstType,
                              PatternRewriter &rewriter, double value) {
  auto vecType = dyn_cast<VectorType>(srcType);
  auto floatType =
      cast<FloatType>(vecType ? vecType.getElementType() : srcType);
  auto scalarAttr = rewriter.getFloatAttr(floatType, value);
  if (!vecType)
    return rewriter.create<LLVM::ConstantOp>(loc, dstType, scalarAttr);
  return rewriter.create<LLVM::ConstantOp>(
      loc, dstType, SplatElementsAttr::get(vecType, scalarAttr));
}
/// Zero-extends or truncates `value` so its integer (element) width matches
/// that of `llvmType`, returning `value` untouched when the widths agree.
/// `value` may carry a type that is not LLVM-compatible; its width is then
/// measured with getBitWidth.
static Value optionallyTruncateOrExtend(Location loc, Value value,
                                        Type llvmType,
                                        PatternRewriter &rewriter) {
  Type valueType = value.getType();
  unsigned fromWidth = LLVM::isCompatibleType(valueType)
                           ? getLLVMTypeBitWidth(valueType)
                           : getBitWidth(valueType);
  unsigned toWidth = getLLVMTypeBitWidth(llvmType);
  if (fromWidth == toWidth)
    return value;
  if (fromWidth < toWidth)
    return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
}
/// Builds an LLVM vector of `numElements` lanes, each holding `toBroadcast`.
/// It starts from an undef vector and performs one insertelement per lane.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
                       LLVMTypeConverter &typeConverter,
                       ConversionPatternRewriter &rewriter) {
  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
  Type llvmVectorType = typeConverter.convertType(vectorType);
  Type llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
  Value result = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
  for (unsigned lane = 0; lane != numElements; ++lane) {
    Value laneIndex = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(lane));
    result = rewriter.create<LLVM::InsertElementOp>(
        loc, llvmVectorType, result, toBroadcast, laneIndex);
  }
  return result;
}
/// Broadcasts `value` to a vector when `srcType` is a vector type (using its
/// element count); otherwise returns `value` as is.
static Value optionallyBroadcast(Location loc, Value value, Type srcType,
                                 LLVMTypeConverter &typeConverter,
                                 ConversionPatternRewriter &rewriter) {
  auto vectorType = dyn_cast<VectorType>(srcType);
  if (!vectorType)
    return value;
  unsigned laneCount = vectorType.getNumElements();
  return broadcast(loc, value, laneCount, typeConverter, rewriter);
}
/// Prepares a bit-field count/offset operand. If the result type `srcType` is
/// a vector, the operand is broadcast to that shape. Its width is then
/// adjusted to match `dstType`.
static Value processCountOrOffset(Location loc, Value value, Type srcType,
                                  Type dstType, LLVMTypeConverter &converter,
                                  ConversionPatternRewriter &rewriter) {
  return optionallyTruncateOrExtend(
      loc, optionallyBroadcast(loc, value, srcType, converter, rewriter),
      dstType, rewriter);
}
/// Converts an offset-decorated SPIR-V struct to a non-packed LLVM literal
/// struct. Only structs whose decorations are unchanged by
/// VulkanLayoutUtils::decorateType are accepted; others, or structs with an
/// unconvertible member, yield a null Type.
static Type convertStructTypeWithOffset(spirv::StructType type,
                                        LLVMTypeConverter &converter) {
  if (VulkanLayoutUtils::decorateType(type) != type)
    return nullptr;
  SmallVector<Type> memberTypes;
  if (failed(converter.convertTypes(type.getElementTypes(), memberTypes)))
    return nullptr;
  return LLVM::LLVMStructType::getLiteral(type.getContext(), memberTypes,
                                          /*isPacked=*/false);
}
/// Converts a SPIR-V struct without offsets to a packed LLVM literal struct,
/// or returns a null Type if any member type fails to convert.
static Type convertStructTypePacked(spirv::StructType type,
                                    LLVMTypeConverter &converter) {
  SmallVector<Type> memberTypes;
  LogicalResult converted =
      converter.convertTypes(type.getElementTypes(), memberTypes);
  if (failed(converted))
    return nullptr;
  return LLVM::LLVMStructType::getLiteral(type.getContext(), memberTypes,
                                          /*isPacked=*/true);
}
/// Materializes `value` as an i32 `llvm.mlir.constant`.
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
                                 unsigned value) {
  auto i32Type = IntegerType::get(rewriter.getContext(), 32);
  auto valueAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), value);
  return rewriter.create<LLVM::ConstantOp>(loc, i32Type, valueAttr);
}
/// Replaces a `spirv.Load` with `llvm.load` or a `spirv.Store` with
/// `llvm.store`, attaching the given alignment/volatile/nontemporal flags.
/// `operands` are the already-remapped operands of `op`. It fails only when
/// the load's result type cannot be converted.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
                                            ConversionPatternRewriter &rewriter,
                                            LLVMTypeConverter &typeConverter,
                                            unsigned alignment, bool isVolatile,
                                            bool isNonTemporal) {
  auto loadOp = dyn_cast<spirv::LoadOp>(op);
  if (!loadOp) {
    // Anything that is not a load is required to be a store.
    auto storeOp = cast<spirv::StoreOp>(op);
    spirv::StoreOpAdaptor storeOperands(operands);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
        storeOp, storeOperands.getValue(), storeOperands.getPtr(), alignment,
        isVolatile, isNonTemporal);
    return success();
  }

  Type loadedType = typeConverter.convertType(loadOp.getType());
  if (!loadedType)
    return rewriter.notifyMatchFailure(op, "type conversion failed");
  spirv::LoadOpAdaptor loadOperands(operands);
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, loadedType,
                                            loadOperands.getPtr(), alignment,
                                            isVolatile, isNonTemporal);
  return success();
}
/// Converts a SPIR-V array type to an LLVM array type.
///
/// Returns std::nullopt (conversion not applicable) when the array has a
/// non-zero stride that differs from the element size in bytes, or when that
/// size is unknown. Returns a null Type (hard conversion failure) when the
/// element type itself cannot be converted.
static std::optional<Type> convertArrayType(spirv::ArrayType type,
                                            TypeConverter &converter) {
  unsigned stride = type.getArrayStride();
  Type elementType = type.getElementType();
  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
  // A zero stride skips the check; otherwise the array must be tightly packed.
  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
    return std::nullopt;

  Type llvmElementType = converter.convertType(elementType);
  // LLVMArrayType::get queries the element type's context, so a failed
  // element conversion must be reported rather than forwarded as null.
  if (!llvmElementType)
    return Type();
  unsigned numElements = type.getNumElements();
  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
}
/// Maps a SPIR-V storage class to the LLVM address space number used for the
/// OpenCL client API. Unlisted storage classes map to `defaultAddressSpace`.
static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass) {
  switch (storageClass) {
  case spirv::StorageClass::Function:
    return 0;
  case spirv::StorageClass::CrossWorkgroup:
  case spirv::StorageClass::Input:
    return 1;
  case spirv::StorageClass::UniformConstant:
    return 2;
  case spirv::StorageClass::Workgroup:
    return 3;
  case spirv::StorageClass::Generic:
    return 4;
  case spirv::StorageClass::DeviceOnlyINTEL:
    return 5;
  case spirv::StorageClass::HostOnlyINTEL:
    return 6;
  default:
    return defaultAddressSpace;
  }
}
/// Maps a storage class to an LLVM address space for the given client API.
/// Only OpenCL has a dedicated table; every other API uses
/// `defaultAddressSpace`.
static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI,
                                  spirv::StorageClass storageClass) {
  if (clientAPI == spirv::ClientAPI::OpenCL)
    return mapToOpenCLAddressSpace(storageClass);
  return defaultAddressSpace;
}
/// Converts a SPIR-V pointer to an opaque LLVM pointer. Only the storage
/// class, mapped to an address space, is preserved; neither the pointee type
/// nor `converter` is consulted.
static Type convertPointerType(spirv::PointerType type,
                               LLVMTypeConverter &converter,
                               spirv::ClientAPI clientAPI) {
  return LLVM::LLVMPointerType::get(
      type.getContext(), mapToAddressSpace(clientAPI, type.getStorageClass()));
}
/// Converts a SPIR-V runtime array to a zero-length LLVM array.
///
/// Returns std::nullopt (conversion not applicable) when the runtime array
/// has a non-zero stride. Returns a null Type (hard conversion failure) when
/// the element type cannot be converted.
static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
                                                   TypeConverter &converter) {
  if (type.getArrayStride() != 0)
    return std::nullopt;
  Type elementType = converter.convertType(type.getElementType());
  // LLVMArrayType::get queries the element type's context; never pass null.
  if (!elementType)
    return Type();
  return LLVM::LLVMArrayType::get(elementType, 0);
}
/// Converts a SPIR-V struct type to an LLVM struct. Structs with any member
/// decorations are rejected (null Type). Structs with offsets go through
/// convertStructTypeWithOffset; all others become packed structs.
static Type convertStructType(spirv::StructType type,
                              LLVMTypeConverter &converter) {
  SmallVector<spirv::StructType::MemberDecorationInfo, 4> decorations;
  type.getMemberDecorations(decorations);
  if (!decorations.empty())
    return nullptr;
  return type.hasOffset() ? convertStructTypeWithOffset(type, converter)
                          : convertStructTypePacked(type, converter);
}
namespace {
class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
public:
  using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.AccessChain` to `llvm.getelementptr`. A zero index is
  /// prepended so the GEP first steps through the base pointer itself. The
  /// zero index uses the type of the first original index.
  LogicalResult
  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultPtrType =
        typeConverter.convertType(op.getComponentPtr().getType());
    if (!resultPtrType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    auto gepIndices = llvm::to_vector<4>(adaptor.getIndices());
    Type spirvIndexType = op.getIndices().front().getType();
    Type llvmIndexType = typeConverter.convertType(spirvIndexType);
    if (!llvmIndexType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    Value zeroIndex = rewriter.create<LLVM::ConstantOp>(
        op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(spirvIndexType, 0));
    gepIndices.insert(gepIndices.begin(), zeroIndex);

    auto basePtrType = cast<spirv::PointerType>(op.getBasePtr().getType());
    Type pointeeType = typeConverter.convertType(basePtrType.getPointeeType());
    if (!pointeeType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, resultPtrType, pointeeType,
                                             adaptor.getBasePtr(), gepIndices);
    return success();
  }
};
class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
public:
  using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.mlir.addressof` to `llvm.mlir.addressof` of the same
  /// variable symbol, with the converted pointer type.
  LogicalResult
  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (Type llvmPtrType =
            typeConverter.convertType(op.getPointer().getType())) {
      rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, llvmPtrType,
                                                     op.getVariable());
      return success();
    }
    return rewriter.notifyMatchFailure(op, "type conversion failed");
  }
};
class BitFieldInsertPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
public:
  using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.BitFieldInsert` to LLVM shift/mask arithmetic:
  ///   fieldMask = ~(-1 << count) << offset
  ///   result    = (base & ~fieldMask) | ((insert << offset) & fieldMask)
  /// `offset` and `count` are first broadcast (for vector results) and
  /// zero-extended/truncated to the result element width.
  ///
  /// NOTE(review): LLVM `shl` by an amount >= the bit width is poison, so a
  /// `count` equal to the full bit width is presumably not handled — confirm.
  LogicalResult
  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getType();
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    Location loc = op.getLoc();

    // Use the remapped operands (as the other patterns here do) so the new
    // LLVM ops never consume values that still carry SPIR-V types.
    Value offset = processCountOrOffset(loc, adaptor.getOffset(), srcType,
                                        dstType, typeConverter, rewriter);
    Value count = processCountOrOffset(loc, adaptor.getCount(), srcType,
                                       dstType, typeConverter, rewriter);

    // fieldMask has exactly bits [offset, offset + count - 1] set.
    Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
    Value maskShiftedByCount =
        rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
    Value lowCountBits = rewriter.create<LLVM::XOrOp>(
        loc, dstType, maskShiftedByCount, minusOne);
    Value fieldMask =
        rewriter.create<LLVM::ShlOp>(loc, dstType, lowCountBits, offset);
    // baseMask is the complement: every bit outside the field.
    Value baseMask =
        rewriter.create<LLVM::XOrOp>(loc, dstType, fieldMask, minusOne);

    // Keep the bits of `base` that lie outside the field.
    Value baseAndMask =
        rewriter.create<LLVM::AndOp>(loc, dstType, adaptor.getBase(), baseMask);
    // Only bits [0, count - 1] of `insert` may land in the result; without
    // this mask, higher bits of `insert` would overwrite `base` above the
    // field.
    Value insertShiftedByOffset = rewriter.create<LLVM::ShlOp>(
        loc, dstType, adaptor.getInsert(), offset);
    Value insertInField = rewriter.create<LLVM::AndOp>(
        loc, dstType, insertShiftedByOffset, fieldMask);
    rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
                                            insertInField);
    return success();
  }
};
class ConstantScalarAndVectorPattern
    : public SPIRVToLLVMConversion<spirv::ConstantOp> {
public:
  using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;

  /// Lowers scalar and vector `spirv.Constant` ops to `llvm.mlir.constant`.
  /// Signed/unsigned integer values are re-typed as signless integers of the
  /// same bit width. Signless integers and floats keep their attributes.
  LogicalResult
  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = constOp.getType();
    bool srcIsVector = isa<VectorType>(srcType);
    if (!(srcIsVector || srcType.isIntOrFloat()))
      return failure();

    Type dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return rewriter.notifyMatchFailure(constOp, "type conversion failed");

    bool hasSignedness =
        isSignedIntegerOrVector(srcType) || isUnsignedIntegerOrVector(srcType);
    if (!hasSignedness) {
      rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
          constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
      return success();
    }

    auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
    if (!srcIsVector) {
      auto scalarAttr = cast<IntegerAttr>(constOp.getValue());
      rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
          constOp, dstType,
          rewriter.getIntegerAttr(signlessType, scalarAttr.getValue()));
      return success();
    }

    // Same bits, new (signless) element type.
    auto elementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
        constOp, dstType,
        elementsAttr.mapValues(signlessType,
                               [](const APInt &bits) { return bits; }));
    return success();
  }
};
class BitFieldSExtractPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
public:
  using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.BitFieldSExtract` as a shift-left / arithmetic-shift-right
  /// pair:
  ///   left   = bitwidth - (count + offset)
  ///   result = (base << left) >>a (offset + left)
  /// This moves the field to the top bits and then sign-extends it back down.
  ///
  /// NOTE(review): when count + offset == 0 the left shift amount equals the
  /// bit width, which is poison for LLVM `shl` — confirm whether that case
  /// needs special handling.
  LogicalResult
  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getType();
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    Location loc = op.getLoc();

    // Use remapped operands, consistent with the other patterns in this file.
    Value offset = processCountOrOffset(loc, adaptor.getOffset(), srcType,
                                        dstType, typeConverter, rewriter);
    Value count = processCountOrOffset(loc, adaptor.getCount(), srcType,
                                       dstType, typeConverter, rewriter);

    // Constant holding the element bit width (splatted for vectors).
    IntegerType integerType;
    if (auto vecType = dyn_cast<VectorType>(srcType))
      integerType = cast<IntegerType>(vecType.getElementType());
    else
      integerType = cast<IntegerType>(srcType);
    auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
    Value size =
        isa<VectorType>(srcType)
            ? rewriter.create<LLVM::ConstantOp>(
                  loc, dstType,
                  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
            : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);

    Value countPlusOffset =
        rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
    Value amountToShiftLeft =
        rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
    Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
        loc, dstType, adaptor.getBase(), amountToShiftLeft);
    Value amountToShiftRight =
        rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
    rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
                                              amountToShiftRight);
    return success();
  }
};
class BitFieldUExtractPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
public:
  using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.BitFieldUExtract` as
  ///   result = (base >>l offset) & ~(-1 << count)
  ///
  /// NOTE(review): LLVM `shl` by an amount >= the bit width is poison, so a
  /// `count` equal to the full bit width is presumably not handled — confirm.
  LogicalResult
  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getType();
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    Location loc = op.getLoc();

    // Use remapped operands, consistent with the other patterns in this file.
    Value offset = processCountOrOffset(loc, adaptor.getOffset(), srcType,
                                        dstType, typeConverter, rewriter);
    Value count = processCountOrOffset(loc, adaptor.getCount(), srcType,
                                       dstType, typeConverter, rewriter);

    // mask has the low `count` bits set.
    Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
    Value maskShiftedByCount =
        rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
    Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
                                              minusOne);

    Value shiftedBase =
        rewriter.create<LLVM::LShrOp>(loc, dstType, adaptor.getBase(), offset);
    rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
    return success();
  }
};
class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
public:
  using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.Branch` to `llvm.br` targeting the same successor block and
  /// forwarding the remapped block operands.
  LogicalResult
  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *successor = branchOp.getTarget();
    rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
                                            successor);
    return success();
  }
};
class BranchConditionalConversionPattern
    : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
public:
  using SPIRVToLLVMConversion<
      spirv::BranchConditionalOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.BranchConditional` to `llvm.cond_br`. Optional branch
  /// weights are carried over as a DenseI32ArrayAttr.
  LogicalResult
  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    DenseI32ArrayAttr branchWeights = nullptr;
    if (auto weights = op.getBranchWeights()) {
      SmallVector<int32_t> weightValues;
      for (auto weight : weights->getAsRange<IntegerAttr>())
        weightValues.push_back(weight.getInt());
      branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
    }

    // Forward the remapped condition and successor operands so they match
    // the converted block signatures (as BranchConversionPattern does).
    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
        op, adaptor.getCondition(), adaptor.getTrueTargetOperands(),
        adaptor.getFalseTargetOperands(), branchWeights, op.getTrueBlock(),
        op.getFalseBlock());
    return success();
  }
};
class CompositeExtractPattern
    : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
public:
  using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.CompositeExtract`. On a vector, only the first index is
  /// used, via `llvm.extractelement` with an i32 constant. Other composites
  /// use `llvm.extractvalue` with the full index path.
  LogicalResult
  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = this->typeConverter.convertType(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (!isa<VectorType>(op.getComposite().getType())) {
      rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
          op, adaptor.getComposite(),
          LLVM::convertArrayToIndices(op.getIndices()));
      return success();
    }

    auto position = cast<IntegerAttr>(op.getIndices()[0]);
    Value lane = createI32ConstantOf(op.getLoc(), rewriter, position.getInt());
    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, resultType, adaptor.getComposite(), lane);
    return success();
  }
};
class CompositeInsertPattern
    : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
public:
  using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.CompositeInsert`. On a vector, only the first index is
  /// used, via `llvm.insertelement` with an i32 constant. Other composites
  /// use `llvm.insertvalue` with the full index path.
  LogicalResult
  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = this->typeConverter.convertType(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (!isa<VectorType>(op.getComposite().getType())) {
      rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
          op, adaptor.getComposite(), adaptor.getObject(),
          LLVM::convertArrayToIndices(op.getIndices()));
      return success();
    }

    auto position = cast<IntegerAttr>(op.getIndices()[0]);
    Value lane = createI32ConstantOf(op.getLoc(), rewriter, position.getInt());
    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, resultType, adaptor.getComposite(), adaptor.getObject(), lane);
    return success();
  }
};
template <typename SPIRVOp, typename LLVMOp>
class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  /// One-to-one rewrite: builds an `LLVMOp` from the converted result type,
  /// the remapped operands and all attributes of the source op.
  LogicalResult
  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = this->typeConverter.convertType(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    rewriter.template replaceOpWithNewOp<LLVMOp>(
        op, resultType, adaptor.getOperands(), op->getAttrs());
    return success();
  }
};
class ExecutionModePattern
    : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
public:
  using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;

  /// Replaces `spirv.ExecutionMode` with a constant, externally linked LLVM
  /// global placed at the start of the enclosing builtin module. The global
  /// is named with the format "__spv_{0}_{1}_execution_mode_info_{2}", where
  /// {0} is "_<module name>" (or empty), {1} the function symbol and {2} the
  /// numeric execution mode. Its initializer is a struct whose first field is
  /// the execution mode as i32; if the op has extra values, a second field
  /// holds them as an i32 array.
  LogicalResult
  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ModuleOp module = op->getParentOfType<ModuleOp>();
    // The global is inserted into the builtin module; bail out before
    // creating any IR if there is none.
    if (!module)
      return rewriter.notifyMatchFailure(op, "no enclosing builtin module");

    spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
    std::string moduleName;
    if (module.getName().has_value())
      moduleName = "_" + module.getName()->str();
    else
      moduleName = "";
    std::string executionModeInfoName = llvm::formatv(
        "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
        static_cast<uint32_t>(executionModeAttr.getValue()));

    MLIRContext *context = rewriter.getContext();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());

    // Struct layout: { i32 mode [, [N x i32] values] }.
    auto llvmI32Type = IntegerType::get(context, 32);
    SmallVector<Type, 2> fields;
    fields.push_back(llvmI32Type);
    ArrayAttr values = op.getValues();
    if (!values.empty()) {
      auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
      fields.push_back(arrayType);
    }
    auto structType = LLVM::LLVMStructType::getLiteral(context, fields);

    auto global = rewriter.create<LLVM::GlobalOp>(
        UnknownLoc::get(context), structType, /*isConstant=*/true,
        LLVM::Linkage::External, executionModeInfoName, Attribute(),
        /*alignment=*/0);
    Location loc = global.getLoc();
    Region &region = global.getInitializerRegion();
    Block *block = rewriter.createBlock(&region);

    // Build the initializer: start from undef and fill in each field.
    rewriter.setInsertionPoint(block, block->begin());
    Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
    Value executionMode = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(
            static_cast<uint32_t>(executionModeAttr.getValue())));
    structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
                                                       executionMode, 0);
    for (unsigned i = 0, e = values.size(); i < e; ++i) {
      auto attr = values.getValue()[i];
      Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
      structValue = rewriter.create<LLVM::InsertValueOp>(
          loc, structValue, entry, ArrayRef<int64_t>({1, i}));
    }
    rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
    rewriter.eraseOp(op);
    return success();
  }
};
class GlobalVariablePattern
    : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
public:
  template <typename... Args>
  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
      : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
            std::forward<Args>(args)...),
        clientAPI(clientAPI) {}

  /// Lowers an initializer-free `spirv.GlobalVariable` whose storage class is
  /// Input, Private, Output, StorageBuffer or UniformConstant to an
  /// `llvm.mlir.global`.
  /// - Input and UniformConstant globals are emitted as constants.
  /// - Private globals get private linkage; the others get external linkage.
  /// - The address space comes from mapToAddressSpace for `clientAPI`.
  /// - Any `location` attribute is copied over.
  LogicalResult
  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getInitializer())
      return failure();

    auto ptrType = cast<spirv::PointerType>(op.getType());
    Type valueType = typeConverter.convertType(ptrType.getPointeeType());
    if (!valueType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    spirv::StorageClass storageClass = ptrType.getStorageClass();
    bool isSupported = storageClass == spirv::StorageClass::Input ||
                       storageClass == spirv::StorageClass::Private ||
                       storageClass == spirv::StorageClass::Output ||
                       storageClass == spirv::StorageClass::StorageBuffer ||
                       storageClass == spirv::StorageClass::UniformConstant;
    if (!isSupported)
      return failure();

    bool emitAsConstant =
        storageClass == spirv::StorageClass::Input ||
        storageClass == spirv::StorageClass::UniformConstant;
    LLVM::Linkage linkage = storageClass == spirv::StorageClass::Private
                                ? LLVM::Linkage::Private
                                : LLVM::Linkage::External;
    unsigned addressSpace = mapToAddressSpace(clientAPI, storageClass);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        op, valueType, emitAsConstant, linkage, op.getSymName(), Attribute(),
        /*alignment=*/0, addressSpace);
    if (auto locAttr = op.getLocationAttr())
      newGlobal->setAttr(op.getLocationAttrName(), locAttr);
    return success();
  }

private:
  spirv::ClientAPI clientAPI;
};
template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  /// Lowers a width-changing conversion to `LLVMExtOp` when the result is
  /// wider, or to `LLVMTruncOp` when it is narrower. Equal widths are not
  /// matched.
  LogicalResult
  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type sourceType = op.getOperand().getType();
    Type resultType = op.getType();
    Type llvmResultType = this->typeConverter.convertType(resultType);
    if (!llvmResultType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    unsigned sourceWidth = getBitWidth(sourceType);
    unsigned resultWidth = getBitWidth(resultType);
    if (sourceWidth == resultWidth)
      return failure();
    if (sourceWidth < resultWidth)
      rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, llvmResultType,
                                                      adaptor.getOperands());
    else
      rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, llvmResultType,
                                                        adaptor.getOperands());
    return success();
  }
};
class FunctionCallPattern
    : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
public:
  using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.FunctionCall` to `llvm.call`, keeping all attributes. The
  /// result type (at most one is handled) is converted; a call without
  /// results is built with no result type.
  LogicalResult
  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (callOp.getNumResults() != 0) {
      Type resultType = typeConverter.convertType(callOp.getType(0));
      if (!resultType)
        return rewriter.notifyMatchFailure(callOp, "type conversion failed");
      rewriter.replaceOpWithNewOp<LLVM::CallOp>(
          callOp, resultType, adaptor.getOperands(), callOp->getAttrs());
      return success();
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
    return success();
  }
};
template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  /// Lowers a SPIR-V floating-point comparison to `llvm.fcmp` with the fixed
  /// `predicate`.
  LogicalResult
  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = this->typeConverter.convertType(op.getType());
    if (!dstType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    // Compare the remapped operands, like the other conversion patterns here.
    rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
        op, dstType, predicate, adaptor.getOperand1(), adaptor.getOperand2());
    return success();
  }
};
template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  /// Lowers a SPIR-V integer comparison to `llvm.icmp` with the fixed
  /// `predicate`. Signedness is expressed by the predicate, not the operands.
  LogicalResult
  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = this->typeConverter.convertType(op.getType());
    if (!dstType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    // Use the remapped operands so the LLVM op sees signless integers rather
    // than the original SPIR-V operand types.
    rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
        op, dstType, predicate, adaptor.getOperand1(), adaptor.getOperand2());
    return success();
  }
};
class InverseSqrtPattern
    : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
public:
  using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.GL.InverseSqrt` to `1.0 / llvm.intr.sqrt(x)`. The constant
  /// 1.0 is splatted for vector operands.
  LogicalResult
  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getType();
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    Location loc = op.getLoc();
    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
    // Use the remapped operand, consistent with the other patterns here.
    Value sqrt =
        rewriter.create<LLVM::SqrtOp>(loc, dstType, adaptor.getOperand());
    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
    return success();
  }
};
template <typename SPIRVOp>
class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.Load`/`spirv.Store` via replaceWithLoadOrStore. Supported
  /// memory-access values are absent, None, Aligned (the alignment attribute
  /// is forwarded), Nontemporal and Volatile. Any other value is not matched.
  LogicalResult
  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    unsigned alignment = 0;
    bool isVolatile = false;
    bool isNonTemporal = false;
    if (auto memoryAccess = op.getMemoryAccess()) {
      switch (*memoryAccess) {
      case spirv::MemoryAccess::Aligned:
        alignment = *op.getAlignment();
        break;
      case spirv::MemoryAccess::Nontemporal:
        isNonTemporal = true;
        break;
      case spirv::MemoryAccess::Volatile:
        isVolatile = true;
        break;
      case spirv::MemoryAccess::None:
        break;
      default:
        return failure();
      }
    }
    return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
                                  this->typeConverter, alignment, isVolatile,
                                  isNonTemporal);
  }
};
template <typename SPIRVOp>
class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  /// Lowers a bitwise/logical NOT to `llvm.xor` with an all-ones constant,
  /// splatted for vectors.
  LogicalResult
  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = notOp.getType();
    auto dstType = this->typeConverter.convertType(srcType);
    if (!dstType)
      return rewriter.notifyMatchFailure(notOp, "type conversion failed");

    // Reuse the shared all-ones helper instead of duplicating its logic.
    Value mask = createConstantAllBitsSet(notOp.getLoc(), srcType, dstType,
                                          rewriter);
    // XOR the remapped operand, consistent with the other patterns here.
    rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(
        notOp, dstType, adaptor.getOperand(), mask);
    return success();
  }
};
template <typename SPIRVOp>
class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  /// Unconditionally removes the matched op without replacement (presumably
  /// for ops that have no LLVM counterpart).
  LogicalResult
  matchAndRewrite(SPIRVOp erasedOp, typename SPIRVOp::Adaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(erasedOp);
    return success();
  }
};
class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
public:
  using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.Return` to an `llvm.return` with no operands.
  LogicalResult
  matchAndRewrite(spirv::ReturnOp retOp, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(retOp, ArrayRef<Type>(),
                                                ArrayRef<Value>());
    return success();
  }
};
class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
public:
  using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;

  /// Lowers `spirv.ReturnValue` to `llvm.return` forwarding the remapped
  /// operand(s).
  LogicalResult
  matchAndRewrite(spirv::ReturnValueOp retOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ValueRange returned = adaptor.getOperands();
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(retOp, ArrayRef<Type>(),
                                                returned);
    return success();
  }
};
class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
public:
  using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;

  /// Flattens a `spirv.mlir.loop` (with no loop control) into the enclosing
  /// region:
  /// 1. The enclosing block branches directly to the loop header, forwarding
  ///    the entry block's branch operands.
  /// 2. The merge block branches to the code that followed the loop.
  /// 3. The loop body blocks are inlined in between.
  LogicalResult
  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (loopOp.getLoopControl() != spirv::LoopControl::None)
      return failure();

    // Validate the entry block before mutating the IR, so that a match
    // failure leaves the IR untouched.
    Block *entryBlock = loopOp.getEntryBlock();
    assert(entryBlock->getOperations().size() == 1);
    auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
    if (!brOp)
      return failure();

    Location loc = loopOp.getLoc();

    // Split the enclosing block at the loop op; `endBlock` receives the
    // operations starting at `position`.
    Block *currentBlock = rewriter.getBlock();
    auto position = Block::iterator(loopOp);
    Block *endBlock = rewriter.splitBlock(currentBlock, position);

    // Enter the loop by jumping straight to the header; the entry block,
    // which only held that branch, is dropped.
    Block *headerBlock = loopOp.getHeaderBlock();
    rewriter.setInsertionPointToEnd(currentBlock);
    rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
    rewriter.eraseBlock(entryBlock);

    // Leave the loop from the merge block, forwarding its terminator operands.
    Block *mergeBlock = loopOp.getMergeBlock();
    Operation *terminator = mergeBlock->getTerminator();
    ValueRange terminatorOperands = terminator->getOperands();
    rewriter.setInsertionPointToEnd(mergeBlock);
    rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);

    rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
    rewriter.replaceOp(loopOp, endBlock->getArguments());
    return success();
  }
};
class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
public:
  using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;

  /// Flattens a `spirv.mlir.selection` (with no selection control) into the
  /// enclosing region:
  /// 1. The header's conditional branch becomes an `llvm.cond_br` at the end
  ///    of the enclosing block.
  /// 2. The merge block branches to the code that followed the selection.
  /// 3. The body blocks are inlined in between.
  /// A body of at most two blocks is simply erased.
  LogicalResult
  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getSelectionControl() != spirv::SelectionControl::None)
      return failure();

    // Presumably only header + merge blocks, i.e. nothing to branch to.
    if (op.getBody().getBlocks().size() <= 2) {
      rewriter.eraseOp(op);
      return success();
    }

    // Validate the header before mutating the IR, so that a match failure
    // leaves the IR untouched.
    auto *headerBlock = op.getHeaderBlock();
    assert(headerBlock->getOperations().size() == 1);
    auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
        headerBlock->getOperations().front());
    if (!condBrOp)
      return failure();

    Location loc = op.getLoc();

    // Split after the selection op; the following code goes to
    // `continueBlock`.
    auto *currentBlock = rewriter.getInsertionBlock();
    rewriter.setInsertionPointAfter(op);
    auto position = rewriter.getInsertionPoint();
    auto *continueBlock = rewriter.splitBlock(currentBlock, position);

    // Branch from the merge block to the continuation.
    auto *mergeBlock = op.getMergeBlock();
    Operation *terminator = mergeBlock->getTerminator();
    ValueRange terminatorOperands = terminator->getOperands();
    rewriter.setInsertionPointToEnd(mergeBlock);
    rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);

    // Re-create the header's conditional branch in the enclosing block. This
    // reads `condBrOp`, so it must happen before its block is erased.
    rewriter.setInsertionPointToEnd(currentBlock);
    rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(),
                                    condBrOp.getTrueBlock(),
                                    condBrOp.getTrueTargetOperands(),
                                    condBrOp.getFalseBlock(),
                                    condBrOp.getFalseTargetOperands());
    rewriter.eraseBlock(headerBlock);

    rewriter.inlineRegionBefore(op.getBody(), continueBlock);
    rewriter.replaceOp(op, continueBlock->getArguments());
    return success();
  }
};
/// Lowers a SPIR-V shift op (`SPIRVOp`) to the LLVM shift op `LLVMOp`.
///
/// When both SPIR-V operands have the same type, the op is rewritten 1:1.
/// Otherwise the shift amount (operand 2) is handled by element bit width,
/// compared against the converted result type:
/// - narrower: it is zero-extended (unsigned SPIR-V type) or sign-extended
///   (any other type);
/// - equal: it is used unchanged;
/// - wider, or either width unknown: the pattern fails.
template <typename SPIRVOp, typename LLVMOp>
class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = this->typeConverter.convertType(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    Type amountType = op.getOperand2().getType();
    if (op.getOperand1().getType() == amountType) {
      rewriter.template replaceOpWithNewOp<LLVMOp>(op, resultType,
                                                   adaptor.getOperands());
      return success();
    }

    std::optional<uint64_t> resultWidth =
        getIntegerOrVectorElementWidth(resultType);
    std::optional<uint64_t> amountWidth =
        getIntegerOrVectorElementWidth(amountType);
    if (!resultWidth || !amountWidth)
      return failure();
    // Truncating the shift amount is not supported.
    if (*amountWidth > *resultWidth)
      return failure();

    Location loc = op.getLoc();
    Value shiftAmount = adaptor.getOperand2();
    if (*amountWidth < *resultWidth) {
      // Widen the shift amount, honoring the signedness of its SPIR-V type.
      if (isUnsignedIntegerOrVector(amountType))
        shiftAmount = rewriter.template create<LLVM::ZExtOp>(loc, resultType,
                                                             shiftAmount);
      else
        shiftAmount = rewriter.template create<LLVM::SExtOp>(loc, resultType,
                                                             shiftAmount);
    }

    rewriter.template replaceOpWithNewOp<LLVMOp>(
        op, resultType, adaptor.getOperand1(), shiftAmount);
    return success();
  }
};
/// Lowers `spirv.GL.Tan` to `sin(x) / cos(x)`, built from `llvm.intr.sin`,
/// `llvm.intr.cos` and `llvm.fdiv` on the converted result type.
class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
public:
  using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = typeConverter.convertType(tanOp.getType());
    if (!dstType)
      return rewriter.notifyMatchFailure(tanOp, "type conversion failed");

    Location loc = tanOp.getLoc();
    // Take the operand from the adaptor, which carries the value as remapped
    // by the conversion framework. `tanOp.getOperand()` is the
    // pre-conversion value.
    Value operand = adaptor.getOperand();
    Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, operand);
    Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, operand);
    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
    return success();
  }
};
/// Lowers `spirv.GL.Tanh` using the identity
/// `tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)`.
///
/// NOTE(review): for large positive `x`, `exp(2x)` overflows to +inf, and the
/// final division becomes inf/inf = NaN instead of 1. A numerically stable
/// formulation may be needed if such inputs matter.
class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
public:
  using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = tanhOp.getType();
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");

    Location loc = tanhOp.getLoc();
    // Use the adaptor's remapped operand rather than the pre-conversion value.
    Value operand = adaptor.getOperand();

    // exp(2 * x)
    Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
    Value multiplied =
        rewriter.create<LLVM::FMulOp>(loc, dstType, two, operand);
    Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);

    // (exp(2x) - 1) / (exp(2x) + 1)
    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
    Value numerator =
        rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
    Value denominator =
        rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
                                              denominator);
    return success();
  }
};
/// Lowers `spirv.Variable` to an `llvm.alloca` of one element of the
/// converted pointee type.
///
/// If the variable has an initializer, the converted initializer value is
/// stored into the new allocation. Initializers are only supported when the
/// pointee is an integer, a float, or a vector; otherwise the pattern fails.
class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
public:
  using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = varOp.getType();
    auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
    auto init = varOp.getInitializer();
    if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
      return failure();

    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return rewriter.notifyMatchFailure(varOp, "type conversion failed");
    // Convert the pointee type once, before any IR is created, so that every
    // failure path leaves the IR untouched.
    auto elementType = typeConverter.convertType(pointerTo);
    if (!elementType)
      return rewriter.notifyMatchFailure(varOp, "type conversion failed");

    Location loc = varOp.getLoc();
    Value size = createI32ConstantOf(loc, rewriter, 1);
    Value allocated =
        rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
    if (init)
      rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
    rewriter.replaceOp(varOp, allocated);
    return success();
  }
};
/// Lowers `spirv.Bitcast` to `llvm.bitcast`.
///
/// When the result converts to an LLVM pointer type, no cast op is emitted
/// and the converted operand is forwarded directly. Presumably this is
/// because LLVM pointers are opaque.
class BitcastConversionPattern
    : public SPIRVToLLVMConversion<spirv::BitcastOp> {
public:
  using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = typeConverter.convertType(bitcastOp.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");

    if (!isa<LLVM::LLVMPointerType>(resultType)) {
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
          bitcastOp, resultType, adaptor.getOperands(), bitcastOp->getAttrs());
    } else {
      rewriter.replaceOp(bitcastOp, adaptor.getOperand());
    }
    return success();
  }
};
/// Lowers `spirv.func` to `llvm.func`.
///
/// The function signature is converted with the LLVM type converter. SPIR-V
/// function control is mapped onto LLVM function attributes:
/// - `Inline` → always-inline;
/// - `DontInline` → no-inline;
/// - `Pure` / `Const` → a `passthrough` of `readonly` / `readnone`.
/// The body is then moved into the new function, and its block argument types
/// are converted according to the signature conversion.
class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
public:
  using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto spvFuncType = funcOp.getFunctionType();
    TypeConverter::SignatureConversion signatureConversion(
        spvFuncType.getNumInputs());
    auto llvmFuncType = typeConverter.convertFunctionSignature(
        spvFuncType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
        signatureConversion);
    if (!llvmFuncType)
      return failure();

    auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), llvmFuncType);

    MLIRContext *ctx = funcOp.getContext();
    // Attaches a single-entry `passthrough` attribute to the new function.
    auto setPassthrough = [&](StringRef attrName) {
      llvmFuncOp->setAttr(
          "passthrough",
          ArrayAttr::get(ctx, {StringAttr::get(ctx, attrName)}));
    };

    switch (funcOp.getFunctionControl()) {
    case spirv::FunctionControl::Inline:
      llvmFuncOp.setAlwaysInline(true);
      break;
    case spirv::FunctionControl::DontInline:
      llvmFuncOp.setNoInline(true);
      break;
    case spirv::FunctionControl::Pure:
      setPassthrough("readonly");
      break;
    case spirv::FunctionControl::Const:
      setPassthrough("readnone");
      break;
    default:
      // Any other value (e.g. `None`) adds no attribute.
      break;
    }

    rewriter.inlineRegionBefore(funcOp.getBody(), llvmFuncOp.getBody(),
                                llvmFuncOp.end());
    if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(),
                                           typeConverter,
                                           &signatureConversion)))
      return failure();

    rewriter.eraseOp(funcOp);
    return success();
  }
};
/// Lowers `spirv.module` to a builtin `module` carrying the same (optional)
/// name, and moves the SPIR-V module's body into it.
class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
public:
  using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto builtinModule =
        rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());

    // The new module comes with its own body block. The SPIR-V blocks are
    // inlined in front of it, after which that original block is the last
    // one in the region and is discarded.
    Block *createdBody = builtinModule.getBody();
    rewriter.inlineRegionBefore(spvModuleOp.getRegion(), createdBody);
    rewriter.eraseBlock(createdBody);

    rewriter.eraseOp(spvModuleOp);
    return success();
  }
};
/// Lowers `spirv.VectorShuffle`.
///
/// When both inputs have the same length, the op maps directly onto
/// `llvm.shufflevector`. Otherwise the result is assembled element by element,
/// starting from an `llvm.mlir.undef` value:
/// - a component index below the first vector's length selects from the
///   first vector;
/// - a larger index selects from the second vector, rebased by the first
///   vector's length;
/// - the selected element is extracted and inserted at the component's
///   position;
/// - a component equal to -1 is skipped, so that element stays undef.
class VectorShufflePattern
    : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
public:
  using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto components = adaptor.getComponents();
    auto vector1 = adaptor.getVector1();
    auto vector2 = adaptor.getVector2();
    int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
    int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
    if (vector1Size == vector2Size) {
      rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
          op, vector1, vector2,
          LLVM::convertArrayToIndices<int32_t>(components));
      return success();
    }

    auto dstType = typeConverter.convertType(op.getType());
    if (!dstType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    auto scalarType = cast<VectorType>(dstType).getElementType();
    auto componentsArray = components.getValue();

    // Validate all components up front, so that the error path does not
    // leave partially built undef/extract/insert ops behind.
    for (Attribute component : componentsArray)
      if (!isa<IntegerAttr>(component))
        return op.emitError("unable to support non-constant component");

    auto llvmI32Type = IntegerType::get(rewriter.getContext(), 32);
    Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
    for (unsigned i = 0; i < componentsArray.size(); i++) {
      int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
      // A component of -1 is skipped; that element keeps its undef value.
      if (indexVal == -1)
        continue;

      // Indices past the end of the first vector refer to the second one.
      int offsetVal = 0;
      Value baseVector = vector1;
      if (indexVal >= vector1Size) {
        offsetVal = vector1Size;
        baseVector = vector2;
      }

      Value dstIndex = rewriter.create<LLVM::ConstantOp>(
          loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
      Value index = rewriter.create<LLVM::ConstantOp>(
          loc, llvmI32Type,
          rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
      auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
          loc, scalarType, baseVector, index);
      targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
                                                        extractOp, dstIndex);
    }
    rewriter.replaceOp(op, targetOp);
    return success();
  }
};
}
/// Registers conversions for SPIR-V array, pointer, runtime-array and struct
/// types on `typeConverter`. Each conversion calls back into `typeConverter`
/// for nested types. The pointer conversion also receives `clientAPI`, which
/// is captured by value.
void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
                                             spirv::ClientAPI clientAPI) {
  typeConverter.addConversion([&typeConverter](spirv::ArrayType arrayType) {
    return convertArrayType(arrayType, typeConverter);
  });
  typeConverter.addConversion(
      [&typeConverter, clientAPI](spirv::PointerType pointerType) {
        return convertPointerType(pointerType, typeConverter, clientAPI);
      });
  typeConverter.addConversion(
      [&typeConverter](spirv::RuntimeArrayType runtimeArrayType) {
        return convertRuntimeArrayType(runtimeArrayType, typeConverter);
      });
  typeConverter.addConversion([&typeConverter](spirv::StructType structType) {
    return convertStructType(structType, typeConverter);
  });
}
/// Adds the SPIR-V -> LLVM op conversion patterns to `patterns`. All of them
/// share `typeConverter`. `spirv.func` and `spirv.module` are not handled
/// here; their patterns are added by
/// `populateSPIRVToLLVMFunctionConversionPatterns` and
/// `populateSPIRVToLLVMModuleConversionPatterns`.
void mlir::populateSPIRVToLLVMConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
spirv::ClientAPI clientAPI) {
patterns.add<
// Arithmetic ops.
DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
// Bit-field and bitwise ops.
BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
NotPattern<spirv::NotOp>,
// Cast ops. The Indirect* entries list an extending and a truncating LLVM
// op for the same SPIR-V conversion.
BitcastConversionPattern,
DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
// Comparison ops.
IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
FComparePattern<spirv::FUnordGreaterThanEqualOp,
LLVM::FCmpPredicate::uge>,
FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
// Constants.
ConstantScalarAndVectorPattern,
// Structured and unstructured control flow.
BranchConversionPattern, BranchConditionalConversionPattern,
FunctionCallPattern, LoopPattern, SelectionPattern,
ErasePattern<spirv::MergeOp>,
// Entry points are erased; execution modes have their own pattern.
ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
// GLSL extended instruction set ops.
DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
InverseSqrtPattern, TanPattern, TanhPattern,
// Logical ops.
DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
NotPattern<spirv::LogicalNotOp>,
// Memory ops.
AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
LoadStorePattern<spirv::StoreOp>, VariablePattern,
// Composite, select, undef and vector shuffle ops.
CompositeExtractPattern, CompositeInsertPattern,
DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
VectorShufflePattern,
// Shift ops.
ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
// Return ops.
ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
// Added separately because its constructor also takes `clientAPI`.
patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
typeConverter);
}
/// Adds the pattern that lowers `spirv.func` to `llvm.func`, using
/// `typeConverter` for the signature and block argument types.
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FuncConversionPattern>(context, typeConverter);
}
/// Adds the pattern that lowers `spirv.module` to a builtin `module`.
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ModuleConversionPattern>(context, typeConverter);
}
// Names of the integer attributes read from `spirv.GlobalVariable` ops.
static constexpr StringRef kBinding = "binding";
static constexpr StringRef kDescriptorSet = "descriptor_set";

/// Folds descriptor-set/binding information into global variable names.
///
/// The function visits every `spirv.GlobalVariable` nested in each
/// `spirv.module` found among `module`'s ops. Each variable that has both a
/// `descriptor_set` and a `binding` integer attribute is renamed to
/// `<spvModuleName>_<varName>_descriptor_set<N>_binding<M>`; when the SPIR-V
/// module has no name, the `<spvModuleName>_` prefix is dropped. Symbol uses
/// inside the SPIR-V module are updated, and both attributes are removed.
/// If updating the symbol uses fails, an error is emitted on the variable,
/// but it is still renamed.
void mlir::encodeBindAttribute(ModuleOp module) {
  for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
    spvModule.walk([&](spirv::GlobalVariableOp globalVar) {
      auto setAttr = globalVar->getAttrOfType<IntegerAttr>(kDescriptorSet);
      auto bindingAttr = globalVar->getAttrOfType<IntegerAttr>(kBinding);
      // Only variables carrying both attributes are renamed.
      if (!setAttr || !bindingAttr)
        return;

      std::string qualifiedName = globalVar.getSymName().str();
      if (spvModule.getName().has_value())
        qualifiedName = spvModule.getName()->str() + "_" + qualifiedName;

      std::string newName =
          llvm::formatv("{0}_descriptor_set{1}_binding{2}", qualifiedName,
                        std::to_string(setAttr.getInt()),
                        std::to_string(bindingAttr.getInt()));
      auto newNameAttr = StringAttr::get(globalVar->getContext(), newName);

      if (failed(SymbolTable::replaceAllSymbolUses(globalVar, newNameAttr,
                                                   spvModule)))
        globalVar.emitError("unable to replace all symbol uses for ")
            << newName;
      SymbolTable::setSymbolName(globalVar, newNameAttr);
      globalVar->removeAttr(kDescriptorSet);
      globalVar->removeAttr(kBinding);
    });
  }
}