#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <any>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
using namespace mlir;
namespace {
/// Translates the optional MLIR `schedule` clause kind into the
/// OpenMPIRBuilder schedule kind. A missing clause selects the default
/// schedule.
llvm::omp::ScheduleKind
convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
  if (!schedKind)
    return llvm::omp::OMP_SCHEDULE_Default;

  using llvm::omp::ScheduleKind;
  switch (*schedKind) {
  case omp::ClauseScheduleKind::Static:
    return ScheduleKind::OMP_SCHEDULE_Static;
  case omp::ClauseScheduleKind::Dynamic:
    return ScheduleKind::OMP_SCHEDULE_Dynamic;
  case omp::ClauseScheduleKind::Guided:
    return ScheduleKind::OMP_SCHEDULE_Guided;
  case omp::ClauseScheduleKind::Auto:
    return ScheduleKind::OMP_SCHEDULE_Auto;
  case omp::ClauseScheduleKind::Runtime:
    return ScheduleKind::OMP_SCHEDULE_Runtime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// ModuleTranslation stack frame that records where allocas should be
/// emitted. The innermost such frame is consulted by findAllocaInsertPoint.
class OpenMPAllocaStackFrame
    : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)

  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy ip)
      : allocaInsertPoint(ip) {}

  /// Insertion point for allocas created while this frame is active.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};

