#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
#define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
}
using namespace mlir;
/// Prefix that is prepended to a GPU kernel module name to form the symbol name
/// of the SPIR-V module looked up by the launch lowering. It is also used as
/// the fallback module name when a SPIR-V module carries no symbol name.
static constexpr const char kSPIRVModule[] = "__spv__";
static std::string descriptorSetName() {
return llvm::convertToSnakeFromCamelCase(
stringifyDecoration(spirv::Decoration::DescriptorSet));
}
static std::string bindingName() {
return llvm::convertToSnakeFromCamelCase(
stringifyDecoration(spirv::Decoration::Binding));
}
/// Returns the value of the binding attribute of `op`, which callers in this
/// file use as the index of the kernel operand backed by that global variable.
/// The attribute is dereferenced unconditionally, so `op` must carry an
/// integer binding attribute (see `hasDescriptorSetAndBinding`).
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
  return op->getAttrOfType<IntegerAttr>(bindingName()).getInt();
}
/// Emits an `llvm.memcpy` at the builder's current insertion point that copies
/// `size` bytes from `src` to `dst`. The copy is never marked volatile.
static void copy(Location loc, Value dst, Value src, Value size,
                 OpBuilder &builder) {
  constexpr bool isVolatile = false;
  builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
}
/// Builds the name of the LLVM global that mirrors the SPIR-V global variable
/// `op`. The name has the form
///   {kernelModuleName}_{symbol}_descriptor_set{set}_binding{binding}
/// where `set` and `binding` are read from the corresponding integer
/// attributes of `op`. Both attributes are dereferenced unconditionally, so
/// they must be present.
static std::string
createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
                                 StringRef kernelModuleName) {
  const int64_t setValue =
      op->getAttrOfType<IntegerAttr>(descriptorSetName()).getInt();
  const int64_t bindingValue =
      op->getAttrOfType<IntegerAttr>(bindingName()).getInt();
  // Keep the formatv object and its conversion in a single expression: the
  // formatter only references its arguments.
  return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
                       kernelModuleName, op.getSymName(), setValue,
                       bindingValue)
      .str();
}
/// Returns true if `op` carries both an integer descriptor-set attribute and
/// an integer binding attribute.
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
  const bool hasDescriptorSet =
      static_cast<bool>(op->getAttrOfType<IntegerAttr>(descriptorSetName()));
  const bool hasBinding =
      static_cast<bool>(op->getAttrOfType<IntegerAttr>(bindingName()));
  return hasDescriptorSet && hasBinding;
}
/// Fills `globalVariableMap` with the global variables of `module` that have
/// both a descriptor set and a binding, keyed by the index computed by
/// `calculateGlobalIndex`. Emits an error on `module` and fails if the module
/// does not contain exactly one `spirv::EntryPointOp`.
static LogicalResult getKernelGlobalVariables(
    spirv::ModuleOp module,
    DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
  if (!llvm::hasSingleElement(module.getOps<spirv::EntryPointOp>())) {
    return module.emitError(
        "The module must contain exactly one entry point function");
  }

  for (spirv::GlobalVariableOp variable :
       module.getOps<spirv::GlobalVariableOp>()) {
    // Variables without a (descriptor set, binding) pair are not recorded.
    if (!hasDescriptorSetAndBinding(variable))
      continue;
    globalVariableMap[calculateGlobalIndex(variable)] = variable;
  }
  return success();
}
/// Renames the function referenced by the single entry point of `module` to
///   {module_symbol_name}_{function_name}
/// and updates all symbol uses inside `module` accordingly. `kSPIRVModule` is
/// used in place of the module name when the module has no symbol name. Emits
/// an error and fails if the module does not contain exactly one entry point;
/// fails silently if the symbol uses cannot be replaced.
static LogicalResult encodeKernelName(spirv::ModuleOp module) {
  auto entryPointOps = module.getOps<spirv::EntryPointOp>();
  if (!llvm::hasSingleElement(entryPointOps)) {
    return module.emitError(
        "The module must contain exactly one entry point function");
  }

  spirv::EntryPointOp entry = *entryPointOps.begin();
  auto kernel = module.lookupSymbol<spirv::FuncOp>(entry.getFnAttr());

  StringRef modulePrefix = module.getSymName().value_or(kSPIRVModule);
  StringRef oldName = entry.getFn();
  // The concatenation is a Twine, so it must be consumed within this
  // expression.
  StringAttr renamed =
      StringAttr::get(module->getContext(), modulePrefix + "_" + oldName);

  if (failed(SymbolTable::replaceAllSymbolUses(kernel, renamed, module)))
    return failure();
  SymbolTable::setSymbolName(kernel, renamed);
  return success();
}
namespace {
/// Records one host-to-global copy made before the kernel call so that the
/// reverse copy can be emitted after the call.
struct CopyInfo {
  // Pointer to the LLVM global that received the data.
  Value dst;
  // Pointer the data was copied from (the memref's allocated pointer).
  Value src;
  // Number of bytes copied.
  Value size;
};
/// Lowers `gpu::LaunchFuncOp` to a call of an argument-less LLVM function.
///
/// The pattern:
///   1. looks up the SPIR-V module named `__spv__{kernel_module_name}`;
///   2. declares (if not yet present) a void LLVM function named
///      `{spv_module_name}_{kernel_name}` in the enclosing module;
///   3. for every kernel operand, which must be a `MemRefType`, creates (if not
///      yet present) an LLVM global mirroring the SPIR-V global variable whose
///      binding equals the operand index, and copies the memref data into it;
///   4. replaces the launch with a call to the declared function;
///   5. copies the data back from each global to the original buffer.
class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
  using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *op = launchOp.getOperation();
    MLIRContext *context = rewriter.getContext();
    auto module = launchOp->getParentOfType<ModuleOp>();

