#include "flang/Lower/HostAssociations.h"
#include "flang/Evaluate/check-expression.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Allocatable.h"
#include "flang/Lower/BoxAnalyzer.h"
#include "flang/Lower/CallInterface.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/OpenMP.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "flang/Semantics/tools.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "flang-host-assoc"
/// Returns true when \p sym has a declared derived type with at least one
/// LEN type parameter; false when it has no declared type or a non-derived one.
static bool isDerivedWithLenParameters(const Fortran::semantics::Symbol &sym) {
  const Fortran::semantics::DeclTypeSpec *typeSpec = sym.GetType();
  if (!typeSpec)
    return false;
  const Fortran::semantics::DerivedTypeSpec *derivedSpec =
      typeSpec->AsDerived();
  return derivedSpec != nullptr &&
         Fortran::semantics::CountLenParameters(*derivedSpec) != 0;
}
/// Make \p val the lowered value of the captured symbol \p sym in \p symMap.
/// When lowering to HLFIR, the binding goes through genDeclareSymbol with the
/// host_assoc variable flag; otherwise the value is added to the map directly.
static void bindCapturedSymbol(const Fortran::semantics::Symbol &sym,
                               fir::ExtendedValue val,
                               Fortran::lower::AbstractConverter &converter,
                               Fortran::lower::SymMap &symMap) {
  const bool lowerToHlfir =
      converter.getLoweringOptions().getLowerToHighLevelFIR();
  if (!lowerToHlfir) {
    symMap.addSymbol(sym, val);
    return;
  }
  Fortran::lower::genDeclareSymbol(converter, symMap, sym, val,
                                   fir::FortranVariableFlagsEnum::host_assoc);
}
namespace {
// Visitor action tags. walkCaptureCategories() hands one of these to the
// capture category selected for a symbol; CapturedSymbols::visit is overloaded
// on the tag type, and `Result` is the return type of that walk.

/// Action: compute the type of the tuple element holding a captured symbol
/// (forwarded to the category's getType()).
struct GetTypeInTuple {
using Result = mlir::Type;
};
/// Action (host side): store the host value of a captured symbol into its
/// tuple slot (forwarded to the category's instantiateHostTuple()).
/// NOTE: member order matters; call sites use aggregate initialization.
struct InstantiateHostTuple {
using Result = void;
// Lowered value of the symbol in the host procedure.
fir::ExtendedValue hostValue;
// Address of the tuple element reserved for this symbol.
mlir::Value addrInTuple;
mlir::Location loc;
};
/// Action (internal procedure side): bind a captured symbol from the value
/// loaded out of its tuple slot (forwarded to the category's getFromTuple()).
/// NOTE: member order matters; call sites use aggregate initialization.
struct GetFromTuple {
using Result = void;
// Symbol map in which the captured symbol gets bound.
Fortran::lower::SymMap &symMap;
// Value loaded from the tuple element of this symbol.
mlir::Value valueInTuple;
mlir::Location loc;
};
template <typename SymbolCategory>
class CapturedSymbols {
public:
template <typename T>
static void visit(const T &, Fortran::lower::AbstractConverter &,
const Fortran::semantics::Symbol &,
const Fortran::lower::BoxAnalyzer &) {
static_assert(!std::is_same_v<T, T> &&
"default visit must not be instantiated");
}
static mlir::Type visit(const GetTypeInTuple &,
Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &sym,
const Fortran::lower::BoxAnalyzer &) {
return SymbolCategory::getType(converter, sym);
}
static void visit(const InstantiateHostTuple &args,
Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &sym,
const Fortran::lower::BoxAnalyzer &) {
return SymbolCategory::instantiateHostTuple(args, converter, sym);
}
static void visit(const GetFromTuple &args,
Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &sym,
const Fortran::lower::BoxAnalyzer &ba) {
return SymbolCategory::getFromTuple(args, converter, sym, ba);
}
};
/// Capture category for trivial scalars. The tuple slot holds a fir.ref to the
/// host variable, and the internal procedure binds the symbol to that address.
class CapturedSimpleScalars : public CapturedSymbols<CapturedSimpleScalars> {
public:
  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
                            const Fortran::semantics::Symbol &sym) {
    mlir::Type valueType = converter.genType(sym);
    return fir::ReferenceType::get(valueType);
  }

