#include "polly/CodeGen/IRBuilder.h"
#include "polly/CodeGen/PPCGCodeGeneration.h"
#include "polly/DependenceInfo.h"
#include "polly/LinkAllPasses.h"
#include "polly/Options.h"
#include "polly/ScopDetection.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
using namespace llvm;
using namespace polly;
// When set, the pass also moves allocas whose address may be captured into
// managed memory, in addition to rewriting malloc/free and global arrays.
static cl::opt<bool> RewriteAllocas(
    "polly-acc-rewrite-allocas", cl::Hidden, cl::cat(PollyCategory),
    cl::desc("Ask the managed memory rewriter to also "
             "rewrite alloca instructions"));

// Global arrays are normally rewritten only when they have private or internal
// linkage, since an externally visible global could be referenced from other
// modules. This flag removes that restriction.
static cl::opt<bool> IgnoreLinkageForGlobals(
    "polly-acc-rewrite-ignore-linkage-for-globals", cl::Hidden,
    cl::cat(PollyCategory),
    cl::desc("By default, we only rewrite globals with internal linkage. "
             "This flag enables rewriting of globals regardless of linkage"));

#define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
namespace {
/// Return the module's `polly_mallocManaged` function.
///
/// If \p M does not contain a function with that name yet, an external
/// declaration of type `i8* (i64)` is added and returned. An existing function
/// is returned as-is, whatever its type.
static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
  const char *Name = "polly_mallocManaged";
  if (Function *Existing = M.getFunction(Name))
    return Existing;

