1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Take a module and rewrite:
10 // 1. `malloc` -> `polly_mallocManaged`
11 // 2. `free` -> `polly_freeManaged`
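// 3. zero-initialized global arrays (by default only those with private or
//    internal linkage) -> global pointers that a module constructor
//    initializes with a call to `polly_mallocManaged`.
// 4. (only with -polly-acc-rewrite-allocas) allocas whose address may be
//    captured -> calls to `polly_mallocManaged`, released with
//    `polly_freeManaged` before each return.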
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "polly/CodeGen/CodeGeneration.h"
19 #include "polly/CodeGen/IslAst.h"
20 #include "polly/CodeGen/IslNodeBuilder.h"
21 #include "polly/CodeGen/PPCGCodeGeneration.h"
22 #include "polly/CodeGen/Utils.h"
23 #include "polly/DependenceInfo.h"
24 #include "polly/LinkAllPasses.h"
25 #include "polly/Options.h"
26 #include "polly/ScopDetection.h"
27 #include "polly/ScopInfo.h"
28 #include "polly/Support/SCEVValidator.h"
29 #include "llvm/Analysis/AliasAnalysis.h"
30 #include "llvm/Analysis/BasicAliasAnalysis.h"
31 #include "llvm/Analysis/CaptureTracking.h"
32 #include "llvm/Analysis/GlobalsModRef.h"
33 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
34 #include "llvm/Analysis/TargetLibraryInfo.h"
35 #include "llvm/Analysis/TargetTransformInfo.h"
36 #include "llvm/IR/LegacyPassManager.h"
37 #include "llvm/IR/Verifier.h"
38 #include "llvm/IRReader/IRReader.h"
39 #include "llvm/Linker/Linker.h"
40 #include "llvm/Support/TargetRegistry.h"
41 #include "llvm/Support/TargetSelect.h"
42 #include "llvm/Target/TargetMachine.h"
43 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
44 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
45 #include "llvm/Transforms/Utils/ModuleUtils.h"
46 
47 static cl::opt<bool> RewriteAllocas(
48     "polly-acc-rewrite-allocas",
49     cl::desc(
50         "Ask the managed memory rewriter to also rewrite alloca instructions"),
51     cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
52 
53 static cl::opt<bool> IgnoreLinkageForGlobals(
54     "polly-acc-rewrite-ignore-linkage-for-globals",
55     cl::desc(
56         "By default, we only rewrite globals with internal linkage. This flag "
57         "enables rewriting of globals regardless of linkage"),
58     cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
59 
60 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
61 namespace {
62 
63 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
64   const char *Name = "polly_mallocManaged";
65   Function *F = M.getFunction(Name);
66 
67   // If F is not available, declare it.
68   if (!F) {
69     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
70     PollyIRBuilder Builder(M.getContext());
71     // TODO: How do I get `size_t`? I assume from DataLayout?
72     FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(),
73                                          {Builder.getInt64Ty()}, false);
74     F = Function::Create(Ty, Linkage, Name, &M);
75   }
76 
77   return F;
78 }
79 
80 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
81   const char *Name = "polly_freeManaged";
82   Function *F = M.getFunction(Name);
83 
84   // If F is not available, declare it.
85   if (!F) {
86     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
87     PollyIRBuilder Builder(M.getContext());
88     // TODO: How do I get `size_t`? I assume from DataLayout?
89     FunctionType *Ty =
90         FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false);
91     F = Function::Create(Ty, Linkage, Name, &M);
92   }
93 
94   return F;
95 }
96 
97 // Expand a constant expression `Cur`, which is used at instruction `Parent`
98 // at index `index`.
99 // Since a constant expression can expand to multiple instructions, store all
100 // the expands into a set called `Expands`.
101 // Note that this goes inorder on the constant expression tree.
102 // A * ((B * D) + C)
103 // will be processed with first A, then B * D, then B, then D, and then C.
104 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can
105 // have something like this:
106 //    *
107 //   /  \
108 //   \  /
109 //    (D)
110 //
111 // For the purposes of this expansion, we expand the two occurences of D
112 // separately. Therefore, we expand the DAG into the tree:
113 //  *
114 // / \
115 // D  D
116 // TODO: We don't _have_to do this, but this is the simplest solution.
117 // We can write a solution that keeps track of which constants have been
118 // already expanded.
119 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
120                                Instruction *Parent, int index,
121                                SmallPtrSet<Instruction *, 4> &Expands) {
122   assert(Cur && "invalid constant expression passed");
123   Instruction *I = Cur->getAsInstruction();
124   assert(I && "unable to convert ConstantExpr to Instruction");
125 
126   LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur
127                     << ") in Instruction: (" << *I << ")\n";);
128 
129   // Invalidate `Cur` so that no one after this point uses `Cur`. Rather,
130   // they should mutate `I`.
131   Cur = nullptr;
132 
133   Expands.insert(I);
134   Parent->setOperand(index, I);
135 
136   // The things that `Parent` uses (its operands) should be created
137   // before `Parent`.
138   Builder.SetInsertPoint(Parent);
139   Builder.Insert(I);
140 
141   for (unsigned i = 0; i < I->getNumOperands(); i++) {
142     Value *Op = I->getOperand(i);
143     assert(isa<Constant>(Op) && "constant must have a constant operand");
144 
145     if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op))
146       expandConstantExpr(CExprOp, Builder, I, i, Expands);
147   }
148 }
149 
150 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite
151 // `ConstantExpr`s that are used in the `Inst`.
152 // Note that `replaceAllUsesWith` is insufficient for this purpose because it
153 // does not rewrite values in `ConstantExpr`s.
154 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
155                                PollyIRBuilder &Builder) {
156 
157   // This contains a set of instructions in which OldVal must be replaced.
158   // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s
159   // from `Inst`s arguments.
160   // We need to go through this process because `replaceAllUsesWith` does not
161   // actually edit `ConstantExpr`s.
162   SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
163 
164   // Expand all `ConstantExpr`s and place it in `InstsToVisit`.
165   for (unsigned i = 0; i < Inst->getNumOperands(); i++) {
166     Value *Operand = Inst->getOperand(i);
167     if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand))
168       expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit);
169   }
170 
171   // Now visit each instruction and use `replaceUsesOfWith`. We know that
172   // will work because `I` cannot have any `ConstantExpr` within it.
173   for (Instruction *I : InstsToVisit)
174     I->replaceUsesOfWith(OldVal, NewVal);
175 }
176 
177 // Given a value `Current`, return all Instructions that may contain `Current`
178 // in an expression.
179 // We need this auxiliary function, because if we have a
180 // `Constant` that is a user of `V`, we need to recurse into the
181 // `Constant`s uses to gather the root instruciton.
182 static void getInstructionUsersOfValue(Value *V,
183                                        SmallVector<Instruction *, 4> &Owners) {
184   if (auto *I = dyn_cast<Instruction>(V)) {
185     Owners.push_back(I);
186   } else {
187     // Anything that is a `User` must be a constant or an instruction.
188     auto *C = cast<Constant>(V);
189     for (Use &CUse : C->uses())
190       getInstructionUsersOfValue(CUse.getUser(), Owners);
191   }
192 }
193 
194 static void
195 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
196                    SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
197   // We only want arrays.
198   ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType());
199   if (!ArrayTy)
200     return;
201   Type *ElemTy = ArrayTy->getElementType();
202   PointerType *ElemPtrTy = ElemTy->getPointerTo();
203 
204   // We only wish to replace arrays that are visible in the module they
205   // inhabit. Otherwise, our type edit from [T] to T* would be illegal across
206   // modules.
207   const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
208                                        Array.hasInternalLinkage() ||
209                                        IgnoreLinkageForGlobals;
210   if (!OnlyVisibleInsideModule) {
211     LLVM_DEBUG(
212         dbgs() << "Not rewriting (" << Array
213                << ") to managed memory "
214                   "because it could be visible externally. To force rewrite, "
215                   "use -polly-acc-rewrite-ignore-linkage-for-globals.\n");
216     return;
217   }
218 
219   if (!Array.hasInitializer() ||
220       !isa<ConstantAggregateZero>(Array.getInitializer())) {
221     LLVM_DEBUG(dbgs() << "Not rewriting (" << Array
222                       << ") to managed memory "
223                          "because it has an initializer which is "
224                          "not a zeroinitializer.\n");
225     return;
226   }
227 
228   // At this point, we have committed to replacing this array.
229   ReplacedGlobals.insert(&Array);
230 
231   std::string NewName = Array.getName();
232   NewName += ".toptr";
233   GlobalVariable *ReplacementToArr =
234       cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy));
235   ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy));
236 
237   Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
238   std::string FnName = Array.getName();
239   FnName += ".constructor";
240   PollyIRBuilder Builder(M.getContext());
241   FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
242   const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
243   Function *F = Function::Create(Ty, Linkage, FnName, &M);
244   BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
245   Builder.SetInsertPoint(Start);
246 
247   const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy);
248   Value *ArraySize = Builder.getInt64(ArraySizeInt);
249   ArraySize->setName("array.size");
250 
251   Value *AllocatedMemRaw =
252       Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
253   Value *AllocatedMemTyped =
254       Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
255   Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
256   Builder.CreateRetVoid();
257 
258   const int Priority = 0;
259   appendToGlobalCtors(M, F, Priority, ReplacementToArr);
260 
261   SmallVector<Instruction *, 4> ArrayUserInstructions;
262   // Get all instructions that use array. We need to do this weird thing
263   // because `Constant`s that contain this array neeed to be expanded into
264   // instructions so that we can replace their parameters. `Constant`s cannot
265   // be edited easily, so we choose to convert all `Constant`s to
266   // `Instruction`s and handle all of the uses of `Array` uniformly.
267   for (Use &ArrayUse : Array.uses())
268     getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);
269 
270   for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
271 
272     Builder.SetInsertPoint(UserOfArrayInst);
273     // <ty>** -> <ty>*
274     Value *ArrPtrLoaded = Builder.CreateLoad(ReplacementToArr, "arrptr.load");
275     // <ty>* -> [ty]*
276     Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
277         ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
278     rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
279   }
280 }
281 
282 // We return all `allocas` that may need to be converted to a call to
283 // cudaMallocManaged.
284 static void getAllocasToBeManaged(Function &F,
285                                   SmallSet<AllocaInst *, 4> &Allocas) {
286   for (BasicBlock &BB : F) {
287     for (Instruction &I : BB) {
288       auto *Alloca = dyn_cast<AllocaInst>(&I);
289       if (!Alloca)
290         continue;
291       LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: ");
292 
293       if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false,
294                                /* StoreCaptures */ true)) {
295         Allocas.insert(Alloca);
296         LLVM_DEBUG(dbgs() << "YES (captured).\n");
297       } else {
298         LLVM_DEBUG(dbgs() << "NO (not captured).\n");
299       }
300     }
301   }
302 }
303 
304 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
305                                          const DataLayout &DL) {
306   LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n");
307   Module *M = Alloca->getModule();
308   assert(M && "Alloca does not have a module");
309 
310   PollyIRBuilder Builder(M->getContext());
311   Builder.SetInsertPoint(Alloca);
312 
313   Function *MallocManagedFn =
314       getOrCreatePollyMallocManaged(*Alloca->getModule());
315   const uint64_t Size =
316       DL.getTypeAllocSize(Alloca->getType()->getElementType());
317   Value *SizeVal = Builder.getInt64(Size);
318   Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
319   Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());
320 
321   Function *F = Alloca->getFunction();
322   assert(F && "Alloca has invalid function");
323 
324   Bitcasted->takeName(Alloca);
325   Alloca->replaceAllUsesWith(Bitcasted);
326   Alloca->eraseFromParent();
327 
328   for (BasicBlock &BB : *F) {
329     ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator());
330     if (!Return)
331       continue;
332     Builder.SetInsertPoint(Return);
333 
334     Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
335     Builder.CreateCall(FreeManagedFn, {RawManagedMem});
336   }
337 }
338 
339 // Replace all uses of `Old` with `New`, even inside `ConstantExpr`.
340 //
341 // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function
342 // actually does replace it in `ConstantExpr`. The caveat is that if there is
343 // a use that is *outside* a function (say, at global declarations), we fail.
344 // So, this is meant to be used on values which we know will only be used
345 // within functions.
346 //
347 // This process works by looking through the uses of `Old`. If it finds a
348 // `ConstantExpr`, it recursively looks for the owning instruction.
349 // Then, it expands all the `ConstantExpr` to instructions and replaces
350 // `Old` with `New` in the expanded instructions.
351 static void replaceAllUsesAndConstantUses(Value *Old, Value *New,
352                                           PollyIRBuilder &Builder) {
353   SmallVector<Instruction *, 4> UserInstructions;
354   // Get all instructions that use array. We need to do this weird thing
355   // because `Constant`s that contain this array neeed to be expanded into
356   // instructions so that we can replace their parameters. `Constant`s cannot
357   // be edited easily, so we choose to convert all `Constant`s to
358   // `Instruction`s and handle all of the uses of `Array` uniformly.
359   for (Use &ArrayUse : Old->uses())
360     getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions);
361 
362   for (Instruction *I : UserInstructions)
363     rewriteOldValToNew(I, Old, New, Builder);
364 }
365 
366 class ManagedMemoryRewritePass : public ModulePass {
367 public:
368   static char ID;
369   GPUArch Architecture;
370   GPURuntime Runtime;
371 
372   ManagedMemoryRewritePass() : ModulePass(ID) {}
373   virtual bool runOnModule(Module &M) {
374     const DataLayout &DL = M.getDataLayout();
375 
376     Function *Malloc = M.getFunction("malloc");
377 
378     if (Malloc) {
379       PollyIRBuilder Builder(M.getContext());
380       Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
381       assert(PollyMallocManaged && "unable to create polly_mallocManaged");
382 
383       replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder);
384       Malloc->eraseFromParent();
385     }
386 
387     Function *Free = M.getFunction("free");
388 
389     if (Free) {
390       PollyIRBuilder Builder(M.getContext());
391       Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
392       assert(PollyFreeManaged && "unable to create polly_freeManaged");
393 
394       replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder);
395       Free->eraseFromParent();
396     }
397 
398     SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
399     for (GlobalVariable &Global : M.globals())
400       replaceGlobalArray(M, DL, Global, GlobalsToErase);
401     for (GlobalVariable *G : GlobalsToErase)
402       G->eraseFromParent();
403 
404     // Rewrite allocas to cudaMallocs if we are asked to do so.
405     if (RewriteAllocas) {
406       SmallSet<AllocaInst *, 4> AllocasToBeManaged;
407       for (Function &F : M.functions())
408         getAllocasToBeManaged(F, AllocasToBeManaged);
409 
410       for (AllocaInst *Alloca : AllocasToBeManaged)
411         rewriteAllocaAsManagedMemory(Alloca, DL);
412     }
413 
414     return true;
415   }
416 };
417 } // namespace
418 char ManagedMemoryRewritePass::ID = 42;
419 
420 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
421                                                 GPURuntime Runtime) {
422   ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass();
423   pass->Runtime = Runtime;
424   pass->Architecture = Arch;
425   return pass;
426 }
427 
// Register the pass with the legacy pass registry under the command-line
// name "polly-acc-rewrite-managed-memory".
// NOTE(review): `runOnModule` does not request any of the analyses listed as
// dependencies below. Presumably they only affect initialization order;
// confirm whether they are still needed.
INITIALIZE_PASS_BEGIN(
    ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
    "Polly - Rewrite all allocations in heap & data section to managed memory",
    false, false)
INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
INITIALIZE_PASS_END(
    ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
    "Polly - Rewrite all allocations in heap & data section to managed memory",
    false, false)
443