#ifndef _OMPTARGET_OMPTINTERFACE_H
#define _OMPTARGET_OMPTINTERFACE_H
#ifdef OMPT_SUPPORT
#include <functional>
#include <tuple>
#include "Callback.h"
#include "omp-tools.h"
#include "llvm/Support/ErrorHandling.h"
/// With OMPT support compiled in, OMPT_IF_BUILT expands to its argument
/// unchanged.
#define OMPT_IF_BUILT(stmt) stmt
/// Function-pointer type: no arguments, returns the task's `ompt_data_t *`.
using ompt_get_task_data_t = ompt_data_t *(*)();
/// Function-pointer type: no arguments, returns the target task's
/// `ompt_data_t *`.
using ompt_get_target_task_data_t = ompt_data_t *(*)();
namespace llvm {
namespace omp {
namespace target {
namespace ompt {
// Function pointers used to obtain the task data and the target task data.
// Both start out null (static storage); whoever looks the entry points up is
// expected to assign them.
// NOTE(review): `static` at namespace scope in a header gives every
// translation unit that includes this file its own independent copy, so an
// assignment in one TU is not visible in another — confirm this is intended.
static ompt_get_target_task_data_t ompt_get_target_task_data_fn = nullptr;
static ompt_get_task_data_t ompt_get_task_data_fn = nullptr;
class Interface {
public:
void beginTargetDataAlloc(int64_t DeviceId, void *HstPtrBegin,
void **TgtPtrBegin, size_t Size, void *Code);
void endTargetDataAlloc(int64_t DeviceId, void *HstPtrBegin,
void **TgtPtrBegin, size_t Size, void *Code);
void beginTargetDataSubmit(int64_t SrcDeviceId, void *SrcPtrBegin,
int64_t DstDeviceId, void *DstPtrBegin,
size_t Size, void *Code);
void endTargetDataSubmit(int64_t SrcDeviceId, void *SrcPtrBegin,
int64_t DstDeviceId, void *DstPtrBegin, size_t Size,
void *Code);
void beginTargetDataDelete(int64_t DeviceId, void *TgtPtrBegin, void *Code);
void endTargetDataDelete(int64_t DeviceId, void *TgtPtrBegin, void *Code);
void beginTargetDataRetrieve(int64_t SrcDeviceId, void *SrcPtrBegin,
int64_t DstDeviceId, void *DstPtrBegin,
size_t Size, void *Code);
void endTargetDataRetrieve(int64_t SrcDeviceId, void *SrcPtrBegin,
int64_t DstDeviceId, void *DstPtrBegin,
size_t Size, void *Code);
void beginTargetSubmit(unsigned int NumTeams = 1);
void endTargetSubmit(unsigned int NumTeams = 1);
void beginTargetDataEnter(int64_t DeviceId, void *Code);
void endTargetDataEnter(int64_t DeviceId, void *Code);
void beginTargetDataExit(int64_t DeviceId, void *Code);
void endTargetDataExit(int64_t DeviceId, void *Code);
void beginTargetUpdate(int64_t DeviceId, void *Code);
void endTargetUpdate(int64_t DeviceId, void *Code);
void beginTargetAssociatePointer(int64_t DeviceId, void *HstPtrBegin,
void *TgtPtrBegin, size_t Size, void *Code);
void endTargetAssociatePointer(int64_t DeviceId, void *HstPtrBegin,
void *TgtPtrBegin, size_t Size, void *Code);
void beginTargetDisassociatePointer(int64_t DeviceId, void *HstPtrBegin,
void *TgtPtrBegin, size_t Size,
void *Code);
void endTargetDisassociatePointer(int64_t DeviceId, void *HstPtrBegin,
void *TgtPtrBegin, size_t Size, void *Code);
void beginTarget(int64_t DeviceId, void *Code);
void endTarget(int64_t DeviceId, void *Code);
template <ompt_target_data_op_t OpType> auto getCallbacks() {
if constexpr (OpType == ompt_target_data_alloc ||
OpType == ompt_target_data_alloc_async)
return std::make_pair(std::mem_fn(&Interface::beginTargetDataAlloc),
std::mem_fn(&Interface::endTargetDataAlloc));
if constexpr (OpType == ompt_target_data_delete ||
OpType == ompt_target_data_delete_async)
return std::make_pair(std::mem_fn(&Interface::beginTargetDataDelete),
std::mem_fn(&Interface::endTargetDataDelete));
if constexpr (OpType == ompt_target_data_transfer_to_device ||
OpType == ompt_target_data_transfer_to_device_async)
return std::make_pair(std::mem_fn(&Interface::beginTargetDataSubmit),
std::mem_fn(&Interface::endTargetDataSubmit));
if constexpr (OpType == ompt_target_data_transfer_from_device ||
OpType == ompt_target_data_transfer_from_device_async)
return std::make_pair(std::mem_fn(&Interface::beginTargetDataRetrieve),
std::mem_fn(&Interface::endTargetDataRetrieve));
if constexpr (OpType == ompt_target_data_associate)
return std::make_pair(
std::mem_fn(&Interface::beginTargetAssociatePointer),
std::mem_fn(&Interface::endTargetAssociatePointer));
if constexpr (OpType == ompt_target_data_disassociate)
return std::make_pair(
std::mem_fn(&Interface::beginTargetDisassociatePointer),
std::mem_fn(&Interface::endTargetDisassociatePointer));
llvm_unreachable("Unhandled target data operation type!");
}
template <ompt_target_t OpType> auto getCallbacks() {
if constexpr (OpType == ompt_target_enter_data ||
OpType == ompt_target_enter_data_nowait)
return std::make_pair(std::mem_fn(&Interface::beginTargetDataEnter),
std::mem_fn(&Interface::endTargetDataEnter));
if constexpr (OpType == ompt_target_exit_data ||
OpType == ompt_target_exit_data_nowait)
return std::make_pair(std::mem_fn(&Interface::beginTargetDataExit),
std::mem_fn(&Interface::endTargetDataExit));
if constexpr (OpType == ompt_target_update ||
OpType == ompt_target_update_nowait)
return std::make_pair(std::mem_fn(&Interface::beginTargetUpdate),
std::mem_fn(&Interface::endTargetUpdate));
if constexpr (OpType == ompt_target || OpType == ompt_target_nowait)
return std::make_pair(std::mem_fn(&Interface::beginTarget),
std::mem_fn(&Interface::endTarget));
llvm_unreachable("Unknown target region operation type!");
}
template <ompt_callbacks_t OpType> auto getCallbacks() {
if constexpr (OpType == ompt_callback_target_submit)
return std::make_pair(std::mem_fn(&Interface::beginTargetSubmit),
std::mem_fn(&Interface::endTargetSubmit));
llvm_unreachable("Unhandled target operation!");
}
void setTargetDataValue(uint64_t DataValue) { TargetData.value = DataValue; }
void setTargetDataPtr(void *DataPtr) { TargetData.ptr = DataPtr; }
void setHostOpId(ompt_id_t OpId) { HostOpId = OpId; }
uint64_t getTargetDataValue() { return TargetData.value; }
void *getTargetDataPtr() { return TargetData.ptr; }
ompt_id_t getHostOpId() { return HostOpId; }
private:
ompt_id_t HostOpId = 0;
ompt_data_t TargetData = ompt_data_none;
ompt_data_t *TaskData = nullptr;
ompt_data_t *TargetTaskData = nullptr;
void beginTargetDataOperation();
void endTargetDataOperation();
void beginTargetRegion();
void endTargetRegion();
};
/// Per-thread code pointer. ReturnAddressSetterRAII sets and clears it, and the
/// OMPT_GET_RETURN_ADDRESS macro names it. Presumably passed as the `Code`
/// argument of the hooks — confirm at the call sites.
extern thread_local void *ReturnAddress;
/// Per-thread Interface instance on which InvokeInterfaceFunction dispatches
/// the begin/end hooks.
extern thread_local Interface RegionInterface;
/// Calls the member-function wrapper \p Func on the thread-local
/// RegionInterface, unpacking the tuple \p Args as the call arguments.
///
/// \param Func  callable taking an Interface& first (e.g. a std::mem_fn
///              wrapper from Interface::getCallbacks).
/// \param Args  tuple of arguments; taken by const reference so the tuple is
///              not copied on every dispatch (the hooks take their parameters
///              by value anyway).
/// \param (unnamed) index sequence enumerating every element of \p Args.
template <typename FuncTy, typename ArgsTy, size_t... IndexSeq>
void InvokeInterfaceFunction(FuncTy Func, const ArgsTy &Args,
                             std::index_sequence<IndexSeq...>) {
  std::invoke(Func, RegionInterface, std::get<IndexSeq>(Args)...);
}
/// Scope guard that fires an OMPT begin hook on construction and the matching
/// end hook on destruction. Both calls are wrapped in performIfOmptInitialized
/// and dispatched on RegionInterface via InvokeInterfaceFunction, and both
/// receive the same arguments captured at construction.
///
/// \tparam CallbackPairTy pair type (e.g. the result of
///         Interface::getCallbacks) whose first/second members are the begin
///         and end hooks.
/// \tparam ArgsTy types of the arguments passed to both hooks.
///
/// Copying or moving an instance would make the end hook fire once per object,
/// so both operations are deleted. Direct construction such as
/// `InterfaceRAII Guard(Callbacks, Args...)` still works (C++17 guaranteed copy
/// elision together with the deduction guide below).
template <typename CallbackPairTy, typename... ArgsTy> class InterfaceRAII {
public:
  InterfaceRAII(CallbackPairTy Callbacks, ArgsTy... Args)
      : Arguments(Args...), beginFunction(std::get<0>(Callbacks)),
        endFunction(std::get<1>(Callbacks)) {
    performIfOmptInitialized(begin());
  }
  ~InterfaceRAII() { performIfOmptInitialized(end()); }

  InterfaceRAII(const InterfaceRAII &) = delete;
  InterfaceRAII &operator=(const InterfaceRAII &) = delete;
  InterfaceRAII(InterfaceRAII &&) = delete;
  InterfaceRAII &operator=(InterfaceRAII &&) = delete;

private:
  /// Index sequence 0..N-1 over the stored arguments.
  using IndexSequenceTy = std::make_index_sequence<sizeof...(ArgsTy)>;

  void begin() {
    InvokeInterfaceFunction(beginFunction, Arguments, IndexSequenceTy{});
  }
  void end() {
    InvokeInterfaceFunction(endFunction, Arguments, IndexSequenceTy{});
  }

  std::tuple<ArgsTy...> Arguments;
  typename CallbackPairTy::first_type beginFunction;
  typename CallbackPairTy::second_type endFunction;
};

/// Deduces the class template arguments from the constructor arguments, so
/// callers can write `InterfaceRAII Guard(getCallbacks<...>(), Args...)`.
template <typename CallbackPairTy, typename... ArgsTy>
InterfaceRAII(CallbackPairTy Callbacks, ArgsTy... Args)
    -> InterfaceRAII<CallbackPairTy, ArgsTy...>;
/// Records \p RA in the thread-local ReturnAddress for the lifetime of this
/// object, but only if ReturnAddress is currently null. Only the instance that
/// stored the value clears it again on destruction, so nested instances leave
/// the outer value in place.
///
/// Copying is deleted (which also suppresses the implicit moves). A copy of
/// the instance that set the value would otherwise reset ReturnAddress when it
/// is destroyed, possibly while the original is still alive.
class ReturnAddressSetterRAII {
public:
  explicit ReturnAddressSetterRAII(void *RA) {
    if (ReturnAddress == nullptr) {
      ReturnAddress = RA;
      IsSetter = true;
    }
  }
  ~ReturnAddressSetterRAII() {
    if (IsSetter)
      ReturnAddress = nullptr;
  }

  ReturnAddressSetterRAII(const ReturnAddressSetterRAII &) = delete;
  ReturnAddressSetterRAII &operator=(const ReturnAddressSetterRAII &) = delete;

private:
  /// True when this instance stored ReturnAddress and must clear it.
  bool IsSetter = false;
};
}
}
}
}
// Names the thread-local ReturnAddress maintained by ReturnAddressSetterRAII.
#define OMPT_GET_RETURN_ADDRESS llvm::omp::target::ompt::ReturnAddress
#else
// Without OMPT support, OMPT_IF_BUILT discards its argument entirely.
#define OMPT_IF_BUILT(stmt)
#endif // OMPT_SUPPORT
#endif