#ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
#include "Clauses.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
namespace fir {
class FirOpBuilder;
}
namespace Fortran {
namespace lower {
namespace omp {
/// Processes the list of clauses attached to an OpenMP construct.
///
/// The object stores references to the converter and the semantics context
/// (so both must outlive it) together with its own copy of the clause list.
/// Most `process<Name>` members receive the clause-operand structure and/or
/// the output vectors they work on by non-const reference or pointer, and all
/// of them return a bool.
///
/// NOTE(review): apart from processMotionClauses and processTODO, the
/// `process*` definitions are not part of this header. Presumably the returned
/// bool reports whether the corresponding clause was present in the list --
/// confirm against the implementation file.
class ClauseProcessor {
public:
/// Binds \p converter and \p semaCtx by reference and copies \p clauses into
/// the processor.
ClauseProcessor(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses)
: converter(converter), semaCtx(semaCtx), clauses(clauses) {}
// Per-clause entry points. Members that need to evaluate expressions take a
// lower::StatementContext; the others only take their result structure.
bool
processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
mlir::omp::CollapseClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
bool processDefault() const;
bool processDevice(lower::StatementContext &stmtCtx,
mlir::omp::DeviceClauseOps &result) const;
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processDistSchedule(lower::StatementContext &stmtCtx,
mlir::omp::DistScheduleClauseOps &result) const;
bool processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const;
bool processFinal(lower::StatementContext &stmtCtx,
mlir::omp::FinalClauseOps &result) const;
bool processHasDeviceAddr(
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(lower::StatementContext &stmtCtx,
mlir::omp::NumTeamsClauseOps &result) const;
bool processNumThreads(lower::StatementContext &stmtCtx,
mlir::omp::NumThreadsClauseOps &result) const;
bool processOrder(mlir::omp::OrderClauseOps &result) const;
bool processOrdered(mlir::omp::OrderedClauseOps &result) const;
bool processPriority(lower::StatementContext &stmtCtx,
mlir::omp::PriorityClauseOps &result) const;
bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
bool processSchedule(lower::StatementContext &stmtCtx,
mlir::omp::ScheduleClauseOps &result) const;
bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const;
bool processThreadLimit(lower::StatementContext &stmtCtx,
mlir::omp::ThreadLimitClauseOps &result) const;
bool processUntied(mlir::omp::UntiedClauseOps &result) const;
// NOTE(review): the members from here down to processMotionClauses look like
// the handlers for clauses that may occur several times in one clause list
// (processMotionClauses is built on findRepeatableClause) -- confirm against
// the implementation file.
bool processAligned(mlir::omp::AlignedClauseOps &result) const;
bool processAllocate(mlir::omp::AllocateClauseOps &result) const;
bool processCopyin() const;
bool processCopyprivate(mlir::Location currentLocation,
mlir::omp::CopyprivateClauseOps &result) const;
bool processDepend(mlir::omp::DependClauseOps &result) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
mlir::omp::IfClauseOps &result) const;
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
/// The three trailing vector pointers are optional outputs; each defaults to
/// nullptr.
bool processMap(
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = nullptr,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
/// The two trailing vector pointers are optional outputs; each defaults to
/// nullptr.
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
nullptr) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processUseDeviceAddr(
mlir::omp::UseDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
bool processUseDevicePtr(
mlir::omp::UseDevicePtrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
/// Handles every clause of type \c T in the list, where \c T must be
/// omp::clause::To or omp::clause::From (checked by a static_assert in the
/// definition). Map info operations are appended to \c result.mapVars.
/// Returns true if at least one clause of type \c T was found. Unlike the
/// other members, this one is not declared const.
template <typename T>
bool processMotionClauses(lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result);
/// Issues a TODO diagnostic at \p currentLocation for each clause in the list
/// whose kind is one of \c Ts; the message names the clause and \p directive.
template <typename... Ts>
void processTODO(mlir::Location currentLocation,
llvm::omp::Directive directive) const;
private:
using ClauseIterator = List<Clause>::const_iterator;
/// Returns an iterator to the first clause in [begin, end) whose variant
/// holds a \c T, or \p end when there is none.
template <typename T>
static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
/// Returns a pointer to the first clause of type \c T in the list, or nullptr
/// when there is none. On success, if \p source is non-null, `*source` is set
/// to the address of that clause's source range.
template <typename T>
const T *findUniqueClause(const parser::CharBlock **source = nullptr) const;
/// Invokes \p callbackFn, in list order, on every clause of type \c T together
/// with its source range. Returns true if the callback ran at least once.
template <typename T>
bool findRepeatableClause(
std::function<void(const T &, const parser::CharBlock &source)>
callbackFn) const;
/// If a clause of type \c T is present, assigns a unit attribute to \p result
/// and returns true; otherwise returns false and leaves \p result untouched.
template <typename T>
bool markClauseOccurrence(mlir::UnitAttr &result) const;
// Referenced objects; not owned by the processor.
lower::AbstractConverter &converter;
semantics::SemanticsContext &semaCtx;
// Stored by value: copy-initialized from the constructor argument.
List<Clause> clauses;
};
/// Handles every clause of type \c T found in the clause list, where \c T is
/// either omp::clause::To or omp::clause::From.
///
/// One map info operation is created per object listed in such a clause.
/// Operations for symbols whose owning scope is a derived type are recorded
/// against their parent and attached to it once all clauses have been visited;
/// every other operation is appended to \c result.mapVars directly.
///
/// \param stmtCtx statement context forwarded to the address/bounds gathering.
/// \param result  receives the generated map operations in \c mapVars.
/// \returns true if at least one clause of type \c T was present.
template <typename T>
bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
                                           mlir::omp::MapClauseOps &result) {
  static_assert(std::is_same_v<T, omp::clause::To> ||
                std::is_same_v<T, omp::clause::From>);

  using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;

  // Map data of derived-type members, keyed by the parent they belong to.
  std::map<const semantics::Symbol *,
           llvm::SmallVector<OmpMapMemberIndicesData>>
      memberIndicesByParent;
  // Symbols whose map operation went straight into result.mapVars.
  llvm::SmallVector<const semantics::Symbol *> directlyMappedSyms;

  auto handleClause = [&](const T &clause, const parser::CharBlock &source) {
    mlir::Location loc = converter.genLocation(source);
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();

    // The clause kind alone decides the mapping direction.
    constexpr MapFlags motionFlag = std::is_same_v<T, omp::clause::To>
                                        ? MapFlags::OMP_MAP_TO
                                        : MapFlags::OMP_MAP_FROM;
    constexpr auto motionFlagBits =
        static_cast<std::underlying_type_t<MapFlags>>(motionFlag);

    const ObjectList &objectList = std::get<ObjectList>(clause.t);
    for (const omp::Object &object : objectList) {
      llvm::SmallVector<mlir::Value> bounds;
      std::stringstream asFortran;
      lower::AddrAndBoundsInfo info =
          lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
                                                mlir::omp::MapBoundsType>(
              converter, builder, semaCtx, stmtCtx, *object.sym(),
              object.ref(), loc, asFortran, bounds, treatIndexAsSection);

      // Use the symbol's own address instead of the gathered one when that
      // address exists and its type carries a descriptor.
      auto origSymbol = converter.getSymbolAddress(*object.sym());
      const bool useSymbolAddress =
          origSymbol && fir::isTypeWithDescriptor(origSymbol.getType());
      mlir::Value mappedAddr = info.addr;
      if (useSymbolAddress)
        mappedAddr = origSymbol;

      // Arguments this call has nothing to supply for are passed as empty /
      // default-constructed values.
      mlir::omp::MapInfoOp mapOp = createMapInfoOp(
          builder, loc, mappedAddr, mlir::Value{}, asFortran.str(), bounds, {},
          mlir::DenseIntElementsAttr{}, motionFlagBits,
          mlir::omp::VariableCaptureKind::ByRef, mappedAddr.getType());

      if (!object.sym()->owner().IsDerivedType()) {
        result.mapVars.push_back(mapOp);
        directlyMappedSyms.push_back(object.sym());
        continue;
      }
      addChildIndexAndMapToParent(object, memberIndicesByParent, mapOp,
                                  semaCtx);
    }
  };

  const bool anyClauseSeen = findRepeatableClause<T>(handleClause);

  insertChildMapInfoIntoParent(converter, memberIndicesByParent, result.mapVars,
                               directlyMappedSyms, nullptr, nullptr);
  return anyClauseSeen;
}
/// Reports clauses that are not handled yet.
///
/// Walks the clause list and, for each clause whose variant holds one of the
/// types in \c Ts, issues a TODO diagnostic at \p currentLocation that names
/// the clause and \p directive (both upper-cased).
template <typename... Ts>
void ClauseProcessor::processTODO(mlir::Location currentLocation,
                                  llvm::omp::Directive directive) const {
  // `held` is non-null only when the clause currently holds the queried type.
  auto reportIfHeld = [&](llvm::omp::Clause clauseId, const auto *held) {
    if (held) {
      TODO(currentLocation,
           "Unhandled clause " +
               llvm::omp::getOpenMPClauseName(clauseId).upper() + " in " +
               llvm::omp::getOpenMPDirectiveName(directive).upper() +
               " construct");
    }
  };

  // Probe every clause against each type of the pack in turn.
  for (const Clause &clause : clauses)
    (reportIfHeld(clause.id, std::get_if<Ts>(&clause.u)), ...);
}
/// Returns an iterator to the first clause in [begin, end) whose variant
/// holds a \c T, or \p end if the range contains no such clause.
template <typename T>
ClauseProcessor::ClauseIterator
ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
  ClauseIterator cursor = begin;
  // Step over every clause that does not hold a T.
  while (cursor != end && std::get_if<T>(&cursor->u) == nullptr)
    ++cursor;
  return cursor;
}
/// Looks up the first clause of type \c T in the clause list.
///
/// \param source optional output; when non-null and a clause is found, it is
///               set to the address of that clause's source range.
/// \returns a pointer to the clause payload, or nullptr if there is none.
template <typename T>
const T *
ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const {
  const ClauseIterator last = clauses.end();
  const ClauseIterator match = findClause<T>(clauses.begin(), last);
  if (match == last)
    return nullptr;

  if (source != nullptr)
    *source = &match->source;
  return &std::get<T>(match->u);
}
/// Visits, in list order, every clause of type \c T.
///
/// \param callbackFn invoked once per matching clause with the clause payload
///                   and the clause's source range.
/// \returns true if at least one matching clause was visited.
template <typename T>
bool ClauseProcessor::findRepeatableClause(
    std::function<void(const T &, const parser::CharBlock &source)> callbackFn)
    const {
  const ClauseIterator endIt = clauses.end();
  bool anyVisited = false;

  ClauseIterator match = findClause<T>(clauses.begin(), endIt);
  while (match != endIt) {
    callbackFn(std::get<T>(match->u), match->source);
    anyVisited = true;
    // Resume the search right after the clause that was just handled.
    ++match;
    match = findClause<T>(match, endIt);
  }
  return anyVisited;
}
/// Records the presence of a clause of type \c T.
///
/// When such a clause is in the list, \p result is assigned a unit attribute
/// obtained from the converter's builder; otherwise \p result is left as is.
/// \returns true if the clause is present.
template <typename T>
bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
  const bool clausePresent = findUniqueClause<T>() != nullptr;
  if (clausePresent)
    result = converter.getFirOpBuilder().getUnitAttr();
  return clausePresent;
}
}
}
}
#endif