    // Find the SPIR-V module that corresponds to the launched kernel module.
    StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
    std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
    auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
        StringAttr::get(context, spvModuleName));
    if (!spvModule) {
      return launchOp.emitOpError("SPIR-V kernel module '")
             << spvModuleName << "' is not found";
    }

    // Declare the kernel function in the enclosing module unless a function
    // with that name already exists. It takes no arguments: data is exchanged
    // through global variables.
    StringRef kernelFuncName = launchOp.getKernelName().getValue();
    std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
    auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
        StringAttr::get(context, newKernelFuncName));
    if (!kernelFunc) {
      // The guard restores the previous insertion point when it goes out of
      // scope, so no manual reset is needed.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
          rewriter.getUnknownLoc(), newKernelFuncName,
          LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
                                      ArrayRef<Type>()));
    }

    // Collect the SPIR-V global variables keyed by kernel operand index.
    DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
    if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
      return failure();

    // The kernel operands are the trailing operands of the converted op.
    Location loc = launchOp.getLoc();
    SmallVector<CopyInfo, 4> copyInfo;
    auto numKernelOperands = launchOp.getNumKernelOperands();
    auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
    for (const auto &operand : llvm::enumerate(kernelOperands)) {
      // Only memref kernel operands are supported.
      auto memRefType = dyn_cast<MemRefType>(
          launchOp.getKernelOperand(operand.index()).getType());
      if (!memRefType)
        return failure();

      // Compute the byte size of the memref and fetch its allocated pointer.
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 4> strides;
      Value sizeBytes;
      getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
                               sizeBytes);
      MemRefDescriptor descriptor(operand.value());
      Value src = descriptor.allocatedPtr(rewriter, loc);

      // Find the SPIR-V global bound to this operand. `lookup` returns a null
      // op when the index is absent instead of inserting one, which lets us
      // fail cleanly rather than dereference a null operation below.
      spirv::GlobalVariableOp spirvGlobal =
          globalVariableMap.lookup(static_cast<uint32_t>(operand.index()));
      if (!spirvGlobal) {
        return launchOp.emitOpError(
                   "no SPIR-V global variable with a descriptor set and "
                   "binding found for kernel operand #")
               << operand.index();
      }

      // Convert both the pointee type (for the global itself) and the pointer
      // type (for taking its address) before creating any global.
      auto pointeeType =
          cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
      auto dstGlobalType = typeConverter->convertType(pointeeType);
      if (!dstGlobalType)
        return failure();
      Type dstPointerType = typeConverter->convertType(spirvGlobal.getType());
      if (!dstPointerType)
        return failure();

      // Create the mirroring LLVM global unless it already exists.
      std::string name =
          createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
      auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
      if (!dstGlobal) {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(module.getBody());
        dstGlobal = rewriter.create<LLVM::GlobalOp>(
            loc, dstGlobalType,
            /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
            /*alignment=*/0);
      }

      // Copy the operand data into the global and remember the triple so the
      // data can be copied back after the call.
      Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstPointerType,
                                                     dstGlobal.getSymName());
      copy(loc, dst, src, sizeBytes, rewriter);

      CopyInfo info;
      info.dst = dst;
      info.src = src;
      info.size = sizeBytes;
      copyInfo.push_back(info);
    }

    // Replace the launch with the kernel call, then copy the results back
    // from the globals into the original buffers.
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
                                              ArrayRef<Value>());
    for (const CopyInfo &info : copyInfo)
      copy(loc, info.src, info.dst, info.size, rewriter);
    return success();
  }
};
/// Pass that lowers the host side of a module to the LLVM dialect: it erases
/// GPU modules, requests C wrappers for `func.func` ops, converts arith,
/// memref, func and `gpu.launch_func` ops to LLVM, and finally renames the
/// entry-point function of every SPIR-V module via `encodeKernelName`.
class LowerHostCodeToLLVM
    : public impl::LowerHostCodeToLLVMPassBase<LowerHostCodeToLLVM> {
public:
  using Base::Base;

  void runOnOperation() override {
    ModuleOp module = getOperation();

    // Erase all GPU modules; early-inc iteration keeps the walk valid while
    // ops are being removed.
    for (gpu::GPUModuleOp gpuModule :
         llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
      gpuModule.erase();

    // Tag every host function with the LLVM dialect's emit-C-wrapper
    // attribute.
    for (func::FuncOp hostFunc : module.getOps<func::FuncOp>())
      hostFunc->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                        UnitAttr::get(&getContext()));

    // Assemble the type converter and the lowering patterns.
    LowerToLLVMOptions options(module.getContext());
    MLIRContext *context = module.getContext();
    RewritePatternSet patterns(context);
    LLVMTypeConverter typeConverter(context, options);
    mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
    populateFuncToLLVMConversionPatterns(typeConverter, patterns);
    patterns.add<GPULaunchLowering>(typeConverter);

    // Register the SPIR-V to LLVM type conversions on the converter.
    populateSPIRVToLLVMTypeConversion(typeConverter);

    // Only the LLVM dialect is declared legal; run a partial conversion.
    ConversionTarget target(*context);
    target.addLegalDialect<LLVM::LLVMDialect>();
    if (failed(applyPartialConversion(module, target, std::move(patterns))))
      signalPassFailure();

    // Rename the kernel function inside each SPIR-V module; stop at the first
    // failure.
    for (spirv::ModuleOp spvModule : module.getOps<spirv::ModuleOp>()) {
      if (succeeded(encodeKernelName(spvModule)))
        continue;
      signalPassFailure();
      return;
    }
  }
};
}