1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Take a module and rewrite:
11 // 1. `malloc` -> `polly_mallocManaged`
12 // 2. `free` -> `polly_freeManaged`
13 // 3. global arrays with initializers -> global arrays that are initialized
14 //                                       with a constructor call to
15 //                                       `polly_mallocManaged`.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #include "polly/CodeGen/CodeGeneration.h"
20 #include "polly/CodeGen/IslAst.h"
21 #include "polly/CodeGen/IslNodeBuilder.h"
22 #include "polly/CodeGen/PPCGCodeGeneration.h"
23 #include "polly/CodeGen/Utils.h"
24 #include "polly/DependenceInfo.h"
25 #include "polly/LinkAllPasses.h"
26 #include "polly/Options.h"
27 #include "polly/ScopDetection.h"
28 #include "polly/ScopInfo.h"
29 #include "polly/Support/SCEVValidator.h"
30 #include "llvm/Analysis/AliasAnalysis.h"
31 #include "llvm/Analysis/BasicAliasAnalysis.h"
32 #include "llvm/Analysis/CaptureTracking.h"
33 #include "llvm/Analysis/GlobalsModRef.h"
34 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
35 #include "llvm/Analysis/TargetLibraryInfo.h"
36 #include "llvm/Analysis/TargetTransformInfo.h"
37 #include "llvm/IR/LegacyPassManager.h"
38 #include "llvm/IR/Verifier.h"
39 #include "llvm/IRReader/IRReader.h"
40 #include "llvm/Linker/Linker.h"
41 #include "llvm/Support/TargetRegistry.h"
42 #include "llvm/Support/TargetSelect.h"
43 #include "llvm/Target/TargetMachine.h"
44 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
45 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
46 #include "llvm/Transforms/Utils/ModuleUtils.h"
47 
48 static cl::opt<bool> RewriteAllocas(
49     "polly-acc-rewrite-allocas",
50     cl::desc(
51         "Ask the managed memory rewriter to also rewrite alloca instructions"),
52     cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
53 
54 static cl::opt<bool> IgnoreLinkageForGlobals(
55     "polly-acc-rewrite-ignore-linkage-for-globals",
56     cl::desc(
57         "By default, we only rewrite globals with internal linkage. This flag "
58         "enables rewriting of globals regardless of linkage"),
59     cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
60 
61 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
62 namespace {
63 
64 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
65   const char *Name = "polly_mallocManaged";
66   Function *F = M.getFunction(Name);
67 
68   // If F is not available, declare it.
69   if (!F) {
70     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
71     PollyIRBuilder Builder(M.getContext());
72     // TODO: How do I get `size_t`? I assume from DataLayout?
73     FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(),
74                                          {Builder.getInt64Ty()}, false);
75     F = Function::Create(Ty, Linkage, Name, &M);
76   }
77 
78   return F;
79 }
80 
81 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
82   const char *Name = "polly_freeManaged";
83   Function *F = M.getFunction(Name);
84 
85   // If F is not available, declare it.
86   if (!F) {
87     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
88     PollyIRBuilder Builder(M.getContext());
89     // TODO: How do I get `size_t`? I assume from DataLayout?
90     FunctionType *Ty =
91         FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false);
92     F = Function::Create(Ty, Linkage, Name, &M);
93   }
94 
95   return F;
96 }
97 
98 // Expand a constant expression `Cur`, which is used at instruction `Parent`
99 // at index `index`.
100 // Since a constant expression can expand to multiple instructions, store all
101 // the expands into a set called `Expands`.
102 // Note that this goes inorder on the constant expression tree.
103 // A * ((B * D) + C)
104 // will be processed with first A, then B * D, then B, then D, and then C.
105 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can
106 // have something like this:
107 //    *
108 //   /  \
109 //   \  /
110 //    (D)
111 //
112 // For the purposes of this expansion, we expand the two occurences of D
113 // separately. Therefore, we expand the DAG into the tree:
114 //  *
115 // / \
116 // D  D
117 // TODO: We don't _have_to do this, but this is the simplest solution.
118 // We can write a solution that keeps track of which constants have been
119 // already expanded.
120 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
121                                Instruction *Parent, int index,
122                                SmallPtrSet<Instruction *, 4> &Expands) {
123   assert(Cur && "invalid constant expression passed");
124 
125   Instruction *I = Cur->getAsInstruction();
126   Expands.insert(I);
127   Parent->setOperand(index, I);
128 
129   assert(I && "unable to convert ConstantExpr to Instruction");
130   // The things that `Parent` uses (its operands) should be created
131   // before `Parent`.
132   Builder.SetInsertPoint(Parent);
133   Builder.Insert(I);
134 
135   DEBUG(dbgs() << "Expanding ConstantExpression: " << *Cur
136                << " | in Instruction: " << *I << "\n";);
137   for (unsigned i = 0; i < Cur->getNumOperands(); i++) {
138     Value *Op = Cur->getOperand(i);
139     assert(isa<Constant>(Op) && "constant must have a constant operand");
140 
141     if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op))
142       expandConstantExpr(CExprOp, Builder, I, i, Expands);
143   }
144 }
145 
146 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite
147 // `ConstantExpr`s that are used in the `Inst`.
148 // Note that `replaceAllUsesWith` is insufficient for this purpose because it
149 // does not rewrite values in `ConstantExpr`s.
150 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
151                                PollyIRBuilder &Builder) {
152 
153   // This contains a set of instructions in which OldVal must be replaced.
154   // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s
155   // from `Inst`s arguments.
156   // We need to go through this process because `replaceAllUsesWith` does not
157   // actually edit `ConstantExpr`s.
158   SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
159 
160   // Expand all `ConstantExpr`s and place it in `InstsToVisit`.
161   for (unsigned i = 0; i < Inst->getNumOperands(); i++) {
162     Value *Operand = Inst->getOperand(i);
163     if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand))
164       expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit);
165   }
166 
167   // Now visit each instruction and use `replaceUsesOfWith`. We know that
168   // will work because `I` cannot have any `ConstantExpr` within it.
169   for (Instruction *I : InstsToVisit)
170     I->replaceUsesOfWith(OldVal, NewVal);
171 }
172 
173 // Given a value `Current`, return all Instructions that may contain `Current`
174 // in an expression.
175 // We need this auxiliary function, because if we have a
176 // `Constant` that is a user of `V`, we need to recurse into the
177 // `Constant`s uses to gather the root instruciton.
178 static void getInstructionUsersOfValue(Value *V,
179                                        SmallVector<Instruction *, 4> &Owners) {
180   if (auto *I = dyn_cast<Instruction>(V)) {
181     Owners.push_back(I);
182   } else {
183     // Anything that is a `User` must be a constant or an instruction.
184     auto *C = cast<Constant>(V);
185     for (Use &CUse : C->uses())
186       getInstructionUsersOfValue(CUse.getUser(), Owners);
187   }
188 }
189 
190 static void
191 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
192                    SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
193   // We only want arrays.
194   ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType());
195   if (!ArrayTy)
196     return;
197   Type *ElemTy = ArrayTy->getElementType();
198   PointerType *ElemPtrTy = ElemTy->getPointerTo();
199 
200   // We only wish to replace arrays that are visible in the module they
201   // inhabit. Otherwise, our type edit from [T] to T* would be illegal across
202   // modules.
203   const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
204                                        Array.hasInternalLinkage() ||
205                                        IgnoreLinkageForGlobals;
206   if (!OnlyVisibleInsideModule)
207     return;
208 
209   if (!Array.hasInitializer() ||
210       !isa<ConstantAggregateZero>(Array.getInitializer()))
211     return;
212 
213   // At this point, we have committed to replacing this array.
214   ReplacedGlobals.insert(&Array);
215 
216   std::string NewName = (Array.getName() + Twine(".toptr")).str();
217   GlobalVariable *ReplacementToArr =
218       cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy));
219   ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy));
220 
221   Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
222   Twine FnName = Array.getName() + ".constructor";
223   PollyIRBuilder Builder(M.getContext());
224   FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
225   const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
226   Function *F = Function::Create(Ty, Linkage, FnName, &M);
227   BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
228   Builder.SetInsertPoint(Start);
229 
230   int ArraySizeInt = DL.getTypeAllocSizeInBits(ArrayTy) / 8;
231   Value *ArraySize = Builder.getInt64(ArraySizeInt);
232   ArraySize->setName("array.size");
233 
234   Value *AllocatedMemRaw =
235       Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
236   Value *AllocatedMemTyped =
237       Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
238   Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
239   Builder.CreateRetVoid();
240 
241   const int Priority = 0;
242   appendToGlobalCtors(M, F, Priority, ReplacementToArr);
243 
244   SmallVector<Instruction *, 4> ArrayUserInstructions;
245   // Get all instructions that use array. We need to do this weird thing
246   // because `Constant`s that contain this array neeed to be expanded into
247   // instructions so that we can replace their parameters. `Constant`s cannot
248   // be edited easily, so we choose to convert all `Constant`s to
249   // `Instruction`s and handle all of the uses of `Array` uniformly.
250   for (Use &ArrayUse : Array.uses())
251     getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);
252 
253   for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
254 
255     Builder.SetInsertPoint(UserOfArrayInst);
256     // <ty>** -> <ty>*
257     Value *ArrPtrLoaded = Builder.CreateLoad(ReplacementToArr, "arrptr.load");
258     // <ty>* -> [ty]*
259     Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
260         ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
261     rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
262   }
263 }
264 
265 // We return all `allocas` that may need to be converted to a call to
266 // cudaMallocManaged.
267 static void getAllocasToBeManaged(Function &F,
268                                   SmallSet<AllocaInst *, 4> &Allocas) {
269   for (BasicBlock &BB : F) {
270     for (Instruction &I : BB) {
271       auto *Alloca = dyn_cast<AllocaInst>(&I);
272       if (!Alloca)
273         continue;
274       dbgs() << "Checking if " << *Alloca << "may be captured: ";
275 
276       if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false,
277                                /* StoreCaptures */ true)) {
278         Allocas.insert(Alloca);
279         DEBUG(dbgs() << "YES (captured)\n");
280       } else {
281         DEBUG(dbgs() << "NO (not captured)\n");
282       }
283     }
284   }
285 }
286 
287 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
288                                          const DataLayout &DL) {
289   DEBUG(dbgs() << "rewriting: " << *Alloca << " to managed mem.\n");
290   Module *M = Alloca->getModule();
291   assert(M && "Alloca does not have a module");
292 
293   PollyIRBuilder Builder(M->getContext());
294   Builder.SetInsertPoint(Alloca);
295 
296   Value *MallocManagedFn = getOrCreatePollyMallocManaged(*Alloca->getModule());
297   const int Size = DL.getTypeAllocSize(Alloca->getType()->getElementType());
298   Value *SizeVal = Builder.getInt64(Size);
299   Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
300   Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());
301 
302   Function *F = Alloca->getFunction();
303   assert(F && "Alloca has invalid function");
304 
305   Bitcasted->takeName(Alloca);
306   Alloca->replaceAllUsesWith(Bitcasted);
307   Alloca->eraseFromParent();
308 
309   for (BasicBlock &BB : *F) {
310     ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator());
311     if (!Return)
312       continue;
313     Builder.SetInsertPoint(Return);
314 
315     Value *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
316     Builder.CreateCall(FreeManagedFn, {RawManagedMem});
317   }
318 }
319 
320 class ManagedMemoryRewritePass : public ModulePass {
321 public:
322   static char ID;
323   GPUArch Architecture;
324   GPURuntime Runtime;
325 
326   ManagedMemoryRewritePass() : ModulePass(ID) {}
327   virtual bool runOnModule(Module &M) {
328     const DataLayout &DL = M.getDataLayout();
329 
330     Function *Malloc = M.getFunction("malloc");
331 
332     if (Malloc) {
333       Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
334       assert(PollyMallocManaged && "unable to create polly_mallocManaged");
335       Malloc->replaceAllUsesWith(PollyMallocManaged);
336       Malloc->eraseFromParent();
337     }
338 
339     Function *Free = M.getFunction("free");
340 
341     if (Free) {
342       Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
343       assert(PollyFreeManaged && "unable to create polly_freeManaged");
344       Free->replaceAllUsesWith(PollyFreeManaged);
345       Free->eraseFromParent();
346     }
347 
348     SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
349     for (GlobalVariable &Global : M.globals())
350       replaceGlobalArray(M, DL, Global, GlobalsToErase);
351     for (GlobalVariable *G : GlobalsToErase)
352       G->eraseFromParent();
353 
354     // Rewrite allocas to cudaMallocs if we are asked to do so.
355     if (RewriteAllocas) {
356       SmallSet<AllocaInst *, 4> AllocasToBeManaged;
357       for (Function &F : M.functions())
358         getAllocasToBeManaged(F, AllocasToBeManaged);
359 
360       for (AllocaInst *Alloca : AllocasToBeManaged)
361         rewriteAllocaAsManagedMemory(Alloca, DL);
362     }
363 
364     return true;
365   }
366 };
367 
368 } // namespace
369 char ManagedMemoryRewritePass::ID = 42;
370 
371 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
372                                                 GPURuntime Runtime) {
373   ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass();
374   pass->Runtime = Runtime;
375   pass->Architecture = Arch;
376   return pass;
377 }
378 
379 INITIALIZE_PASS_BEGIN(
380     ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
381     "Polly - Rewrite all allocations in heap & data section to managed memory",
382     false, false)
383 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
384 INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
385 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
386 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
387 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
388 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
389 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
390 INITIALIZE_PASS_END(
391     ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
392     "Polly - Rewrite all allocations in heap & data section to managed memory",
393     false, false)
394