#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
namespace {
// Fallback model implementing `gpu::OffloadingLLVMTranslationAttrInterface`.
// It selects one object out of a `gpu.binary` operation, embeds it in the
// LLVM module and emits the calls needed to launch kernels from it.
class SelectObjectAttrImpl
: public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
SelectObjectAttrImpl> {
public:
// Translates a `gpu.binary` operation: the selected object is stored as a
// constant string global in the LLVM module being built. Returns failure if
// `operation` is not a `gpu.binary` or if no object could be selected.
LogicalResult embedBinary(Attribute attribute, Operation *operation,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const;
// Translates a `gpu.launch_func` operation into the sequence of runtime calls
// that launches the kernel, using the object selected from `binaryOperation`.
// Returns failure if the operations have unexpected kinds or if no object
// could be selected.
LogicalResult launchKernel(Attribute attribute,
Operation *launchFuncOperation,
Operation *binaryOperation,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const;
// Returns the object of `op` designated by the target of its offloading
// handler, or a null attribute (after emitting an error on `op`) when the
// requested object can't be found.
gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
};
/// Returns the name of the LLVM global holding the serialized binary called
/// `binaryName`: the binary name followed by the `_bin_cst` suffix.
std::string getBinaryIdentifier(StringRef binaryName) {
  std::string identifier = binaryName.str();
  identifier += "_bin_cst";
  return identifier;
}
}
void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
});
}
/// Returns the object of `op` selected by the target of its
/// `gpu::SelectObjectAttr` offloading handler:
///  - no target: the first object is selected;
///  - integer target: it is used as an index into the objects array;
///  - any other attribute: the object whose target equals it is selected (if
///    several objects match, the last one wins).
/// On failure an error is emitted on `op` and a null attribute is returned.
gpu::ObjectAttr
SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
  ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();

  // Index of the selected object; a negative value means "not found".
  int64_t index = -1;
  if (Attribute target =
          cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
              .getTarget()) {
    if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      // The target is an explicit index; it is range-checked below.
      index = indexAttr.getInt();
    } else {
      // The target is a target attribute: search the objects for it.
      for (auto [i, attr] : llvm::enumerate(objects)) {
        auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
        // `dyn_cast` yields a null attribute on mismatch; never dereference
        // it.
        if (obj && obj.getTarget() == target)
          index = i;
      }
    }
  } else {
    // Without a target, default to the first object.
    index = 0;
  }

  if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
    op->emitError("the requested target object couldn't be found");
    return nullptr;
  }

  auto selected = mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
  // Callers treat a null result as a failure; make sure it is never silent.
  if (!selected)
    op->emitError("the selected object is not a GPU object attribute");
  return selected;
}
/// Embeds the object selected from the `gpu.binary` operation `operation` into
/// the LLVM module of `moduleTranslation` as a constant, internal-linkage
/// global named after the binary (see `getBinaryIdentifier`).
/// Fails if `operation` is null or not a `gpu.binary`, or if no object could
/// be selected.
LogicalResult SelectObjectAttrImpl::embedBinary(
    Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  assert(operation && "The binary operation must be non null.");
  if (!operation)
    return failure();

  gpu::BinaryOp binaryOp = mlir::dyn_cast<gpu::BinaryOp>(operation);
  if (!binaryOp) {
    operation->emitError("operation must be a GPU binary");
    return failure();
  }

  gpu::ObjectAttr selected = getSelectedObject(binaryOp);
  if (!selected)
    return failure();

  llvm::Module &llvmModule = *moduleTranslation.getLLVMModule();

  // The raw bytes of the object, stored without a trailing null character.
  llvm::Constant *payload = llvm::ConstantDataArray::getString(
      builder.getContext(), selected.getObject().getValue(),
      /*AddNull=*/false);

  // The module takes ownership of the global on construction.
  auto *global = new llvm::GlobalVariable(
      llvmModule, payload->getType(), /*isConstant=*/true,
      llvm::GlobalValue::InternalLinkage, payload,
      getBinaryIdentifier(binaryOp.getName()));
  global->setLinkage(llvm::GlobalValue::InternalLinkage);
  global->setAlignment(llvm::MaybeAlign(8));
  global->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
  return success();
}
namespace llvm {
namespace {
// Helper emitting the LLVM IR that launches a GPU kernel through the `mgpu*`
// runtime functions. It caches the LLVM types used by the runtime function
// signatures.
class LaunchKernel {
public:
// Stores the references and initializes the cached LLVM types from `builder`
// and the data layout of `module`.
LaunchKernel(Module &module, IRBuilderBase &builder,
mlir::LLVM::ModuleTranslation &moduleTranslation);
// Returns the `mgpuLaunchKernel` callee, declaring it if needed.
FunctionCallee getKernelLaunchFn();
// Returns the `mgpuLaunchClusterKernel` callee, declaring it if needed.
FunctionCallee getClusterKernelLaunchFn();
// Returns the `mgpuModuleGetFunction` callee, declaring it if needed.
FunctionCallee getModuleFunctionFn();
// Returns the `mgpuModuleLoad` callee, declaring it if needed.
FunctionCallee getModuleLoadFn();
// Returns the `mgpuModuleLoadJIT` callee, declaring it if needed.
FunctionCallee getModuleLoadJITFn();
// Returns the `mgpuModuleUnload` callee, declaring it if needed.
FunctionCallee getModuleUnloadFn();
// Returns the `mgpuStreamCreate` callee, declaring it if needed.
FunctionCallee getStreamCreateFn();
// Returns the `mgpuStreamDestroy` callee, declaring it if needed.
FunctionCallee getStreamDestroyFn();
// Returns the `mgpuStreamSynchronize` callee, declaring it if needed.
FunctionCallee getStreamSyncFn();
// Returns a global string holding `kernelName`; the global is named
// `<moduleName>_<kernelName>_kernel_name`.
Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
// Stores the kernel operands of `op` on the stack and returns an array of
// pointers to them.
Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
// Emits the full launch sequence for `op`, loading the module from the
// embedded global that corresponds to the kernel module of `op`.
llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
mlir::gpu::ObjectAttr object);
private:
// LLVM module receiving the declarations and globals.
Module &module;
// Builder used to emit instructions at its current insertion point.
IRBuilderBase &builder;
// Translation state used to map MLIR values to LLVM values.
mlir::LLVM::ModuleTranslation &moduleTranslation;
// Cached LLVM types, set by the constructor.
Type *i32Ty{};
Type *i64Ty{};
Type *voidTy{};
// Integer type obtained from the data layout of `module`.
Type *intPtrTy{};
// Pointer type in address space 0.
PointerType *ptrTy{};
};
}
}
/// Translates the `gpu.launch_func` operation `launchFuncOperation` into the
/// runtime calls launching the kernel, using the object selected from the
/// `gpu.binary` operation `binaryOperation`.
/// Fails if either operation is null or of an unexpected kind, if no object
/// could be selected, or if emitting the launch sequence fails.
LogicalResult SelectObjectAttrImpl::launchKernel(
    Attribute attribute, Operation *launchFuncOperation,
    Operation *binaryOperation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  assert(launchFuncOperation && "The launch func operation must be non null.");
  if (!launchFuncOperation)
    return failure();

  auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
  if (!launchFuncOp) {
    launchFuncOperation->emitError("operation must be a GPU launch func Op.");
    return failure();
  }

  // `dyn_cast` and the diagnostic below both require a non-null operation, so
  // reject a missing binary up front and report it on the launch operation.
  if (!binaryOperation) {
    launchFuncOperation->emitError(
        "the GPU binary operation must be non null.");
    return failure();
  }

  auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
  if (!binOp) {
    binaryOperation->emitError("operation must be a GPU binary.");
    return failure();
  }

  gpu::ObjectAttr object = getSelectedObject(binOp);
  if (!object)
    return failure();

  return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder,
                            moduleTranslation)
      .createKernelLaunch(launchFuncOp, object);
}
/// Binds the helper to `module`, `builder` and `moduleTranslation`, and caches
/// the LLVM types used by the runtime function signatures. Inside the
/// initializers, `module` and `builder` name the constructor parameters.
llvm::LaunchKernel::LaunchKernel(
    Module &module, IRBuilderBase &builder,
    mlir::LLVM::ModuleTranslation &moduleTranslation)
    : module(module), builder(builder), moduleTranslation(moduleTranslation),
      i32Ty(builder.getInt32Ty()), i64Ty(builder.getInt64Ty()),
      voidTy(builder.getVoidTy()),
      intPtrTy(builder.getIntPtrTy(module.getDataLayout())),
      ptrTy(builder.getPtrTy(0)) {}