  PollyIRBuilder Builder(M.getContext());
  Type *ResultTy = Builder.getInt8PtrTy();
  Type *SizeTy = Builder.getInt64Ty();
  FunctionType *Ty = FunctionType::get(ResultTy, {SizeTy}, /*isVarArg=*/false);
  return Function::Create(Ty, Function::ExternalLinkage, Name, &M);
}
/// Return the module's `polly_freeManaged` function.
///
/// If \p M does not contain a function with that name yet, an external
/// declaration of type `void (i8*)` is added and returned. An existing function
/// is returned as-is, whatever its type.
static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
  const char *Name = "polly_freeManaged";
  if (Function *Existing = M.getFunction(Name))
    return Existing;

  PollyIRBuilder Builder(M.getContext());
  Type *PtrTy = Builder.getInt8PtrTy();
  FunctionType *Ty =
      FunctionType::get(Builder.getVoidTy(), {PtrTy}, /*isVarArg=*/false);
  return Function::Create(Ty, Function::ExternalLinkage, Name, &M);
}
/// Expand the constant expression \p Cur, which is operand number \p index of
/// instruction \p Parent, into a real instruction.
///
/// The new instruction replaces \p Cur as that operand of \p Parent, is
/// inserted directly in front of \p Parent, and is recorded in \p Expands.
/// Operands of the new instruction that are themselves constant expressions
/// are expanded the same way, recursively. A nested expression therefore
/// becomes a sequence of instructions, each placed before its user.
///
/// Constant expressions form DAGs rather than trees: one sub-expression may
/// feed several operands, as in `mul (D, D)`. Nothing is memoized, so every
/// occurrence is expanded separately. The DAG is unfolded into a tree, and a
/// shared sub-expression yields one instruction per use. This is simpler than
/// tracking which constants were already expanded, at the price of some
/// duplicated instructions.
///
/// NOTE(review): insertion is always "before Parent". If Parent were a
/// PHINode, this would put a non-PHI instruction ahead of a PHI, which is
/// invalid IR. Confirm PHI users cannot reach this function.
static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
                               Instruction *Parent, int index,
                               SmallPtrSet<Instruction *, 4> &Expands) {
  assert(Cur && "invalid constant expression passed");
  Instruction *Expanded = Cur->getAsInstruction();
  assert(Expanded && "unable to convert ConstantExpr to Instruction");
  LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur
                    << ") in Instruction: (" << *Expanded << ")\n";);

  Expands.insert(Expanded);
  Parent->setOperand(index, Expanded);
  Builder.SetInsertPoint(Parent);
  Builder.Insert(Expanded);

  // setOperand never changes the operand count, so it can be read once.
  const unsigned NumOps = Expanded->getNumOperands();
  for (unsigned OpIdx = 0; OpIdx != NumOps; ++OpIdx) {
    Value *Op = Expanded->getOperand(OpIdx);
    assert(isa<Constant>(Op) && "constant must have a constant operand");
    auto *NestedExpr = dyn_cast<ConstantExpr>(Op);
    if (!NestedExpr)
      continue;
    expandConstantExpr(NestedExpr, Builder, Expanded, OpIdx, Expands);
  }
}
/// Replace every use of \p OldVal in \p Inst with \p NewVal. This includes
/// uses nested inside ConstantExpr operands of \p Inst.
///
/// Constant expressions are uniqued, immutable values and cannot be patched
/// in place. Each ConstantExpr operand of \p Inst is therefore first expanded
/// into instructions placed before \p Inst. The substitution is then applied
/// to \p Inst and to every instruction produced by the expansion.
static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
                               PollyIRBuilder &Builder) {
  SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
  for (unsigned OpIdx = 0, E = Inst->getNumOperands(); OpIdx < E; ++OpIdx) {
    if (auto *CE = dyn_cast<ConstantExpr>(Inst->getOperand(OpIdx)))
      expandConstantExpr(CE, Builder, Inst, OpIdx, InstsToVisit);
  }
  for (Instruction *Visited : InstsToVisit)
    Visited->replaceUsesOfWith(OldVal, NewVal);
}
/// Append to \p Owners the instructions reached from \p V.
///
/// \p V is expected to be a user of the value being replaced:
///  - an Instruction is recorded as-is;
///  - any other constant (e.g. a ConstantExpr or constant aggregate) is
///    looked through by recursively visiting its users;
///  - a GlobalValue (variable, function, alias) is NOT looked through.
///
/// An instruction that uses a global does not use whatever that global's
/// initializer refers to, so following its users only collects unrelated
/// instructions. It can also recurse forever when a global's initializer
/// refers back to the global itself.
///
/// An instruction may be appended more than once if it is reached along
/// several paths.
static void getInstructionUsersOfValue(Value *V,
                                       SmallVector<Instruction *, 4> &Owners) {
  if (auto *I = dyn_cast<Instruction>(V)) {
    Owners.push_back(I);
    return;
  }

  // Uses of a global are uses of the global itself; stop here (see above).
  if (isa<GlobalValue>(V))
    return;

  auto *C = cast<Constant>(V);
  for (User *U : C->users())
    getInstructionUsersOfValue(U, Owners);
}
/// Move the zero-initialized global array \p Array into managed memory.
///
/// Steps:
///  - Create a global `<name>.toptr` of the array's element-pointer type,
///    initialized to null.
///  - Create a constructor `<name>.constructor`, registered in
///    llvm.global_ctors with priority 0. It calls polly_mallocManaged with
///    the array's alloc size and stores the result into `<name>.toptr`.
///  - Rewrite every instruction that uses \p Array, directly or through
///    constant expressions, to load `<name>.toptr` and use that pointer
///    (bitcast to the array pointer type) instead.
///
/// \p Array is not erased here. It is added to \p ReplacedGlobals so the
/// caller can erase it after all globals have been processed.
///
/// The array is left untouched (and not added to \p ReplacedGlobals) if:
///  - it is not of array type;
///  - it has linkage other than private/internal, unless
///    -polly-acc-rewrite-ignore-linkage-for-globals is given;
///  - it has no initializer or an initializer other than zeroinitializer.
///
/// The helper global and the constructor get internal linkage because only
/// this module references them. External linkage would make identically named
/// arrays in different translation units (e.g. two `static int A[N]`) produce
/// clashing `A.toptr` / `A.constructor` symbols at link time.
static void
replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
                   SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
  ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getValueType());
  if (!ArrayTy)
    return;

  Type *ElemTy = ArrayTy->getElementType();
  PointerType *ElemPtrTy = ElemTy->getPointerTo();

  const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
                                       Array.hasInternalLinkage() ||
                                       IgnoreLinkageForGlobals;
  if (!OnlyVisibleInsideModule) {
    LLVM_DEBUG(
        dbgs() << "Not rewriting (" << Array
               << ") to managed memory "
                  "because it could be visible externally. To force rewrite, "
                  "use -polly-acc-rewrite-ignore-linkage-for-globals.\n");
    return;
  }

  if (!Array.hasInitializer() ||
      !isa<ConstantAggregateZero>(Array.getInitializer())) {
    LLVM_DEBUG(dbgs() << "Not rewriting (" << Array
                      << ") to managed memory "
                         "because it has an initializer which is "
                         "not a zeroinitializer.\n");
    return;
  }

  ReplacedGlobals.insert(&Array);

  // Pointer to the managed copy of the array; null until the constructor
  // runs. It is created directly rather than via Module::getOrInsertGlobal.
  // That way a pre-existing global with the same name is never reused (or
  // handed back as a cast). On a name clash the new variable is renamed
  // automatically instead.
  std::string NewName = Array.getName().str();
  NewName += ".toptr";
  auto *ReplacementToArr = new GlobalVariable(
      M, ElemPtrTy, /*isConstant=*/false, GlobalValue::InternalLinkage,
      ConstantPointerNull::get(ElemPtrTy), NewName);

  // Constructor body:
  //   ReplacementToArr = (ElemTy *)polly_mallocManaged(sizeof(Array)).
  Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
  std::string FnName = Array.getName().str();
  FnName += ".constructor";
  PollyIRBuilder Builder(M.getContext());
  FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
  Function *F = Function::Create(Ty, GlobalValue::InternalLinkage, FnName, &M);
  BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
  Builder.SetInsertPoint(Start);

  const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy);
  Value *ArraySize = Builder.getInt64(ArraySizeInt);
  Value *AllocatedMemRaw =
      Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
  Value *AllocatedMemTyped =
      Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
  Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
  Builder.CreateRetVoid();

  const int Priority = 0;
  appendToGlobalCtors(M, F, Priority, ReplacementToArr);

  // Collect every instruction that uses the array (directly or through
  // constant expressions) before rewriting, because rewriting edits the
  // array's use list.
  SmallVector<Instruction *, 4> ArrayUserInstructions;
  for (Use &ArrayUse : Array.uses())
    getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);

  // The same instruction can be collected several times, e.g. when it uses
  // the array twice or through two different constant expressions. The first
  // rewrite already replaces all of its expandable uses, so later visits would
  // only leave dead loads behind; skip them.
  SmallPtrSet<Instruction *, 8> AlreadyRewritten;
  for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
    if (!AlreadyRewritten.insert(UserOfArrayInst).second)
      continue;
    Builder.SetInsertPoint(UserOfArrayInst);
    Value *ArrPtrLoaded =
        Builder.CreateLoad(ElemPtrTy, ReplacementToArr, "arrptr.load");
    Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
        ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
    rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
  }
}
/// Add to \p Allocas every alloca in \p F whose address may be captured, as
/// decided by PointerMayBeCaptured.
///
/// Returning the pointer is not counted as a capture (ReturnCaptures = false).
/// Storing it is (StoreCaptures = true).
static void getAllocasToBeManaged(Function &F,
                                  SmallSet<AllocaInst *, 4> &Allocas) {
  for (BasicBlock &Block : F) {
    for (Instruction &Inst : Block) {
      if (auto *AI = dyn_cast<AllocaInst>(&Inst)) {
        LLVM_DEBUG(dbgs() << "Checking if (" << *AI << ") may be captured: ");
        const bool MayBeCaptured = PointerMayBeCaptured(
            AI, /*ReturnCaptures=*/false, /*StoreCaptures=*/true);
        if (!MayBeCaptured) {
          LLVM_DEBUG(dbgs() << "NO (not captured).\n");
          continue;
        }
        Allocas.insert(AI);
        LLVM_DEBUG(dbgs() << "YES (captured).\n");
      }
    }
  }
}
/// Replace the stack allocation \p Alloca with a call to polly_mallocManaged.
/// The memory is released with polly_freeManaged immediately before every
/// `ret` in the enclosing function.
///
/// The requested size is the alloc size (per \p DL) of the allocated type,
/// multiplied by the alloca's element count. An array alloca such as
/// `alloca i32, i64 %n` thus receives `4 * %n` bytes rather than room for a
/// single element. The count operand is treated as unsigned and
/// zero-extended/truncated to i64. A constant count is folded by the builder.
///
/// NOTE(review): the allocation is placed where the alloca was, and a free is
/// added at every `ret`. That is only sound for allocas that run exactly once
/// per call and dominate all returns (e.g. entry-block allocas). Otherwise:
///  - an alloca inside a loop leaks;
///  - one that does not dominate a return yields IR that fails verification.
/// Exits other than `ret` (such as `resume`) get no free. Confirm that
/// callers only pass suitable allocas.
static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
                                         const DataLayout &DL) {
  LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n");
  Module *M = Alloca->getModule();
  assert(M && "Alloca does not have a module");
  Function *F = Alloca->getFunction();
  assert(F && "Alloca has invalid function");

  PollyIRBuilder Builder(M->getContext());
  Builder.SetInsertPoint(Alloca);

  Function *MallocManagedFn = getOrCreatePollyMallocManaged(*M);
  const uint64_t ElemSize = DL.getTypeAllocSize(Alloca->getAllocatedType());
  Value *SizeVal = Builder.getInt64(ElemSize);
  if (Alloca->isArrayAllocation()) {
    // `alloca T, N` reserves N elements of T, not one.
    Value *NumElems = Builder.CreateZExtOrTrunc(
        Alloca->getArraySize(), Builder.getInt64Ty(), "alloca.numelems");
    SizeVal = Builder.CreateMul(SizeVal, NumElems, "alloca.size");
  }
  Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
  Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());

  Bitcasted->takeName(Alloca);
  Alloca->replaceAllUsesWith(Bitcasted);
  Alloca->eraseFromParent();

  for (BasicBlock &BB : *F) {
    auto *Return = dyn_cast_or_null<ReturnInst>(BB.getTerminator());
    if (!Return)
      continue;
    Builder.SetInsertPoint(Return);
    Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
    Builder.CreateCall(FreeManagedFn, {RawManagedMem});
  }
}
/// Make every instruction that uses \p Old use \p New instead. This covers
/// instructions that use \p Old directly and those that reach it through
/// constant expressions, which get expanded into instructions on the way.
static void replaceAllUsesAndConstantUses(Value *Old, Value *New,
                                          PollyIRBuilder &Builder) {
  // Gather all affected instructions first: the rewrite below edits the use
  // list of Old, so it must not run while that list is being walked.
  SmallVector<Instruction *, 4> Owners;
  for (User *U : Old->users())
    getInstructionUsersOfValue(U, Owners);

  for (Instruction *Owner : Owners)
    rewriteOldValToNew(Owner, Old, New, Builder);
}
/// Module pass that routes the module's memory through Polly's managed-memory
/// runtime functions (polly_mallocManaged / polly_freeManaged):
///  - every use of `malloc` and `free` is redirected to them, and the original
///    declarations are erased;
///  - zero-initialized global arrays with private/internal linkage (any
///    linkage with -polly-acc-rewrite-ignore-linkage-for-globals) are
///    replaced by storage allocated from a module constructor;
///  - with -polly-acc-rewrite-allocas, allocas whose address may be captured
///    are replaced by managed allocations as well.
class ManagedMemoryRewritePass final : public ModulePass {
public:
  static char ID;

