#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
#include "mlir/Conversion/Passes.h.inc"
}
using namespace mlir;
// Callee name of the C-interface launch call; every call to it is rewritten
// by this pass into the sequence of Vulkan runtime calls below.
static constexpr const char *kCInterfaceVulkanLaunch =
"_mlir_ciface_vulkanLaunch";
// Names of the runtime functions this pass declares (if missing) and calls.
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
// Symbol name of the global string that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
// Attributes read from `vulkanLaunch` calls: the SPIR-V binary blob, the
// shader entry-point name, and an array of element types indexed by the
// position of each memref operand after the launch-configuration operands.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";
// Callee name of the call that carries the SPIR-V attributes listed above.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
namespace {

/// Lowers calls to the C-interface launch function (`_mlir_ciface_vulkanLaunch`)
/// into a sequence of calls to Vulkan runtime wrapper functions. The SPIR-V
/// blob, the entry-point name and the memref element types are taken from
/// attributes attached to `vulkanLaunch` calls in the same module.
class VulkanLaunchFuncToVulkanCallsPass
    : public impl::ConvertVulkanLaunchFuncToVulkanCallsPassBase<
          VulkanLaunchFuncToVulkanCallsPass> {
private:
  /// Fills the cached type members from the pass's MLIR context.
  void initializeCachedTypes() {
    MLIRContext *ctx = &getContext();
    llvmFloatType = Float32Type::get(ctx);
    llvmVoidType = LLVM::LLVMVoidType::get(ctx);
    llvmPointerType = LLVM::LLVMPointerType::get(ctx);
    llvmInt32Type = IntegerType::get(ctx, 32);
    llvmInt64Type = IntegerType::get(ctx, 64);
  }

  /// Builds the literal struct `{ptr, ptr, i64, [rank x i64], [rank x i64]}`.
  /// The second argument does not influence the result.
  /// NOTE(review): not referenced elsewhere in this file as shown.
  Type getMemRefType(uint32_t rank, Type /*elementType*/) {
    Type rankArrayTy = LLVM::LLVMArrayType::get(getInt64Type(), rank);
    SmallVector<Type, 5> fields = {llvmPointerType, llvmPointerType,
                                   getInt64Type(), rankArrayTy, rankArrayTy};
    return LLVM::LLVMStructType::getLiteral(&getContext(), fields);
  }

  Type getVoidType() { return llvmVoidType; }
  Type getPointerType() { return llvmPointerType; }
  Type getInt32Type() { return llvmInt32Type; }
  Type getInt64Type() { return llvmInt64Type; }

  /// Creates an internal global string holding `name` (NUL-terminated) and
  /// returns a value referring to it.
  Value createEntryPointNameConstant(StringRef name, Location loc,
                                     OpBuilder &builder);

  /// Declares the Vulkan runtime functions used by the rewrite.
  void declareVulkanFunctions(Location loc);

  /// True when `callOp` directly calls `calleeName` and has at least the
  /// launch-configuration operands.
  bool isLaunchCallTo(LLVM::CallOp callOp, StringRef calleeName) {
    auto callee = callOp.getCallee();
    if (!callee || *callee != calleeName)
      return false;
    return callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands;
  }

  /// Matches the `vulkanLaunch` call that carries the SPIR-V attributes.
  bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
    return isLaunchCallTo(callOp, kVulkanLaunch);
  }

  /// Matches the `_mlir_ciface_vulkanLaunch` call that gets rewritten.
  bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
    return isLaunchCallTo(callOp, kCInterfaceVulkanLaunch);
  }

  void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
  void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
                             Value vulkanRuntime);
  void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
  LogicalResult deduceMemRefRank(Value launchCallArg, uint32_t &rank);

  /// Returns the suffix used in `bindMemRef<rank>D<suffix>` function names.
  /// Only f32, f16 and 8/16/32-bit integers are handled; any other type
  /// reaches llvm_unreachable.
  StringRef stringifyType(Type type) {
    if (auto intType = dyn_cast<IntegerType>(type)) {
      switch (intType.getWidth()) {
      case 32:
        return "Int32";
      case 16:
        return "Int16";
      case 8:
        return "Int8";
      default:
        break;
      }
    } else if (isa<Float32Type>(type)) {
      return "Float";
    } else if (isa<Float16Type>(type)) {
      return "Half";
    }
    llvm_unreachable("unsupported type");
  }