/// ModuleTranslation stack frame holding a copy of an MLIR -> LLVM value
/// mapping (for example, reduction variables to their private copies) while
/// the frame is on the translation stack.
class OpenMPVarMappingStackFrame
    : public LLVM::ModuleTranslation::StackFrameBase<
          OpenMPVarMappingStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPVarMappingStackFrame)

  explicit OpenMPVarMappingStackFrame(
      const DenseMap<Value, llvm::Value *> &valueMap)
      : mapping(valueMap) {}

  /// The copied mapping.
  DenseMap<Value, llvm::Value *> mapping;
};
} // namespace
/// Returns the insertion point where allocas for the current OpenMP construct
/// should be placed.
///
/// The innermost OpenMPAllocaStackFrame on the translation stack wins. When no
/// frame exists, the start of the enclosing function's entry block is used.
/// If the builder currently sits in that entry block, its code is first
/// branched into a new "entry" block, so that the original entry block can
/// hold the allocas.
static llvm::OpenMPIRBuilder::InsertPointTy
findAllocaInsertPoint(llvm::IRBuilderBase &builder,
                      const LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder::InsertPointTy frameIP;
  bool foundFrame =
      moduleTranslation
          .stackWalk<OpenMPAllocaStackFrame>(
              [&](const OpenMPAllocaStackFrame &frame) {
                frameIP = frame.allocaInsertPoint;
                // Stop at the innermost frame.
                return WalkResult::interrupt();
              })
          .wasInterrupted();
  if (foundFrame)
    return frameIP;

  llvm::BasicBlock *currentBB = builder.GetInsertBlock();
  llvm::Function *parentFn = currentBB->getParent();
  if (currentBB == &parentFn->getEntryBlock()) {
    assert(builder.GetInsertPoint() == currentBB->end() &&
           "Assuming end of basic block");
    // Move subsequent code into a fresh block so the entry block only
    // receives allocas.
    llvm::BasicBlock *newBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", parentFn, currentBB->getNextNode());
    builder.CreateBr(newBB);
    builder.SetInsertPoint(newBB);
  }

  llvm::BasicBlock &entryBB = parentFn->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(&entryBB,
                                              entryBB.getFirstInsertionPt());
}
/// Translates an MLIR region into freshly created LLVM basic blocks.
///
/// The builder's current block is split at the insertion point. The split-off
/// tail ("omp.region.cont") is returned as the continuation block. One LLVM
/// block named `blockName` is created per region block. The source block's
/// terminator is redirected to the region entry, and every `omp.terminator` /
/// `omp.yield` is lowered to a branch to the continuation block.
///
/// \param region            the region to translate.
/// \param blockName         name used for every created LLVM block.
/// \param builder           IR builder positioned where the region goes.
/// \param moduleTranslation translation state (block/value mappings).
/// \param bodyGenStatus     set to failure() if converting any block fails.
/// \param continuationBlockPHIs receives one PHI per value yielded by
///                          `omp.yield`. It must be non-null when the region
///                          yields values.
/// \returns the continuation block.
static llvm::BasicBlock *convertOmpOpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus,
    SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  // Create and map an LLVM block for every MLIR block up front, so branches
  // between region blocks can be resolved while converting.
  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  }

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  // Determine the types of values yielded to the continuation block. Every
  // `omp.yield` must agree on count and types.
  SmallVector<llvm::Type *> continuationBlockPHITypes;
  bool operandsProcessed = false;
  unsigned numYields = 0;
  for (Block &bb : region.getBlocks()) {
    if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
      if (!operandsProcessed) {
        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
          continuationBlockPHITypes.push_back(
              moduleTranslation.convertType(yield->getOperand(i).getType()));
        }
        operandsProcessed = true;
      } else {
        assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
               "mismatching number of values yielded from the region");
        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
          llvm::Type *operandType =
              moduleTranslation.convertType(yield->getOperand(i).getType());
          (void)operandType;
          assert(continuationBlockPHITypes[i] == operandType &&
                 "values of mismatching types yielded from the region");
        }
      }
      numYields++;
    }
  }

  assert((continuationBlockPHITypes.empty() || continuationBlockPHIs) &&
         "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  }

  // Convert blocks in dominance order so that values are defined before use.
  SetVector<Block *> blocks = getBlocksSortedByDominance(region);
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    if (bb->isEntryBlock()) {
      // The split created an unconditional branch to the continuation block;
      // retarget it to the region entry.
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    }

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
      bodyGenStatus = failure();
      return continuationBlock;
    }

    // OpenMP region terminators become branches to the continuation block,
    // and yielded operands feed the continuation PHIs.
    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);
      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(
            moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
    }
  }

  LLVM::detail::connectPHINodes(region, moduleTranslation);
  moduleTranslation.forgetMapping(region);
  return continuationBlock;
}
/// Translates an MLIR `proc_bind` clause kind into its OpenMPIRBuilder
/// counterpart.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
  using llvm::omp::ProcBindKind;
  switch (kind) {
  case omp::ClauseProcBindKind::Close:
    return ProcBindKind::OMP_PROC_BIND_close;
  case omp::ClauseProcBindKind::Master:
    return ProcBindKind::OMP_PROC_BIND_master;
  case omp::ClauseProcBindKind::Primary:
    return ProcBindKind::OMP_PROC_BIND_primary;
  case omp::ClauseProcBindKind::Spread:
    return ProcBindKind::OMP_PROC_BIND_spread;
  }
  llvm_unreachable("Unknown ClauseProcBindKind kind");
}
/// Translates `omp.masked` into an OpenMPIRBuilder masked region.
///
/// The filter value is the translated `filtered_thread_id` operand, or the
/// i32 constant 0 when the operand is absent. Returns failure if translating
/// the region body failed.
static LogicalResult
convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto maskedOp = cast<omp::MaskedOp>(opInst);
  LogicalResult bodyGenStatus = success();

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    Region &region = maskedOp.getRegion();
    builder.restoreIP(codeGenIP);
    convertOmpOpRegions(region, "omp.masked.region", builder, moduleTranslation,
                        bodyGenStatus);
  };
  // No finalization code is emitted for this construct.
  auto finiCB = [&](InsertPointTy codeGenIP) {};

  llvm::Value *filterVal = nullptr;
  if (auto filterVar = maskedOp.getFilteredThreadId()) {
    filterVal = moduleTranslation.lookupValue(filterVar);
  } else {
    llvm::LLVMContext &llvmContext = builder.getContext();
    filterVal =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
  }
  assert(filterVal != nullptr);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMasked(
      ompLoc, bodyGenCB, finiCB, filterVal));
  // The body callback runs inside createMasked; propagate its status instead
  // of silently reporting success.
  return bodyGenStatus;
}
/// Translates `omp.master` into an OpenMPIRBuilder master region.
/// Returns failure if translating the region body failed.
static LogicalResult
convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto masterOp = cast<omp::MasterOp>(opInst);
  LogicalResult bodyGenStatus = success();

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    Region &region = masterOp.getRegion();
    builder.restoreIP(codeGenIP);
    convertOmpOpRegions(region, "omp.master.region", builder, moduleTranslation,
                        bodyGenStatus);
  };
  // No finalization code is emitted for this construct.
  auto finiCB = [&](InsertPointTy codeGenIP) {};

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster(
      ompLoc, bodyGenCB, finiCB));
  // Propagate body translation failures instead of always returning success.
  return bodyGenStatus;
}
/// Translates `omp.critical` into an OpenMPIRBuilder critical region.
///
/// When the op names a critical section, the hint of the referenced
/// `omp.critical.declare` is passed to the builder as an i32 constant.
/// Returns failure if the declaration cannot be found or if translating the
/// region body failed.
static LogicalResult
convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto criticalOp = cast<omp::CriticalOp>(opInst);
  LogicalResult bodyGenStatus = success();

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    Region &region = criticalOp.getRegion();
    builder.restoreIP(codeGenIP);
    convertOmpOpRegions(region, "omp.critical.region", builder,
                        moduleTranslation, bodyGenStatus);
  };
  // No finalization code is emitted for this construct.
  auto finiCB = [&](InsertPointTy codeGenIP) {};

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;

  if (criticalOp.getNameAttr()) {
    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
    auto criticalDeclareOp =
        SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
                                                                     symbolRef);
    // Guard against a dangling reference instead of dereferencing null.
    if (!criticalDeclareOp)
      return criticalOp.emitOpError()
             << "could not find critical declaration " << symbolRef;
    hint = llvm::ConstantInt::get(
        llvm::Type::getInt32Ty(llvmContext),
        static_cast<int>(criticalDeclareOp.getHintVal()));
  }

  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
      ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint));
  // Propagate body translation failures instead of always returning success.
  return bodyGenStatus;
}
template <typename T>
static void
collectReductionDecls(T loop,
SmallVectorImpl<omp::DeclareReductionOp> &reductions) {
std::optional<ArrayAttr> attr = loop.getReductions();
if (!attr)
return;
reductions.reserve(reductions.size() + loop.getNumReductionVars());
for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
reductions.push_back(
SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
loop, symbolRef));
}
}
/// Translates `region` in place at the builder's insertion point.
///
/// A single-block region is emitted directly into the current block. Any
/// existing terminator is detached for the duration of the translation and
/// re-attached at the end of the builder's block afterwards. Multi-block
/// regions go through convertOmpOpRegions, and the builder is left at the
/// start of the resulting continuation block.
///
/// \param continuationBlockArgs if non-null, receives the values yielded by
///        the region: the translated operands of the single block's last op,
///        or the continuation-block PHIs for multi-block regions.
static LogicalResult inlineConvertOmpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
  if (region.empty())
    return success();

  if (llvm::hasSingleElement(region)) {
    llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
    llvm::Instruction *potentialTerminator =
        insertBlock->empty() ? nullptr : &insertBlock->back();
    const bool detachedTerminator =
        potentialTerminator && potentialTerminator->isTerminator();
    if (detachedTerminator)
      potentialTerminator->removeFromParent();

    // Puts the detached terminator back at the end of the builder's current
    // block, so that it is neither leaked nor left without a parent.
    auto reattachTerminator = [&]() {
      if (!detachedTerminator)
        return;
      llvm::BasicBlock *block = builder.GetInsertBlock();
      if (block->empty()) {
        potentialTerminator->insertInto(block, block->begin());
      } else {
        potentialTerminator->insertAfter(&block->back());
      }
    };

    moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(
            region.front(), /*ignoreArguments=*/true, builder))) {
      reattachTerminator();
      return failure();
    }

    if (continuationBlockArgs)
      llvm::append_range(
          *continuationBlockArgs,
          moduleTranslation.lookupValues(region.front().back().getOperands()));

    moduleTranslation.forgetMapping(region);
    reattachTerminator();
    return success();
  }

  LogicalResult bodyGenStatus = success();
  SmallVector<llvm::PHINode *> phis;
  llvm::BasicBlock *continuationBlock = convertOmpOpRegions(
      region, blockName, builder, moduleTranslation, bodyGenStatus, &phis);
  if (failed(bodyGenStatus))
    return failure();
  if (continuationBlockArgs)
    llvm::append_range(*continuationBlockArgs, phis);
  builder.SetInsertPoint(continuationBlock,
                         continuationBlock->getFirstInsertionPt());
  return success();
}
namespace {
/// Owning callback that emits the non-atomic combination of two reduction
/// values. Parameters: the insertion point, the LHS value, the RHS value, and
/// an out-parameter that receives the combined value. It returns the insertion
/// point after the emitted code (an empty insertion point on failure, as built
/// by makeReductionGen).
using OwningReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointTy(
llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
llvm::Value *&)>;
/// Owning callback that emits the atomic form of a reduction. Parameters: the
/// insertion point, an element type (ignored by makeAtomicReductionGen), and
/// the LHS and RHS values. It returns the insertion point after the emitted
/// code. An empty callback means the declaration has no atomic region.
using OwningAtomicReductionGen =
std::function<llvm::OpenMPIRBuilder::InsertPointTy(
llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
llvm::Value *)>;
}
/// Builds a callback that inlines the combiner region of `decl`.
///
/// The callback maps the region's two entry-block arguments to `lhs` and
/// `rhs`, then translates the region at the given insertion point. It stores
/// the single yielded value into `result`. It returns an empty insertion point
/// if translation fails. `builder` and `moduleTranslation` are captured by
/// reference, so they must outlive the callback.
static OwningReductionGen
makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  return [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                   llvm::Value *lhs, llvm::Value *rhs,
                   llvm::Value *&result) mutable
             -> llvm::OpenMPIRBuilder::InsertPointTy {
    Region &combiner = decl.getReductionRegion();
    Block &combinerEntry = combiner.front();
    moduleTranslation.mapValue(combinerEntry.getArgument(0), lhs);
    moduleTranslation.mapValue(combinerEntry.getArgument(1), rhs);
    builder.restoreIP(insertPoint);

    SmallVector<llvm::Value *> yielded;
    if (failed(inlineConvertOmpRegions(combiner,
                                       "omp.reduction.nonatomic.body", builder,
                                       moduleTranslation, &yielded)))
      return llvm::OpenMPIRBuilder::InsertPointTy();

    assert(yielded.size() == 1);
    result = yielded.front();
    return builder.saveIP();
  };
}
/// Builds a callback that inlines the atomic reduction region of `decl`.
///
/// It returns an empty callback when the declaration has no atomic region;
/// callers check for this before using it. The callback maps the region's two
/// entry-block arguments to `lhs` and `rhs`, then translates the region, which
/// must yield nothing. It returns an empty insertion point if translation
/// fails.
static OwningAtomicReductionGen
makeAtomicReductionGen(omp::DeclareReductionOp decl,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  if (decl.getAtomicReductionRegion().empty())
    return {};

  return [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                   llvm::Type *, llvm::Value *lhs, llvm::Value *rhs) mutable
             -> llvm::OpenMPIRBuilder::InsertPointTy {
    Region &atomicBody = decl.getAtomicReductionRegion();
    Block &atomicEntry = atomicBody.front();
    moduleTranslation.mapValue(atomicEntry.getArgument(0), lhs);
    moduleTranslation.mapValue(atomicEntry.getArgument(1), rhs);
    builder.restoreIP(insertPoint);

    SmallVector<llvm::Value *> yielded;
    if (failed(inlineConvertOmpRegions(atomicBody, "omp.reduction.atomic.body",
                                       builder, moduleTranslation, &yielded)))
      return llvm::OpenMPIRBuilder::InsertPointTy();

    assert(yielded.empty());
    return builder.saveIP();
  };
}
/// Translates a stand-alone `omp.ordered` with a `depend` clause.
///
/// The depend-vector values are consumed `num_loops_val` at a time. Each
/// chunk is emitted through createOrderedDepend as a source or a sink,
/// according to the depend type. An inconsistent operand count is reported as
/// an error rather than looping forever (`num_loops_val == 0`) or reading past
/// the end of the values.
static LogicalResult
convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
  auto orderedOp = cast<omp::OrderedOp>(opInst);

  omp::ClauseDepend dependType = *orderedOp.getDependTypeVal();
  bool isDependSource = dependType == omp::ClauseDepend::dependsource;
  unsigned numLoops = *orderedOp.getNumLoopsVal();
  SmallVector<llvm::Value *> vecValues =
      moduleTranslation.lookupValues(orderedOp.getDependVecVars());

  if (vecValues.empty())
    return success();
  if (numLoops == 0 || vecValues.size() % numLoops != 0)
    return orderedOp.emitError(
        "number of depend vector values must be a non-zero multiple of "
        "num_loops_val");

  ArrayRef<llvm::Value *> allValues(vecValues);
  for (size_t offset = 0; offset < allValues.size(); offset += numLoops) {
    ArrayRef<llvm::Value *> storeValues = allValues.slice(offset, numLoops);
    // Recomputed per chunk: findAllocaInsertPoint may reposition the builder.
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
        findAllocaInsertPoint(builder, moduleTranslation);
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
    builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
        ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
  }
  return success();
}
/// Translates `omp.ordered.region` into an OpenMPIRBuilder ordered-threads
/// region. The `simd` variant is not supported and returns failure. Otherwise
/// the result reflects whether the region body was translated successfully.
static LogicalResult
convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);

  // Code generation for `ordered simd` is not implemented.
  if (orderedRegionOp.getSimd())
    return failure();

  LogicalResult bodyGenStatus = success();
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    Region &region = orderedRegionOp.getRegion();
    builder.restoreIP(codeGenIP);
    convertOmpOpRegions(region, "omp.ordered.region", builder,
                        moduleTranslation, bodyGenStatus);
  };
  // No finalization code is emitted for this construct.
  auto finiCB = [&](InsertPointTy codeGenIP) {};

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  builder.restoreIP(
      moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
          ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getSimd()));
  return bodyGenStatus;
}
template <typename T>
static void allocByValReductionVars(
T loop, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
SmallVectorImpl<llvm::Value *> &privateReductionVariables,
DenseMap<Value, llvm::Value *> &reductionVariableMap,
llvm::ArrayRef<bool> isByRefs) {
llvm::IRBuilderBase::InsertPointGuard guard(builder);
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
if (isByRefs[i])
continue;
llvm::Value *var = builder.CreateAlloca(
moduleTranslation.convertType(reductionDecls[i].getType()));
moduleTranslation.mapValue(reductionArgs[i], var);
privateReductionVariables[i] = var;
reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
}
}
template <typename T>
static void
mapInitializationArg(T loop, LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
unsigned i) {
mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
Region &initializerRegion = reduction.getInitializerRegion();
Block &entry = initializerRegion.front();
assert(entry.getNumArguments() == 1 &&
"the initialization region has one argument");
mlir::Value mlirSource = loop.getReductionVars()[i];
llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
assert(llvmSource && "lookup reduction var");
moduleTranslation.mapValue(entry.getArgument(0), llvmSource);
}
template <typename T>
static void collectReductionInfo(
T loop, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
SmallVectorImpl<OwningReductionGen> &owningReductionGens,
SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
const ArrayRef<llvm::Value *> privateReductionVariables,
SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) {
unsigned numReductions = loop.getNumReductionVars();
for (unsigned i = 0; i < numReductions; ++i) {
owningReductionGens.push_back(
makeReductionGen(reductionDecls[i], builder, moduleTranslation));
owningAtomicReductionGens.push_back(
makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
}
reductionInfos.reserve(numReductions);
for (unsigned i = 0; i < numReductions; ++i) {
llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
if (owningAtomicReductionGens[i])
atomicGen = owningAtomicReductionGens[i];
llvm::Value *variable =
moduleTranslation.lookupValue(loop.getReductionVars()[i]);
reductionInfos.push_back(
{moduleTranslation.convertType(reductionDecls[i].getType()), variable,
privateReductionVariables[i],
llvm::OpenMPIRBuilder::EvalKind::Scalar,
owningReductionGens[i],
nullptr, atomicGen});
}
}
/// Inlines every non-empty cleanup region at the builder's insertion point.
///
/// The region's first argument is bound to the matching private variable,
/// after loading from it when `shouldLoadCleanupRegionArg` is true. If the
/// current block already ends in a terminator, the code is emitted before
/// that terminator. Returns failure if any region fails to translate.
static LogicalResult
inlineOmpRegionCleanup(llvm::SmallVectorImpl<Region *> &cleanupRegions,
                       llvm::ArrayRef<llvm::Value *> privateVariables,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::IRBuilderBase &builder, StringRef regionName,
                       bool shouldLoadCleanupRegionArg = true) {
  for (auto [idx, region] : llvm::enumerate(cleanupRegions)) {
    if (region->empty())
      continue;

    Block &entry = region->front();

    llvm::BasicBlock *currentBB = builder.GetInsertBlock();
    if (!currentBB->empty() && currentBB->back().isTerminator())
      builder.SetInsertPoint(&currentBB->back());

    llvm::Value *privateVarValue = privateVariables[idx];
    if (shouldLoadCleanupRegionArg)
      privateVarValue = builder.CreateLoad(
          moduleTranslation.convertType(entry.getArgument(0).getType()),
          privateVarValue);
    moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);

    if (failed(inlineConvertOmpRegions(*region, regionName, builder,
                                       moduleTranslation)))
      return failure();

    moduleTranslation.forgetMapping(*region);
  }
  return success();
}
/// Emits the reduction combination for `op`, followed by a barrier and the
/// inlined cleanup regions of the reduction declarations.
///
/// Does nothing when `op` has no reduction variables. The reductions are
/// emitted in front of a temporary `unreachable` terminator, which is erased
/// afterwards. This presumably gives createReductions a terminated block to
/// split (NOTE(review): confirm against the OpenMPIRBuilder contract). Cleanup
/// regions receive the value loaded from each private reduction variable.
template <class OP>
static LogicalResult createReductionsAndCleanup(
    OP op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
    ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef) {
  if (op.getNumReductionVars() == 0)
    return success();

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  SmallVector<OwningReductionGen> owningReductionGens;
  SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
  SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
  collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
                       owningReductionGens, owningAtomicReductionGens,
                       privateReductionVariables, reductionInfos);

  llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
  builder.SetInsertPoint(tempTerminator);
  llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
      ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
                                   isByRef, op.getNowait());
  if (!contInsertPoint.getBlock())
    return op->emitOpError() << "failed to convert reductions";

  auto nextInsertionPoint =
      ompBuilder->createBarrier(contInsertPoint, llvm::omp::OMPD_for);
  tempTerminator->eraseFromParent();
  builder.restoreIP(nextInsertionPoint);

  SmallVector<Region *> reductionRegions;
  llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
                  [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                  });
  return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
                                moduleTranslation, builder,
                                "omp.reduction.cleanup");
}
/// Returns the per-variable by-reference flags of a reduction clause, or an
/// empty array when the attribute is absent.
static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
  return attr ? *attr : ArrayRef<bool>();
}
/// Allocates the private reduction variables of `op` and initializes them
/// from the reductions' initializer regions.
///
/// By-value variables are allocated up front by allocByValReductionVars, and
/// the neutral value is stored into them. For by-reference variables, an
/// alloca holding the neutral value is created after the initializer runs. The
/// region argument and `reductionVariableMap` are then bound to the neutral
/// value itself. Returns failure if an initializer region fails to translate.
template <typename OP>
static LogicalResult allocAndInitializeReductionVars(
    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
    DenseMap<Value, llvm::Value *> &reductionVariableMap,
    llvm::ArrayRef<bool> isByRef) {
  const unsigned count = op.getNumReductionVars();
  if (count == 0)
    return success();

  allocByValReductionVars(op, reductionArgs, builder, moduleTranslation,
                          allocaIP, reductionDecls, privateReductionVariables,
                          reductionVariableMap, isByRef);

  for (unsigned idx = 0; idx < count; ++idx) {
    Region &initRegion = reductionDecls[idx].getInitializerRegion();
    mapInitializationArg(op, moduleTranslation, reductionDecls, idx);

    SmallVector<llvm::Value *> initValues;
    if (failed(inlineConvertOmpRegions(initRegion, "omp.reduction.neutral",
                                       builder, moduleTranslation,
                                       &initValues)))
      return failure();
    assert(initValues.size() == 1 &&
           "expected one value to be yielded from the "
           "reduction neutral element declaration region");
    llvm::Value *neutral = initValues.front();

    if (!isByRef[idx]) {
      builder.CreateStore(neutral, privateReductionVariables[idx]);
    } else {
      llvm::Value *byRefSlot = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[idx].getType()));
      builder.CreateStore(neutral, byRefSlot);
      privateReductionVariables[idx] = byRefSlot;
      moduleTranslation.mapValue(reductionArgs[idx], neutral);
      reductionVariableMap.try_emplace(op.getReductionVars()[idx], neutral);
    }

    moduleTranslation.forgetMapping(initRegion);
  }
  return success();
}
/// Translates `omp.sections` (with its nested `omp.section` ops) through
/// OpenMPIRBuilder::createSections, including reduction setup and cleanup.
///
/// Allocate clauses are rejected. Reduction variables are allocated and
/// initialized before the sections are emitted. Each section's region
/// arguments are mapped to the values of the enclosing `omp.sections` region
/// arguments. Returns failure if any section body fails to translate.
static LogicalResult
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;

  auto sectionsOp = cast<omp::SectionsOp>(opInst);

  if (!sectionsOp.getAllocateVars().empty() ||
      !sectionsOp.getAllocatorsVars().empty())
    return emitError(sectionsOp.getLoc())
           << "allocate clause is not supported for sections construct";

  llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionVarsByref());
  assert(isByRef.size() == sectionsOp.getNumReductionVars());

  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(sectionsOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      sectionsOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      sectionsOp.getRegion().getArguments();

  if (failed(allocAndInitializeReductionVars(
          sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
          reductionDecls, privateReductionVariables, reductionVariableMap,
          isByRef)))
    return failure();

  // Keep the reduction mapping on the translation stack while the sections
  // are emitted.
  LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
      moduleTranslation, reductionVariableMap);

  LogicalResult bodyGenStatus = success();
  SmallVector<StorableBodyGenCallbackTy> sectionCBs;

  for (Operation &op : *sectionsOp.getRegion().begin()) {
    // Skip anything that is not an omp.section (e.g. the region terminator).
    auto sectionOp = dyn_cast<omp::SectionOp>(op);
    if (!sectionOp)
      continue;

    Region &region = sectionOp.getRegion();
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation,
                      &bodyGenStatus](InsertPointTy allocaIP,
                                      InsertPointTy codeGenIP) {
      builder.restoreIP(codeGenIP);

      // Forward the values bound to the enclosing sections' region arguments
      // to this section's region arguments.
      assert(region.getNumArguments() ==
             sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
        assert(llvmVal);
        moduleTranslation.mapValue(sectionArg, llvmVal);
      }

      convertOmpOpRegions(region, "omp.section.region", builder,
                          moduleTranslation, bodyGenStatus);
    };
    sectionCBs.push_back(sectionCB);
  }

  // No sections means nothing to emit.
  if (sectionCBs.empty())
    return success();

  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));

  // Variables are not privatized by createSections here: the callback hands
  // back the original pointer unchanged.
  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr,
                    llvm::Value *&replacementValue) -> InsertPointTy {
    replacementValue = &vPtr;
    return codeGenIP;
  };

  // No finalization code is emitted for this construct.
  auto finiCB = [&](InsertPointTy codeGenIP) {};

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSections(
      ompLoc, allocaIP, sectionCBs, privCB, finiCB, false,
      sectionsOp.getNowait()));

  if (failed(bodyGenStatus))
    return bodyGenStatus;

  return createReductionsAndCleanup(sectionsOp, builder, moduleTranslation,
                                    allocaIP, reductionDecls,
                                    privateReductionVariables, isByRef);
}
/// Translates `omp.single` through OpenMPIRBuilder::createSingle.
///
/// Each copyprivate variable is paired with the LLVM function translated from
/// the `llvm.func` named by the corresponding copyprivate symbol. Returns
/// failure if the region body fails to translate.
static LogicalResult
convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  LogicalResult bodyGenStatus = success();

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
    convertOmpOpRegions(singleOp.getRegion(), "omp.single.region", builder,
                        moduleTranslation, bodyGenStatus);
  };
  // No finalization code is emitted for this construct.
  auto finiCB = [&](InsertPointTy codeGenIP) {};

  Operation::operand_range cpVars = singleOp.getCopyprivateVars();
  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateFuncs();
  llvm::SmallVector<llvm::Value *> llvmCPVars;
  llvm::SmallVector<llvm::Function *> llvmCPFuncs;
  for (auto [idx, cpVar] : llvm::enumerate(cpVars)) {
    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVar));
    auto copyFnOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
        singleOp, cast<SymbolRefAttr>((*cpFuncs)[idx]));
    llvmCPFuncs.push_back(moduleTranslation.lookupFunction(copyFnOp.getName()));
  }

  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSingle(
      ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars, llvmCPFuncs));
  return bodyGenStatus;
}
/// Translates `omp.teams` through OpenMPIRBuilder::createTeams.
///
/// Allocators and reductions are rejected. Each absent optional operand
/// (`num_teams` bounds, `thread_limit`, `if`) is passed as nullptr. The
/// builder's alloca point is pushed for the body. Returns failure if the
/// region body fails to translate.
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  LogicalResult bodyGenStatus = success();

  if (!op.getAllocatorsVars().empty() || op.getReductions())
    return op.emitError("unhandled clauses for translation to LLVM IR");

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);
    builder.restoreIP(codegenIP);
    convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
                        moduleTranslation, bodyGenStatus);
  };

  // Translates an optional operand, yielding nullptr when it is absent.
  auto lookupOptional = [&moduleTranslation](Value mlirVal) -> llvm::Value * {
    if (!mlirVal)
      return nullptr;
    return moduleTranslation.lookupValue(mlirVal);
  };
  llvm::Value *numTeamsLower = lookupOptional(op.getNumTeamsLower());
  llvm::Value *numTeamsUpper = lookupOptional(op.getNumTeamsUpper());
  llvm::Value *threadLimit = lookupOptional(op.getThreadLimit());
  llvm::Value *ifExpr = lookupOptional(op.getIfExpr());

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTeams(
      ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr));
  return bodyGenStatus;
}
/// Appends one OpenMPIRBuilder::DependData per `depend` clause entry to `dds`.
///
/// `in` maps to DepIn; both `out` and `inout` map to DepInOut. Nothing is
/// appended when there are no depend variables.
static void
buildDependData(std::optional<ArrayAttr> depends, OperandRange dependVars,
                LLVM::ModuleTranslation &moduleTranslation,
                SmallVectorImpl<llvm::OpenMPIRBuilder::DependData> &dds) {
  if (dependVars.empty())
    return;

  for (auto [dependVar, kindAttr] :
       llvm::zip(dependVars, depends->getValue())) {
    llvm::omp::RTLDependenceKindTy type;
    switch (cast<mlir::omp::ClauseTaskDependAttr>(kindAttr).getValue()) {
    case mlir::omp::ClauseTaskDepend::taskdependin:
      type = llvm::omp::RTLDependenceKindTy::DepIn;
      break;
    case mlir::omp::ClauseTaskDepend::taskdependout:
    case mlir::omp::ClauseTaskDepend::taskdependinout:
      type = llvm::omp::RTLDependenceKindTy::DepInOut;
      break;
    }
    llvm::Value *depVal = moduleTranslation.lookupValue(dependVar);
    dds.emplace_back(type, depVal->getType(), depVal);
  }
}
/// Translates `omp.task` through OpenMPIRBuilder::createTask.
///
/// The untied, mergeable, in_reduction, priority and allocate clauses are
/// rejected. Depend clauses become DependData entries. The `final` and `if`
/// operands are passed as their translated values (nullptr when absent).
/// Returns failure if the region body fails to translate.
static LogicalResult
convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  LogicalResult bodyGenStatus = success();

  const bool hasUnhandledClause =
      taskOp.getUntiedAttr() || taskOp.getMergeableAttr() ||
      taskOp.getInReductions() || taskOp.getPriority() ||
      !taskOp.getAllocateVars().empty();
  if (hasUnhandledClause)
    return taskOp.emitError("unhandled clauses for translation to LLVM IR");

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);
    builder.restoreIP(codegenIP);
    convertOmpOpRegions(taskOp.getRegion(), "omp.task.region", builder,
                        moduleTranslation, bodyGenStatus);
  };

  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(taskOp.getDepends(), taskOp.getDependVars(),
                  moduleTranslation, dds);

  // findAllocaInsertPoint may move the builder, so take the location after it.
  InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::Value *finalCond = moduleTranslation.lookupValue(taskOp.getFinalExpr());
  llvm::Value *ifCond = moduleTranslation.lookupValue(taskOp.getIfExpr());
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTask(
      ompLoc, allocaIP, bodyCB, !taskOp.getUntied(), finalCond, ifCond, dds));
  return bodyGenStatus;
}
/// Translates `omp.taskgroup` through OpenMPIRBuilder::createTaskgroup.
/// Task reductions and allocate clauses are rejected. Returns failure if the
/// region body fails to translate.
static LogicalResult
convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  LogicalResult bodyGenStatus = success();

  const bool hasUnhandledClause = !tgOp.getTaskReductionVars().empty() ||
                                  !tgOp.getAllocateVars().empty();
  if (hasUnhandledClause)
    return tgOp.emitError("unhandled clauses for translation to LLVM IR");

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
    convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region", builder,
                        moduleTranslation, bodyGenStatus);
  };

  InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTaskgroup(
      ompLoc, allocaIP, bodyCB));
  return bodyGenStatus;
}
/// Converts an `omp.wsloop` operation (which wraps an `omp.loop_nest`) into a
/// worksharing loop.
///
/// Each level of the loop nest becomes an OpenMPIRBuilder canonical loop. The
/// nest is then collapsed and `applyWorkshareLoop` is applied with the
/// schedule kind, chunk, simd, monotonic/nonmonotonic and ordered settings
/// taken from the op. Reduction variables are allocated and initialized
/// before the loop and finalized by createReductionsAndCleanup.
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto wsloopOp = cast<omp::WsloopOp>(opInst);
  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());

  llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionVarsByref());
  assert(isByRef.size() == wsloopOp.getNumReductionVars());

  // Without a schedule clause the loop is scheduled statically.
  auto schedule =
      wsloopOp.getScheduleVal().value_or(omp::ClauseScheduleKind::Static);

  // The chunk size is sign-extended/truncated to the induction variable type,
  // which is taken from the type of the outermost step value.
  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunkVar()) {
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(wsloopOp.getScheduleChunkVar());
    chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
  }

  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(wsloopOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      wsloopOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      wsloopOp.getRegion().getArguments();

  if (failed(allocAndInitializeReductionVars(
          wsloopOp, reductionArgs, builder, moduleTranslation, allocaIP,
          reductionDecls, privateReductionVariables, reductionVariableMap,
          isByRef)))
    return failure();

  // Keep the reduction-variable mapping on the ModuleTranslation stack for as
  // long as this guard is alive, so nested conversions can see it.
  LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
      moduleTranslation, reductionVariableMap);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Canonical loops created so far (outermost first), and the body insertion
  // point that each level passed to bodyGen.
  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  LogicalResult bodyGenStatus = success();
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
    // While level i is being created loopInfos.size() == i, so this maps the
    // loop nest's block argument for the current level to `iv`.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Remembered so that the next (inner) level is created inside this body.
    bodyInsertPoints.push_back(ip);

    // Only the innermost level converts the loop nest's region.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return;

    builder.restoreIP(ip);
    convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
                        moduleTranslation, bodyGenStatus);
  };

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLowerBound()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getUpperBound()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[i]);

    // Inner levels are placed at the previous level's body insertion point,
    // with ComputeIP set to the outermost loop's preheader.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      // Pass the op's DebugLoc explicitly (as the simd translation does) so
      // nested loops keep the same debug location as the outermost one.
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }
    loopInfos.push_back(ompBuilder->createCanonicalLoop(
        loc, bodyGen, lowerBound, upperBound, step,
        /*IsSigned=*/true, loopOp.getInclusive(), computeIP));

    if (failed(bodyGenStatus))
      return failure();
  }

  // Capture the point after the outermost loop before collapsing; it is
  // restored once the workshare transformation has been applied.
  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
  llvm::CanonicalLoopInfo *loopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);

  bool isOrdered = wsloopOp.getOrderedVal().has_value();
  std::optional<omp::ScheduleModifier> scheduleModifier =
      wsloopOp.getScheduleModifier();
  bool isSimd = wsloopOp.getSimdModifier();

  ompBuilder->applyWorkshareLoop(
      ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
      convertToScheduleKind(schedule), chunk, isSimd,
      scheduleModifier == omp::ScheduleModifier::monotonic,
      scheduleModifier == omp::ScheduleModifier::nonmonotonic, isOrdered);

  builder.restoreIP(afterIP);

  return createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
                                    allocaIP, reductionDecls,
                                    privateReductionVariables, isByRef);
}
/// Scoped helper used while translating an `omp.parallel` op.
///
/// The constructor replaces, inside the op's region, all uses of the
/// private-variable block arguments (the block arguments at indices
/// [numReductionVars, numReductionVars + numPrivateVars)) with the matching
/// `private` operand values. The destructor performs the reverse
/// replacement. The object is neither copyable nor movable, because two
/// objects would both perform the restore in their destructors.
class OmpParallelOpConversionManager {
public:
  explicit OmpParallelOpConversionManager(omp::ParallelOp opInst)
      : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
        privateArgBeginIdx(opInst.getNumReductionVars()),
        privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
    auto privateVarsIt = privateVars.begin();

