1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Take a module and rewrite:
10 // 1. `malloc` -> `polly_mallocManaged`
11 // 2. `free` -> `polly_freeManaged`
12 // 3. global arrays with initializers -> global arrays that are initialized
13 //                                       with a constructor call to
14 //                                       `polly_mallocManaged`.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "polly/CodeGen/IRBuilder.h"
19 #include "polly/CodeGen/PPCGCodeGeneration.h"
20 #include "polly/DependenceInfo.h"
21 #include "polly/LinkAllPasses.h"
22 #include "polly/Options.h"
23 #include "polly/ScopDetection.h"
24 #include "llvm/ADT/SmallSet.h"
25 #include "llvm/Analysis/CaptureTracking.h"
26 #include "llvm/Transforms/Utils/ModuleUtils.h"
27 
28 using namespace polly;
29 
30 static cl::opt<bool> RewriteAllocas(
31     "polly-acc-rewrite-allocas",
32     cl::desc(
33         "Ask the managed memory rewriter to also rewrite alloca instructions"),
34     cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
35 
36 static cl::opt<bool> IgnoreLinkageForGlobals(
37     "polly-acc-rewrite-ignore-linkage-for-globals",
38     cl::desc(
39         "By default, we only rewrite globals with internal linkage. This flag "
40         "enables rewriting of globals regardless of linkage"),
41     cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
42 
43 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
44 namespace {
45 
46 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
47   const char *Name = "polly_mallocManaged";
48   Function *F = M.getFunction(Name);
49 
50   // If F is not available, declare it.
51   if (!F) {
52     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
53     PollyIRBuilder Builder(M.getContext());
54     // TODO: How do I get `size_t`? I assume from DataLayout?
55     FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(),
56                                          {Builder.getInt64Ty()}, false);
57     F = Function::Create(Ty, Linkage, Name, &M);
58   }
59 
60   return F;
61 }
62 
63 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
64   const char *Name = "polly_freeManaged";
65   Function *F = M.getFunction(Name);
66 
67   // If F is not available, declare it.
68   if (!F) {
69     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
70     PollyIRBuilder Builder(M.getContext());
71     // TODO: How do I get `size_t`? I assume from DataLayout?
72     FunctionType *Ty =
73         FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false);
74     F = Function::Create(Ty, Linkage, Name, &M);
75   }
76 
77   return F;
78 }
79 
80 // Expand a constant expression `Cur`, which is used at instruction `Parent`
81 // at index `index`.
82 // Since a constant expression can expand to multiple instructions, store all
83 // the expands into a set called `Expands`.
84 // Note that this goes inorder on the constant expression tree.
85 // A * ((B * D) + C)
86 // will be processed with first A, then B * D, then B, then D, and then C.
87 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can
88 // have something like this:
89 //    *
90 //   /  \
91 //   \  /
92 //    (D)
93 //
94 // For the purposes of this expansion, we expand the two occurences of D
95 // separately. Therefore, we expand the DAG into the tree:
96 //  *
97 // / \
98 // D  D
99 // TODO: We don't _have_to do this, but this is the simplest solution.
100 // We can write a solution that keeps track of which constants have been
101 // already expanded.
102 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
103                                Instruction *Parent, int index,
104                                SmallPtrSet<Instruction *, 4> &Expands) {
105   assert(Cur && "invalid constant expression passed");
106   Instruction *I = Cur->getAsInstruction();
107   assert(I && "unable to convert ConstantExpr to Instruction");
108 
109   LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur
110                     << ") in Instruction: (" << *I << ")\n";);
111 
112   // Invalidate `Cur` so that no one after this point uses `Cur`. Rather,
113   // they should mutate `I`.
114   Cur = nullptr;
115 
116   Expands.insert(I);
117   Parent->setOperand(index, I);
118 
119   // The things that `Parent` uses (its operands) should be created
120   // before `Parent`.
121   Builder.SetInsertPoint(Parent);
122   Builder.Insert(I);
123 
124   for (unsigned i = 0; i < I->getNumOperands(); i++) {
125     Value *Op = I->getOperand(i);
126     assert(isa<Constant>(Op) && "constant must have a constant operand");
127 
128     if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op))
129       expandConstantExpr(CExprOp, Builder, I, i, Expands);
130   }
131 }
132 
133 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite
134 // `ConstantExpr`s that are used in the `Inst`.
135 // Note that `replaceAllUsesWith` is insufficient for this purpose because it
136 // does not rewrite values in `ConstantExpr`s.
137 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
138                                PollyIRBuilder &Builder) {
139 
140   // This contains a set of instructions in which OldVal must be replaced.
141   // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s
142   // from `Inst`s arguments.
143   // We need to go through this process because `replaceAllUsesWith` does not
144   // actually edit `ConstantExpr`s.
145   SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
146 
147   // Expand all `ConstantExpr`s and place it in `InstsToVisit`.
148   for (unsigned i = 0; i < Inst->getNumOperands(); i++) {
149     Value *Operand = Inst->getOperand(i);
150     if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand))
151       expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit);
152   }
153 
154   // Now visit each instruction and use `replaceUsesOfWith`. We know that
155   // will work because `I` cannot have any `ConstantExpr` within it.
156   for (Instruction *I : InstsToVisit)
157     I->replaceUsesOfWith(OldVal, NewVal);
158 }
159 
160 // Given a value `Current`, return all Instructions that may contain `Current`
161 // in an expression.
162 // We need this auxiliary function, because if we have a
163 // `Constant` that is a user of `V`, we need to recurse into the
164 // `Constant`s uses to gather the root instruciton.
165 static void getInstructionUsersOfValue(Value *V,
166                                        SmallVector<Instruction *, 4> &Owners) {
167   if (auto *I = dyn_cast<Instruction>(V)) {
168     Owners.push_back(I);
169   } else {
170     // Anything that is a `User` must be a constant or an instruction.
171     auto *C = cast<Constant>(V);
172     for (Use &CUse : C->uses())
173       getInstructionUsersOfValue(CUse.getUser(), Owners);
174   }
175 }
176 
177 static void
178 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
179                    SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
180   // We only want arrays.
181   ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType());
182   if (!ArrayTy)
183     return;
184   Type *ElemTy = ArrayTy->getElementType();
185   PointerType *ElemPtrTy = ElemTy->getPointerTo();
186 
187   // We only wish to replace arrays that are visible in the module they
188   // inhabit. Otherwise, our type edit from [T] to T* would be illegal across
189   // modules.
190   const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
191                                        Array.hasInternalLinkage() ||
192                                        IgnoreLinkageForGlobals;
193   if (!OnlyVisibleInsideModule) {
194     LLVM_DEBUG(
195         dbgs() << "Not rewriting (" << Array
196                << ") to managed memory "
197                   "because it could be visible externally. To force rewrite, "
198                   "use -polly-acc-rewrite-ignore-linkage-for-globals.\n");
199     return;
200   }
201 
202   if (!Array.hasInitializer() ||
203       !isa<ConstantAggregateZero>(Array.getInitializer())) {
204     LLVM_DEBUG(dbgs() << "Not rewriting (" << Array
205                       << ") to managed memory "
206                          "because it has an initializer which is "
207                          "not a zeroinitializer.\n");
208     return;
209   }
210 
211   // At this point, we have committed to replacing this array.
212   ReplacedGlobals.insert(&Array);
213 
214   std::string NewName = Array.getName();
215   NewName += ".toptr";
216   GlobalVariable *ReplacementToArr =
217       cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy));
218   ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy));
219 
220   Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
221   std::string FnName = Array.getName();
222   FnName += ".constructor";
223   PollyIRBuilder Builder(M.getContext());
224   FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
225   const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
226   Function *F = Function::Create(Ty, Linkage, FnName, &M);
227   BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
228   Builder.SetInsertPoint(Start);
229 
230   const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy);
231   Value *ArraySize = Builder.getInt64(ArraySizeInt);
232   ArraySize->setName("array.size");
233 
234   Value *AllocatedMemRaw =
235       Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
236   Value *AllocatedMemTyped =
237       Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
238   Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
239   Builder.CreateRetVoid();
240 
241   const int Priority = 0;
242   appendToGlobalCtors(M, F, Priority, ReplacementToArr);
243 
244   SmallVector<Instruction *, 4> ArrayUserInstructions;
245   // Get all instructions that use array. We need to do this weird thing
246   // because `Constant`s that contain this array neeed to be expanded into
247   // instructions so that we can replace their parameters. `Constant`s cannot
248   // be edited easily, so we choose to convert all `Constant`s to
249   // `Instruction`s and handle all of the uses of `Array` uniformly.
250   for (Use &ArrayUse : Array.uses())
251     getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);
252 
253   for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
254 
255     Builder.SetInsertPoint(UserOfArrayInst);
256     // <ty>** -> <ty>*
257     Value *ArrPtrLoaded = Builder.CreateLoad(ReplacementToArr, "arrptr.load");
258     // <ty>* -> [ty]*
259     Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
260         ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
261     rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
262   }
263 }
264 
265 // We return all `allocas` that may need to be converted to a call to
266 // cudaMallocManaged.
267 static void getAllocasToBeManaged(Function &F,
268                                   SmallSet<AllocaInst *, 4> &Allocas) {
269   for (BasicBlock &BB : F) {
270     for (Instruction &I : BB) {
271       auto *Alloca = dyn_cast<AllocaInst>(&I);
272       if (!Alloca)
273         continue;
274       LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: ");
275 
276       if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false,
277                                /* StoreCaptures */ true)) {
278         Allocas.insert(Alloca);
279         LLVM_DEBUG(dbgs() << "YES (captured).\n");
280       } else {
281         LLVM_DEBUG(dbgs() << "NO (not captured).\n");
282       }
283     }
284   }
285 }
286 
287 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
288                                          const DataLayout &DL) {
289   LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n");
290   Module *M = Alloca->getModule();
291   assert(M && "Alloca does not have a module");
292 
293   PollyIRBuilder Builder(M->getContext());
294   Builder.SetInsertPoint(Alloca);
295 
296   Function *MallocManagedFn =
297       getOrCreatePollyMallocManaged(*Alloca->getModule());
298   const uint64_t Size =
299       DL.getTypeAllocSize(Alloca->getType()->getElementType());
300   Value *SizeVal = Builder.getInt64(Size);
301   Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
302   Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());
303 
304   Function *F = Alloca->getFunction();
305   assert(F && "Alloca has invalid function");
306 
307   Bitcasted->takeName(Alloca);
308   Alloca->replaceAllUsesWith(Bitcasted);
309   Alloca->eraseFromParent();
310 
311   for (BasicBlock &BB : *F) {
312     ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator());
313     if (!Return)
314       continue;
315     Builder.SetInsertPoint(Return);
316 
317     Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
318     Builder.CreateCall(FreeManagedFn, {RawManagedMem});
319   }
320 }
321 
322 // Replace all uses of `Old` with `New`, even inside `ConstantExpr`.
323 //
324 // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function
325 // actually does replace it in `ConstantExpr`. The caveat is that if there is
326 // a use that is *outside* a function (say, at global declarations), we fail.
327 // So, this is meant to be used on values which we know will only be used
328 // within functions.
329 //
330 // This process works by looking through the uses of `Old`. If it finds a
331 // `ConstantExpr`, it recursively looks for the owning instruction.
332 // Then, it expands all the `ConstantExpr` to instructions and replaces
333 // `Old` with `New` in the expanded instructions.
334 static void replaceAllUsesAndConstantUses(Value *Old, Value *New,
335                                           PollyIRBuilder &Builder) {
336   SmallVector<Instruction *, 4> UserInstructions;
337   // Get all instructions that use array. We need to do this weird thing
338   // because `Constant`s that contain this array neeed to be expanded into
339   // instructions so that we can replace their parameters. `Constant`s cannot
340   // be edited easily, so we choose to convert all `Constant`s to
341   // `Instruction`s and handle all of the uses of `Array` uniformly.
342   for (Use &ArrayUse : Old->uses())
343     getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions);
344 
345   for (Instruction *I : UserInstructions)
346     rewriteOldValToNew(I, Old, New, Builder);
347 }
348 
349 class ManagedMemoryRewritePass : public ModulePass {
350 public:
351   static char ID;
352   GPUArch Architecture;
353   GPURuntime Runtime;
354 
355   ManagedMemoryRewritePass() : ModulePass(ID) {}
356   virtual bool runOnModule(Module &M) {
357     const DataLayout &DL = M.getDataLayout();
358 
359     Function *Malloc = M.getFunction("malloc");
360 
361     if (Malloc) {
362       PollyIRBuilder Builder(M.getContext());
363       Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
364       assert(PollyMallocManaged && "unable to create polly_mallocManaged");
365 
366       replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder);
367       Malloc->eraseFromParent();
368     }
369 
370     Function *Free = M.getFunction("free");
371 
372     if (Free) {
373       PollyIRBuilder Builder(M.getContext());
374       Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
375       assert(PollyFreeManaged && "unable to create polly_freeManaged");
376 
377       replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder);
378       Free->eraseFromParent();
379     }
380 
381     SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
382     for (GlobalVariable &Global : M.globals())
383       replaceGlobalArray(M, DL, Global, GlobalsToErase);
384     for (GlobalVariable *G : GlobalsToErase)
385       G->eraseFromParent();
386 
387     // Rewrite allocas to cudaMallocs if we are asked to do so.
388     if (RewriteAllocas) {
389       SmallSet<AllocaInst *, 4> AllocasToBeManaged;
390       for (Function &F : M.functions())
391         getAllocasToBeManaged(F, AllocasToBeManaged);
392 
393       for (AllocaInst *Alloca : AllocasToBeManaged)
394         rewriteAllocaAsManagedMemory(Alloca, DL);
395     }
396 
397     return true;
398   }
399 };
400 } // namespace
401 char ManagedMemoryRewritePass::ID = 42;
402 
403 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
404                                                 GPURuntime Runtime) {
405   ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass();
406   pass->Runtime = Runtime;
407   pass->Architecture = Arch;
408   return pass;
409 }
410 
411 INITIALIZE_PASS_BEGIN(
412     ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
413     "Polly - Rewrite all allocations in heap & data section to managed memory",
414     false, false)
415 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
416 INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
417 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
418 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
419 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
420 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
421 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
422 INITIALIZE_PASS_END(
423     ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
424     "Polly - Rewrite all allocations in heap & data section to managed memory",
425     false, false)
426