18a2c07f6SSiddharth Bhat //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===//
2c4a4af47SSiddharth Bhat //
32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information.
52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c4a4af47SSiddharth Bhat //
7c4a4af47SSiddharth Bhat //===----------------------------------------------------------------------===//
8c4a4af47SSiddharth Bhat //
9c4a4af47SSiddharth Bhat // Take a module and rewrite:
10c4a4af47SSiddharth Bhat // 1. `malloc` -> `polly_mallocManaged`
11c4a4af47SSiddharth Bhat // 2. `free` -> `polly_freeManaged`
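// 3. zero-initialized global arrays (private/internal linkage only, unless
//    -polly-acc-rewrite-ignore-linkage-for-globals is given) -> global
//    pointers that a constructor initializes by calling
//    `polly_mallocManaged`.
// 4. if -polly-acc-rewrite-allocas is given: `alloca`s whose address may be
//    captured -> calls to `polly_mallocManaged`, released with
//    `polly_freeManaged` before each return.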
15c4a4af47SSiddharth Bhat //
16c4a4af47SSiddharth Bhat //===----------------------------------------------------------------------===//
17c4a4af47SSiddharth Bhat 
18031bb165SMichael Kruse #include "polly/CodeGen/IRBuilder.h"
19c4a4af47SSiddharth Bhat #include "polly/CodeGen/PPCGCodeGeneration.h"
20c4a4af47SSiddharth Bhat #include "polly/DependenceInfo.h"
21c4a4af47SSiddharth Bhat #include "polly/LinkAllPasses.h"
22c4a4af47SSiddharth Bhat #include "polly/Options.h"
23c4a4af47SSiddharth Bhat #include "polly/ScopDetection.h"
24031bb165SMichael Kruse #include "llvm/ADT/SmallSet.h"
258a2c07f6SSiddharth Bhat #include "llvm/Analysis/CaptureTracking.h"
262c831971SMichael Kruse #include "llvm/InitializePasses.h"
278a2c07f6SSiddharth Bhat #include "llvm/Transforms/Utils/ModuleUtils.h"
288a2c07f6SSiddharth Bhat 
2991ca9adcSMichael Kruse using namespace llvm;
30031bb165SMichael Kruse using namespace polly;
31031bb165SMichael Kruse 
328a2c07f6SSiddharth Bhat static cl::opt<bool> RewriteAllocas(
338a2c07f6SSiddharth Bhat     "polly-acc-rewrite-allocas",
348a2c07f6SSiddharth Bhat     cl::desc(
358a2c07f6SSiddharth Bhat         "Ask the managed memory rewriter to also rewrite alloca instructions"),
3636c7d79dSFangrui Song     cl::Hidden, cl::cat(PollyCategory));
378a2c07f6SSiddharth Bhat 
388a2c07f6SSiddharth Bhat static cl::opt<bool> IgnoreLinkageForGlobals(
398a2c07f6SSiddharth Bhat     "polly-acc-rewrite-ignore-linkage-for-globals",
408a2c07f6SSiddharth Bhat     cl::desc(
418a2c07f6SSiddharth Bhat         "By default, we only rewrite globals with internal linkage. This flag "
428a2c07f6SSiddharth Bhat         "enables rewriting of globals regardless of linkage"),
43*95a13425SFangrui Song     cl::Hidden, cl::cat(PollyCategory));
448a2c07f6SSiddharth Bhat 
458a2c07f6SSiddharth Bhat #define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
46c4a4af47SSiddharth Bhat namespace {
47c4a4af47SSiddharth Bhat 
getOrCreatePollyMallocManaged(Module & M)488a2c07f6SSiddharth Bhat static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
49c4a4af47SSiddharth Bhat   const char *Name = "polly_mallocManaged";
50c4a4af47SSiddharth Bhat   Function *F = M.getFunction(Name);
51c4a4af47SSiddharth Bhat 
52c4a4af47SSiddharth Bhat   // If F is not available, declare it.
53c4a4af47SSiddharth Bhat   if (!F) {
54c4a4af47SSiddharth Bhat     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
55c4a4af47SSiddharth Bhat     PollyIRBuilder Builder(M.getContext());
56c4a4af47SSiddharth Bhat     // TODO: How do I get `size_t`? I assume from DataLayout?
57c4a4af47SSiddharth Bhat     FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(),
58c4a4af47SSiddharth Bhat                                          {Builder.getInt64Ty()}, false);
59c4a4af47SSiddharth Bhat     F = Function::Create(Ty, Linkage, Name, &M);
60c4a4af47SSiddharth Bhat   }
61c4a4af47SSiddharth Bhat 
62c4a4af47SSiddharth Bhat   return F;
63c4a4af47SSiddharth Bhat }
64c4a4af47SSiddharth Bhat 
getOrCreatePollyFreeManaged(Module & M)658a2c07f6SSiddharth Bhat static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
66c4a4af47SSiddharth Bhat   const char *Name = "polly_freeManaged";
67c4a4af47SSiddharth Bhat   Function *F = M.getFunction(Name);
68c4a4af47SSiddharth Bhat 
69c4a4af47SSiddharth Bhat   // If F is not available, declare it.
70c4a4af47SSiddharth Bhat   if (!F) {
71c4a4af47SSiddharth Bhat     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
72c4a4af47SSiddharth Bhat     PollyIRBuilder Builder(M.getContext());
73c4a4af47SSiddharth Bhat     // TODO: How do I get `size_t`? I assume from DataLayout?
74c4a4af47SSiddharth Bhat     FunctionType *Ty =
75c4a4af47SSiddharth Bhat         FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false);
76c4a4af47SSiddharth Bhat     F = Function::Create(Ty, Linkage, Name, &M);
77c4a4af47SSiddharth Bhat   }
78c4a4af47SSiddharth Bhat 
79c4a4af47SSiddharth Bhat   return F;
80c4a4af47SSiddharth Bhat }
81c4a4af47SSiddharth Bhat 
828a2c07f6SSiddharth Bhat // Expand a constant expression `Cur`, which is used at instruction `Parent`
838a2c07f6SSiddharth Bhat // at index `index`.
848a2c07f6SSiddharth Bhat // Since a constant expression can expand to multiple instructions, store all
858a2c07f6SSiddharth Bhat // the expands into a set called `Expands`.
868a2c07f6SSiddharth Bhat // Note that this goes inorder on the constant expression tree.
878a2c07f6SSiddharth Bhat // A * ((B * D) + C)
888a2c07f6SSiddharth Bhat // will be processed with first A, then B * D, then B, then D, and then C.
898a2c07f6SSiddharth Bhat // Though ConstantExprs are not treated as "trees" but as DAGs, since you can
908a2c07f6SSiddharth Bhat // have something like this:
918a2c07f6SSiddharth Bhat //    *
928a2c07f6SSiddharth Bhat //   /  \
938a2c07f6SSiddharth Bhat //   \  /
948a2c07f6SSiddharth Bhat //    (D)
958a2c07f6SSiddharth Bhat //
968a2c07f6SSiddharth Bhat // For the purposes of this expansion, we expand the two occurences of D
978a2c07f6SSiddharth Bhat // separately. Therefore, we expand the DAG into the tree:
988a2c07f6SSiddharth Bhat //  *
998a2c07f6SSiddharth Bhat // / \
1008a2c07f6SSiddharth Bhat // D  D
1018a2c07f6SSiddharth Bhat // TODO: We don't _have_to do this, but this is the simplest solution.
1028a2c07f6SSiddharth Bhat // We can write a solution that keeps track of which constants have been
1038a2c07f6SSiddharth Bhat // already expanded.
expandConstantExpr(ConstantExpr * Cur,PollyIRBuilder & Builder,Instruction * Parent,int index,SmallPtrSet<Instruction *,4> & Expands)1048a2c07f6SSiddharth Bhat static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
1058a2c07f6SSiddharth Bhat                                Instruction *Parent, int index,
1068a2c07f6SSiddharth Bhat                                SmallPtrSet<Instruction *, 4> &Expands) {
1078a2c07f6SSiddharth Bhat   assert(Cur && "invalid constant expression passed");
1088a2c07f6SSiddharth Bhat   Instruction *I = Cur->getAsInstruction();
109205a78a6SSiddharth Bhat   assert(I && "unable to convert ConstantExpr to Instruction");
110205a78a6SSiddharth Bhat 
111349506a9SNicola Zaghen   LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur
112a8c329b0SSiddharth Bhat                     << ") in Instruction: (" << *I << ")\n";);
113205a78a6SSiddharth Bhat 
114205a78a6SSiddharth Bhat   // Invalidate `Cur` so that no one after this point uses `Cur`. Rather,
115205a78a6SSiddharth Bhat   // they should mutate `I`.
116205a78a6SSiddharth Bhat   Cur = nullptr;
117205a78a6SSiddharth Bhat 
1188a2c07f6SSiddharth Bhat   Expands.insert(I);
1198a2c07f6SSiddharth Bhat   Parent->setOperand(index, I);
1208a2c07f6SSiddharth Bhat 
1218a2c07f6SSiddharth Bhat   // The things that `Parent` uses (its operands) should be created
1228a2c07f6SSiddharth Bhat   // before `Parent`.
1238a2c07f6SSiddharth Bhat   Builder.SetInsertPoint(Parent);
1248a2c07f6SSiddharth Bhat   Builder.Insert(I);
1258a2c07f6SSiddharth Bhat 
126205a78a6SSiddharth Bhat   for (unsigned i = 0; i < I->getNumOperands(); i++) {
127205a78a6SSiddharth Bhat     Value *Op = I->getOperand(i);
1288a2c07f6SSiddharth Bhat     assert(isa<Constant>(Op) && "constant must have a constant operand");
1298a2c07f6SSiddharth Bhat 
1308a2c07f6SSiddharth Bhat     if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op))
1318a2c07f6SSiddharth Bhat       expandConstantExpr(CExprOp, Builder, I, i, Expands);
1328a2c07f6SSiddharth Bhat   }
1338a2c07f6SSiddharth Bhat }
1348a2c07f6SSiddharth Bhat 
1358a2c07f6SSiddharth Bhat // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite
1368a2c07f6SSiddharth Bhat // `ConstantExpr`s that are used in the `Inst`.
1378a2c07f6SSiddharth Bhat // Note that `replaceAllUsesWith` is insufficient for this purpose because it
1388a2c07f6SSiddharth Bhat // does not rewrite values in `ConstantExpr`s.
rewriteOldValToNew(Instruction * Inst,Value * OldVal,Value * NewVal,PollyIRBuilder & Builder)1398a2c07f6SSiddharth Bhat static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
1408a2c07f6SSiddharth Bhat                                PollyIRBuilder &Builder) {
1418a2c07f6SSiddharth Bhat 
1428a2c07f6SSiddharth Bhat   // This contains a set of instructions in which OldVal must be replaced.
1438a2c07f6SSiddharth Bhat   // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s
1448a2c07f6SSiddharth Bhat   // from `Inst`s arguments.
1458a2c07f6SSiddharth Bhat   // We need to go through this process because `replaceAllUsesWith` does not
1468a2c07f6SSiddharth Bhat   // actually edit `ConstantExpr`s.
1478a2c07f6SSiddharth Bhat   SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
1488a2c07f6SSiddharth Bhat 
1498a2c07f6SSiddharth Bhat   // Expand all `ConstantExpr`s and place it in `InstsToVisit`.
1508a2c07f6SSiddharth Bhat   for (unsigned i = 0; i < Inst->getNumOperands(); i++) {
1518a2c07f6SSiddharth Bhat     Value *Operand = Inst->getOperand(i);
1528a2c07f6SSiddharth Bhat     if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand))
1538a2c07f6SSiddharth Bhat       expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit);
1548a2c07f6SSiddharth Bhat   }
1558a2c07f6SSiddharth Bhat 
1568a2c07f6SSiddharth Bhat   // Now visit each instruction and use `replaceUsesOfWith`. We know that
1578a2c07f6SSiddharth Bhat   // will work because `I` cannot have any `ConstantExpr` within it.
1588a2c07f6SSiddharth Bhat   for (Instruction *I : InstsToVisit)
1598a2c07f6SSiddharth Bhat     I->replaceUsesOfWith(OldVal, NewVal);
1608a2c07f6SSiddharth Bhat }
1618a2c07f6SSiddharth Bhat 
1628a2c07f6SSiddharth Bhat // Given a value `Current`, return all Instructions that may contain `Current`
1638a2c07f6SSiddharth Bhat // in an expression.
1648a2c07f6SSiddharth Bhat // We need this auxiliary function, because if we have a
1658a2c07f6SSiddharth Bhat // `Constant` that is a user of `V`, we need to recurse into the
1668a2c07f6SSiddharth Bhat // `Constant`s uses to gather the root instruciton.
getInstructionUsersOfValue(Value * V,SmallVector<Instruction *,4> & Owners)1678a2c07f6SSiddharth Bhat static void getInstructionUsersOfValue(Value *V,
1688a2c07f6SSiddharth Bhat                                        SmallVector<Instruction *, 4> &Owners) {
1698a2c07f6SSiddharth Bhat   if (auto *I = dyn_cast<Instruction>(V)) {
1708a2c07f6SSiddharth Bhat     Owners.push_back(I);
1718a2c07f6SSiddharth Bhat   } else {
1728a2c07f6SSiddharth Bhat     // Anything that is a `User` must be a constant or an instruction.
1738a2c07f6SSiddharth Bhat     auto *C = cast<Constant>(V);
1748a2c07f6SSiddharth Bhat     for (Use &CUse : C->uses())
1758a2c07f6SSiddharth Bhat       getInstructionUsersOfValue(CUse.getUser(), Owners);
1768a2c07f6SSiddharth Bhat   }
1778a2c07f6SSiddharth Bhat }
1788a2c07f6SSiddharth Bhat 
1798a2c07f6SSiddharth Bhat static void
replaceGlobalArray(Module & M,const DataLayout & DL,GlobalVariable & Array,SmallPtrSet<GlobalVariable *,4> & ReplacedGlobals)1808a2c07f6SSiddharth Bhat replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
1818a2c07f6SSiddharth Bhat                    SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
1828a2c07f6SSiddharth Bhat   // We only want arrays.
183ee423d93SNikita Popov   ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getValueType());
1848a2c07f6SSiddharth Bhat   if (!ArrayTy)
1858a2c07f6SSiddharth Bhat     return;
1868a2c07f6SSiddharth Bhat   Type *ElemTy = ArrayTy->getElementType();
1878a2c07f6SSiddharth Bhat   PointerType *ElemPtrTy = ElemTy->getPointerTo();
1888a2c07f6SSiddharth Bhat 
1898a2c07f6SSiddharth Bhat   // We only wish to replace arrays that are visible in the module they
1908a2c07f6SSiddharth Bhat   // inhabit. Otherwise, our type edit from [T] to T* would be illegal across
1918a2c07f6SSiddharth Bhat   // modules.
1928a2c07f6SSiddharth Bhat   const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
1938a2c07f6SSiddharth Bhat                                        Array.hasInternalLinkage() ||
1948a2c07f6SSiddharth Bhat                                        IgnoreLinkageForGlobals;
195557ce3a8SSiddharth Bhat   if (!OnlyVisibleInsideModule) {
196349506a9SNicola Zaghen     LLVM_DEBUG(
197349506a9SNicola Zaghen         dbgs() << "Not rewriting (" << Array
198a8c329b0SSiddharth Bhat                << ") to managed memory "
199557ce3a8SSiddharth Bhat                   "because it could be visible externally. To force rewrite, "
200557ce3a8SSiddharth Bhat                   "use -polly-acc-rewrite-ignore-linkage-for-globals.\n");
2018a2c07f6SSiddharth Bhat     return;
202557ce3a8SSiddharth Bhat   }
2038a2c07f6SSiddharth Bhat 
2048a2c07f6SSiddharth Bhat   if (!Array.hasInitializer() ||
205557ce3a8SSiddharth Bhat       !isa<ConstantAggregateZero>(Array.getInitializer())) {
206349506a9SNicola Zaghen     LLVM_DEBUG(dbgs() << "Not rewriting (" << Array
207a8c329b0SSiddharth Bhat                       << ") to managed memory "
208557ce3a8SSiddharth Bhat                          "because it has an initializer which is "
209557ce3a8SSiddharth Bhat                          "not a zeroinitializer.\n");
2108a2c07f6SSiddharth Bhat     return;
211557ce3a8SSiddharth Bhat   }
2128a2c07f6SSiddharth Bhat 
2138a2c07f6SSiddharth Bhat   // At this point, we have committed to replacing this array.
2148a2c07f6SSiddharth Bhat   ReplacedGlobals.insert(&Array);
2158a2c07f6SSiddharth Bhat 
2161a53b732SMichael Kruse   std::string NewName = Array.getName().str();
21790411189STobias Grosser   NewName += ".toptr";
2188a2c07f6SSiddharth Bhat   GlobalVariable *ReplacementToArr =
2198a2c07f6SSiddharth Bhat       cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy));
2208a2c07f6SSiddharth Bhat   ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy));
2218a2c07f6SSiddharth Bhat 
2228a2c07f6SSiddharth Bhat   Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
2231a53b732SMichael Kruse   std::string FnName = Array.getName().str();
22490411189STobias Grosser   FnName += ".constructor";
2258a2c07f6SSiddharth Bhat   PollyIRBuilder Builder(M.getContext());
2268a2c07f6SSiddharth Bhat   FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
2278a2c07f6SSiddharth Bhat   const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
2288a2c07f6SSiddharth Bhat   Function *F = Function::Create(Ty, Linkage, FnName, &M);
2298a2c07f6SSiddharth Bhat   BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
2308a2c07f6SSiddharth Bhat   Builder.SetInsertPoint(Start);
2318a2c07f6SSiddharth Bhat 
23260354486SSiddharth Bhat   const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy);
2338a2c07f6SSiddharth Bhat   Value *ArraySize = Builder.getInt64(ArraySizeInt);
2348a2c07f6SSiddharth Bhat   ArraySize->setName("array.size");
2358a2c07f6SSiddharth Bhat 
2368a2c07f6SSiddharth Bhat   Value *AllocatedMemRaw =
2378a2c07f6SSiddharth Bhat       Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
2388a2c07f6SSiddharth Bhat   Value *AllocatedMemTyped =
2398a2c07f6SSiddharth Bhat       Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
2408a2c07f6SSiddharth Bhat   Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
2418a2c07f6SSiddharth Bhat   Builder.CreateRetVoid();
2428a2c07f6SSiddharth Bhat 
2438a2c07f6SSiddharth Bhat   const int Priority = 0;
2448a2c07f6SSiddharth Bhat   appendToGlobalCtors(M, F, Priority, ReplacementToArr);
2458a2c07f6SSiddharth Bhat 
2468a2c07f6SSiddharth Bhat   SmallVector<Instruction *, 4> ArrayUserInstructions;
2478a2c07f6SSiddharth Bhat   // Get all instructions that use array. We need to do this weird thing
2488a2c07f6SSiddharth Bhat   // because `Constant`s that contain this array neeed to be expanded into
2498a2c07f6SSiddharth Bhat   // instructions so that we can replace their parameters. `Constant`s cannot
2508a2c07f6SSiddharth Bhat   // be edited easily, so we choose to convert all `Constant`s to
2518a2c07f6SSiddharth Bhat   // `Instruction`s and handle all of the uses of `Array` uniformly.
2528a2c07f6SSiddharth Bhat   for (Use &ArrayUse : Array.uses())
2538a2c07f6SSiddharth Bhat     getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);
2548a2c07f6SSiddharth Bhat 
2558a2c07f6SSiddharth Bhat   for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
2568a2c07f6SSiddharth Bhat 
2578a2c07f6SSiddharth Bhat     Builder.SetInsertPoint(UserOfArrayInst);
2588a2c07f6SSiddharth Bhat     // <ty>** -> <ty>*
2599c486eb3SMichael Kruse     Value *ArrPtrLoaded =
2609c486eb3SMichael Kruse         Builder.CreateLoad(ElemPtrTy, ReplacementToArr, "arrptr.load");
2618a2c07f6SSiddharth Bhat     // <ty>* -> [ty]*
2628a2c07f6SSiddharth Bhat     Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
2638a2c07f6SSiddharth Bhat         ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
2648a2c07f6SSiddharth Bhat     rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
2658a2c07f6SSiddharth Bhat   }
2668a2c07f6SSiddharth Bhat }
2678a2c07f6SSiddharth Bhat 
2688a2c07f6SSiddharth Bhat // We return all `allocas` that may need to be converted to a call to
2698a2c07f6SSiddharth Bhat // cudaMallocManaged.
getAllocasToBeManaged(Function & F,SmallSet<AllocaInst *,4> & Allocas)2708a2c07f6SSiddharth Bhat static void getAllocasToBeManaged(Function &F,
2718a2c07f6SSiddharth Bhat                                   SmallSet<AllocaInst *, 4> &Allocas) {
2728a2c07f6SSiddharth Bhat   for (BasicBlock &BB : F) {
2738a2c07f6SSiddharth Bhat     for (Instruction &I : BB) {
2748a2c07f6SSiddharth Bhat       auto *Alloca = dyn_cast<AllocaInst>(&I);
2758a2c07f6SSiddharth Bhat       if (!Alloca)
2768a2c07f6SSiddharth Bhat         continue;
277349506a9SNicola Zaghen       LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: ");
2788a2c07f6SSiddharth Bhat 
2798a2c07f6SSiddharth Bhat       if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false,
2808a2c07f6SSiddharth Bhat                                /* StoreCaptures */ true)) {
2818a2c07f6SSiddharth Bhat         Allocas.insert(Alloca);
282349506a9SNicola Zaghen         LLVM_DEBUG(dbgs() << "YES (captured).\n");
2838a2c07f6SSiddharth Bhat       } else {
284349506a9SNicola Zaghen         LLVM_DEBUG(dbgs() << "NO (not captured).\n");
2858a2c07f6SSiddharth Bhat       }
2868a2c07f6SSiddharth Bhat     }
2878a2c07f6SSiddharth Bhat   }
2888a2c07f6SSiddharth Bhat }
2898a2c07f6SSiddharth Bhat 
rewriteAllocaAsManagedMemory(AllocaInst * Alloca,const DataLayout & DL)2908a2c07f6SSiddharth Bhat static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
2918a2c07f6SSiddharth Bhat                                          const DataLayout &DL) {
292349506a9SNicola Zaghen   LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n");
2938a2c07f6SSiddharth Bhat   Module *M = Alloca->getModule();
2948a2c07f6SSiddharth Bhat   assert(M && "Alloca does not have a module");
2958a2c07f6SSiddharth Bhat 
2968a2c07f6SSiddharth Bhat   PollyIRBuilder Builder(M->getContext());
2978a2c07f6SSiddharth Bhat   Builder.SetInsertPoint(Alloca);
2988a2c07f6SSiddharth Bhat 
299ae2f9512SJames Y Knight   Function *MallocManagedFn =
300ae2f9512SJames Y Knight       getOrCreatePollyMallocManaged(*Alloca->getModule());
301ee423d93SNikita Popov   const uint64_t Size = DL.getTypeAllocSize(Alloca->getAllocatedType());
3028a2c07f6SSiddharth Bhat   Value *SizeVal = Builder.getInt64(Size);
3038a2c07f6SSiddharth Bhat   Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
3048a2c07f6SSiddharth Bhat   Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());
3058a2c07f6SSiddharth Bhat 
3068a2c07f6SSiddharth Bhat   Function *F = Alloca->getFunction();
3078a2c07f6SSiddharth Bhat   assert(F && "Alloca has invalid function");
3088a2c07f6SSiddharth Bhat 
3098a2c07f6SSiddharth Bhat   Bitcasted->takeName(Alloca);
3108a2c07f6SSiddharth Bhat   Alloca->replaceAllUsesWith(Bitcasted);
3118a2c07f6SSiddharth Bhat   Alloca->eraseFromParent();
3128a2c07f6SSiddharth Bhat 
3138a2c07f6SSiddharth Bhat   for (BasicBlock &BB : *F) {
3148a2c07f6SSiddharth Bhat     ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator());
3158a2c07f6SSiddharth Bhat     if (!Return)
3168a2c07f6SSiddharth Bhat       continue;
3178a2c07f6SSiddharth Bhat     Builder.SetInsertPoint(Return);
3188a2c07f6SSiddharth Bhat 
319ae2f9512SJames Y Knight     Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
3208a2c07f6SSiddharth Bhat     Builder.CreateCall(FreeManagedFn, {RawManagedMem});
3218a2c07f6SSiddharth Bhat   }
3228a2c07f6SSiddharth Bhat }
3238a2c07f6SSiddharth Bhat 
324a2c41127SSiddharth Bhat // Replace all uses of `Old` with `New`, even inside `ConstantExpr`.
325a2c41127SSiddharth Bhat //
326a2c41127SSiddharth Bhat // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function
327a2c41127SSiddharth Bhat // actually does replace it in `ConstantExpr`. The caveat is that if there is
328a2c41127SSiddharth Bhat // a use that is *outside* a function (say, at global declarations), we fail.
329a2c41127SSiddharth Bhat // So, this is meant to be used on values which we know will only be used
330a2c41127SSiddharth Bhat // within functions.
331a2c41127SSiddharth Bhat //
332a2c41127SSiddharth Bhat // This process works by looking through the uses of `Old`. If it finds a
333a2c41127SSiddharth Bhat // `ConstantExpr`, it recursively looks for the owning instruction.
334a2c41127SSiddharth Bhat // Then, it expands all the `ConstantExpr` to instructions and replaces
335a2c41127SSiddharth Bhat // `Old` with `New` in the expanded instructions.
replaceAllUsesAndConstantUses(Value * Old,Value * New,PollyIRBuilder & Builder)336a2c41127SSiddharth Bhat static void replaceAllUsesAndConstantUses(Value *Old, Value *New,
337a2c41127SSiddharth Bhat                                           PollyIRBuilder &Builder) {
338a2c41127SSiddharth Bhat   SmallVector<Instruction *, 4> UserInstructions;
339a2c41127SSiddharth Bhat   // Get all instructions that use array. We need to do this weird thing
340a2c41127SSiddharth Bhat   // because `Constant`s that contain this array neeed to be expanded into
341a2c41127SSiddharth Bhat   // instructions so that we can replace their parameters. `Constant`s cannot
342a2c41127SSiddharth Bhat   // be edited easily, so we choose to convert all `Constant`s to
343a2c41127SSiddharth Bhat   // `Instruction`s and handle all of the uses of `Array` uniformly.
344a2c41127SSiddharth Bhat   for (Use &ArrayUse : Old->uses())
345a2c41127SSiddharth Bhat     getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions);
346a2c41127SSiddharth Bhat 
347a2c41127SSiddharth Bhat   for (Instruction *I : UserInstructions)
348a2c41127SSiddharth Bhat     rewriteOldValToNew(I, Old, New, Builder);
349a2c41127SSiddharth Bhat }
350a2c41127SSiddharth Bhat 
351bd93df93SMichael Kruse class ManagedMemoryRewritePass final : public ModulePass {
352c4a4af47SSiddharth Bhat public:
353c4a4af47SSiddharth Bhat   static char ID;
354c4a4af47SSiddharth Bhat   GPUArch Architecture;
355c4a4af47SSiddharth Bhat   GPURuntime Runtime;
3568a2c07f6SSiddharth Bhat 
ManagedMemoryRewritePass()357c4a4af47SSiddharth Bhat   ManagedMemoryRewritePass() : ModulePass(ID) {}
runOnModule(Module & M)35833ca0b0eSMichael Kruse   bool runOnModule(Module &M) override {
3598a2c07f6SSiddharth Bhat     const DataLayout &DL = M.getDataLayout();
3608a2c07f6SSiddharth Bhat 
361c4a4af47SSiddharth Bhat     Function *Malloc = M.getFunction("malloc");
362c4a4af47SSiddharth Bhat 
363c4a4af47SSiddharth Bhat     if (Malloc) {
364a2c41127SSiddharth Bhat       PollyIRBuilder Builder(M.getContext());
3658a2c07f6SSiddharth Bhat       Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
366c4a4af47SSiddharth Bhat       assert(PollyMallocManaged && "unable to create polly_mallocManaged");
367a2c41127SSiddharth Bhat 
368a2c41127SSiddharth Bhat       replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder);
3699298ff2dSSiddharth Bhat       Malloc->eraseFromParent();
370c4a4af47SSiddharth Bhat     }
371c4a4af47SSiddharth Bhat 
372c4a4af47SSiddharth Bhat     Function *Free = M.getFunction("free");
373c4a4af47SSiddharth Bhat 
374c4a4af47SSiddharth Bhat     if (Free) {
375a2c41127SSiddharth Bhat       PollyIRBuilder Builder(M.getContext());
3768a2c07f6SSiddharth Bhat       Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
377c4a4af47SSiddharth Bhat       assert(PollyFreeManaged && "unable to create polly_freeManaged");
378a2c41127SSiddharth Bhat 
379a2c41127SSiddharth Bhat       replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder);
3809298ff2dSSiddharth Bhat       Free->eraseFromParent();
381c4a4af47SSiddharth Bhat     }
382c4a4af47SSiddharth Bhat 
3838a2c07f6SSiddharth Bhat     SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
3848a2c07f6SSiddharth Bhat     for (GlobalVariable &Global : M.globals())
3858a2c07f6SSiddharth Bhat       replaceGlobalArray(M, DL, Global, GlobalsToErase);
3868a2c07f6SSiddharth Bhat     for (GlobalVariable *G : GlobalsToErase)
3878a2c07f6SSiddharth Bhat       G->eraseFromParent();
3888a2c07f6SSiddharth Bhat 
3898a2c07f6SSiddharth Bhat     // Rewrite allocas to cudaMallocs if we are asked to do so.
3908a2c07f6SSiddharth Bhat     if (RewriteAllocas) {
3918a2c07f6SSiddharth Bhat       SmallSet<AllocaInst *, 4> AllocasToBeManaged;
3928a2c07f6SSiddharth Bhat       for (Function &F : M.functions())
3938a2c07f6SSiddharth Bhat         getAllocasToBeManaged(F, AllocasToBeManaged);
3948a2c07f6SSiddharth Bhat 
3958a2c07f6SSiddharth Bhat       for (AllocaInst *Alloca : AllocasToBeManaged)
3968a2c07f6SSiddharth Bhat         rewriteAllocaAsManagedMemory(Alloca, DL);
3978a2c07f6SSiddharth Bhat     }
3988a2c07f6SSiddharth Bhat 
399c4a4af47SSiddharth Bhat     return true;
400c4a4af47SSiddharth Bhat   }
401c4a4af47SSiddharth Bhat };
402c4a4af47SSiddharth Bhat } // namespace
403c4a4af47SSiddharth Bhat char ManagedMemoryRewritePass::ID = 42;
404c4a4af47SSiddharth Bhat 
createManagedMemoryRewritePassPass(GPUArch Arch,GPURuntime Runtime)405c4a4af47SSiddharth Bhat Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
406c4a4af47SSiddharth Bhat                                                 GPURuntime Runtime) {
407c4a4af47SSiddharth Bhat   ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass();
408c4a4af47SSiddharth Bhat   pass->Runtime = Runtime;
409c4a4af47SSiddharth Bhat   pass->Architecture = Arch;
410c4a4af47SSiddharth Bhat   return pass;
411c4a4af47SSiddharth Bhat }
412c4a4af47SSiddharth Bhat 
// Register ManagedMemoryRewritePass with the legacy pass registry under the
// command-line name "polly-acc-rewrite-managed-memory". The two trailing
// `false` arguments are the `cfg` and `analysis` flags: the pass is neither
// CFG-only nor an analysis. The listed dependencies are initialized together
// with this pass, although runOnModule does not query any of them.
INITIALIZE_PASS_BEGIN(
    ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
    "Polly - Rewrite all allocations in heap & data section to managed memory",
    false, false)
INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
INITIALIZE_PASS_END(
    ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
    "Polly - Rewrite all allocations in heap & data section to managed memory",
    false, false)
428