public:
  using Base::Base;
  void runOnOperation() override;

private:
  // Types cached by initializeCachedTypes().
  Type llvmFloatType;
  Type llvmVoidType;
  Type llvmPointerType;
  Type llvmInt32Type;
  Type llvmInt64Type;

  /// Attributes collected from a `vulkanLaunch` call.
  struct SPIRVAttributes {
    StringAttr blob;
    StringAttr entryPoint;
    SmallVector<Type> elementTypes;
  };
  SPIRVAttributes spirvAttributes;

  // Leading launch-call operands that are forwarded to setNumWorkGroups;
  // any operands after them are memref descriptors.
  static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
};

} // namespace
void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
initializeCachedTypes();
getOperation().walk([this](LLVM::CallOp op) {
if (isVulkanLaunchCallOp(op))
collectSPIRVAttributes(op);
});
getOperation().walk([this](LLVM::CallOp op) {
if (isCInterfaceVulkanLaunchCallOp(op))
translateVulkanLaunchCall(op);
});
}
/// Reads the SPIR-V blob, the entry-point name and the per-memref element
/// types attached to `vulkanLaunchCallOp` and stores them in `spirvAttributes`.
/// On a missing or malformed attribute, an error is emitted, the pass is marked
/// as failed, and `spirvAttributes` is left unchanged.
void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
    LLVM::CallOp vulkanLaunchCallOp) {
  // Common reporting path for a required attribute that is absent.
  auto failMissing = [&](StringRef attrName) {
    vulkanLaunchCallOp.emitError() << "missing " << attrName << " attribute";
    signalPassFailure();
  };

  StringAttr blob =
      vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
  if (!blob)
    return failMissing(kSPIRVBlobAttrName);

  StringAttr entryPoint =
      vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
  if (!entryPoint)
    return failMissing(kSPIRVEntryPointAttrName);

  ArrayAttr elementTypes =
      vulkanLaunchCallOp->getAttrOfType<ArrayAttr>(kSPIRVElementTypesAttrName);
  if (!elementTypes)
    return failMissing(kSPIRVElementTypesAttrName);

  // Every entry must be a TypeAttr so it can be unwrapped below.
  bool allTypeAttrs = llvm::all_of(
      elementTypes, [](Attribute attr) { return isa<TypeAttr>(attr); });
  if (!allTypeAttrs) {
    vulkanLaunchCallOp.emitError()
        << "expected " << elementTypes << " to be an array of types";
    return signalPassFailure();
  }

  spirvAttributes.blob = blob;
  spirvAttributes.entryPoint = entryPoint;
  spirvAttributes.elementTypes.clear();
  for (Attribute attr : elementTypes)
    spirvAttributes.elementTypes.push_back(cast<TypeAttr>(attr).getValue());
}
/// Emits one `bindMemRef<rank>D<type>` runtime call for every memref
/// descriptor operand of `cInterfaceVulkanLaunchCallOp`. These are the operands
/// after the launch-configuration operands. All memrefs use descriptor set 0.
/// The binding index is the memref's position among those operands.
///
/// Emits an error and signals pass failure in four cases:
///  - an operand has no element type in `spirv_element_types`;
///  - the element type has no matching runtime entry point (only f32, f16,
///    i32, i16 and i8 are handled by stringifyType);
///  - the descriptor rank cannot be deduced;
///  - the rank is outside 1..3, the only ranks declared by
///    declareVulkanFunctions.
void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
    LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
  if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
      kVulkanLaunchNumConfigOperands)
    return;
  OpBuilder builder(cInterfaceVulkanLaunchCallOp);
  Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
  Value descriptorSet =
      builder.create<LLVM::ConstantOp>(loc, getInt32Type(), 0);

  // Element types accepted by stringifyType; anything else would hit
  // llvm_unreachable, and the types come from a user-supplied attribute.
  auto isSupportedElementType = [](Type type) {
    if (isa<Float32Type, Float16Type>(type))
      return true;
    auto intType = dyn_cast<IntegerType>(type);
    if (!intType)
      return false;
    unsigned width = intType.getWidth();
    return width == 8 || width == 16 || width == 32;
  };

  for (auto [index, ptrToMemRefDescriptor] :
       llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
           kVulkanLaunchNumConfigOperands))) {
    if (index >= spirvAttributes.elementTypes.size()) {
      cInterfaceVulkanLaunchCallOp.emitError()
          << kSPIRVElementTypesAttrName << " missing element type for "
          << ptrToMemRefDescriptor;
      return signalPassFailure();
    }
    Type type = spirvAttributes.elementTypes[index];
    if (!isSupportedElementType(type)) {
      cInterfaceVulkanLaunchCallOp.emitError()
          << "unsupported element type " << type << " in "
          << kSPIRVElementTypesAttrName;
      return signalPassFailure();
    }
    uint32_t rank = 0;
    if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
      cInterfaceVulkanLaunchCallOp.emitError()
          << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
      return signalPassFailure();
    }
    // Only bindMemRef{1,2,3}D* are declared; any other rank would reference
    // an undefined symbol.
    if (rank < 1 || rank > 3) {
      cInterfaceVulkanLaunchCallOp.emitError()
          << "unsupported memref rank " << rank << " (expected 1, 2 or 3)";
      return signalPassFailure();
    }
    // Created only after validation so no stray constants are left behind on
    // error.
    Value descriptorBinding =
        builder.create<LLVM::ConstantOp>(loc, getInt32Type(), index);
    auto symbolName =
        llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
    builder.create<LLVM::CallOp>(
        loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()),
        ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
                   ptrToMemRefDescriptor});
  }
}
/// Deduces the rank of the memref whose descriptor `launchCallArg` points to.
///
/// `launchCallArg` must be produced by an `llvm.alloca` of an LLVM struct type:
///  - a struct with exactly 3 fields is treated as rank 0;
///  - otherwise field 3 must be an LLVM array, and its length is the rank.
/// Any other shape (including structs with fewer than 4 fields, or a non-array
/// field 3) returns failure. The previous unchecked index/cast could go out of
/// bounds or assert on such shapes.
LogicalResult
VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
                                                     uint32_t &rank) {
  auto alloca = launchCallArg.getDefiningOp<LLVM::AllocaOp>();
  if (!alloca)
    return failure();
  std::optional<Type> elementType = alloca.getElemType();
  assert(elementType && "expected to work with opaque pointers");
  auto llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
  if (!llvmDescriptorTy)
    return failure();
  ArrayRef<Type> fields = llvmDescriptorTy.getBody();
  if (fields.size() == 3) {
    rank = 0;
    return success();
  }
  if (fields.size() < 4)
    return failure();
  auto sizesArrayTy = dyn_cast<LLVM::LLVMArrayType>(fields[3]);
  if (!sizesArrayTy)
    return failure();
  rank = sizesArrayTy.getNumElements();
  return success();
}
/// Appends declarations of the Vulkan runtime functions used by this pass to
/// the end of the module. A function is skipped if a symbol with its name
/// already exists.
///
/// Declaration order:
///  1. setEntryPoint, setNumWorkGroups, setBinaryShader, runOnVulkan;
///  2. bindMemRef{1,2,3}D{Float,Int32,Int16,Int8,Half};
///  3. initVulkan, deinitVulkan.
void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
  ModuleOp module = getOperation();
  auto builder = OpBuilder::atBlockEnd(module.getBody());
  MLIRContext *ctx = &getContext();

  // Adds an external declaration unless a symbol named `name` already exists.
  auto declareIfMissing = [&](StringRef name, LLVM::LLVMFunctionType fnType) {
    if (!module.lookupSymbol(name))
      builder.create<LLVM::LLVMFuncOp>(loc, name, fnType);
  };

  Type voidTy = getVoidType();
  Type ptrTy = getPointerType();
  Type i32Ty = getInt32Type();
  Type i64Ty = getInt64Type();

  declareIfMissing(kSetEntryPoint,
                   LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy}));
  declareIfMissing(kSetNumWorkGroups,
                   LLVM::LLVMFunctionType::get(voidTy,
                                               {ptrTy, i64Ty, i64Ty, i64Ty}));
  declareIfMissing(kSetBinaryShader,
                   LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, i32Ty}));
  declareIfMissing(kRunOnVulkan, LLVM::LLVMFunctionType::get(voidTy, {ptrTy}));

  // All bindMemRef variants share one signature:
  // (runtime, descriptor set, binding, memref descriptor pointer).
  auto bindMemRefFnType = LLVM::LLVMFunctionType::get(
      voidTy, {ptrTy, i32Ty, i32Ty, ptrTy}, /*isVarArg=*/false);
  // Element types with a runtime entry point; the order fixes the order in
  // which the declarations are emitted.
  const SmallVector<Type, 5> bindElementTypes = {
      Float32Type::get(ctx), IntegerType::get(ctx, 32),
      IntegerType::get(ctx, 16), IntegerType::get(ctx, 8),
      Float16Type::get(ctx)};
  for (unsigned rank = 1; rank <= 3; ++rank) {
    for (Type elementType : bindElementTypes) {
      std::string fnName = llvm::formatv("bindMemRef{0}D{1}", rank,
                                         stringifyType(elementType))
                               .str();
      declareIfMissing(fnName, bindMemRefFnType);
    }
  }

  declareIfMissing(kInitVulkan, LLVM::LLVMFunctionType::get(ptrTy, {}));
  declareIfMissing(kDeinitVulkan,
                   LLVM::LLVMFunctionType::get(voidTy, {ptrTy}));
}
/// Materializes `name` as an internal global string named
/// `<name>_spv_entry_point_name`, with an explicit trailing NUL, and returns
/// the value produced by LLVM::createGlobalString for it.
Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
    StringRef name, Location loc, OpBuilder &builder) {
  std::string globalSymbol = name.str() + "_spv_entry_point_name";
  // The stored bytes must end in '\0' so they can be read as a C string.
  SmallString<16> nulTerminatedName(name);
  nulTerminatedName.push_back('\0');
  return LLVM::createGlobalString(loc, builder, globalSymbol,
                                  nulTerminatedName, LLVM::Linkage::Internal);
}
/// Replaces one `_mlir_ciface_vulkanLaunch` call with this sequence, inserted
/// before the call:
///   initVulkan -> (SPIR-V blob global + i32 size) -> bindMemRef* calls ->
///   setBinaryShader -> setEntryPoint -> setNumWorkGroups(operands 0..2) ->
///   runOnVulkan -> deinitVulkan
/// It then declares the runtime functions and erases the original call.
/// Reads `spirvAttributes.blob` and `.entryPoint` unconditionally, so both must
/// already be set by collectSPIRVAttributes.
/// NOTE(review): the blob global always uses the symbol name kSPIRVBinary, so
/// more than one launch per module presumably yields conflicting globals; verify.
/// NOTE(review): translation continues even if createBindMemRefCalls signaled
/// pass failure.
void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
OpBuilder builder(cInterfaceVulkanLaunchCallOp);
Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
// The pointer returned by initVulkan is threaded through every runtime call.
auto initVulkanCall = builder.create<LLVM::CallOp>(
loc, TypeRange{getPointerType()}, kInitVulkan);
auto vulkanRuntime = initVulkanCall.getResult();
// Embed the SPIR-V binary as an internal global and pass its byte size.
Value ptrToSPIRVBinary = LLVM::createGlobalString(
loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(),
LLVM::Linkage::Internal);
Value binarySize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), spirvAttributes.blob.getValue().size());
createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
builder.create<LLVM::CallOp>(
loc, TypeRange(), kSetBinaryShader,
ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
Value entryPointName = createEntryPointNameConstant(
spirvAttributes.entryPoint.getValue(), loc, builder);
builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
ValueRange{vulkanRuntime, entryPointName});
// The first three launch-call operands are forwarded as workgroup counts.
builder.create<LLVM::CallOp>(
loc, TypeRange(), kSetNumWorkGroups,
ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
cInterfaceVulkanLaunchCallOp.getOperand(1),
cInterfaceVulkanLaunchCallOp.getOperand(2)});
builder.create<LLVM::CallOp>(loc, TypeRange(), kRunOnVulkan,
ValueRange{vulkanRuntime});
builder.create<LLVM::CallOp>(loc, TypeRange(), kDeinitVulkan,
ValueRange{vulkanRuntime});
// Make sure every callee used above exists, then drop the original call.
declareVulkanFunctions(loc);
cInterfaceVulkanLaunchCallOp.erase();
}