    for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
         ++argIdx, ++privateVarsIt)
      mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx),
                                       *privateVarsIt, region);
  }

  OmpParallelOpConversionManager(const OmpParallelOpConversionManager &) =
      delete;
  OmpParallelOpConversionManager &
  operator=(const OmpParallelOpConversionManager &) = delete;

  ~OmpParallelOpConversionManager() {
    auto privateVarsIt = privateVars.begin();

    for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
         ++argIdx, ++privateVarsIt)
      mlir::replaceAllUsesInRegionWith(*privateVarsIt,
                                       region.getArgument(argIdx), region);
  }

private:
  Region &region;
  OperandRange privateVars;
  unsigned privateArgBeginIdx;
  unsigned privateArgEndIdx;
};
/// Converts an `omp.parallel` operation through
/// OpenMPIRBuilder::createParallel.
///
/// - bodyGenCB allocates and initializes the reduction variables, converts
///   the region, and emits the reduction combination when reductions exist.
/// - privCB materializes an `omp.private` privatizer by inlining a temporary
///   clone of its `alloc` region. For firstprivate, the `copy` region is
///   first merged into that `alloc` region.
/// - finiCB inlines the reduction cleanup regions and the privatizers'
///   `dealloc` regions.
///
/// Failures inside the callbacks are reported through `bodyGenStatus`.
static LogicalResult
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Rewrites uses of private block arguments for the rest of this function;
  // the original uses are restored when `raii` is destroyed.
  OmpParallelOpConversionManager raii(opInst);
  ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionVarsByref());
  assert(isByRef.size() == opInst.getNumReductionVars());

  // Callbacks cannot return an error, so failures are recorded here.
  LogicalResult bodyGenStatus = success();
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(opInst, reductionDecls);
  SmallVector<llvm::Value *> privateReductionVariables(
      opInst.getNumReductionVars());

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    DenseMap<Value, llvm::Value *> reductionVariableMap;

    MutableArrayRef<BlockArgument> reductionArgs =
        opInst.getRegion().getArguments().slice(
            opInst.getNumAllocateVars() + opInst.getNumAllocatorsVars(),
            opInst.getNumReductionVars());

    allocByValReductionVars(opInst, reductionArgs, builder, moduleTranslation,
                            allocaIP, reductionDecls, privateReductionVariables,
                            reductionVariableMap, isByRef);

    // Split off a block for the reduction initializers. The alloca point is
    // then moved to just before the terminator of the original block.
    builder.restoreIP(allocaIP);
    llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
    allocaIP =
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());
    SmallVector<llvm::Value *> byRefVars(opInst.getNumReductionVars());
    for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
      if (isByRef[i]) {
        // Slot that will hold the value yielded by the initializer region.
        byRefVars[i] = builder.CreateAlloca(
            moduleTranslation.convertType(reductionDecls[i].getType()));
      }
    }

    builder.SetInsertPoint(initBlock->getFirstNonPHIOrDbgOrAlloca());

    for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
      SmallVector<llvm::Value *> phis;

      mapInitializationArg(opInst, moduleTranslation, reductionDecls, i);
      if (failed(inlineConvertOmpRegions(
              reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
              builder, moduleTranslation, &phis))) {
        // Do not read `phis` after a failed inline; the yielded values may
        // not have been collected. This mirrors the early return taken below
        // when the reductions cannot be created.
        bodyGenStatus = failure();
        return;
      }
      assert(phis.size() == 1 &&
             "expected one value to be yielded from the "
             "reduction neutral element declaration region");

      builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());

      if (isByRef[i]) {
        builder.CreateStore(phis[0], byRefVars[i]);

        privateReductionVariables[i] = byRefVars[i];
        moduleTranslation.mapValue(reductionArgs[i], phis[0]);
        reductionVariableMap.try_emplace(opInst.getReductionVars()[i], phis[0]);
      } else {
        builder.CreateStore(phis[0], privateReductionVariables[i]);
      }

      // Drop the initializer-region mapping so that the same declaration can
      // be inlined again for another reduction variable.
      moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
    }

    // Keep the reduction mapping and the alloca point on the
    // ModuleTranslation stack while the region is converted.
    LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
        moduleTranslation, reductionVariableMap);

    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    builder.restoreIP(codeGenIP);
    auto regionBlock =
        convertOmpOpRegions(opInst.getRegion(), "omp.par.region", builder,
                            moduleTranslation, bodyGenStatus);

    if (opInst.getNumReductionVars() > 0) {
      SmallVector<OwningReductionGen> owningReductionGens;
      SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
                           owningReductionGens, owningAtomicReductionGens,
                           privateReductionVariables, reductionInfos);

      // Emit the reductions before the region's terminator, using a
      // temporary `unreachable` as the insertion anchor.
      builder.SetInsertPoint(regionBlock->getTerminator());
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);

      llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
          ompBuilder->createReductions(builder.saveIP(), allocaIP,
                                       reductionInfos, isByRef, false);
      if (!contInsertPoint.getBlock()) {
        bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
        return;
      }

      tempTerminator->eraseFromParent();
      builder.restoreIP(contInsertPoint);
    }
  };

  // Temporary privatizer clones, and the LLVM values their `alloc` regions
  // yielded, kept for the dealloc cleanup in finiCB.
  SmallVector<omp::PrivateClauseOp> privatizerClones;
  SmallVector<llvm::Value *> privateVariables;

  auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                    llvm::Value &, llvm::Value &vPtr,
                    llvm::Value *&replacementValue) -> InsertPointTy {
    // By default the value is used unchanged.
    replacementValue = &vPtr;

    // Find the private operand whose LLVM value is `vPtr`, and clone its
    // privatizer. Returns empty values when `vPtr` is not a private operand.
    auto [privVar, privatizerClone] =
        [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
      if (!opInst.getPrivateVars().empty()) {
        auto privVars = opInst.getPrivateVars();
        auto privatizers = opInst.getPrivatizers();

        for (auto [privVar, privatizerAttr] :
             llvm::zip_equal(privVars, *privatizers)) {
          llvm::Value *llvmPrivVar = moduleTranslation.lookupValue(privVar);
          if (llvmPrivVar != &vPtr)
            continue;

          SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
          omp::PrivateClauseOp privatizer =
              SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
                  opInst, privSym);

          // The privatizer's regions are modified below before inlining, so
          // a clone is used instead of the original op.
          MLIRContext &context = moduleTranslation.getContext();
          mlir::IRRewriter opCloner(&context);
          opCloner.setInsertionPoint(privatizer);
          auto clone = llvm::cast<mlir::omp::PrivateClauseOp>(
              opCloner.clone(*privatizer));

          // Give the clone a symbol name not already visible from opInst.
          unsigned counter = 0;
          SmallString<256> cloneName = SymbolTable::generateSymbolName<256>(
              privatizer.getSymName(),
              [&](llvm::StringRef candidate) {
                return SymbolTable::lookupNearestSymbolFrom(
                           opInst, StringAttr::get(&context, candidate)) !=
                       nullptr;
              },
              counter);

          clone.setSymName(cloneName);
          return {privVar, clone};
        }
      }

      return {mlir::Value(), omp::PrivateClauseOp()};
    }();

    if (privVar) {
      Region &allocRegion = privatizerClone.getAllocRegion();

      // For firstprivate, append a clone of the `copy` region to the `alloc`
      // region. Merge its first block into the last `alloc` block, binding the
      // copy region's arguments to the alloc argument and the yielded value.
      // Then drop the old `alloc` terminator.
      if (privatizerClone.getDataSharingType() ==
          omp::DataSharingClauseType::FirstPrivate) {
        auto oldAllocBackBlock = std::prev(allocRegion.end());
        omp::YieldOp oldAllocYieldOp =
            llvm::cast<omp::YieldOp>(oldAllocBackBlock->getTerminator());

        Region &copyRegion = privatizerClone.getCopyRegion();

        mlir::IRRewriter copyCloneBuilder(&moduleTranslation.getContext());
        copyCloneBuilder.cloneRegionBefore(copyRegion, allocRegion,
                                           allocRegion.end());

        auto newCopyRegionFrontBlock = std::next(oldAllocBackBlock);
        copyCloneBuilder.mergeBlocks(
            &*newCopyRegionFrontBlock, &*oldAllocBackBlock,
            {allocRegion.getArgument(0), oldAllocYieldOp.getOperand(0)});

        oldAllocYieldOp.erase();
      }

      // Make the `alloc` region use the privatized value instead of its
      // block argument, then inline it at the alloca insertion point.
      auto allocRegionArg = allocRegion.getArgument(0);
      replaceAllUsesInRegionWith(allocRegionArg, privVar, allocRegion);

      auto oldIP = builder.saveIP();
      builder.restoreIP(allocaIP);

      SmallVector<llvm::Value *, 1> yieldedValues;
      if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
                                         moduleTranslation, &yieldedValues))) {
        opInst.emitError("failed to inline `alloc` region of an `omp.private` "
                         "op in the parallel region");
        bodyGenStatus = failure();
        privatizerClone.erase();
      } else {
        assert(yieldedValues.size() == 1);
        replacementValue = yieldedValues.front();

        // Kept for the dealloc cleanup emitted by finiCB.
        privateVariables.push_back(replacementValue);
        privatizerClones.push_back(privatizerClone);
      }

      builder.restoreIP(oldIP);
    }

    return codeGenIP;
  };

  auto finiCB = [&](InsertPointTy codeGenIP) {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(codeGenIP);

    SmallVector<Region *> reductionCleanupRegions;
    llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
                    [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
                    });
    if (failed(inlineOmpRegionCleanup(
            reductionCleanupRegions, privateReductionVariables,
            moduleTranslation, builder, "omp.reduction.cleanup")))
      bodyGenStatus = failure();

    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(privatizerClones, std::back_inserter(privateCleanupRegions),
                    [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

    if (failed(inlineOmpRegionCleanup(
            privateCleanupRegions, privateVariables, moduleTranslation, builder,
            "omp.private.dealloc", false)))
      bodyGenStatus = failure();

    builder.restoreIP(oldIP);
  };

  llvm::Value *ifCond = nullptr;
  if (auto ifExprVar = opInst.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifExprVar);
  llvm::Value *numThreads = nullptr;
  if (auto numThreadsVar = opInst.getNumThreadsVar())
    numThreads = moduleTranslation.lookupValue(numThreadsVar);
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindVal())
    pbKind = getProcBindKind(*bind);
  bool isCancellable = false;

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  builder.restoreIP(
      ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                 ifCond, numThreads, pbKind, isCancellable));

  // The privatizer clones were only needed during translation.
  for (mlir::omp::PrivateClauseOp privatizerClone : privatizerClones)
    privatizerClone.erase();

  return bodyGenStatus;
}
/// Maps the optional `order` clause of an MLIR OpenMP op onto the
/// OpenMPIRBuilder order kind. An absent clause yields OMP_ORDER_unknown.
static llvm::omp::OrderKind
convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
  if (o.has_value()) {
    switch (o.value()) {
    case omp::ClauseOrderKind::Concurrent:
      return llvm::omp::OrderKind::OMP_ORDER_concurrent;
    }
    llvm_unreachable("Unknown ClauseOrderKind kind");
  }
  return llvm::omp::OrderKind::OMP_ORDER_unknown;
}
/// Converts an `omp.simd` operation (which wraps an `omp.loop_nest`) into a
/// collapsed canonical loop. The loop is then annotated via
/// OpenMPIRBuilder::applySimd using the op's if, order, simdlen and safelen
/// clauses. No aligned variables are forwarded: an empty map is passed to
/// applySimd.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  auto simdOp = cast<omp::SimdOp>(opInst);
  auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Canonical loops created so far (outermost first), and the body insertion
  // point each level passed to bodyGen.
  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  LogicalResult bodyGenStatus = success();
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
    // While level i is being created loopInfos.size() == i, so this maps the
    // loop nest's block argument for the current level to `iv`.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Remembered so that the next (inner) level is created inside this body.
    bodyInsertPoints.push_back(ip);

    // Only the innermost level converts the loop nest's region.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return;

    builder.restoreIP(ip);
    convertOmpOpRegions(loopOp.getRegion(), "omp.simd.region", builder,
                        moduleTranslation, bodyGenStatus);
  };

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLowerBound()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getUpperBound()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[i]);

    // Inner levels are placed at the previous level's body insertion point,
    // keeping the op's DebugLoc, with ComputeIP set to the outermost loop's
    // preheader.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }
    // The bounds are treated as signed and the upper bound as inclusive.
    loopInfos.push_back(ompBuilder->createCanonicalLoop(
        loc, bodyGen, lowerBound, upperBound, step,
        true, true, computeIP));

    if (failed(bodyGenStatus))
      return failure();
  }

  // Capture the point after the outermost loop before collapsing; it is
  // restored once applySimd has run.
  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
  llvm::CanonicalLoopInfo *loopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});

  // simdlen / safelen are passed as i64 constants, or null when absent.
  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());

  // Intentionally empty: no aligned variables are forwarded.
  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrderVal());
  ompBuilder->applySimd(loopInfo, alignedVars,
                        simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                            : nullptr,
                        order, simdlen, safelen);

  builder.restoreIP(afterIP);
  return success();
}
/// Translates the optional memory-order clause of an OpenMP atomic op into an
/// LLVM atomic ordering. An absent clause is handled like `relaxed`, and both
/// produce `Monotonic`.
static llvm::AtomicOrdering
convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
  const omp::ClauseMemoryOrderKind kind =
      ao.value_or(omp::ClauseMemoryOrderKind::Relaxed);
  switch (kind) {
  case omp::ClauseMemoryOrderKind::Relaxed:
    return llvm::AtomicOrdering::Monotonic;
  case omp::ClauseMemoryOrderKind::Release:
    return llvm::AtomicOrdering::Release;
  case omp::ClauseMemoryOrderKind::Acquire:
    return llvm::AtomicOrdering::Acquire;
  case omp::ClauseMemoryOrderKind::Acq_rel:
    return llvm::AtomicOrdering::AcquireRelease;
  case omp::ClauseMemoryOrderKind::Seq_cst:
    return llvm::AtomicOrdering::SequentiallyConsistent;
  }
  llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
}
/// Converts an `omp.atomic.read` into an OpenMPIRBuilder atomic read from `x`
/// into `v`. Both operands are accessed with the op's element type and are
/// flagged neither signed nor volatile.
static LogicalResult
convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  auto readOp = cast<omp::AtomicReadOp>(opInst);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::AtomicOrdering ordering =
      convertAtomicOrdering(readOp.getMemoryOrderVal());
  llvm::Type *elemTy = moduleTranslation.convertType(readOp.getElementType());

  llvm::OpenMPIRBuilder::AtomicOpValue source = {
      moduleTranslation.lookupValue(readOp.getX()), elemTy,
      /*IsSigned=*/false, /*IsVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue dest = {
      moduleTranslation.lookupValue(readOp.getV()), elemTy,
      /*IsSigned=*/false, /*IsVolatile=*/false};

  builder.restoreIP(
      ompBuilder->createAtomicRead(ompLoc, source, dest, ordering));
  return success();
}
/// Converts an `omp.atomic.write` into an OpenMPIRBuilder atomic store of the
/// op's expression into `x`. The access type is the expression's converted
/// type; the target is flagged neither signed nor volatile.
static LogicalResult
convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  auto writeOp = cast<omp::AtomicWriteOp>(opInst);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::AtomicOrdering ordering =
      convertAtomicOrdering(writeOp.getMemoryOrderVal());
  mlir::Value storedValue = writeOp.getExpr();
  llvm::Value *llvmStored = moduleTranslation.lookupValue(storedValue);
  llvm::Type *storedTy = moduleTranslation.convertType(storedValue.getType());

  llvm::OpenMPIRBuilder::AtomicOpValue target = {
      moduleTranslation.lookupValue(writeOp.getX()), storedTy,
      /*IsSigned=*/false, /*IsVolatile=*/false};

  builder.restoreIP(
      ompBuilder->createAtomicWrite(ompLoc, target, llvmStored, ordering));
  return success();
}
/// Returns the atomicrmw operation matching the LLVM dialect binary op `op`.
/// Returns BAD_BINOP when `op` is not one of the listed arithmetic/bitwise
/// ops.
llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
  using RMWOp = llvm::AtomicRMWInst::BinOp;
  if (isa<LLVM::AddOp>(op))
    return RMWOp::Add;
  if (isa<LLVM::SubOp>(op))
    return RMWOp::Sub;
  if (isa<LLVM::AndOp>(op))
    return RMWOp::And;
  if (isa<LLVM::OrOp>(op))
    return RMWOp::Or;
  if (isa<LLVM::XOrOp>(op))
    return RMWOp::Xor;
  if (isa<LLVM::UMaxOp>(op))
    return RMWOp::UMax;
  if (isa<LLVM::UMinOp>(op))
    return RMWOp::UMin;
  if (isa<LLVM::FAddOp>(op))
    return RMWOp::FAdd;
  if (isa<LLVM::FSubOp>(op))
    return RMWOp::FSub;
  return RMWOp::BAD_BINOP;
}
/// Converts an `omp.atomic.update` through
/// OpenMPIRBuilder::createAtomicUpdate.
///
/// An atomicrmw binop and its expression operand are derived only when the
/// update region consists of a single binary operation (plus its terminator)
/// that uses the region argument. In every other case `binop` stays
/// BAD_BINOP and no expression is passed. The region itself is always
/// available to the builder through the `updateFn` callback, which converts
/// the region's block and returns the value yielded by `omp.yield`.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  mlir::Value regionArg = opInst.getRegion().getArgument(0);
  if (innerOpList.size() == 2) {
    // Exactly one operation followed by the terminator.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(), regionArg)) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    // Only a two-operand op can be described as `x op expr` / `expr op x`.
    // For other arities (e.g. a unary op on the region argument), indexing
    // operand 1 would read past the operand list, so keep BAD_BINOP.
    if (innerOp.getNumOperands() == 2) {
      binop = convertBinOpToAtomic(innerOp);
      isXBinopExpr = innerOp.getOperand(0) == regionArg;
      mlirExpr =
          (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
      llvmExpr = moduleTranslation.lookupValue(mlirExpr);
    }
  }

  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrderVal());

  // Converts the update region with its argument bound to `atomicx` and
  // returns the LLVM value of the yielded result (nullptr on failure).
  LogicalResult updateGenStatus = success();
  auto updateFn = [&opInst, &moduleTranslation, &updateGenStatus](
                      llvm::Value *atomicx,
                      llvm::IRBuilder<> &builder) -> llvm::Value * {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
      updateGenStatus = (opInst.emitError()
                         << "unable to convert update operation to llvm IR");
      return nullptr;
    }
    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  builder.restoreIP(ompBuilder->createAtomicUpdate(
      ompLoc, allocaIP, llvmAtomicX, llvmExpr, atomicOrdering, binop, updateFn,
      isXBinopExpr));
  return updateGenStatus;
}
/// Converts an `omp.atomic.capture` through
/// OpenMPIRBuilder::createAtomicCapture.
///
/// The capture op pairs an `omp.atomic.read` with either an
/// `omp.atomic.write` or an `omp.atomic.update`:
/// - write: the written expression is used, and the capture is treated as
///   postfix.
/// - update: the first two-operand op in the update region that uses the
///   region argument determines the atomicrmw binop and the expression
///   operand. The full region is also handed to the builder through the
///   `updateFn` callback.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // Postfix when the update is the capture's second op, i.e. the read
    // comes first.
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    bool isRegionArgUsed{false};
    // `binop` is also overwritten by two-operand ops that do not use the
    // region argument; the value kept is the one from the op that matches
    // when the loop breaks.
    // NOTE(review): unlike convertOmpAtomicUpdate, this does not require the
    // region to contain a single op. For multi-op regions only this one op
    // drives binop/expr, and `mlirExpr` may be defined inside the region
    // while it is looked up below, before the region is converted. Confirm
    // createAtomicCapture handles such regions correctly.
    for (Operation &innerOp : innerOpList) {
      if (innerOp.getNumOperands() == 2) {
        binop = convertBinOpToAtomic(innerOp);
        if (!llvm::is_contained(innerOp.getOperands(),
                                atomicUpdateOp.getRegion().getArgument(0)))
          continue;
        isRegionArgUsed = true;
        isXBinopExpr =
            innerOp.getNumOperands() > 0 &&
            innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
        mlirExpr =
            (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
        break;
      }
    }
    if (!isRegionArgUsed)
      return atomicUpdateOp.emitError(
          "no atomic update operation with region argument"
          " as operand found inside atomic.update region");
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  // `x` and `v` are both accessed with the read op's element type, and are
  // flagged neither signed nor volatile.
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      false,
                                                      false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      false,
                                                      false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrderVal());

  // Produces the new value of `x`. For a write this is the written
  // expression; for an update the region is converted with its argument
  // bound to `atomicx`, and the yielded value is returned (nullptr on
  // failure).
  LogicalResult updateGenStatus = success();
  auto updateFn = [&](llvm::Value *atomicx,
                      llvm::IRBuilder<> &builder) -> llvm::Value * {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
      updateGenStatus = (atomicUpdateOp.emitError()
                         << "unable to convert update operation to llvm IR");
      return nullptr;
    }
    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  builder.restoreIP(ompBuilder->createAtomicCapture(
      ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
      binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr));
  return updateGenStatus;
}
/// Converts an `omp.threadprivate` op into a call created by
/// OpenMPIRBuilder::createCachedThreadPrivate.
///
/// The op's symbol address must be produced by an `llvm.mlir.addressof`.
/// Otherwise an error is emitted. The call's result is mapped to the op's
/// result.
static LogicalResult
convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);

  Value symAddr = threadprivateOp.getSymAddr();
  // getDefiningOp() is null when the address is a block argument. Use a
  // null-tolerant cast instead of isa<> (which asserts on null) followed by
  // a redundant dyn_cast.
  auto addressOfOp =
      llvm::dyn_cast_if_present<LLVM::AddressOfOp>(symAddr.getDefiningOp());
  if (!addressOfOp)
    return opInst.emitError("Addressing symbol not found");

  LLVM::GlobalOp global =
      addressOfOp.getGlobal(moduleTranslation.symbolTable());
  llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
  llvm::Type *type = globalValue->getValueType();
  llvm::TypeSize typeSize =
      builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
          type);
  llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
  std::string cacheName = Twine(global.getSymName()).concat(".cache").str();
  llvm::Value *callInst =
      moduleTranslation.getOpenMPBuilder()->createCachedThreadPrivate(
          ompLoc, globalValue, size, cacheName);
  moduleTranslation.mapValue(opInst.getResult(0), callInst);
  return success();
}
/// Maps an MLIR `declare target` device type onto the corresponding
/// OffloadEntriesInfoManager device clause kind.
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
  // Every case returns directly; the former `break`s after `return` were
  // unreachable.
  switch (deviceClause) {
  case mlir::omp::DeclareTargetDeviceType::host:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
  case mlir::omp::DeclareTargetDeviceType::nohost:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
  case mlir::omp::DeclareTargetDeviceType::any:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
  }
  llvm_unreachable("unhandled device clause");
}
/// Maps an MLIR `declare target` capture clause (to / link / enter) onto the
/// corresponding OffloadEntriesInfoManager global-variable entry kind.
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
convertToCaptureClauseKind(
    mlir::omp::DeclareTargetCaptureClause captureClause) {
  switch (captureClause) {
  case mlir::omp::DeclareTargetCaptureClause::to:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
  case mlir::omp::DeclareTargetCaptureClause::link:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
  case mlir::omp::DeclareTargetCaptureClause::enter:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
  }
  llvm_unreachable("unhandled capture clause");
}
/// Builds the name suffix of the reference pointer generated for a
/// `declare target` global.
///
/// For private globals, `_<hex FileID>` (from the OpenMPIRBuilder's target
/// entry unique info) is prepended to `_decl_tgt_ref_ptr`. All other globals
/// get `_decl_tgt_ref_ptr` alone.
static llvm::SmallString<64>
getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
                             llvm::OpenMPIRBuilder &ompBuilder) {
  llvm::SmallString<64> suffix;
  llvm::raw_svector_ostream os(suffix);
  if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
    auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
    // The global may carry no file/line/column location (findInstanceOf then
    // yields a null location). Fall back to an empty file name and line 0
    // rather than dereferencing it.
    auto fileInfoCallBack = [&loc]() {
      if (!loc)
        return std::pair<std::string, uint64_t>(std::string(), 0);
      return std::pair<std::string, uint64_t>(
          llvm::StringRef(loc.getFilename()), loc.getLine());
    };

    os << llvm::format(
        "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
  }
  os << "_decl_tgt_ref_ptr";

  return suffix;
}
/// Returns true if `value` is the address (via `llvm.mlir.addressof`) of a
/// global whose `declare target` capture clause is `link`.
static bool isDeclareTargetLink(mlir::Value value) {
  auto addressOfOp =
      llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp());
  if (!addressOfOp)
    return false;

  // Guard against the op not being nested in a module, and against the
  // symbol not resolving. dyn_cast<> asserts on null, so use the
  // null-tolerant variant.
  auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
  if (!modOp)
    return false;
  Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
  auto declareTargetGlobal =
      llvm::dyn_cast_if_present<mlir::omp::DeclareTargetInterface>(gOp);
  return declareTargetGlobal &&
         declareTargetGlobal.getDeclareTargetCaptureClause() ==
             mlir::omp::DeclareTargetCaptureClause::link;
}
/// Returns the LLVM reference-pointer global for `value` when it is the
/// address of a `declare target` global that is accessed through such a
/// pointer. That applies to capture clause `link`, and to `to` when unified
/// shared memory is required. Returns nullptr otherwise.
static llvm::Value *
getRefPtrIfDeclareTarget(mlir::Value value,
                         LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  auto addressOfOp =
      llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp());
  if (!addressOfOp)
    return nullptr;

  auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
      addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
          addressOfOp.getGlobalName()));
  if (!gOp)
    return nullptr;

  auto declareTargetGlobal =
      llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp.getOperation());
  if (!declareTargetGlobal)
    return nullptr;

  mlir::omp::DeclareTargetCaptureClause captureClause =
      declareTargetGlobal.getDeclareTargetCaptureClause();
  const bool viaRefPtr =
      captureClause == mlir::omp::DeclareTargetCaptureClause::link ||
      (captureClause == mlir::omp::DeclareTargetCaptureClause::to &&
       ompBuilder->Config.hasRequiresUnifiedSharedMemory());
  if (!viaRefPtr)
    return nullptr;

  llvm::SmallString<64> suffix = getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
  // If the global's own name already carries the suffix, look it up as-is.
  if (gOp.getSymName().contains(suffix))
    return moduleTranslation.getLLVMModule()->getNamedValue(gOp.getSymName());
  return moduleTranslation.getLLVMModule()->getNamedValue(
      (gOp.getSymName().str() + suffix.str()).str());
}
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
llvm::SmallVector<bool, 4> IsDeclareTarget;
llvm::SmallVector<bool, 4> IsAMember;
llvm::SmallVector<mlir::Operation *, 4> MapClause;
llvm::SmallVector<llvm::Value *, 4> OriginalValue;
llvm::SmallVector<llvm::Type *, 4> BaseType;
void append(MapInfoData &CurInfo) {
IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
CurInfo.IsDeclareTarget.end());
MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
OriginalValue.append(CurInfo.OriginalValue.begin(),
CurInfo.OriginalValue.end());
BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
}
};
/// Returns the size in bits of the innermost (non-array) element type of a
/// possibly nested LLVM array type.
uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
  LLVM::LLVMArrayType innermost = arrTy;
  // Peel nested array levels until the element type is no longer an array.
  while (auto nested = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
             innermost.getElementType()))
    innermost = nested;
  return dl.getTypeSizeInBits(innermost.getElementType());
}
/// Returns an i64 value holding the number of bytes to map for `type`.
///
/// When `clauseOp` is an `omp.map.info` with bounds, the size is the product
/// of (upper - lower + 1) over every `omp.map.bounds` operand, times the
/// element size. For array types the element size is that of the innermost
/// element. Otherwise the full size of `type` is returned as a constant.
/// `basePointer` and `baseType` are currently unused.
llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
                            Operation *clauseOp, llvm::Value *basePointer,
                            llvm::Type *baseType, llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) {
  auto memberClause =
      mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp);
  if (!memberClause || memberClause.getBounds().empty())
    return builder.getInt64(dl.getTypeSizeInBits(type) / 8);

  llvm::Value *elementCount = builder.getInt64(1);
  for (mlir::Value bound : memberClause.getBounds()) {
    auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
        bound.getDefiningOp());
    if (!boundOp)
      continue;
    llvm::Value *span = builder.CreateSub(
        moduleTranslation.lookupValue(boundOp.getUpperBound()),
        moduleTranslation.lookupValue(boundOp.getLowerBound()));
    llvm::Value *extent = builder.CreateAdd(span, builder.getInt64(1));
    elementCount = builder.CreateMul(elementCount, extent);
  }

  uint64_t elemSizeInBits = dl.getTypeSizeInBits(type);
  if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
    elemSizeInBits = getArrayElementSizeInBits(arrTy, dl);
  return builder.CreateMul(elementCount,
                           builder.getInt64(elemSizeInBits / 8));
}
/// Populates `mapData` with one entry per omp.map.info operand in
/// `mapOperands`; operands defined by other ops are skipped.
///
/// For each entry it records:
///  - the offloaded pointer (var_ptr_ptr if present, else var_ptr);
///  - the base pointer (the declare-target reference pointer when one exists);
///  - the element type and the computed byte size;
///  - the raw map type and a source-location name;
///  - whether the op is listed as a member of some other map operand.
void collectMapDataFromMapOperands(MapInfoData &mapData,
                                   llvm::SmallVectorImpl<Value> &mapOperands,
                                   LLVM::ModuleTranslation &moduleTranslation,
                                   DataLayout &dl,
                                   llvm::IRBuilderBase &builder) {
  // Returns true if `candidate` appears in the member list of any map
  // operand. Stops at the first match.
  auto isMemberOfAnyMap = [&](mlir::omp::MapInfoOp candidate) {
    for (mlir::Value parentValue : mapOperands) {
      auto parentMap = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
          parentValue.getDefiningOp());
      if (!parentMap)
        continue;
      for (mlir::Value member : parentMap.getMembers())
        if (member == candidate)
          return true;
    }
    return false;
  };

  for (mlir::Value mapValue : mapOperands) {
    auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
        mapValue.getDefiningOp());
    if (!mapOp)
      continue;

    mlir::Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    mapData.OriginalValue.push_back(
        moduleTranslation.lookupValue(offloadPtr));
    mapData.Pointers.push_back(mapData.OriginalValue.back());

    // Declare-target variables are addressed through their reference pointer.
    if (llvm::Value *refPtr =
            getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
      mapData.IsDeclareTarget.push_back(true);
      mapData.BasePointers.push_back(refPtr);
    } else {
      mapData.IsDeclareTarget.push_back(false);
      mapData.BasePointers.push_back(mapData.OriginalValue.back());
    }

    mapData.BaseType.push_back(
        moduleTranslation.convertType(mapOp.getVarType()));
    mapData.Sizes.push_back(
        getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
                       mapData.BaseType.back(), builder, moduleTranslation));
    mapData.MapClause.push_back(mapOp.getOperation());
    mapData.Types.push_back(
        llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
    mapData.Names.push_back(LLVM::createMappingInformation(
        mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
    mapData.DevicePointers.push_back(
        llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    mapData.IsAMember.push_back(isMemberOfAnyMap(mapOp));
  }
}
/// Returns the position of `memberOp` within `mapData.MapClause`.
/// In debug builds, asserts that the op is present.
static int getMapDataMemberIdx(MapInfoData &mapData,
                               mlir::omp::MapInfoOp memberOp) {
  auto &clauses = mapData.MapClause;
  auto found = llvm::find(clauses, memberOp);
  assert(found != clauses.end() &&
         "MapInfoOp for member not found in MapData, cannot return index");
  return static_cast<int>(found - clauses.begin());
}
/// Returns the mapped member of `mapInfo` that orders first (`first == true`)
/// or last (`first == false`), according to the member index paths stored in
/// its members-index attribute.
///
/// The attribute is read as a row-major int32 matrix. It has shape[0] rows,
/// one per entry of `getMembers()`, and shape[1] columns holding that
/// member's index path.
/// NOTE(review): -1 entries appear to be padding for shorter index paths;
/// confirm against the op definition.
static mlir::omp::MapInfoOp
getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
mlir::DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();
// A single index element can only describe one member, which is then both
// the first and the last one.
if (indexAttr.size() == 1)
if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
mapInfo.getMembers()[0].getDefiningOp()))
return mapOp;
llvm::ArrayRef<int64_t> shape = indexAttr.getShapedType().getShape();
// Sort the member positions [0, shape[0]) so that the requested member
// ends up at indices.front().
llvm::SmallVector<size_t> indices(shape[0]);
std::iota(indices.begin(), indices.end(), 0);
llvm::sort(indices.begin(), indices.end(),
[&](const size_t a, const size_t b) {
auto indexValues = indexAttr.getValues<int32_t>();
// Compare the two index paths position by position; the first
// differing position decides the order.
for (int i = 0; i < shape[1]; ++i) {
int aIndex = indexValues[a * shape[1] + i];
int bIndex = indexValues[b * shape[1] + i];
if (aIndex == bIndex)
continue;
// A -1 entry orders before any other value, regardless of
// `first`.
// NOTE(review): with `first == false` this can select a member
// whose path is shorter (padded) over a deeper sibling; confirm
// this is intended.
if (aIndex != -1 && bIndex == -1)
return false;
if (aIndex == -1 && bIndex != -1)
return true;
// Ascending order when looking for the first member, descending
// when looking for the last.
if (aIndex < bIndex)
return first;
if (bIndex < aIndex)
return !first;
}
// Identical paths: equal elements must compare false for a valid
// strict weak ordering.
return false;
});
return llvm::cast<mlir::omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
}
/// Computes the GEP indices that offset a mapped pointer to the start of the
/// section described by `bounds`. Only omp.map.bounds operands contribute;
/// other operands are ignored.
///
/// - No bounds: returns an empty vector.
/// - LLVM array base type (`isArrayTy`): returns a leading 0, followed by the
///   lower bound of each bound in reverse operand order.
/// - Otherwise: returns a single linearised element offset,
///   sum_i(lowerBound_i * stride_i), where stride_0 = 1 and
///   stride_i = extent_i * stride_{i-1}.
std::vector<llvm::Value *>
calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
llvm::IRBuilderBase &builder, bool isArrayTy,
mlir::OperandRange bounds) {
std::vector<llvm::Value *> idx;
if (bounds.empty())
return idx;
if (isArrayTy) {
// The leading 0 steps through the base pointer itself; the remaining
// indices select the element inside the (possibly nested) array.
// NOTE(review): bounds are visited last-to-first, presumably because they
// are listed innermost dimension first while LLVM array GEP indices go
// outermost first. Confirm.
idx.push_back(builder.getInt64(0));
for (int i = bounds.size() - 1; i >= 0; --i) {
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
bounds[i].getDefiningOp())) {
idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
}
}
} else {
// Running products of extents, used as per-bound strides.
// NOTE(review): two things to confirm are intended here:
//  - stride_i multiplies in the extent of bound i itself, not the extents
//    of the preceding bounds;
//  - a bound that is not an omp.map.bounds op skips its push_back, so later
//    entries no longer line up with index `i`.
std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
for (size_t i = 1; i < bounds.size(); ++i) {
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
bounds[i].getDefiningOp())) {
dimensionIndexSizeOffset.push_back(builder.CreateMul(
moduleTranslation.lookupValue(boundOp.getExtent()),
dimensionIndexSizeOffset[i - 1]));
}
}
// Accumulate lowerBound_i * stride_i into a single index value.
for (int i = bounds.size() - 1; i >= 0; --i) {
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
bounds[i].getDefiningOp())) {
if (idx.empty())
idx.emplace_back(builder.CreateMul(
moduleTranslation.lookupValue(boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]));
else
idx.back() = builder.CreateAdd(
idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]));
}
}
}
return idx;
}
/// Emits the map entries for a parent record that has mapped members, and
/// returns the MEMBER_OF flag those members must carry.
///
/// The first entry always covers the parent region. For a full map this is
/// the whole record; for a partial map it spans from the first to the last
/// mapped member. Its size is computed at runtime as highAddr - lowAddr, and
/// it carries only TARGET_PARAM (or NONE).
///
/// For a full map, a second entry is emitted that is MEMBER_OF the first.
/// It carries the parent's own map type (to/from/tofrom, ...) and its static
/// size, so the record's data is actually transferred in the requested
/// direction(s).
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
    llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
    uint64_t mapDataIndex, bool isTargetParams) {
  // Entry 1: the combined parent region.
  combinedInfo.Types.emplace_back(
      isTargetParams
          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
  combinedInfo.DevicePointers.emplace_back(
      llvm::OpenMPIRBuilder::DeviceInfoTy::None);
  combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);

  auto parentClause =
      llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  llvm::Value *lowAddr, *highAddr;
  if (!parentClause.getPartialMap()) {
    // Whole record: [ptr, ptr + 1 element of the base type).
    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                        builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                   mapData.Pointers[mapDataIndex], 1),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
  } else {
    // Partial map: [first mapped member, one past the last mapped member).
    int firstMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(parentClause, true));
    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
                                        builder.getPtrTy());
    int lastMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(parentClause, false));
    highAddr = builder.CreatePointerCast(
        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
  }

  llvm::Value *size = builder.CreateIntCast(
      builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
      builder.getInt64Ty(),
      /*isSigned=*/false);
  combinedInfo.Sizes.push_back(size);

  // Members (and the optional second parent entry) refer back to entry 1.
  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
      ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);

  if (!parentClause.getPartialMap()) {
    // Entry 2: the whole record with its true map type. Adjust it the same
    // way member flags are adjusted: it is not a kernel argument by itself,
    // and it is tagged MEMBER_OF entry 1.
    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
    mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
    combinedInfo.Types.emplace_back(mapFlag);
    combinedInfo.DevicePointers.emplace_back(
        llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
        mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
    combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
  }
  return memberOfFlag;
}
/// Returns true when `mapOp` maps through a pointer. That is the case when it
/// has a var_ptr_ptr operand, or when its var_ptr is a declare-target link
/// variable.
static bool checkIfPointerMap(mlir::omp::MapInfoOp mapOp) {
  return mapOp.getVarPtrPtr() || isDeclareTargetLink(mapOp.getVarPtr());
}
/// Emits one map entry per member of the parent map at `mapDataIndex`.
///
/// Each member entry:
///  - uses the parent's base pointer and the member's own pointer and size;
///  - has TARGET_PARAM cleared and is tagged with `memberOfFlag`;
///  - additionally gets PTR_AND_OBJ for pointer maps.
static void processMapMembersWithParent(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
    llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
    uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
  using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;

  auto parentClause =
      llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
  llvm::Value *parentBasePtr = mapData.BasePointers[mapDataIndex];

  for (mlir::Value memberValue : parentClause.getMembers()) {
    auto memberOp =
        llvm::cast<mlir::omp::MapInfoOp>(memberValue.getDefiningOp());
    const int memberIdx = getMapDataMemberIdx(mapData, memberOp);
    assert(memberIdx >= 0 && "could not find mapped member of structure");

    MapFlags flags = MapFlags(memberOp.getMapType().value());
    flags &= ~MapFlags::OMP_MAP_TARGET_PARAM;
    // Set the MEMBER_OF placeholder so setCorrectMemberOfFlag fills it in.
    flags |= MapFlags::OMP_MAP_MEMBER_OF;
    ompBuilder.setCorrectMemberOfFlag(flags, memberOfFlag);
    if (checkIfPointerMap(memberOp))
      flags |= MapFlags::OMP_MAP_PTR_AND_OBJ;

    combinedInfo.Types.emplace_back(flags);
    combinedInfo.DevicePointers.emplace_back(
        llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    combinedInfo.Names.emplace_back(
        LLVM::createMappingInformation(memberOp.getLoc(), ompBuilder));
    combinedInfo.BasePointers.emplace_back(parentBasePtr);
    combinedInfo.Pointers.emplace_back(mapData.Pointers[memberIdx]);
    combinedInfo.Sizes.emplace_back(mapData.Sizes[memberIdx]);
  }
}
/// Emits a single map entry for `mapData` entry `mapDataIdx`.
///
/// The entry's map type is extended as follows:
///  - PTR_AND_OBJ for pointer maps;
///  - TARGET_PARAM when `isTargetParams` is set and the entry is not declare
///    target;
///  - LITERAL for non-pointer ByCopy captures.
///
/// When `mapDataParentIdx` is non-negative, the parent's base pointer is used
/// instead of the entry's own (a lone member mapped on behalf of its parent).
static void
processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
                     llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
                     bool isTargetParams, int mapDataParentIdx = -1) {
  auto mapFlag = mapData.Types[mapDataIdx];
  auto mapInfoOp =
      llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);

  bool isPtrTy = checkIfPointerMap(mapInfoOp);
  if (isPtrTy)
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;

  if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;

  // A missing capture type is treated as ByRef, matching
  // createAlteredByCaptureMap, instead of aborting on an empty optional.
  mlir::omp::VariableCaptureKind captureKind =
      mapInfoOp.getMapCaptureType().value_or(
          mlir::omp::VariableCaptureKind::ByRef);
  if (captureKind == mlir::omp::VariableCaptureKind::ByCopy && !isPtrTy)
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;

  if (mapDataParentIdx >= 0)
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[mapDataParentIdx]);
  else
    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);

  combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
  combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
  combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
  combinedInfo.Types.emplace_back(mapFlag);
  combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
}
/// Emits the map entries for a parent map that lists members.
///
/// A partial map with exactly one member degenerates to a single entry for
/// that member, using the parent's base pointer. Otherwise the parent entries
/// are emitted first, and every member is then tagged MEMBER_OF them.
static void processMapWithMembersOf(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
    llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
    uint64_t mapDataIndex, bool isTargetParams) {
  auto parentClause =
      llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
  const bool singleMemberPartialMap =
      parentClause.getPartialMap() && parentClause.getMembers().size() == 1;

  if (!singleMemberPartialMap) {
    llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
        mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
                             combinedInfo, mapData, mapDataIndex,
                             isTargetParams);
    processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
                                combinedInfo, mapData, mapDataIndex,
                                memberOfParentFlag);
    return;
  }

  auto soleMember = llvm::cast<mlir::omp::MapInfoOp>(
      parentClause.getMembers()[0].getDefiningOp());
  processIndividualMap(mapData, getMapDataMemberIdx(mapData, soleMember),
                       combinedInfo, isTargetParams, mapDataIndex);
}
/// Rewrites `mapData.Pointers` (and, for ByCopy, `mapData.BasePointers`)
/// according to each entry's capture kind. Declare-target entries are left
/// unchanged.
///
///  - ByRef: for pointer maps, loads through the pointer first; then applies
///    the lower-bound offset derived from the map bounds (if any) with an
///    inbounds GEP.
///  - ByCopy: loads the value when the current pointer operand is
///    pointer-typed. For non-pointer maps it then stores the value into a
///    pointer-sized stack slot named ".casted" and reloads it as a pointer
///    (presumably so the scalar can be passed as a pointer-typed literal).
///  - This / VLAType: unsupported; an error is emitted on the map op.
static void
createAlteredByCaptureMap(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder) {
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    if (mapData.IsDeclareTarget[i])
      continue;

    // Entries are expected to be omp.map.info ops. Use a checked cast rather
    // than dereferencing a possibly-null dyn_cast_if_present result.
    auto mapOp = llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
    mlir::omp::VariableCaptureKind captureKind =
        mapOp.getMapCaptureType().value_or(
            mlir::omp::VariableCaptureKind::ByRef);
    bool isPtrTy = checkIfPointerMap(mapOp);

    switch (captureKind) {
    case mlir::omp::VariableCaptureKind::ByRef: {
      llvm::Value *newV = mapData.Pointers[i];
      std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
          moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
          mapOp.getBounds());
      if (isPtrTy)
        newV = builder.CreateLoad(builder.getPtrTy(), newV);
      if (!offsetIdx.empty())
        newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
                                         "array_offset");
      mapData.Pointers[i] = newV;
    } break;
    case mlir::omp::VariableCaptureKind::ByCopy: {
      llvm::Type *type = mapData.BaseType[i];
      llvm::Value *newV;
      if (mapData.Pointers[i]->getType()->isPointerTy())
        newV = builder.CreateLoad(type, mapData.Pointers[i]);
      else
        newV = mapData.Pointers[i];

      if (!isPtrTy) {
        // The slot must be created at the function's alloca insertion point;
        // the store and reload happen at the current position.
        auto curInsert = builder.saveIP();
        builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
        auto *memTempAlloc =
            builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
        builder.restoreIP(curInsert);

        builder.CreateStore(newV, memTempAlloc);
        newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
      }

      mapData.Pointers[i] = newV;
      mapData.BasePointers[i] = newV;
    } break;
    case mlir::omp::VariableCaptureKind::This:
    case mlir::omp::VariableCaptureKind::VLAType:
      mapData.MapClause[i]->emitOpError("Unhandled capture kind");
      break;
    }
  }
}
/// Builds `combinedInfo`, the runtime map list for an offloading construct,
/// from the collected `mapData`.
///
/// Steps:
///  1. On the host, pointers are first rewritten for their capture kind via
///     createAlteredByCaptureMap.
///  2. Entries that are members of another map are skipped at the top level,
///     because they are emitted together with their parent.
///  3. Each use_device_ptr / use_device_addr operand is handled:
///     - if an entry with the same base pointer exists, it is marked
///       RETURN_PARAM;
///     - otherwise a new zero-size RETURN_PARAM entry is appended.
///
/// If any device operand is not an LLVM pointer, `combinedInfo` is cleared
/// and no further device operands are processed.
static void genMapInfos(llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation,
                        DataLayout &dl,
                        llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
                        MapInfoData &mapData,
                        const SmallVector<Value> &devPtrOperands = {},
                        const SmallVector<Value> &devAddrOperands = {},
                        bool isTargetParams = false) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  if (!ompBuilder->Config.isTargetDevice())
    createAlteredByCaptureMap(mapData, moduleTranslation, builder);

  // Empties every array of combinedInfo.
  auto fail = [&combinedInfo]() -> void {
    combinedInfo.BasePointers.clear();
    combinedInfo.Pointers.clear();
    combinedInfo.DevicePointers.clear();
    combinedInfo.Sizes.clear();
    combinedInfo.Types.clear();
    combinedInfo.Names.clear();
  };

  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    if (mapData.IsAMember[i])
      continue;

    auto mapInfoOp = llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
    if (!mapInfoOp.getMembers().empty()) {
      processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
                              combinedInfo, mapData, i, isTargetParams);
      continue;
    }

    processIndividualMap(mapData, i, combinedInfo, isTargetParams);
  }

  // Sets `index` to the position of `val` among the emitted base pointers.
  auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {
    index = 0;
    for (llvm::Value *basePtr : combinedInfo.BasePointers) {
      if (basePtr == val)
        return true;
      index++;
    }
    return false;
  };

  // Records `devOperands` as RETURN_PARAM entries of kind `devOpType`.
  // Returns false (after clearing combinedInfo) on an unsupported operand.
  auto addDevInfos =
      [&](const SmallVector<Value> &devOperands,
          llvm::OpenMPIRBuilder::DeviceInfoTy devOpType) -> bool {
    for (const auto &devOp : devOperands) {
      // TODO: Only LLVMPointerTypes are handled.
      if (!isa<LLVM::LLVMPointerType>(devOp.getType())) {
        fail();
        return false;
      }

      llvm::Value *mapOpValue = moduleTranslation.lookupValue(devOp);
      unsigned infoIndex;
      if (findMapInfo(mapOpValue, infoIndex)) {
        combinedInfo.Types[infoIndex] |=
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
        combinedInfo.DevicePointers[infoIndex] = devOpType;
      } else {
        combinedInfo.BasePointers.emplace_back(mapOpValue);
        combinedInfo.Pointers.emplace_back(mapOpValue);
        combinedInfo.DevicePointers.emplace_back(devOpType);
        combinedInfo.Names.emplace_back(
            LLVM::createMappingInformation(devOp.getLoc(), *ompBuilder));
        combinedInfo.Types.emplace_back(
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
        combinedInfo.Sizes.emplace_back(builder.getInt64(0));
      }
    }
    return true;
  };

  // After a failure, stop; otherwise the address operands would be appended
  // to the just-emptied list, producing a partial map list.
  if (!addDevInfos(devPtrOperands,
                   llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer))
    return;
  addDevInfos(devAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
}
/// Lowers omp.target_data, omp.target_enter_data, omp.target_exit_data and
/// omp.target_update through OpenMPIRBuilder::createTargetData.
///
/// Clause handling:
///  - The `if` clause is honoured.
///  - A `device` clause is honoured only when it is an llvm.mlir.constant
///    integer; any other device expression leaves deviceID at
///    OMP_DEVICEID_UNDEF.
///  - `nowait` on the standalone ops is rejected with an error.
///
/// For omp.target_data the region is translated in place. Its
/// use_device_ptr/use_device_addr block arguments are mapped to the values
/// recorded in `info.DevicePtrInfoMap` (device addresses are loaded first).
static LogicalResult
convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Value *ifCond = nullptr;
  int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
  SmallVector<Value> mapOperands;
  SmallVector<Value> useDevPtrOperands;
  SmallVector<Value> useDevAddrOperands;
  // Only set (and only read) for the standalone enter/exit/update ops.
  llvm::omp::RuntimeFunction RTLFn;
  DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // Reads the `if` and `device` clauses shared by all four ops. The device
  // operand may be a block argument (no defining op), hence
  // dyn_cast_if_present.
  auto readIfAndDeviceClauses = [&](auto clauseOp) {
    if (auto ifExprVar = clauseOp.getIfExpr())
      ifCond = moduleTranslation.lookupValue(ifExprVar);
    if (auto devId = clauseOp.getDevice())
      if (auto constOp =
              dyn_cast_if_present<LLVM::ConstantOp>(devId.getDefiningOp()))
        if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
          deviceID = intAttr.getInt();
  };

  LogicalResult result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::TargetDataOp dataOp) {
            readIfAndDeviceClauses(dataOp);
            mapOperands = dataOp.getMapOperands();
            useDevPtrOperands = dataOp.getUseDevicePtr();
            useDevAddrOperands = dataOp.getUseDeviceAddr();
            return success();
          })
          .Case([&](omp::TargetEnterDataOp enterDataOp) {
            if (enterDataOp.getNowait())
              return (LogicalResult)(enterDataOp.emitError(
                  "`nowait` is not supported yet"));
            readIfAndDeviceClauses(enterDataOp);
            RTLFn = llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
            mapOperands = enterDataOp.getMapOperands();
            return success();
          })
          .Case([&](omp::TargetExitDataOp exitDataOp) {
            if (exitDataOp.getNowait())
              return (LogicalResult)(exitDataOp.emitError(
                  "`nowait` is not supported yet"));
            readIfAndDeviceClauses(exitDataOp);
            RTLFn = llvm::omp::OMPRTL___tgt_target_data_end_mapper;
            mapOperands = exitDataOp.getMapOperands();
            return success();
          })
          .Case([&](omp::TargetUpdateOp updateDataOp) {
            if (updateDataOp.getNowait())
              return (LogicalResult)(updateDataOp.emitError(
                  "`nowait` is not supported yet"));
            readIfAndDeviceClauses(updateDataOp);
            RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
            mapOperands = updateDataOp.getMapOperands();
            return success();
          })
          .Default([&](Operation *op) {
            return op->emitError("unsupported OpenMP operation: ")
                   << op->getName();
          });

  if (failed(result))
    return failure();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapOperands, moduleTranslation, DL,
                                builder);

  // Builds the map list (plus use_device_{ptr,addr} entries for target_data).
  llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
    builder.restoreIP(codeGenIP);
    if (isa<omp::TargetDataOp>(op))
      genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
                  useDevPtrOperands, useDevAddrOperands);
    else
      genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
    return combinedInfo;
  };

  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/true,
      /*SeparateBeginEndCalls=*/true);

  using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
  LogicalResult bodyGenStatus = success();
  auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) {
    assert(isa<omp::TargetDataOp>(op) &&
           "BodyGen requested for non TargetDataOp");
    Region &region = cast<omp::TargetDataOp>(op).getRegion();
    switch (bodyGenType) {
    case BodyGenTy::Priv:
      // Emit the body only if device pointer info exists; block arguments
      // are bound to the privatized device pointers/addresses.
      if (!info.DevicePtrInfoMap.empty()) {
        builder.restoreIP(codeGenIP);
        unsigned argIndex = 0;
        for (auto &devPtrOp : useDevPtrOperands) {
          llvm::Value *mapOpValue = moduleTranslation.lookupValue(devPtrOp);
          const auto &arg = region.front().getArgument(argIndex);
          moduleTranslation.mapValue(arg,
                                     info.DevicePtrInfoMap[mapOpValue].second);
          argIndex++;
        }

        for (auto &devAddrOp : useDevAddrOperands) {
          llvm::Value *mapOpValue = moduleTranslation.lookupValue(devAddrOp);
          const auto &arg = region.front().getArgument(argIndex);
          auto *LI = builder.CreateLoad(
              builder.getPtrTy(), info.DevicePtrInfoMap[mapOpValue].second);
          moduleTranslation.mapValue(arg, LI);
          argIndex++;
        }

        bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
                                                builder, moduleTranslation);
      }
      break;
    case BodyGenTy::DupNoPriv:
      break;
    case BodyGenTy::NoPriv:
      // Without device pointer info the body is emitted here instead.
      if (info.DevicePtrInfoMap.empty()) {
        builder.restoreIP(codeGenIP);
        bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
                                                builder, moduleTranslation);
      }
      break;
    }
    return builder.saveIP();
  };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  if (isa<omp::TargetDataOp>(op)) {
    builder.restoreIP(ompBuilder->createTargetData(
        ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
        info, genMapInfoCB, nullptr, bodyGenCB));
  } else {
    builder.restoreIP(ompBuilder->createTargetData(
        ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
        info, genMapInfoCB, &RTLFn));
  }

  return bodyGenStatus;
}
/// Lowers the omp flags attribute attached to a module.
///
/// Always records the "openmp-device" module flag. Unless `no_gpu_lib` is
/// set, it also emits the `__omp_rtl_*` configuration globals from the
/// attribute's fields.
/// Returns failure() when `op` is not a ModuleOp.
LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
                               LLVM::ModuleTranslation &moduleTranslation) {
  // Use `isa`, not `cast`: `cast` asserts instead of returning null, so a
  // non-module op could never reach the failure() path.
  if (!isa<mlir::ModuleOp>(op))
    return failure();

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
                              attribute.getOpenmpDeviceVersion());

  if (attribute.getNoGpuLib())
    return success();

  ompBuilder->createGlobalFlag(attribute.getDebugKind(),
                               "__omp_rtl_debug_kind");
  ompBuilder->createGlobalFlag(attribute.getAssumeTeamsOversubscription(),
                               "__omp_rtl_assume_teams_oversubscription");
  ompBuilder->createGlobalFlag(attribute.getAssumeThreadsOversubscription(),
                               "__omp_rtl_assume_threads_oversubscription");
  ompBuilder->createGlobalFlag(attribute.getAssumeNoThreadState(),
                               "__omp_rtl_assume_no_thread_state");
  ompBuilder->createGlobalFlag(attribute.getAssumeNoNestedParallelism(),
                               "__omp_rtl_assume_no_nested_parallelism");
  return success();
}
/// Fills `targetInfo` with the unique entry info for `targetOp`, derived from
/// its FileLineColLoc: parent name, the file's device and file IDs, and the
/// line number.
///
/// Emits an error on `targetOp` and returns false when the op carries no
/// file location, or when the file's unique ID cannot be determined.
static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
                                     omp::TargetOp targetOp,
                                     llvm::StringRef parentName = "") {
  auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
  // Report a missing file location as a translation error. An assert is
  // compiled out in release builds and would leave a null location to be
  // dereferenced below.
  if (!fileLoc) {
    targetOp.emitError("No file found from location");
    return false;
  }

  StringRef fileName = fileLoc.getFilename().getValue();
  llvm::sys::fs::UniqueID id;
  if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
    targetOp.emitError("Unable to get unique ID for file");
    return false;
  }

  uint64_t line = fileLoc.getLine();
  targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
                                           id.getFile(), line);
  return true;
}
/// Returns true when the omp.target op `opInst` uses only supported clauses.
/// Otherwise emits an error naming the first unsupported clause found (if,
/// device, thread_limit, nowait) and returns false.
static bool targetOpSupported(Operation &opInst) {
  auto targetOp = cast<omp::TargetOp>(opInst);
  auto reject = [&opInst](const char *reason) {
    opInst.emitError(reason);
    return false;
  };

  if (targetOp.getIfExpr())
    return reject("If clause not yet supported");
  if (targetOp.getDevice())
    return reject("Device clause not yet supported");
  if (targetOp.getThreadLimit())
    return reject("Thread limit clause not yet supported");
  if (targetOp.getNowait())
    return reject("Nowait clause not yet supported");
  return true;
}
/// Redirects uses of declare-target variables inside `func` to a load from
/// the variable's reference pointer (its recorded base pointer).
///
/// For every declare-target map entry:
///  - constant users in `func` are first converted to instructions;
///  - each instruction in `func` that uses the original value gets a fresh
///    pointer-typed load of the reference pointer placed right before it;
///  - that use is then replaced by the load.
static void
handleDeclareTargetMapVar(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder, llvm::Function *func) {
  for (size_t entry = 0, numEntries = mapData.MapClause.size();
       entry != numEntries; ++entry) {
    if (!mapData.IsDeclareTarget[entry])
      continue;

    llvm::Value *originalValue = mapData.OriginalValue[entry];
    llvm::Value *refPtr = mapData.BasePointers[entry];

    if (auto *asConstant = dyn_cast<llvm::Constant>(originalValue))
      convertUsersOfConstantsToInstructions(asConstant, func, false);

    // Snapshot the users: replaceUsesOfWith below edits the use list.
    llvm::SmallVector<llvm::User *> snapshot(originalValue->users().begin(),
                                             originalValue->users().end());
    for (llvm::User *user : snapshot) {
      auto *inst = dyn_cast<llvm::Instruction>(user);
      if (!inst || inst->getFunction() != func)
        continue;
      auto *reload = builder.CreateLoad(refPtr->getType(), refPtr);
      reload->moveBefore(inst);
      user->replaceUsesOfWith(originalValue, reload);
    }
  }
}
/// Device-side kernel argument accessor. It spills `arg` into a stack slot
/// created at `allocaIP`, then at `codeGenIP` sets `retVal` to the value the
/// kernel body should use for `input`.
///
/// The capture kind comes from the first map entry whose original value
/// equals `input`. It defaults to ByRef when no entry matches or the entry
/// has no capture type.
///   - ByCopy: `retVal` is the stack slot itself.
///   - ByRef:  `retVal` is an aligned load from the stack slot.
///   - This / VLAType: only an assert; `retVal` is not written.
/// `moduleTranslation` is currently unused.
/// Returns the builder insertion point after the emitted code.
static llvm::IRBuilderBase::InsertPoint
createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
llvm::Value *input, llvm::Value *&retVal,
llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder &ompBuilder,
LLVM::ModuleTranslation &moduleTranslation,
llvm::IRBuilderBase::InsertPoint allocaIP,
llvm::IRBuilderBase::InsertPoint codeGenIP) {
builder.restoreIP(allocaIP);
mlir::omp::VariableCaptureKind capture =
mlir::omp::VariableCaptureKind::ByRef;
// Find the map entry for this kernel input to learn how it was captured.
for (size_t i = 0; i < mapData.MapClause.size(); ++i)
if (mapData.OriginalValue[i] == input) {
if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
mapData.MapClause[i])) {
capture = mapOp.getMapCaptureType().value_or(
mlir::omp::VariableCaptureKind::ByRef);
}
break;
}
unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
unsigned int defaultAS =
ompBuilder.M.getDataLayout().getProgramAddressSpace();
// The slot is allocated in the alloca address space. When that differs
// from the program address space and `arg` is a pointer, the slot pointer
// is cast to a pointer in the program address space.
llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
if (allocaAS != defaultAS && arg.getType()->isPointerTy())
v = builder.CreatePointerBitCastOrAddrSpaceCast(
v, arg.getType()->getPointerTo(defaultAS));
builder.CreateStore(&arg, v);
builder.restoreIP(codeGenIP);
switch (capture) {
case mlir::omp::VariableCaptureKind::ByCopy: {
retVal = v;
break;
}
case mlir::omp::VariableCaptureKind::ByRef: {
// NOTE(review): the load type is the slot's own pointer type rather than
// arg.getType(). The two appear to coincide only when `arg` is a pointer
// in the program address space. Confirm that ByRef arguments always are.
retVal = builder.CreateAlignedLoad(
v->getType(), v,
ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
break;
}
case mlir::omp::VariableCaptureKind::This:
case mlir::omp::VariableCaptureKind::VLAType:
assert(false && "Currently unsupported capture kind");
break;
}
return builder.saveIP();
}
/// Lowers an omp.target op through OpenMPIRBuilder::createTarget.
///
/// Returns failure() in two cases:
///  - the op uses a clause rejected by targetOpSupported;
///  - no unique target-region entry info can be derived from its location.
/// Otherwise returns the status of translating the target region body. On
/// the device, uses of declare-target variables in the outlined function are
/// afterwards redirected through their reference pointers.
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
if (!targetOpSupported(opInst))
return failure();
auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
auto targetOp = cast<omp::TargetOp>(opInst);
auto &targetRegion = targetOp.getRegion();
DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
SmallVector<Value> mapOperands = targetOp.getMapOperands();
// Set by bodyCB to the function that contains the region body's insertion
// point (the outlined kernel).
llvm::Function *llvmOutlinedFn = nullptr;
LogicalResult bodyGenStatus = success();
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// Body generator passed to createTarget.
auto bodyCB = [&](InsertPointTy allocaIP,
InsertPointTy codeGenIP) -> InsertPointTy {
llvm::Function *llvmParentFn =
moduleTranslation.lookupFunction(parentFn.getName());
llvmOutlinedFn = codeGenIP.getBlock()->getParent();
assert(llvmParentFn && llvmOutlinedFn &&
"Both parent and outlined functions must exist at this point");
// Copy the parent's "target-cpu" / "target-features" string attributes
// onto the outlined function.
if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
attr.isStringAttribute())
llvmOutlinedFn->addFnAttr(attr);
if (auto attr = llvmParentFn->getFnAttribute("target-features");
attr.isStringAttribute())
llvmOutlinedFn->addFnAttr(attr);
builder.restoreIP(codeGenIP);
// Region block argument i is mapped to the translated var_ptr of map
// operand i.
unsigned argIndex = 0;
for (auto &mapOp : mapOperands) {
// NOTE(review): unchecked dyn_cast; assumes every map operand is
// defined by an omp.map.info op.
auto mapInfoOp =
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
llvm::Value *mapOpValue =
moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
const auto &arg = targetRegion.front().getArgument(argIndex);
moduleTranslation.mapValue(arg, mapOpValue);
argIndex++;
}
llvm::BasicBlock *exitBlock = convertOmpOpRegions(
targetRegion, "omp.target", builder, moduleTranslation, bodyGenStatus);
builder.SetInsertPoint(exitBlock);
return builder.saveIP();
};
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
StringRef parentName = parentFn.getName();
llvm::TargetRegionEntryInfo entryInfo;
if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
return failure();
// NOTE(review): passed to createTarget as the default teams/threads values;
// the meaning of -1 and 0 is defined by OpenMPIRBuilder. Confirm there.
int32_t defaultValTeams = -1;
int32_t defaultValThreads = 0;
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
MapInfoData mapData;
collectMapDataFromMapOperands(mapData, mapOperands, moduleTranslation, dl,
builder);
// Map entries for a target region are generated with isTargetParams = true.
llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
-> llvm::OpenMPIRBuilder::MapInfosTy & {
builder.restoreIP(codeGenIP);
genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, {}, {},
true);
return combinedInfos;
};
// On the host the kernel argument is used directly; on the device it goes
// through createDeviceArgumentAccessor.
auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
llvm::Value *&retVal, InsertPointTy allocaIP,
InsertPointTy codeGenIP) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
if (!ompBuilder->Config.isTargetDevice()) {
retVal = cast<llvm::Value>(&arg);
return codeGenIP;
}
return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
*ompBuilder, moduleTranslation,
allocaIP, codeGenIP);
};
// Kernel inputs are the original values of all map entries that are neither
// declare target nor members of another map.
// NOTE(review): indexes mapData by map-operand position, which assumes every
// map operand produced a mapData entry (i.e. is an omp.map.info op).
llvm::SmallVector<llvm::Value *, 4> kernelInput;
for (size_t i = 0; i < mapOperands.size(); ++i) {
if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
kernelInput.push_back(mapData.OriginalValue[i]);
}
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
buildDependData(targetOp.getDepends(), targetOp.getDependVars(),
moduleTranslation, dds);
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, allocaIP, builder.saveIP(), entryInfo, defaultValTeams,
defaultValThreads, kernelInput, genMapInfoCB, bodyCB, argAccessorCB,
dds));
// On the device, redirect declare-target variable uses inside the outlined
// function through their reference pointers.
if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
llvmOutlinedFn);
return bodyGenStatus;
}
/// Applies an omp declare_target attribute to the translated LLVM module.
///
/// - Functions: in a target-device module, functions whose device_type is
///   `host` are erased from the LLVM module. Host modules are left unchanged.
/// - Globals: the global is registered with
///   OpenMPIRBuilder::registerTargetGlobalVariable. On the device,
///   getAddrOfDeclareTargetVar is also called when the capture clause is not
///   `to`, or when unified shared memory is required.
/// Always returns success().
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
LLVM::ModuleTranslation &moduleTranslation) {
if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
op->getParentOfType<ModuleOp>().getOperation())) {
if (!offloadMod.getIsTargetDevice())
return success();
omp::DeclareTargetDeviceType declareType =
attribute.getDeviceType().getValue();
// Host-only functions are removed from the device module.
// NOTE(review): lookupFunction's result is not null-checked here.
if (declareType == omp::DeclareTargetDeviceType::host) {
llvm::Function *llvmFunc =
moduleTranslation.lookupFunction(funcOp.getName());
llvmFunc->dropAllReferences();
llvmFunc->eraseFromParent();
}
}
return success();
}
if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
bool isDeclaration = gOp.isDeclaration();
bool isExternallyVisible =
gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
llvm::StringRef mangledName = gOp.getSymName();
auto captureClause =
convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
auto deviceClause =
convertToDeviceClauseKind(attribute.getDeviceType().getValue());
std::vector<llvm::GlobalVariable *> generatedRefs;
// Use the module's target triple attribute, when present.
std::vector<llvm::Triple> targetTriple;
auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
op->getParentOfType<mlir::ModuleOp>()->getAttr(
LLVM::LLVMDialect::getTargetTripleAttrName()));
if (targetTripleAttr)
targetTriple.emplace_back(targetTripleAttr.data());
// Supplies (filename, line) for the entry info; ("", 0) when the op has
// no FileLineColLoc.
auto fileInfoCallBack = [&loc]() {
std::string filename = "";
std::uint64_t lineNo = 0;
if (loc) {
filename = loc.getFilename().str();
lineNo = loc.getLine();
}
return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
lineNo);
};
ompBuilder->registerTargetGlobalVariable(
captureClause, deviceClause, isDeclaration, isExternallyVisible,
ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
generatedRefs, /*OpenMPSimd=*/false, targetTriple,
/*GlobalInitializer=*/nullptr, /*VariableLinkage=*/nullptr,
gVal->getType(), gVal);
if (ompBuilder->Config.isTargetDevice() &&
(attribute.getCaptureClause().getValue() !=
mlir::omp::DeclareTargetCaptureClause::to ||
ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
ompBuilder->getAddrOfDeclareTargetVar(
captureClause, deviceClause, isDeclaration, isExternallyVisible,
ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
generatedRefs, /*OpenMPSimd=*/false, targetTriple, gVal->getType(),
/*GlobalInitializer=*/nullptr,
/*VariableLinkage=*/nullptr);
}
}
}
return success();
}
/// Returns true when `op` should be treated as device code. That is the case
/// when either:
///  - it is nested in an omp.target op, or
///  - its enclosing LLVM function is declare-target with a device type other
///    than `host`.
static bool isTargetDeviceOp(Operation *op) {
  if (op->getParentOfType<omp::TargetOp>())
    return true;

  auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>();
  if (!parentFn)
    return false;

  auto declareTargetIface = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
      parentFn.getOperation());
  if (!declareTargetIface)
    return false;

  return declareTargetIface.isDeclareTarget() &&
         declareTargetIface.getDeclareTargetDeviceType() !=
             mlir::omp::DeclareTargetDeviceType::host;
}
/// Dispatches an OpenMP dialect operation to its dedicated lowering routine
/// using the module's OpenMPIRBuilder.
///
/// Terminators, declaration ops, and map/bounds/privatizer ops are accepted
/// without emitting anything here. Any other op type not listed yields an
/// "unsupported OpenMP operation" error on the op.
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  // Several lowering helpers take the generic operation by reference.
  Operation &opInst = *op;

  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
      // Standalone directives lowered directly through the OpenMPIRBuilder.
      .Case([&](omp::BarrierOp) {
        ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
        return success();
      })
      .Case([&](omp::TaskwaitOp) {
        ompBuilder->createTaskwait(builder.saveIP());
        return success();
      })
      .Case([&](omp::TaskyieldOp) {
        ompBuilder->createTaskyield(builder.saveIP());
        return success();
      })
      .Case([&](omp::FlushOp) {
        ompBuilder->createFlush(builder.saveIP());
        return success();
      })
      // Ops with a dedicated conversion helper.
      .Case([&](omp::ParallelOp parallelOp) {
        return convertOmpParallel(parallelOp, builder, moduleTranslation);
      })
      .Case([&](omp::MaskedOp) {
        return convertOmpMasked(opInst, builder, moduleTranslation);
      })
      .Case([&](omp::MasterOp) {
        return convertOmpMaster(opInst, builder, moduleTranslation);
      })
      .Case([&](omp::CriticalOp) {
        return convertOmpCritical(opInst, builder, moduleTranslation);
      })
      .Case([&](omp::OrderedRegionOp) {
        return convertOmpOrderedRegion(opInst, builder, moduleTranslation);
      })
      .Case([&](omp::OrderedOp) {
        return convertOmpOrdered(opInst, builder, moduleTranslation);
      })
      .Case([&](omp::WsloopOp) {
        return convertOmpWsloop(opInst, builder, moduleTranslation);
      })
      .Case([&](omp::SimdOp) {
        return convertOmpSimd(opInst, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicReadOp) {
        return convertOmpAtomicRead(opInst, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicWriteOp) {
        return convertOmpAtomicWrite(opInst, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicUpdateOp updateOp) {
        return convertOmpAtomicUpdate(updateOp, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicCaptureOp captureOp) {
        return convertOmpAtomicCapture(captureOp, builder, moduleTranslation);
      })
      .Case([&](omp::SectionsOp) {
        return convertOmpSections(opInst, builder, moduleTranslation);
      })
      .Case([&](omp::SingleOp singleOp) {
        return convertOmpSingle(singleOp, builder, moduleTranslation);
      })
      .Case([&](omp::TeamsOp teamsOp) {
        return convertOmpTeams(teamsOp, builder, moduleTranslation);
      })
      .Case([&](omp::TaskOp taskOp) {
        return convertOmpTaskOp(taskOp, builder, moduleTranslation);
      })
      .Case([&](omp::TaskgroupOp taskgroupOp) {
        return convertOmpTaskgroupOp(taskgroupOp, builder, moduleTranslation);
      })
      // Terminators and declaration ops: nothing is emitted for them here.
      .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareReductionOp,
            omp::CriticalDeclareOp>([](auto) { return success(); })
      .Case([&](omp::ThreadprivateOp) {
        return convertOmpThreadprivate(opInst, builder, moduleTranslation);
      })
      .Case<omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
            omp::TargetUpdateOp>([&](auto dataOp) {
        return convertOmpTargetData(dataOp, builder, moduleTranslation);
      })
      .Case([&](omp::TargetOp) {
        return convertOmpTarget(opInst, builder, moduleTranslation);
      })
      // Map/bounds info and privatizer ops: nothing is emitted here
      // (presumably consumed by the ops that reference them -- confirm).
      .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
          [](auto) { return success(); })
      .Default([&](Operation *unhandled) {
        return unhandled->emitError("unsupported OpenMP operation: ")
               << unhandled->getName();
      });
}
/// Translates an op that is already in a device context. Device-context ops
/// share the same dispatcher used for host compilation.
static LogicalResult
convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  LogicalResult status =
      convertHostOrTargetOperation(op, builder, moduleTranslation);
  return status;
}
/// Translates only the omp.target and omp.target_data operations found in
/// `op`. If `op` itself is one of them it is translated directly. Otherwise
/// its nested ops are visited in pre-order; each target-related op found is
/// translated and its body is not walked further. Any other op is left
/// untranslated here. Returns failure as soon as one translation fails.
static LogicalResult
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  // Translates `candidate` if it is target-related; std::nullopt otherwise.
  auto translateTargetRelated =
      [&](Operation *candidate) -> std::optional<LogicalResult> {
    if (isa<omp::TargetOp>(candidate))
      return convertOmpTarget(*candidate, builder, moduleTranslation);
    if (isa<omp::TargetDataOp>(candidate))
      return convertOmpTargetData(candidate, builder, moduleTranslation);
    return std::nullopt;
  };

  if (std::optional<LogicalResult> direct = translateTargetRelated(op))
    return *direct;

  WalkResult walkStatus =
      op->walk<WalkOrder::PreOrder>([&](Operation *nested) {
        std::optional<LogicalResult> status = translateTargetRelated(nested);
        if (!status)
          return WalkResult::advance();
        // The op's regions were handled by its own translation; don't descend.
        return failed(*status) ? WalkResult::interrupt() : WalkResult::skip();
      });
  return failure(walkStatus.wasInterrupted());
}
namespace {
/// Dialect interface that lowers OpenMP dialect operations and attributes to
/// LLVM IR through the module's OpenMPIRBuilder.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Acts on an OpenMP dialect attribute (e.g. omp.is_target_device,
  /// omp.declare_target) attached to `op`.
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;