  /// Target GPU architecture and runtime, filled in by
  /// createManagedMemoryRewritePassPass. Value-initialized so that an instance
  /// built through the default constructor alone (e.g. by the pass registry)
  /// never holds indeterminate values.
  GPUArch Architecture{};
  GPURuntime Runtime{};

  ManagedMemoryRewritePass() : ModulePass(ID) {}

  bool runOnModule(Module &M) override {
    const DataLayout &DL = M.getDataLayout();

    // Redirect malloc to polly_mallocManaged, then drop the old function.
    Function *Malloc = M.getFunction("malloc");
    if (Malloc) {
      PollyIRBuilder Builder(M.getContext());
      Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
      assert(PollyMallocManaged && "unable to create polly_mallocManaged");
      replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder);
      Malloc->eraseFromParent();
    }

    // Likewise redirect free to polly_freeManaged.
    Function *Free = M.getFunction("free");
    if (Free) {
      PollyIRBuilder Builder(M.getContext());
      Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
      assert(PollyFreeManaged && "unable to create polly_freeManaged");
      replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder);
      Free->eraseFromParent();
    }

    // Walk the live global list rather than a snapshot.
    // replaceGlobalArray adds globals while we iterate, and its call to
    // appendToGlobalCtors may replace llvm.global_ctors; a snapshot could
    // then hold a dangling pointer. Globals created during the walk are
    // visited too. The ones created here (pointer-typed `.toptr` variables,
    // llvm.global_ctors with a non-zero initializer) fail replaceGlobalArray's
    // checks and are left alone.
    SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
    for (GlobalVariable &Global : M.globals())
      replaceGlobalArray(M, DL, Global, GlobalsToErase);
    // Erase only after the walk, so the list is not modified under it.
    for (GlobalVariable *G : GlobalsToErase)
      G->eraseFromParent();

    if (RewriteAllocas) {
      SmallSet<AllocaInst *, 4> AllocasToBeManaged;
      for (Function &F : M.functions())
        getAllocasToBeManaged(F, AllocasToBeManaged);

      // A set of pointers iterates in an address/hash-dependent order.
      // Rewriting in that order would make the order of the inserted free
      // calls differ from run to run. Rewrite in program order instead. The
      // list is collected before any rewrite, since rewriting erases allocas.
      SmallVector<AllocaInst *, 4> OrderedAllocas;
      for (Function &F : M.functions())
        for (BasicBlock &BB : F)
          for (Instruction &I : BB)
            if (auto *Alloca = dyn_cast<AllocaInst>(&I))
              if (AllocasToBeManaged.count(Alloca))
                OrderedAllocas.push_back(Alloca);

      for (AllocaInst *Alloca : OrderedAllocas)
        rewriteAllocaAsManagedMemory(Alloca, DL);
    }

    return true;
  }
};
}
// The legacy pass framework identifies the pass by the address of ID, which
// is handed to the ModulePass base constructor.
char ManagedMemoryRewritePass::ID = 42;

/// Create a ManagedMemoryRewritePass for GPU architecture \p Arch and runtime
/// \p Runtime. The pass is heap-allocated, and ownership goes to the caller
/// (typically a legacy pass manager).
Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
                                                 GPURuntime Runtime) {
  auto *Rewriter = new ManagedMemoryRewritePass();
  Rewriter->Architecture = Arch;
  Rewriter->Runtime = Runtime;
  return Rewriter;
}
// Register the pass with the legacy pass registry under the command-line name
// "polly-acc-rewrite-managed-memory". The two trailing `false` arguments mark
// it as neither CFG-only nor an analysis pass. Each INITIALIZE_PASS_DEPENDENCY
// makes sure the named pass is initialized (registered) as well.
// NOTE(review): runOnModule above does not request any of these analyses, so
// they act as initialization-only dependencies. Confirm they are still needed.
INITIALIZE_PASS_BEGIN(
ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
"Polly - Rewrite all allocations in heap & data section to managed memory",
false, false)
INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
INITIALIZE_PASS_END(
ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
"Polly - Rewrite all allocations in heap & data section to managed memory",
false, false)