/// Returns the `mgpuLaunchKernel` callee, declaring it in the module if
/// needed. Signature: `void(ptr, intptr x6, i32, ptr, ptr, ptr, i64)`.
llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
  Type *paramTys[] = {ptrTy,    intPtrTy, intPtrTy, intPtrTy,
                      intPtrTy, intPtrTy, intPtrTy, i32Ty,
                      ptrTy,    ptrTy,    ptrTy,    i64Ty};
  FunctionType *fnTy =
      FunctionType::get(voidTy, paramTys, /*isVarArg=*/false);
  return module.getOrInsertFunction("mgpuLaunchKernel", fnTy);
}
/// Returns the `mgpuLaunchClusterKernel` callee, declaring it in the module if
/// needed. Signature: `void(ptr, intptr x9, i32, ptr, ptr, ptr)`.
llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
  Type *paramTys[] = {ptrTy,    intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                      intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                      i32Ty,    ptrTy,    ptrTy,    ptrTy};
  FunctionType *fnTy =
      FunctionType::get(voidTy, paramTys, /*isVarArg=*/false);
  return module.getOrInsertFunction("mgpuLaunchClusterKernel", fnTy);
}
/// Returns the `mgpuModuleGetFunction` callee, declaring it in the module if
/// needed. Signature: `ptr(ptr, ptr)`.
llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
  Type *paramTys[] = {ptrTy, ptrTy};
  FunctionType *fnTy = FunctionType::get(ptrTy, paramTys, /*isVarArg=*/false);
  return module.getOrInsertFunction("mgpuModuleGetFunction", fnTy);
}
/// Returns the `mgpuModuleLoad` callee, declaring it in the module if needed.
/// Signature: `ptr(ptr, i64)`.
llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
  Type *paramTys[] = {ptrTy, i64Ty};
  FunctionType *fnTy = FunctionType::get(ptrTy, paramTys, /*isVarArg=*/false);
  return module.getOrInsertFunction("mgpuModuleLoad", fnTy);
}
/// Returns the `mgpuModuleLoadJIT` callee, declaring it in the module if
/// needed. Signature: `ptr(ptr, i32)`.
llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
  Type *paramTys[] = {ptrTy, i32Ty};
  FunctionType *fnTy = FunctionType::get(ptrTy, paramTys, /*isVarArg=*/false);
  return module.getOrInsertFunction("mgpuModuleLoadJIT", fnTy);
}
/// Returns the `mgpuModuleUnload` callee, declaring it in the module if
/// needed. Signature: `void(ptr)`.
llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
  Type *paramTys[] = {ptrTy};
  FunctionType *fnTy =
      FunctionType::get(voidTy, paramTys, /*isVarArg=*/false);
  return module.getOrInsertFunction("mgpuModuleUnload", fnTy);
}
/// Returns the `mgpuStreamCreate` callee, declaring it in the module if
/// needed. Signature: `ptr()`.
llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
  FunctionType *fnTy = FunctionType::get(ptrTy, /*isVarArg=*/false);
  return module.getOrInsertFunction("mgpuStreamCreate", fnTy);
}
/// Returns the `mgpuStreamDestroy` callee, declaring it in the module if
/// needed. Signature: `void(ptr)`.
llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
  Type *paramTys[] = {ptrTy};
  FunctionType *fnTy =
      FunctionType::get(voidTy, paramTys, /*isVarArg=*/false);
  return module.getOrInsertFunction("mgpuStreamDestroy", fnTy);
}
/// Returns the `mgpuStreamSynchronize` callee, declaring it in the module if
/// needed. Signature: `void(ptr)`.
llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
  Type *paramTys[] = {ptrTy};
  FunctionType *fnTy =
      FunctionType::get(voidTy, paramTys, /*isVarArg=*/false);
  return module.getOrInsertFunction("mgpuStreamSynchronize", fnTy);
}
/// Returns a global string holding `kernelName`. The global is named
/// `<moduleName>_<kernelName>_kernel_name` and is reused if it already exists
/// in the module.
llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
                                                         StringRef kernelName) {
  std::string globalName =
      std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));

  // The strings created below have local linkage, which the default lookup
  // skips; allow internal globals so that a previously created name is found
  // instead of emitting a duplicate for every launch of the same kernel.
  if (GlobalVariable *gv =
          module.getGlobalVariable(globalName, /*AllowInternal=*/true))
    return gv;

  return builder.CreateGlobalString(kernelName, globalName);
}
/// Packs the kernel operands of `op` for the launch call and returns the
/// resulting array of pointers.
///
/// The emitted IR allocates a struct with one member per operand and an array
/// of `ptr` with one slot per operand; each operand is stored into its struct
/// member and the address of that member is stored into the matching slot.
llvm::Value *
llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
  SmallVector<Value *> kernelArgs =
      moduleTranslation.lookupValues(op.getKernelOperands());

  // One struct member per kernel operand, in operand order.
  SmallVector<Type *> memberTypes;
  memberTypes.reserve(kernelArgs.size());
  for (Value *kernelArg : kernelArgs)
    memberTypes.push_back(kernelArg->getType());

  Type *argStructTy = StructType::create(module.getContext(), memberTypes);
  Value *argStruct = builder.CreateAlloca(argStructTy, 0u);
  Value *argPtrArray = builder.CreateAlloca(
      ptrTy, ConstantInt::get(intPtrTy, memberTypes.size()));

  for (unsigned idx = 0, numArgs = kernelArgs.size(); idx != numArgs; ++idx) {
    // Spill the operand into the struct, then record where it lives.
    Value *memberPtr = builder.CreateStructGEP(argStructTy, argStruct, idx);
    builder.CreateStore(kernelArgs[idx], memberPtr);
    Value *slotPtr = builder.CreateConstGEP1_32(ptrTy, argPtrArray, idx);
    builder.CreateStore(memberPtr, slotPtr);
  }
  return argPtrArray;
}
/// Emits the IR launching the kernel referenced by `op`:
///   1. translate the grid/block sizes and the dynamic shared memory size;
///   2. pack the kernel operands (`createKernelArgArray`);
///   3. load the module from the embedded binary global, through
///      `mgpuModuleLoadJIT` for assembly objects and `mgpuModuleLoad`
///      otherwise;
///   4. get the kernel function from the module;
///   5. use the async object of `op` as stream, or create a temporary one;
///   6. call `mgpuLaunchClusterKernel` or `mgpuLaunchKernel`;
///   7. synchronize and destroy the temporary stream, if one was created;
///   8. unload the module.
/// `object` provides the object format and the optional integer property `O`
/// passed as optimization level to `mgpuModuleLoadJIT`.
/// Fails, after emitting an error on `op`, if `O` is not an integer or if the
/// global holding the binary can't be found or has an unexpected form.
llvm::LogicalResult
llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                       mlir::gpu::ObjectAttr object) {
  // Maps an MLIR value to its already translated LLVM value.
  auto llvmValue = [&](mlir::Value value) -> Value * {
    Value *v = moduleTranslation.lookupValue(value);
    assert(v && "Value has not been translated.");
    return v;
  };

  // Grid dimensions.
  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
        *gz = llvmValue(grid.z);

  // Block dimensions.
  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
        *bz = llvmValue(block.z);

  // Dynamic shared memory size; defaults to the i32 constant 0.
  Value *dynamicMemorySize = nullptr;
  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
    dynamicMemorySize = llvmValue(dynSz);
  else
    dynamicMemorySize = ConstantInt::get(i32Ty, 0);

  // Kernel arguments.
  Value *argArray = createKernelArgArray(op);

  // Optimization level, read from the optional `O` property of the object;
  // defaults to 0.
  llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
  DictionaryAttr objectProps = object.getProperties();
  mlir::Attribute optAttr;
  if (objectProps && (optAttr = objectProps.get("O"))) {
    auto optLevel = dyn_cast<IntegerAttr>(optAttr);
    if (!optLevel)
      return op.emitError("the optimization level must be an integer");
    // `ConstantInt::get` requires the APInt width to match the integer type,
    // so normalize the attribute value to 32 bits whatever its own type is.
    optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue().sextOrTrunc(32));
  }

  // Find the global holding the embedded binary and compute its byte size.
  StringRef moduleName = op.getKernelModuleName().getValue();
  std::string binaryIdentifier = getBinaryIdentifier(moduleName);
  Value *binary = module.getGlobalVariable(binaryIdentifier, true);
  if (!binary)
    return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;

  auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
  if (!binaryVar)
    return op.emitError() << "Binary is not a global variable: "
                          << binaryIdentifier;
  llvm::Constant *binaryInit = binaryVar->getInitializer();
  auto binaryDataSeq =
      dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
  if (!binaryDataSeq)
    return op.emitError() << "Couldn't find binary data array: "
                          << binaryIdentifier;
  llvm::Constant *binarySize =
      llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
                                        binaryDataSeq->getElementByteSize());

  // Load the module: assembly objects go through the JIT entry point, which
  // takes the optimization level; other formats take the binary size.
  Value *moduleObject =
      object.getFormat() == gpu::CompilationTarget::Assembly
          ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
          : builder.CreateCall(getModuleLoadFn(), {binary, binarySize});

  // Get the kernel function from the loaded module.
  Value *moduleFunction = builder.CreateCall(
      getModuleFunctionFn(),
      {moduleObject,
       getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});

  // Use the stream given by the async object, or create one that is
  // synchronized and destroyed after the launch.
  Value *stream = nullptr;
  bool handleStream = false;
  if (mlir::Value asyncObject = op.getAsyncObject()) {
    stream = llvmValue(asyncObject);
  } else {
    handleStream = true;
    stream = builder.CreateCall(getStreamCreateFn(), {});
  }

  llvm::Constant *paramsCount =
      llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());

  // Null pointer passed as the pointer argument following the argument array.
  Value *nullPtr = ConstantPointerNull::get(ptrTy);

  // Launch the kernel, with cluster sizes when the operation has them.
  if (op.hasClusterSize()) {
    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
          *cz = llvmValue(cluster.z);
    builder.CreateCall(
        getClusterKernelLaunchFn(),
        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
                           dynamicMemorySize, stream, argArray, nullPtr}));
  } else {
    builder.CreateCall(getKernelLaunchFn(),
                       ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
                                          bz, dynamicMemorySize, stream,
                                          argArray, nullPtr, paramsCount}));
  }

  // Synchronize and destroy the stream only if it was created here.
  if (handleStream) {
    builder.CreateCall(getStreamSyncFn(), {stream});
    builder.CreateCall(getStreamDestroyFn(), {stream});
  }

  // Unload the module.
  builder.CreateCall(getModuleUnloadFn(), {moduleObject});

  return success();
}