  static void instantiateHostTuple(const InstantiateHostTuple &args,
                                   Fortran::lower::AbstractConverter &converter,
                                   const Fortran::semantics::Symbol &) {
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
    const mlir::Location loc = args.loc;
    mlir::Type slotType = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
    assert(slotType && "addrInTuple must be an address");
    // Store the host address, converted to the slot's element type.
    mlir::Value hostAddr = fir::getBase(args.hostValue);
    mlir::Value converted = builder.createConvert(loc, slotType, hostAddr);
    builder.create<fir::StoreOp>(loc, converted, args.addrInTuple);
  }

  static void getFromTuple(const GetFromTuple &args,
                           Fortran::lower::AbstractConverter &converter,
                           const Fortran::semantics::Symbol &sym,
                           const Fortran::lower::BoxAnalyzer &) {
    // The loaded slot value is the variable's address; bind it as-is.
    bindCapturedSymbol(sym, args.valueInTuple, converter, args.symMap);
  }
};
/// Capture category for procedure symbols. The slot type is the dummy
/// procedure type, wrapped in a fir.ref when the symbol is a POINTER.
class CapturedProcedure : public CapturedSymbols<CapturedProcedure> {
public:
  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
                            const Fortran::semantics::Symbol &sym) {
    mlir::Type procType = Fortran::lower::getDummyProcedureType(sym, converter);
    return Fortran::semantics::IsPointer(sym)
               ? mlir::Type(fir::ReferenceType::get(procType))
               : procType;
  }

  static void instantiateHostTuple(const InstantiateHostTuple &args,
                                   Fortran::lower::AbstractConverter &converter,
                                   const Fortran::semantics::Symbol &) {
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
    const mlir::Location loc = args.loc;
    mlir::Type slotType = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
    assert(slotType && "addrInTuple must be an address");
    // Convert the base of the host value to the slot type and store it.
    mlir::Value hostBase = fir::getBase(args.hostValue);
    mlir::Value converted = builder.createConvert(loc, slotType, hostBase);
    builder.create<fir::StoreOp>(loc, converted, args.addrInTuple);
  }

  static void getFromTuple(const GetFromTuple &args,
                           Fortran::lower::AbstractConverter &converter,
                           const Fortran::semantics::Symbol &sym,
                           const Fortran::lower::BoxAnalyzer &) {
    bindCapturedSymbol(sym, args.valueInTuple, converter, args.symMap);
  }
};
/// Capture category for character scalars. The tuple slot holds a fir.boxchar
/// (address and length) with the same character kind as the symbol.
class CapturedCharacterScalars
    : public CapturedSymbols<CapturedCharacterScalars> {
public:
  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
                            const Fortran::semantics::Symbol &sym) {
    auto charType = mlir::cast<fir::CharacterType>(converter.genType(sym));
    return fir::BoxCharType::get(&converter.getMLIRContext(),
                                 charType.getFKind());
  }

  static void instantiateHostTuple(const InstantiateHostTuple &args,
                                   Fortran::lower::AbstractConverter &converter,
                                   const Fortran::semantics::Symbol &) {
    const fir::CharBoxValue *hostChar = args.hostValue.getCharBox();
    assert(hostChar && "host value must be a fir::CharBoxValue");
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
    fir::factory::CharacterExprHelper helper(builder, args.loc);
    // Pack address + length into a boxchar and store it into the slot.
    mlir::Value packed = helper.createEmbox(*hostChar);
    builder.create<fir::StoreOp>(args.loc, packed, args.addrInTuple);
  }

  static void getFromTuple(const GetFromTuple &args,
                           Fortran::lower::AbstractConverter &converter,
                           const Fortran::semantics::Symbol &sym,
                           const Fortran::lower::BoxAnalyzer &) {
    fir::factory::CharacterExprHelper helper(converter.getFirOpBuilder(),
                                             args.loc);
    // Split the boxchar back into address and length.
    auto [charAddr, charLen] = helper.createUnboxChar(args.valueInTuple);
    bindCapturedSymbol(sym, fir::CharBoxValue{charAddr, charLen}, converter,
                       args.symMap);
  }
};
/// Capture category for polymorphic scalars. The tuple slot holds the
/// symbol's descriptor type as-is. getFromTuple treats the loaded value as a
/// fir::BaseBoxType.
class CapturedPolymorphicScalar
    : public CapturedSymbols<CapturedPolymorphicScalar> {
public:
  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
                            const Fortran::semantics::Symbol &sym) {
    mlir::Type descriptorType = converter.genType(sym);
    return descriptorType;
  }

  static void instantiateHostTuple(const InstantiateHostTuple &args,
                                   Fortran::lower::AbstractConverter &converter,
                                   const Fortran::semantics::Symbol &sym) {
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
    mlir::Location loc = args.loc;
    mlir::Type slotType = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
    assert(slotType && "addrInTuple must be an address");
    mlir::Value hostBox = builder.createConvert(loc, slotType,
                                                fir::getBase(args.hostValue));
    if (!Fortran::semantics::IsOptional(sym)) {
      builder.create<fir::StoreOp>(loc, hostBox, args.addrInTuple);
      return;
    }
    // OPTIONAL: store the host box when present; otherwise store a box
    // built by createUnallocatedBox for the slot type.
    mlir::Value present =
        builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), hostBox);
    auto storeHostBox = [&]() {
      builder.create<fir::StoreOp>(loc, hostBox, args.addrInTuple);
    };
    auto storeUnallocatedBox = [&]() {
      mlir::Value unallocated = fir::factory::createUnallocatedBox(
          builder, loc, slotType, mlir::ValueRange{});
      builder.create<fir::StoreOp>(loc, unallocated, args.addrInTuple);
    };
    builder.genIfThenElse(loc, present)
        .genThen(storeHostBox)
        .genElse(storeUnallocatedBox)
        .end();
  }

  static void getFromTuple(const GetFromTuple &args,
                           Fortran::lower::AbstractConverter &converter,
                           const Fortran::semantics::Symbol &sym,
                           const Fortran::lower::BoxAnalyzer &) {
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
    const mlir::Location loc = args.loc;
    mlir::Value capturedBox = args.valueInTuple;
    if (Fortran::semantics::IsOptional(sym)) {
      // A box whose base address is null is replaced by a fir.absent value.
      // NOTE(review): presumably this is how the unallocated box stored by
      // instantiateHostTuple for an absent OPTIONAL shows up here.
      auto baseBoxTy = mlir::cast<fir::BaseBoxType>(capturedBox.getType());
      mlir::Type addrTy = baseBoxTy.getEleTy();
      if (!fir::isa_ref_type(addrTy))
        addrTy = builder.getRefType(addrTy);
      mlir::Value baseAddr =
          builder.create<fir::BoxAddrOp>(loc, addrTy, capturedBox);
      mlir::Value isPresent = builder.genIsNotNullAddr(loc, baseAddr);
      mlir::Value absent = builder.create<fir::AbsentOp>(loc, baseBoxTy);
      capturedBox = builder.create<mlir::arith::SelectOp>(loc, isPresent,
                                                          capturedBox, absent);
    }
    bindCapturedSymbol(sym, capturedBox, converter, args.symMap);
  }
};
/// Capture category for allocatables, pointers and Cray pointees. The tuple
/// slot holds a fir.ref to the host's descriptor, so the internal procedure
/// binds the symbol as a fir::MutableBoxValue over that descriptor.
class CapturedAllocatableAndPointer
    : public CapturedSymbols<CapturedAllocatableAndPointer> {
public:
  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
                            const Fortran::semantics::Symbol &sym) {
    mlir::Type boxType = converter.genType(sym);
    const bool isCrayPointee =
        sym.GetUltimate().test(Fortran::semantics::Symbol::Flag::CrayPointee);
    if (isCrayPointee)
      boxType = Fortran::lower::getCrayPointeeBoxType(boxType);
    return fir::ReferenceType::get(boxType);
  }

  static void instantiateHostTuple(const InstantiateHostTuple &args,
                                   Fortran::lower::AbstractConverter &converter,
                                   const Fortran::semantics::Symbol &) {
    assert(args.hostValue.getBoxOf<fir::MutableBoxValue>() &&
           "host value must be a fir::MutableBoxValue");
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
    mlir::Type slotType = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
    assert(slotType && "addrInTuple must be an address");
    // Store the (converted) address of the host descriptor.
    mlir::Value descriptorAddr = fir::getBase(args.hostValue);
    mlir::Value converted =
        builder.createConvert(args.loc, slotType, descriptorAddr);
    builder.create<fir::StoreOp>(args.loc, converted, args.addrInTuple);
  }

  static void getFromTuple(const GetFromTuple &args,
                           Fortran::lower::AbstractConverter &converter,
                           const Fortran::semantics::Symbol &sym,
                           const Fortran::lower::BoxAnalyzer &ba) {
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
    const mlir::Location loc = args.loc;
    // Record a character length that is not deferred (constant, assumed, or
    // given by an expression) as a non-deferred parameter of the bound box.
    llvm::SmallVector<mlir::Value> lenParams;
    if (ba.isChar()) {
      mlir::IndexType idxTy = builder.getIndexType();
      if (std::optional<int64_t> constLen = ba.getCharLenConst())
        lenParams.push_back(
            builder.createIntegerConstant(loc, idxTy, *constLen));
      else if (Fortran::semantics::IsAssumedLengthCharacter(sym) ||
               ba.getCharLenExpr())
        lenParams.push_back(
            Fortran::lower::getAssumedCharAllocatableOrPointerLen(
                builder, loc, sym, args.valueInTuple));
    } else if (isDerivedWithLenParameters(sym)) {
      TODO(loc, "host associated derived type allocatable or pointer with "
                "length parameters");
    }
    fir::MutableBoxValue mutableBox(args.valueInTuple, lenParams, {});
    bindCapturedSymbol(sym, mutableBox, converter, args.symMap);
  }
};
/// Capture category for arrays that are neither allocatable nor pointer.
/// The tuple slot holds a descriptor for the host array: a fir.box of the
/// sequence type, or the symbol's own (fir.class) type when polymorphic.
class CapturedArrays : public CapturedSymbols<CapturedArrays> {
public:
  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
                            const Fortran::semantics::Symbol &sym) {
    mlir::Type type = converter.genType(sym);
    bool isPolymorphic = Fortran::semantics::IsPolymorphic(sym);
    assert((mlir::isa<fir::SequenceType>(type) ||
            (isPolymorphic && mlir::isa<fir::ClassType>(type))) &&
           "must be a sequence type or a polymorphic class type");
    // Polymorphic symbols keep their lowered type (expected to be fir.class).
    if (isPolymorphic)
      return type;
    return fir::BoxType::get(type);
  }

  static void instantiateHostTuple(const InstantiateHostTuple &args,
                                   Fortran::lower::AbstractConverter &converter,
                                   const Fortran::semantics::Symbol &sym) {
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
    mlir::Location loc = args.loc;
    // Treat the tuple slot as a mutable box so the host array can be
    // associated with it (or disassociated for an absent OPTIONAL).
    fir::MutableBoxValue boxInTuple(args.addrInTuple, {}, {});
    if (args.hostValue.getBoxOf<fir::BoxValue>() &&
        Fortran::semantics::IsOptional(sym)) {
      auto isPresent = builder.create<fir::IsPresentOp>(
          loc, builder.getI1Type(), fir::getBase(args.hostValue));
      builder.genIfThenElse(loc, isPresent)
          .genThen([&]() {
            fir::factory::associateMutableBox(builder, loc, boxInTuple,
                                              args.hostValue, std::nullopt);
          })
          .genElse([&]() {
            fir::factory::disassociateMutableBox(builder, loc, boxInTuple);
          })
          .end();
    } else {
      fir::factory::associateMutableBox(builder, loc, boxInTuple,
                                        args.hostValue, std::nullopt);
    }
  }

  static void getFromTuple(const GetFromTuple &args,
                           Fortran::lower::AbstractConverter &converter,
                           const Fortran::semantics::Symbol &sym,
                           const Fortran::lower::BoxAnalyzer &ba) {
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
    mlir::Location loc = args.loc;
    mlir::Value box = args.valueInTuple;
    llvm::SmallVector<mlir::Value> lbounds =
        genCapturedLowerBounds(builder, loc, sym, ba, box);

    if (canReadCapturedBoxValue(converter, sym)) {
      fir::BoxValue boxValue(box, lbounds, /*explicitParams=*/std::nullopt);
      bindCapturedSymbol(sym,
                         fir::factory::readBoxValue(builder, loc, boxValue),
                         converter, args.symMap);
      return;
    }

    // Keep the symbol as a descriptor. For an OPTIONAL, a descriptor whose
    // base address is null (see disassociateMutableBox in
    // instantiateHostTuple) is replaced by a fir.absent value.
    if (Fortran::semantics::IsOptional(sym)) {
      auto boxTy = mlir::cast<fir::BaseBoxType>(box.getType());
      mlir::Type addrTy = boxTy.getEleTy();
      if (!fir::isa_ref_type(addrTy))
        addrTy = builder.getRefType(addrTy);
      mlir::Value addr = builder.create<fir::BoxAddrOp>(loc, addrTy, box);
      mlir::Value isPresent = builder.genIsNotNullAddr(loc, addr);
      mlir::Value absentBox = builder.create<fir::AbsentOp>(loc, boxTy);
      box = builder.create<mlir::arith::SelectOp>(loc, isPresent, box,
                                                  absentBox);
    }
    fir::BoxValue boxValue(box, lbounds, /*explicitParams=*/std::nullopt);
    bindCapturedSymbol(sym, boxValue, converter, args.symMap);
  }

