18a2c07f6SSiddharth Bhat //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===//
2c4a4af47SSiddharth Bhat //
32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information.
52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c4a4af47SSiddharth Bhat //
7c4a4af47SSiddharth Bhat //===----------------------------------------------------------------------===//
8c4a4af47SSiddharth Bhat //
9c4a4af47SSiddharth Bhat // Take a module and rewrite:
10c4a4af47SSiddharth Bhat // 1. `malloc` -> `polly_mallocManaged`
11c4a4af47SSiddharth Bhat // 2. `free` -> `polly_freeManaged`
// 3. zero-initialized global arrays (by default only those with private or
//    internal linkage) -> global pointers that a constructor initializes with
//    the result of a call to `polly_mallocManaged`.
15c4a4af47SSiddharth Bhat //
16c4a4af47SSiddharth Bhat //===----------------------------------------------------------------------===//
17c4a4af47SSiddharth Bhat
18031bb165SMichael Kruse #include "polly/CodeGen/IRBuilder.h"
19c4a4af47SSiddharth Bhat #include "polly/CodeGen/PPCGCodeGeneration.h"
20c4a4af47SSiddharth Bhat #include "polly/DependenceInfo.h"
21c4a4af47SSiddharth Bhat #include "polly/LinkAllPasses.h"
22c4a4af47SSiddharth Bhat #include "polly/Options.h"
23c4a4af47SSiddharth Bhat #include "polly/ScopDetection.h"
24031bb165SMichael Kruse #include "llvm/ADT/SmallSet.h"
258a2c07f6SSiddharth Bhat #include "llvm/Analysis/CaptureTracking.h"
262c831971SMichael Kruse #include "llvm/InitializePasses.h"
278a2c07f6SSiddharth Bhat #include "llvm/Transforms/Utils/ModuleUtils.h"
288a2c07f6SSiddharth Bhat
2991ca9adcSMichael Kruse using namespace llvm;
30031bb165SMichael Kruse using namespace polly;
31031bb165SMichael Kruse
328a2c07f6SSiddharth Bhat static cl::opt<bool> RewriteAllocas(
338a2c07f6SSiddharth Bhat "polly-acc-rewrite-allocas",
348a2c07f6SSiddharth Bhat cl::desc(
358a2c07f6SSiddharth Bhat "Ask the managed memory rewriter to also rewrite alloca instructions"),
3636c7d79dSFangrui Song cl::Hidden, cl::cat(PollyCategory));
378a2c07f6SSiddharth Bhat
388a2c07f6SSiddharth Bhat static cl::opt<bool> IgnoreLinkageForGlobals(
398a2c07f6SSiddharth Bhat "polly-acc-rewrite-ignore-linkage-for-globals",
408a2c07f6SSiddharth Bhat cl::desc(
418a2c07f6SSiddharth Bhat "By default, we only rewrite globals with internal linkage. This flag "
428a2c07f6SSiddharth Bhat "enables rewriting of globals regardless of linkage"),
43*95a13425SFangrui Song cl::Hidden, cl::cat(PollyCategory));
448a2c07f6SSiddharth Bhat
458a2c07f6SSiddharth Bhat #define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
46c4a4af47SSiddharth Bhat namespace {
47c4a4af47SSiddharth Bhat
getOrCreatePollyMallocManaged(Module & M)488a2c07f6SSiddharth Bhat static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
49c4a4af47SSiddharth Bhat const char *Name = "polly_mallocManaged";
50c4a4af47SSiddharth Bhat Function *F = M.getFunction(Name);
51c4a4af47SSiddharth Bhat
52c4a4af47SSiddharth Bhat // If F is not available, declare it.
53c4a4af47SSiddharth Bhat if (!F) {
54c4a4af47SSiddharth Bhat GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
55c4a4af47SSiddharth Bhat PollyIRBuilder Builder(M.getContext());
56c4a4af47SSiddharth Bhat // TODO: How do I get `size_t`? I assume from DataLayout?
57c4a4af47SSiddharth Bhat FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(),
58c4a4af47SSiddharth Bhat {Builder.getInt64Ty()}, false);
59c4a4af47SSiddharth Bhat F = Function::Create(Ty, Linkage, Name, &M);
60c4a4af47SSiddharth Bhat }
61c4a4af47SSiddharth Bhat
62c4a4af47SSiddharth Bhat return F;
63c4a4af47SSiddharth Bhat }
64c4a4af47SSiddharth Bhat
getOrCreatePollyFreeManaged(Module & M)658a2c07f6SSiddharth Bhat static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
66c4a4af47SSiddharth Bhat const char *Name = "polly_freeManaged";
67c4a4af47SSiddharth Bhat Function *F = M.getFunction(Name);
68c4a4af47SSiddharth Bhat
69c4a4af47SSiddharth Bhat // If F is not available, declare it.
70c4a4af47SSiddharth Bhat if (!F) {
71c4a4af47SSiddharth Bhat GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
72c4a4af47SSiddharth Bhat PollyIRBuilder Builder(M.getContext());
73c4a4af47SSiddharth Bhat // TODO: How do I get `size_t`? I assume from DataLayout?
74c4a4af47SSiddharth Bhat FunctionType *Ty =
75c4a4af47SSiddharth Bhat FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false);
76c4a4af47SSiddharth Bhat F = Function::Create(Ty, Linkage, Name, &M);
77c4a4af47SSiddharth Bhat }
78c4a4af47SSiddharth Bhat
79c4a4af47SSiddharth Bhat return F;
80c4a4af47SSiddharth Bhat }
81c4a4af47SSiddharth Bhat
828a2c07f6SSiddharth Bhat // Expand a constant expression `Cur`, which is used at instruction `Parent`
838a2c07f6SSiddharth Bhat // at index `index`.
848a2c07f6SSiddharth Bhat // Since a constant expression can expand to multiple instructions, store all
858a2c07f6SSiddharth Bhat // the expands into a set called `Expands`.
868a2c07f6SSiddharth Bhat // Note that this goes inorder on the constant expression tree.
878a2c07f6SSiddharth Bhat // A * ((B * D) + C)
888a2c07f6SSiddharth Bhat // will be processed with first A, then B * D, then B, then D, and then C.
898a2c07f6SSiddharth Bhat // Though ConstantExprs are not treated as "trees" but as DAGs, since you can
908a2c07f6SSiddharth Bhat // have something like this:
918a2c07f6SSiddharth Bhat // *
928a2c07f6SSiddharth Bhat // / \
938a2c07f6SSiddharth Bhat // \ /
948a2c07f6SSiddharth Bhat // (D)
958a2c07f6SSiddharth Bhat //
968a2c07f6SSiddharth Bhat // For the purposes of this expansion, we expand the two occurences of D
978a2c07f6SSiddharth Bhat // separately. Therefore, we expand the DAG into the tree:
988a2c07f6SSiddharth Bhat // *
998a2c07f6SSiddharth Bhat // / \
1008a2c07f6SSiddharth Bhat // D D
1018a2c07f6SSiddharth Bhat // TODO: We don't _have_to do this, but this is the simplest solution.
1028a2c07f6SSiddharth Bhat // We can write a solution that keeps track of which constants have been
1038a2c07f6SSiddharth Bhat // already expanded.
expandConstantExpr(ConstantExpr * Cur,PollyIRBuilder & Builder,Instruction * Parent,int index,SmallPtrSet<Instruction *,4> & Expands)1048a2c07f6SSiddharth Bhat static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
1058a2c07f6SSiddharth Bhat Instruction *Parent, int index,
1068a2c07f6SSiddharth Bhat SmallPtrSet<Instruction *, 4> &Expands) {
1078a2c07f6SSiddharth Bhat assert(Cur && "invalid constant expression passed");
1088a2c07f6SSiddharth Bhat Instruction *I = Cur->getAsInstruction();
109205a78a6SSiddharth Bhat assert(I && "unable to convert ConstantExpr to Instruction");
110205a78a6SSiddharth Bhat
111349506a9SNicola Zaghen LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur
112a8c329b0SSiddharth Bhat << ") in Instruction: (" << *I << ")\n";);
113205a78a6SSiddharth Bhat
114205a78a6SSiddharth Bhat // Invalidate `Cur` so that no one after this point uses `Cur`. Rather,
115205a78a6SSiddharth Bhat // they should mutate `I`.
116205a78a6SSiddharth Bhat Cur = nullptr;
117205a78a6SSiddharth Bhat
1188a2c07f6SSiddharth Bhat Expands.insert(I);
1198a2c07f6SSiddharth Bhat Parent->setOperand(index, I);
1208a2c07f6SSiddharth Bhat
1218a2c07f6SSiddharth Bhat // The things that `Parent` uses (its operands) should be created
1228a2c07f6SSiddharth Bhat // before `Parent`.
1238a2c07f6SSiddharth Bhat Builder.SetInsertPoint(Parent);
1248a2c07f6SSiddharth Bhat Builder.Insert(I);
1258a2c07f6SSiddharth Bhat
126205a78a6SSiddharth Bhat for (unsigned i = 0; i < I->getNumOperands(); i++) {
127205a78a6SSiddharth Bhat Value *Op = I->getOperand(i);
1288a2c07f6SSiddharth Bhat assert(isa<Constant>(Op) && "constant must have a constant operand");
1298a2c07f6SSiddharth Bhat
1308a2c07f6SSiddharth Bhat if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op))
1318a2c07f6SSiddharth Bhat expandConstantExpr(CExprOp, Builder, I, i, Expands);
1328a2c07f6SSiddharth Bhat }
1338a2c07f6SSiddharth Bhat }
1348a2c07f6SSiddharth Bhat
1358a2c07f6SSiddharth Bhat // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite
1368a2c07f6SSiddharth Bhat // `ConstantExpr`s that are used in the `Inst`.
1378a2c07f6SSiddharth Bhat // Note that `replaceAllUsesWith` is insufficient for this purpose because it
1388a2c07f6SSiddharth Bhat // does not rewrite values in `ConstantExpr`s.
rewriteOldValToNew(Instruction * Inst,Value * OldVal,Value * NewVal,PollyIRBuilder & Builder)1398a2c07f6SSiddharth Bhat static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
1408a2c07f6SSiddharth Bhat PollyIRBuilder &Builder) {
1418a2c07f6SSiddharth Bhat
1428a2c07f6SSiddharth Bhat // This contains a set of instructions in which OldVal must be replaced.
1438a2c07f6SSiddharth Bhat // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s
1448a2c07f6SSiddharth Bhat // from `Inst`s arguments.
1458a2c07f6SSiddharth Bhat // We need to go through this process because `replaceAllUsesWith` does not
1468a2c07f6SSiddharth Bhat // actually edit `ConstantExpr`s.
1478a2c07f6SSiddharth Bhat SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
1488a2c07f6SSiddharth Bhat
1498a2c07f6SSiddharth Bhat // Expand all `ConstantExpr`s and place it in `InstsToVisit`.
1508a2c07f6SSiddharth Bhat for (unsigned i = 0; i < Inst->getNumOperands(); i++) {
1518a2c07f6SSiddharth Bhat Value *Operand = Inst->getOperand(i);
1528a2c07f6SSiddharth Bhat if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand))
1538a2c07f6SSiddharth Bhat expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit);
1548a2c07f6SSiddharth Bhat }
1558a2c07f6SSiddharth Bhat
1568a2c07f6SSiddharth Bhat // Now visit each instruction and use `replaceUsesOfWith`. We know that
1578a2c07f6SSiddharth Bhat // will work because `I` cannot have any `ConstantExpr` within it.
1588a2c07f6SSiddharth Bhat for (Instruction *I : InstsToVisit)
1598a2c07f6SSiddharth Bhat I->replaceUsesOfWith(OldVal, NewVal);
1608a2c07f6SSiddharth Bhat }
1618a2c07f6SSiddharth Bhat
1628a2c07f6SSiddharth Bhat // Given a value `Current`, return all Instructions that may contain `Current`
1638a2c07f6SSiddharth Bhat // in an expression.
1648a2c07f6SSiddharth Bhat // We need this auxiliary function, because if we have a
1658a2c07f6SSiddharth Bhat // `Constant` that is a user of `V`, we need to recurse into the
1668a2c07f6SSiddharth Bhat // `Constant`s uses to gather the root instruciton.
getInstructionUsersOfValue(Value * V,SmallVector<Instruction *,4> & Owners)1678a2c07f6SSiddharth Bhat static void getInstructionUsersOfValue(Value *V,
1688a2c07f6SSiddharth Bhat SmallVector<Instruction *, 4> &Owners) {
1698a2c07f6SSiddharth Bhat if (auto *I = dyn_cast<Instruction>(V)) {
1708a2c07f6SSiddharth Bhat Owners.push_back(I);
1718a2c07f6SSiddharth Bhat } else {
1728a2c07f6SSiddharth Bhat // Anything that is a `User` must be a constant or an instruction.
1738a2c07f6SSiddharth Bhat auto *C = cast<Constant>(V);
1748a2c07f6SSiddharth Bhat for (Use &CUse : C->uses())
1758a2c07f6SSiddharth Bhat getInstructionUsersOfValue(CUse.getUser(), Owners);
1768a2c07f6SSiddharth Bhat }
1778a2c07f6SSiddharth Bhat }
1788a2c07f6SSiddharth Bhat
1798a2c07f6SSiddharth Bhat static void
replaceGlobalArray(Module & M,const DataLayout & DL,GlobalVariable & Array,SmallPtrSet<GlobalVariable *,4> & ReplacedGlobals)1808a2c07f6SSiddharth Bhat replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
1818a2c07f6SSiddharth Bhat SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
1828a2c07f6SSiddharth Bhat // We only want arrays.
183ee423d93SNikita Popov ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getValueType());
1848a2c07f6SSiddharth Bhat if (!ArrayTy)
1858a2c07f6SSiddharth Bhat return;
1868a2c07f6SSiddharth Bhat Type *ElemTy = ArrayTy->getElementType();
1878a2c07f6SSiddharth Bhat PointerType *ElemPtrTy = ElemTy->getPointerTo();
1888a2c07f6SSiddharth Bhat
1898a2c07f6SSiddharth Bhat // We only wish to replace arrays that are visible in the module they
1908a2c07f6SSiddharth Bhat // inhabit. Otherwise, our type edit from [T] to T* would be illegal across
1918a2c07f6SSiddharth Bhat // modules.
1928a2c07f6SSiddharth Bhat const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
1938a2c07f6SSiddharth Bhat Array.hasInternalLinkage() ||
1948a2c07f6SSiddharth Bhat IgnoreLinkageForGlobals;
195557ce3a8SSiddharth Bhat if (!OnlyVisibleInsideModule) {
196349506a9SNicola Zaghen LLVM_DEBUG(
197349506a9SNicola Zaghen dbgs() << "Not rewriting (" << Array
198a8c329b0SSiddharth Bhat << ") to managed memory "
199557ce3a8SSiddharth Bhat "because it could be visible externally. To force rewrite, "
200557ce3a8SSiddharth Bhat "use -polly-acc-rewrite-ignore-linkage-for-globals.\n");
2018a2c07f6SSiddharth Bhat return;
202557ce3a8SSiddharth Bhat }
2038a2c07f6SSiddharth Bhat
2048a2c07f6SSiddharth Bhat if (!Array.hasInitializer() ||
205557ce3a8SSiddharth Bhat !isa<ConstantAggregateZero>(Array.getInitializer())) {
206349506a9SNicola Zaghen LLVM_DEBUG(dbgs() << "Not rewriting (" << Array
207a8c329b0SSiddharth Bhat << ") to managed memory "
208557ce3a8SSiddharth Bhat "because it has an initializer which is "
209557ce3a8SSiddharth Bhat "not a zeroinitializer.\n");
2108a2c07f6SSiddharth Bhat return;
211557ce3a8SSiddharth Bhat }
2128a2c07f6SSiddharth Bhat
2138a2c07f6SSiddharth Bhat // At this point, we have committed to replacing this array.
2148a2c07f6SSiddharth Bhat ReplacedGlobals.insert(&Array);
2158a2c07f6SSiddharth Bhat
2161a53b732SMichael Kruse std::string NewName = Array.getName().str();
21790411189STobias Grosser NewName += ".toptr";
2188a2c07f6SSiddharth Bhat GlobalVariable *ReplacementToArr =
2198a2c07f6SSiddharth Bhat cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy));
2208a2c07f6SSiddharth Bhat ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy));
2218a2c07f6SSiddharth Bhat
2228a2c07f6SSiddharth Bhat Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
2231a53b732SMichael Kruse std::string FnName = Array.getName().str();
22490411189STobias Grosser FnName += ".constructor";
2258a2c07f6SSiddharth Bhat PollyIRBuilder Builder(M.getContext());
2268a2c07f6SSiddharth Bhat FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
2278a2c07f6SSiddharth Bhat const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
2288a2c07f6SSiddharth Bhat Function *F = Function::Create(Ty, Linkage, FnName, &M);
2298a2c07f6SSiddharth Bhat BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
2308a2c07f6SSiddharth Bhat Builder.SetInsertPoint(Start);
2318a2c07f6SSiddharth Bhat
23260354486SSiddharth Bhat const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy);
2338a2c07f6SSiddharth Bhat Value *ArraySize = Builder.getInt64(ArraySizeInt);
2348a2c07f6SSiddharth Bhat ArraySize->setName("array.size");
2358a2c07f6SSiddharth Bhat
2368a2c07f6SSiddharth Bhat Value *AllocatedMemRaw =
2378a2c07f6SSiddharth Bhat Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
2388a2c07f6SSiddharth Bhat Value *AllocatedMemTyped =
2398a2c07f6SSiddharth Bhat Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
2408a2c07f6SSiddharth Bhat Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
2418a2c07f6SSiddharth Bhat Builder.CreateRetVoid();
2428a2c07f6SSiddharth Bhat
2438a2c07f6SSiddharth Bhat const int Priority = 0;
2448a2c07f6SSiddharth Bhat appendToGlobalCtors(M, F, Priority, ReplacementToArr);
2458a2c07f6SSiddharth Bhat
2468a2c07f6SSiddharth Bhat SmallVector<Instruction *, 4> ArrayUserInstructions;
2478a2c07f6SSiddharth Bhat // Get all instructions that use array. We need to do this weird thing
2488a2c07f6SSiddharth Bhat // because `Constant`s that contain this array neeed to be expanded into
2498a2c07f6SSiddharth Bhat // instructions so that we can replace their parameters. `Constant`s cannot
2508a2c07f6SSiddharth Bhat // be edited easily, so we choose to convert all `Constant`s to
2518a2c07f6SSiddharth Bhat // `Instruction`s and handle all of the uses of `Array` uniformly.
2528a2c07f6SSiddharth Bhat for (Use &ArrayUse : Array.uses())
2538a2c07f6SSiddharth Bhat getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);
2548a2c07f6SSiddharth Bhat
2558a2c07f6SSiddharth Bhat for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
2568a2c07f6SSiddharth Bhat
2578a2c07f6SSiddharth Bhat Builder.SetInsertPoint(UserOfArrayInst);
2588a2c07f6SSiddharth Bhat // <ty>** -> <ty>*
2599c486eb3SMichael Kruse Value *ArrPtrLoaded =
2609c486eb3SMichael Kruse Builder.CreateLoad(ElemPtrTy, ReplacementToArr, "arrptr.load");
2618a2c07f6SSiddharth Bhat // <ty>* -> [ty]*
2628a2c07f6SSiddharth Bhat Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
2638a2c07f6SSiddharth Bhat ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
2648a2c07f6SSiddharth Bhat rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
2658a2c07f6SSiddharth Bhat }
2668a2c07f6SSiddharth Bhat }
2678a2c07f6SSiddharth Bhat
2688a2c07f6SSiddharth Bhat // We return all `allocas` that may need to be converted to a call to
2698a2c07f6SSiddharth Bhat // cudaMallocManaged.
getAllocasToBeManaged(Function & F,SmallSet<AllocaInst *,4> & Allocas)2708a2c07f6SSiddharth Bhat static void getAllocasToBeManaged(Function &F,
2718a2c07f6SSiddharth Bhat SmallSet<AllocaInst *, 4> &Allocas) {
2728a2c07f6SSiddharth Bhat for (BasicBlock &BB : F) {
2738a2c07f6SSiddharth Bhat for (Instruction &I : BB) {
2748a2c07f6SSiddharth Bhat auto *Alloca = dyn_cast<AllocaInst>(&I);
2758a2c07f6SSiddharth Bhat if (!Alloca)
2768a2c07f6SSiddharth Bhat continue;
277349506a9SNicola Zaghen LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: ");
2788a2c07f6SSiddharth Bhat
2798a2c07f6SSiddharth Bhat if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false,
2808a2c07f6SSiddharth Bhat /* StoreCaptures */ true)) {
2818a2c07f6SSiddharth Bhat Allocas.insert(Alloca);
282349506a9SNicola Zaghen LLVM_DEBUG(dbgs() << "YES (captured).\n");
2838a2c07f6SSiddharth Bhat } else {
284349506a9SNicola Zaghen LLVM_DEBUG(dbgs() << "NO (not captured).\n");
2858a2c07f6SSiddharth Bhat }
2868a2c07f6SSiddharth Bhat }
2878a2c07f6SSiddharth Bhat }
2888a2c07f6SSiddharth Bhat }
2898a2c07f6SSiddharth Bhat
rewriteAllocaAsManagedMemory(AllocaInst * Alloca,const DataLayout & DL)2908a2c07f6SSiddharth Bhat static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
2918a2c07f6SSiddharth Bhat const DataLayout &DL) {
292349506a9SNicola Zaghen LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n");
2938a2c07f6SSiddharth Bhat Module *M = Alloca->getModule();
2948a2c07f6SSiddharth Bhat assert(M && "Alloca does not have a module");
2958a2c07f6SSiddharth Bhat
2968a2c07f6SSiddharth Bhat PollyIRBuilder Builder(M->getContext());
2978a2c07f6SSiddharth Bhat Builder.SetInsertPoint(Alloca);
2988a2c07f6SSiddharth Bhat
299ae2f9512SJames Y Knight Function *MallocManagedFn =
300ae2f9512SJames Y Knight getOrCreatePollyMallocManaged(*Alloca->getModule());
301ee423d93SNikita Popov const uint64_t Size = DL.getTypeAllocSize(Alloca->getAllocatedType());
3028a2c07f6SSiddharth Bhat Value *SizeVal = Builder.getInt64(Size);
3038a2c07f6SSiddharth Bhat Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
3048a2c07f6SSiddharth Bhat Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());
3058a2c07f6SSiddharth Bhat
3068a2c07f6SSiddharth Bhat Function *F = Alloca->getFunction();
3078a2c07f6SSiddharth Bhat assert(F && "Alloca has invalid function");
3088a2c07f6SSiddharth Bhat
3098a2c07f6SSiddharth Bhat Bitcasted->takeName(Alloca);
3108a2c07f6SSiddharth Bhat Alloca->replaceAllUsesWith(Bitcasted);
3118a2c07f6SSiddharth Bhat Alloca->eraseFromParent();
3128a2c07f6SSiddharth Bhat
3138a2c07f6SSiddharth Bhat for (BasicBlock &BB : *F) {
3148a2c07f6SSiddharth Bhat ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator());
3158a2c07f6SSiddharth Bhat if (!Return)
3168a2c07f6SSiddharth Bhat continue;
3178a2c07f6SSiddharth Bhat Builder.SetInsertPoint(Return);
3188a2c07f6SSiddharth Bhat
319ae2f9512SJames Y Knight Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
3208a2c07f6SSiddharth Bhat Builder.CreateCall(FreeManagedFn, {RawManagedMem});
3218a2c07f6SSiddharth Bhat }
3228a2c07f6SSiddharth Bhat }
3238a2c07f6SSiddharth Bhat
324a2c41127SSiddharth Bhat // Replace all uses of `Old` with `New`, even inside `ConstantExpr`.
325a2c41127SSiddharth Bhat //
326a2c41127SSiddharth Bhat // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function
327a2c41127SSiddharth Bhat // actually does replace it in `ConstantExpr`. The caveat is that if there is
328a2c41127SSiddharth Bhat // a use that is *outside* a function (say, at global declarations), we fail.
329a2c41127SSiddharth Bhat // So, this is meant to be used on values which we know will only be used
330a2c41127SSiddharth Bhat // within functions.
331a2c41127SSiddharth Bhat //
332a2c41127SSiddharth Bhat // This process works by looking through the uses of `Old`. If it finds a
333a2c41127SSiddharth Bhat // `ConstantExpr`, it recursively looks for the owning instruction.
334a2c41127SSiddharth Bhat // Then, it expands all the `ConstantExpr` to instructions and replaces
335a2c41127SSiddharth Bhat // `Old` with `New` in the expanded instructions.
replaceAllUsesAndConstantUses(Value * Old,Value * New,PollyIRBuilder & Builder)336a2c41127SSiddharth Bhat static void replaceAllUsesAndConstantUses(Value *Old, Value *New,
337a2c41127SSiddharth Bhat PollyIRBuilder &Builder) {
338a2c41127SSiddharth Bhat SmallVector<Instruction *, 4> UserInstructions;
339a2c41127SSiddharth Bhat // Get all instructions that use array. We need to do this weird thing
340a2c41127SSiddharth Bhat // because `Constant`s that contain this array neeed to be expanded into
341a2c41127SSiddharth Bhat // instructions so that we can replace their parameters. `Constant`s cannot
342a2c41127SSiddharth Bhat // be edited easily, so we choose to convert all `Constant`s to
343a2c41127SSiddharth Bhat // `Instruction`s and handle all of the uses of `Array` uniformly.
344a2c41127SSiddharth Bhat for (Use &ArrayUse : Old->uses())
345a2c41127SSiddharth Bhat getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions);
346a2c41127SSiddharth Bhat
347a2c41127SSiddharth Bhat for (Instruction *I : UserInstructions)
348a2c41127SSiddharth Bhat rewriteOldValToNew(I, Old, New, Builder);
349a2c41127SSiddharth Bhat }
350a2c41127SSiddharth Bhat
351bd93df93SMichael Kruse class ManagedMemoryRewritePass final : public ModulePass {
352c4a4af47SSiddharth Bhat public:
353c4a4af47SSiddharth Bhat static char ID;
354c4a4af47SSiddharth Bhat GPUArch Architecture;
355c4a4af47SSiddharth Bhat GPURuntime Runtime;
3568a2c07f6SSiddharth Bhat
ManagedMemoryRewritePass()357c4a4af47SSiddharth Bhat ManagedMemoryRewritePass() : ModulePass(ID) {}
runOnModule(Module & M)35833ca0b0eSMichael Kruse bool runOnModule(Module &M) override {
3598a2c07f6SSiddharth Bhat const DataLayout &DL = M.getDataLayout();
3608a2c07f6SSiddharth Bhat
361c4a4af47SSiddharth Bhat Function *Malloc = M.getFunction("malloc");
362c4a4af47SSiddharth Bhat
363c4a4af47SSiddharth Bhat if (Malloc) {
364a2c41127SSiddharth Bhat PollyIRBuilder Builder(M.getContext());
3658a2c07f6SSiddharth Bhat Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
366c4a4af47SSiddharth Bhat assert(PollyMallocManaged && "unable to create polly_mallocManaged");
367a2c41127SSiddharth Bhat
368a2c41127SSiddharth Bhat replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder);
3699298ff2dSSiddharth Bhat Malloc->eraseFromParent();
370c4a4af47SSiddharth Bhat }
371c4a4af47SSiddharth Bhat
372c4a4af47SSiddharth Bhat Function *Free = M.getFunction("free");
373c4a4af47SSiddharth Bhat
374c4a4af47SSiddharth Bhat if (Free) {
375a2c41127SSiddharth Bhat PollyIRBuilder Builder(M.getContext());
3768a2c07f6SSiddharth Bhat Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
377c4a4af47SSiddharth Bhat assert(PollyFreeManaged && "unable to create polly_freeManaged");
378a2c41127SSiddharth Bhat
379a2c41127SSiddharth Bhat replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder);
3809298ff2dSSiddharth Bhat Free->eraseFromParent();
381c4a4af47SSiddharth Bhat }
382c4a4af47SSiddharth Bhat
3838a2c07f6SSiddharth Bhat SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
3848a2c07f6SSiddharth Bhat for (GlobalVariable &Global : M.globals())
3858a2c07f6SSiddharth Bhat replaceGlobalArray(M, DL, Global, GlobalsToErase);
3868a2c07f6SSiddharth Bhat for (GlobalVariable *G : GlobalsToErase)
3878a2c07f6SSiddharth Bhat G->eraseFromParent();
3888a2c07f6SSiddharth Bhat
3898a2c07f6SSiddharth Bhat // Rewrite allocas to cudaMallocs if we are asked to do so.
3908a2c07f6SSiddharth Bhat if (RewriteAllocas) {
3918a2c07f6SSiddharth Bhat SmallSet<AllocaInst *, 4> AllocasToBeManaged;
3928a2c07f6SSiddharth Bhat for (Function &F : M.functions())
3938a2c07f6SSiddharth Bhat getAllocasToBeManaged(F, AllocasToBeManaged);
3948a2c07f6SSiddharth Bhat
3958a2c07f6SSiddharth Bhat for (AllocaInst *Alloca : AllocasToBeManaged)
3968a2c07f6SSiddharth Bhat rewriteAllocaAsManagedMemory(Alloca, DL);
3978a2c07f6SSiddharth Bhat }
3988a2c07f6SSiddharth Bhat
399c4a4af47SSiddharth Bhat return true;
400c4a4af47SSiddharth Bhat }
401c4a4af47SSiddharth Bhat };
402c4a4af47SSiddharth Bhat } // namespace
403c4a4af47SSiddharth Bhat char ManagedMemoryRewritePass::ID = 42;
404c4a4af47SSiddharth Bhat
createManagedMemoryRewritePassPass(GPUArch Arch,GPURuntime Runtime)405c4a4af47SSiddharth Bhat Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
406c4a4af47SSiddharth Bhat GPURuntime Runtime) {
407c4a4af47SSiddharth Bhat ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass();
408c4a4af47SSiddharth Bhat pass->Runtime = Runtime;
409c4a4af47SSiddharth Bhat pass->Architecture = Arch;
410c4a4af47SSiddharth Bhat return pass;
411c4a4af47SSiddharth Bhat }
412c4a4af47SSiddharth Bhat
413c4a4af47SSiddharth Bhat INITIALIZE_PASS_BEGIN(
414c4a4af47SSiddharth Bhat ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
415c4a4af47SSiddharth Bhat "Polly - Rewrite all allocations in heap & data section to managed memory",
416c4a4af47SSiddharth Bhat false, false)
417c4a4af47SSiddharth Bhat INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
418c4a4af47SSiddharth Bhat INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
419c4a4af47SSiddharth Bhat INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
420c4a4af47SSiddharth Bhat INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
421c4a4af47SSiddharth Bhat INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
422c4a4af47SSiddharth Bhat INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
423c4a4af47SSiddharth Bhat INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
424c4a4af47SSiddharth Bhat INITIALIZE_PASS_END(
425c4a4af47SSiddharth Bhat ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
426c4a4af47SSiddharth Bhat "Polly - Rewrite all allocations in heap & data section to managed memory",
427c4a4af47SSiddharth Bhat false, false)
428