#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/PassPlugin.h"
#include <map>
using namespace llvm;
namespace {
struct LoadStoreMemSpace : public PassInfoMixin<LoadStoreMemSpace> {
PreservedAnalyses run(llvm::Module &module, ModuleAnalysisManager &) {
bool modifiedCodeGen = runOnModule(module);
return (modifiedCodeGen ? llvm::PreservedAnalyses::none()
: llvm::PreservedAnalyses::all());
}
bool runOnModule(llvm::Module &module);
static bool isRequired() { return true; }
};
}
static std::map<int, std::string> AddrSpaceMap = {
{0, "FLAT"}, {1, "GLOBAL"}, {3, "SHARED"}, {4, "CONSTANT"}, {5, "SCRATCH"}};
static std::map<std::string, uint32_t> LocationCounterSourceMap;
static std::string LoadOrStoreMap(const BasicBlock::iterator &I) {
if (LoadInst *LI = dyn_cast<LoadInst>(I))
return "LOAD";
else if (StoreInst *SI = dyn_cast<StoreInst>(I))
return "STORE";
else
throw std::runtime_error("Error: unknown operation type");
}
template <typename LoadOrStoreInst>
static void InstrumentationFunction(const BasicBlock::iterator &I,
const Function &F, const llvm::Module &M,
uint32_t &LocationCounter) {
auto LSI = dyn_cast<LoadOrStoreInst>(I);
if (not LSI)
return;
Value *Op = LSI->getPointerOperand()->stripPointerCasts();
uint32_t AddrSpace = cast<PointerType>(Op->getType())->getAddressSpace();
DILocation *DL = dyn_cast<Instruction>(I)->getDebugLoc();
std::string SourceAndAddrSpaceInfo =
(F.getName() + " " + DL->getFilename() + ":" + Twine(DL->getLine()) +
":" + Twine(DL->getColumn()))
.str() +
" " + AddrSpaceMap[AddrSpace] + " " + LoadOrStoreMap(I);
if (LocationCounterSourceMap.find(SourceAndAddrSpaceInfo) ==
LocationCounterSourceMap.end()) {
errs() << LocationCounter << " " << SourceAndAddrSpaceInfo << "\n";
LocationCounterSourceMap[SourceAndAddrSpaceInfo] = LocationCounter;
LocationCounter++;
}
}
bool LoadStoreMemSpace::runOnModule(Module &M) {
bool ModifiedCodeGen = false;
uint32_t LocationCounter = 0;
for (auto &F : M) {
if (F.isIntrinsic())
continue;
StringRef functionName = F.getName();
if (F.getCallingConv() == CallingConv::AMDGPU_KERNEL ||
F.getCallingConv() == CallingConv::PTX_Kernel ||
functionName.contains("kernel")) {
for (Function::iterator BB = F.begin(); BB != F.end(); BB++) {
for (BasicBlock::iterator I = BB->begin(); I != BB->end(); I++) {
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
InstrumentationFunction<LoadInst>(I, F, M, LocationCounter);
} else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
InstrumentationFunction<StoreInst>(I, F, M, LocationCounter);
}
}
}
}
}
return ModifiedCodeGen;
}
static PassPluginLibraryInfo getPassPluginInfo() {
const auto callback = [](PassBuilder &PB) {
PB.registerOptimizerLastEPCallback([&](ModulePassManager &MPM, auto, auto) {
MPM.addPass(LoadStoreMemSpace());
return true;
});
};
return {LLVM_PLUGIN_API_VERSION, "print-mem-space", LLVM_VERSION_STRING,
callback};
};
extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo llvmGetPassPluginInfo() {
return getPassPluginInfo();
}