private:
  /// Lower bounds to attach to the captured descriptor.
  /// - Empty when all lower bounds are one, or for assumed-rank arrays.
  /// - Constants for arrays whose lower bounds are static.
  /// - Otherwise read from the captured descriptor with fir.box_dims.
  static llvm::SmallVector<mlir::Value>
  genCapturedLowerBounds(fir::FirOpBuilder &builder, mlir::Location loc,
                         const Fortran::semantics::Symbol &sym,
                         const Fortran::lower::BoxAnalyzer &ba,
                         mlir::Value box) {
    llvm::SmallVector<mlir::Value> lbounds;
    if (ba.lboundIsAllOnes() || Fortran::evaluate::IsAssumedRank(sym))
      return lbounds;
    mlir::IndexType idxTy = builder.getIndexType();
    if (ba.isStaticArray()) {
      for (std::int64_t lb : ba.staticLBound())
        lbounds.emplace_back(builder.createIntegerConstant(loc, idxTy, lb));
      return lbounds;
    }
    const unsigned rank = sym.Rank();
    for (unsigned dim = 0; dim < rank; ++dim) {
      mlir::Value dimVal = builder.createIntegerConstant(loc, idxTy, dim);
      auto dims = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box,
                                                 dimVal);
      lbounds.emplace_back(dims.getResult(0));
    }
    return lbounds;
  }

  /// Can the captured descriptor be read into simpler values
  /// (fir::factory::readBoxValue)? Requires a scalar or simply contiguous
  /// entity that is not polymorphic, not a PDT with LEN parameters, and not
  /// assumed-rank.
  static bool
  canReadCapturedBoxValue(Fortran::lower::AbstractConverter &converter,
                          const Fortran::semantics::Symbol &sym) {
    // Cheap disqualifiers first, so the contiguity analysis only runs when
    // its answer matters.
    if (Fortran::evaluate::IsAssumedRank(sym) ||
        isDerivedWithLenParameters(sym))
      return false;
    const Fortran::semantics::DeclTypeSpec *type = sym.GetType();
    if (type && type->IsPolymorphic())
      return false;
    if (sym.Rank() == 0)
      return true;
    // If no expression can be formed for the symbol, conservatively keep the
    // descriptor instead of unwrapping an empty optional.
    auto expr = Fortran::evaluate::AsGenericExpr(sym);
    return expr && Fortran::evaluate::IsSimplyContiguous(
                       *expr, converter.getFoldingContext());
  }
};
}
template <typename T>
static typename T::Result
walkCaptureCategories(T visitor, Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &sym) {
if (isDerivedWithLenParameters(sym))
TODO(converter.genLocation(sym.name()),
"host associated derived type with length parameters");
Fortran::lower::BoxAnalyzer ba;
if (Fortran::semantics::IsProcedure(sym))
return CapturedProcedure::visit(visitor, converter, sym, ba);
ba.analyze(sym);
if (Fortran::semantics::IsAllocatableOrPointer(sym) ||
sym.GetUltimate().test(Fortran::semantics::Symbol::Flag::CrayPointee))
return CapturedAllocatableAndPointer::visit(visitor, converter, sym, ba);
if (ba.isArray())
return CapturedArrays::visit(visitor, converter, sym, ba);
if (Fortran::semantics::IsPolymorphic(sym))
return CapturedPolymorphicScalar::visit(visitor, converter, sym, ba);
if (ba.isChar())
return CapturedCharacterScalars::visit(visitor, converter, sym, ba);
assert(ba.isTrivial() && "must be trivial scalar");
return CapturedSimpleScalars::visit(visitor, converter, sym, ba);
}
/// Given a pointer-like type to a tuple, return the pointed-to tuple type.
static mlir::TupleType unwrapTupleTy(mlir::Type t) {
  mlir::Type pointee = fir::dyn_cast_ptrEleTy(t);
  return mlir::cast<mlir::TupleType>(pointee);
}
/// Generate the address of element \p offset of the tuple \p tupleArg, whose
/// element type is \p varTy. An element that is itself a fir.ref is addressed
/// through a fir.llvm_ptr; any other element through a fir.ref.
/// NOTE(review): presumably this avoids a fir.ref of a fir.ref — confirm.
static mlir::Value genTupleCoor(fir::FirOpBuilder &builder, mlir::Location loc,
                                mlir::Type varTy, mlir::Value tupleArg,
                                mlir::Value offset) {
  mlir::Type coorTy;
  if (mlir::isa<fir::ReferenceType>(varTy))
    coorTy = fir::LLVMPointerType::get(varTy);
  else
    coorTy = builder.getRefType(varTy);
  return builder.create<fir::CoordinateOp>(loc, coorTy, tupleArg, offset);
}
/// Split the host-associated \p symbols into those passed through the tuple
/// and global ones (including OpenMP threadprivate variables). Global symbols
/// are recorded by their ultimate symbol and re-instantiated from
/// \p hostScope by internalProcedureBindings.
void Fortran::lower::HostAssociations::addSymbolsToBind(
    const llvm::SetVector<const Fortran::semantics::Symbol *> &symbols,
    const Fortran::semantics::Scope &hostScope) {
  assert(tupleSymbols.empty() && globalSymbols.empty() &&
         "must be initially empty");
  this->hostScope = &hostScope;
  for (const Fortran::semantics::Symbol *sym : symbols) {
    const bool bindAsGlobal =
        Fortran::lower::symbolIsGlobal(*sym) ||
        sym->test(Fortran::semantics::Symbol::Flag::OmpThreadprivate);
    if (!bindAsGlobal) {
      tupleSymbols.insert(sym);
      continue;
    }
    globalSymbols.insert(&sym->GetUltimate());
  }
}
/// In the host procedure, allocate the host-association tuple and fill slot i
/// with the captured value of tupleSymbols[i]. The tuple is then registered
/// with the converter via bindHostAssocTuple. Does nothing when there are no
/// tuple symbols.
void Fortran::lower::HostAssociations::hostProcedureBindings(
    Fortran::lower::AbstractConverter &converter,
    Fortran::lower::SymMap &symMap) {
  if (tupleSymbols.empty())
    return;
  mlir::TupleType tupleTy = unwrapTupleTy(getArgumentType(converter));
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  mlir::Location loc = converter.getCurrentLocation();
  auto hostTuple = builder.create<fir::AllocaOp>(loc, tupleTy);
  mlir::IntegerType offsetTy = builder.getIntegerType(32);
  for (auto entry : llvm::enumerate(tupleSymbols)) {
    auto slot = entry.index();
    const Fortran::semantics::Symbol &sym = *entry.value();
    mlir::Value slotIndex = builder.createIntegerConstant(loc, offsetTy, slot);
    mlir::Value slotAddr = genTupleCoor(builder, loc, tupleTy.getType(slot),
                                        hostTuple, slotIndex);
    InstantiateHostTuple action{converter.getSymbolExtendedValue(sym, &symMap),
                                slotAddr, loc};
    walkCaptureCategories(action, converter, sym);
  }
  converter.bindHostAssocTuple(hostTuple);
}
/// Bind the host-associated symbols inside an internal procedure.
/// Symbols recorded as global are re-instantiated from the host scope's
/// variable list. All other captured symbols are loaded from the
/// host-association tuple argument of the current function and bound through
/// walkCaptureCategories with a GetFromTuple action.
void Fortran::lower::HostAssociations::internalProcedureBindings(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symMap) {
// Global (and threadprivate) symbols are not in the tuple: instantiate the
// corresponding host-scope variables directly in this procedure.
if (!globalSymbols.empty()) {
assert(hostScope && "host scope must have been set");
Fortran::lower::AggregateStoreMap storeMap;
// Instantiate global aggregate stores, plus every host variable whose
// ultimate symbol was recorded in globalSymbols by addSymbolsToBind.
for (auto &hostVariable : pft::getScopeVariableList(*hostScope))
if ((hostVariable.isAggregateStore() && hostVariable.isGlobal()) ||
(hostVariable.hasSymbol() &&
globalSymbols.contains(&hostVariable.getSymbol().GetUltimate()))) {
Fortran::lower::instantiateVariable(converter, hostVariable, symMap,
storeMap);
// Threadprivate variables additionally get a threadprivate op.
if (hostVariable.hasSymbol() &&
hostVariable.getSymbol().test(
Fortran::semantics::Symbol::Flag::OmpThreadprivate))
Fortran::lower::genThreadprivateOp(converter, hostVariable);
}
}
if (tupleSymbols.empty())
return;
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Type argTy = getArgumentType(converter);
mlir::TupleType tupTy = unwrapTupleTy(argTy);
mlir::Location loc = converter.getCurrentLocation();
mlir::func::FuncOp func = builder.getFunction();
// Find the tuple argument by type, scanning the arguments from last to first
// and taking the first match. NOTE(review): presumably the tuple is the
// trailing argument of that type — confirm against the call interface.
mlir::Value tupleArg;
for (auto [ty, arg] : llvm::reverse(llvm::zip(
func.getFunctionType().getInputs(), func.front().getArguments())))
if (ty == argTy) {
tupleArg = arg;
break;
}
if (!tupleArg)
fir::emitFatalError(loc, "no host association argument found");
converter.bindHostAssocTuple(tupleArg);
// Slot i of the tuple holds tupleSymbols[i]: load it and bind the symbol
// according to its capture category.
mlir::IntegerType offTy = builder.getIntegerType(32);
for (auto s : llvm::enumerate(tupleSymbols)) {
mlir::Value off = builder.createIntegerConstant(loc, offTy, s.index());
mlir::Type varTy = tupTy.getType(s.index());
mlir::Value eleOff = genTupleCoor(builder, loc, varTy, tupleArg, off);
mlir::Value valueInTuple = builder.create<fir::LoadOp>(loc, eleOff);
GetFromTuple getFromTuple{symMap, valueInTuple, loc};
walkCaptureCategories(getFromTuple, converter, *s.value());
}
}
/// Type of the host-association argument: a fir.ref to a tuple whose i-th
/// member is the capture type of tupleSymbols[i]. Returns a null type when
/// nothing is captured through the tuple. The type is computed once and
/// cached in argType.
mlir::Type Fortran::lower::HostAssociations::getArgumentType(
    Fortran::lower::AbstractConverter &converter) {
  if (tupleSymbols.empty())
    return {};
  if (!argType) {
    llvm::SmallVector<mlir::Type> memberTypes;
    memberTypes.reserve(tupleSymbols.size());
    for (const Fortran::semantics::Symbol *sym : tupleSymbols)
      memberTypes.push_back(
          walkCaptureCategories(GetTypeInTuple{}, converter, *sym));
    mlir::TupleType tupleTy =
        mlir::TupleType::get(&converter.getMLIRContext(), memberTypes);
    argType = fir::ReferenceType::get(tupleTy);
  }
  return argType;
}