  /// Translates a single OpenMP dialect operation into LLVM IR.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;
};
} // namespace
/// Handles an OpenMP dialect attribute attached to `op`.
///
/// - omp.is_target_device, omp.is_gpu and omp.requires update the
///   OpenMPIRBuilder configuration.
/// - omp.host_ir_filepath loads offload info metadata from the given file.
/// - omp.version adds an "openmp" module flag (Max merge behavior).
/// - omp.flags and omp.declare_target are delegated to their converters.
///
/// Returns failure if a recognized attribute's value has an unexpected type,
/// or if a delegated conversion fails. Attributes with any other name are
/// accepted without action.
LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
    Operation *op, ArrayRef<llvm::Instruction *> instructions,
    NamedAttribute attribute,
    LLVM::ModuleTranslation &moduleTranslation) const {
  // The handler is selected by name and invoked on the value within this
  // single full-expression, so the function_ref never outlives the lambda
  // temporary it refers to.
  return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
             attribute.getName())
      .Case("omp.is_target_device",
            [&](Attribute attr) {
              if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsTargetDevice(deviceAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.is_gpu",
            [&](Attribute attr) {
              if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsGPU(gpuAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.host_ir_filepath",
            [&](Attribute attr) {
              if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.flags",
            [&](Attribute attr) {
              if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
                return convertFlagsAttr(op, rtlAttr, moduleTranslation);
              return failure();
            })
      .Case("omp.version",
            [&](Attribute attr) {
              if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
                                            versionAttr.getVersion());
                return success();
              }
              return failure();
            })
      .Case("omp.declare_target",
            [&](Attribute attr) {
              if (auto declareTargetAttr =
                      dyn_cast<omp::DeclareTargetAttr>(attr))
                return convertDeclareTargetAttr(op, declareTargetAttr,
                                                moduleTranslation);
              return failure();
            })
      .Case("omp.requires",
            [&](Attribute attr) {
              if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
                using Requires = omp::ClauseRequires;
                Requires flags = requiresAttr.getValue();
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setHasRequiresReverseOffload(
                    bitEnumContainsAll(flags, Requires::reverse_offload));
                config.setHasRequiresUnifiedAddress(
                    bitEnumContainsAll(flags, Requires::unified_address));
                config.setHasRequiresUnifiedSharedMemory(
                    bitEnumContainsAll(flags, Requires::unified_shared_memory));
                config.setHasRequiresDynamicAllocators(
                    bitEnumContainsAll(flags, Requires::dynamic_allocators));
                return success();
              }
              return failure();
            })
      .Default([](Attribute) {
        // Attributes without a dedicated handler are accepted without action.
        return success();
      })(attribute.getValue());
  // NOTE: the former trailing `return failure();` was unreachable (every path
  // returns above) and has been removed.
}
/// Translates one OpenMP operation. Host compilation sends every op through
/// the common dispatcher. Device compilation does the same only for ops
/// already in a device context. For any other op, only the omp.target and
/// omp.target_data ops it contains are translated.
LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
    Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  const bool compilingForDevice =
      moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice();
  if (!compilingForDevice)
    return convertHostOrTargetOperation(op, builder, moduleTranslation);

  return isTargetDeviceOp(op)
             ? convertTargetDeviceOp(op, builder, moduleTranslation)
             : convertTargetOpsInNest(op, builder, moduleTranslation);
}
void mlir::registerOpenMPDialectTranslation(DialectRegistry ®istry) {
registry.insert<omp::OpenMPDialect>();
registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
});
}
/// Makes the OpenMP-to-LLVM-IR translation available in `context`. It builds
/// a local registry with the translation registered and appends that registry
/// to the context.
void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
  DialectRegistry openmpRegistry;
  mlir::registerOpenMPDialectTranslation(openmpRegistry);
  context.appendDialectRegistry(openmpRegistry);
}