1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Take a module and rewrite: 11 // 1. `malloc` -> `polly_mallocManaged` 12 // 2. `free` -> `polly_freeManaged` 13 // 3. global arrays with initializers -> global arrays that are initialized 14 // with a constructor call to 15 // `polly_mallocManaged`. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "polly/CodeGen/CodeGeneration.h" 20 #include "polly/CodeGen/IslAst.h" 21 #include "polly/CodeGen/IslNodeBuilder.h" 22 #include "polly/CodeGen/PPCGCodeGeneration.h" 23 #include "polly/CodeGen/Utils.h" 24 #include "polly/DependenceInfo.h" 25 #include "polly/LinkAllPasses.h" 26 #include "polly/Options.h" 27 #include "polly/ScopDetection.h" 28 #include "polly/ScopInfo.h" 29 #include "polly/Support/SCEVValidator.h" 30 #include "llvm/Analysis/AliasAnalysis.h" 31 #include "llvm/Analysis/BasicAliasAnalysis.h" 32 #include "llvm/Analysis/CaptureTracking.h" 33 #include "llvm/Analysis/GlobalsModRef.h" 34 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" 35 #include "llvm/Analysis/TargetLibraryInfo.h" 36 #include "llvm/Analysis/TargetTransformInfo.h" 37 #include "llvm/IR/LegacyPassManager.h" 38 #include "llvm/IR/Verifier.h" 39 #include "llvm/IRReader/IRReader.h" 40 #include "llvm/Linker/Linker.h" 41 #include "llvm/Support/TargetRegistry.h" 42 #include "llvm/Support/TargetSelect.h" 43 #include "llvm/Target/TargetMachine.h" 44 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 45 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 46 #include "llvm/Transforms/Utils/ModuleUtils.h" 47 48 static cl::opt<bool> RewriteAllocas( 49 "polly-acc-rewrite-allocas", 50 cl::desc( 51 "Ask the managed memory rewriter to also rewrite alloca instructions"), 52 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 53 54 static cl::opt<bool> IgnoreLinkageForGlobals( 55 "polly-acc-rewrite-ignore-linkage-for-globals", 56 cl::desc( 57 "By default, we only rewrite globals with internal linkage. This flag " 58 "enables rewriting of globals regardless of linkage"), 59 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 60 61 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory" 62 namespace { 63 64 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) { 65 const char *Name = "polly_mallocManaged"; 66 Function *F = M.getFunction(Name); 67 68 // If F is not available, declare it. 69 if (!F) { 70 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 71 PollyIRBuilder Builder(M.getContext()); 72 // TODO: How do I get `size_t`? I assume from DataLayout? 73 FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(), 74 {Builder.getInt64Ty()}, false); 75 F = Function::Create(Ty, Linkage, Name, &M); 76 } 77 78 return F; 79 } 80 81 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) { 82 const char *Name = "polly_freeManaged"; 83 Function *F = M.getFunction(Name); 84 85 // If F is not available, declare it. 86 if (!F) { 87 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 88 PollyIRBuilder Builder(M.getContext()); 89 // TODO: How do I get `size_t`? I assume from DataLayout? 90 FunctionType *Ty = 91 FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false); 92 F = Function::Create(Ty, Linkage, Name, &M); 93 } 94 95 return F; 96 } 97 98 // Expand a constant expression `Cur`, which is used at instruction `Parent` 99 // at index `index`. 100 // Since a constant expression can expand to multiple instructions, store all 101 // the expands into a set called `Expands`. 102 // Note that this goes inorder on the constant expression tree. 103 // A * ((B * D) + C) 104 // will be processed with first A, then B * D, then B, then D, and then C. 105 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can 106 // have something like this: 107 // * 108 // / \ 109 // \ / 110 // (D) 111 // 112 // For the purposes of this expansion, we expand the two occurences of D 113 // separately. Therefore, we expand the DAG into the tree: 114 // * 115 // / \ 116 // D D 117 // TODO: We don't _have_to do this, but this is the simplest solution. 118 // We can write a solution that keeps track of which constants have been 119 // already expanded. 120 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder, 121 Instruction *Parent, int index, 122 SmallPtrSet<Instruction *, 4> &Expands) { 123 assert(Cur && "invalid constant expression passed"); 124 125 Instruction *I = Cur->getAsInstruction(); 126 Expands.insert(I); 127 Parent->setOperand(index, I); 128 129 assert(I && "unable to convert ConstantExpr to Instruction"); 130 // The things that `Parent` uses (its operands) should be created 131 // before `Parent`. 132 Builder.SetInsertPoint(Parent); 133 Builder.Insert(I); 134 135 DEBUG(dbgs() << "Expanding ConstantExpression: " << *Cur 136 << " | in Instruction: " << *I << "\n";); 137 for (unsigned i = 0; i < Cur->getNumOperands(); i++) { 138 Value *Op = Cur->getOperand(i); 139 assert(isa<Constant>(Op) && "constant must have a constant operand"); 140 141 if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op)) 142 expandConstantExpr(CExprOp, Builder, I, i, Expands); 143 } 144 } 145 146 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite 147 // `ConstantExpr`s that are used in the `Inst`. 148 // Note that `replaceAllUsesWith` is insufficient for this purpose because it 149 // does not rewrite values in `ConstantExpr`s. 150 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal, 151 PollyIRBuilder &Builder) { 152 153 // This contains a set of instructions in which OldVal must be replaced. 154 // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s 155 // from `Inst`s arguments. 156 // We need to go through this process because `replaceAllUsesWith` does not 157 // actually edit `ConstantExpr`s. 158 SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst}; 159 160 // Expand all `ConstantExpr`s and place it in `InstsToVisit`. 161 for (unsigned i = 0; i < Inst->getNumOperands(); i++) { 162 Value *Operand = Inst->getOperand(i); 163 if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand)) 164 expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit); 165 } 166 167 // Now visit each instruction and use `replaceUsesOfWith`. We know that 168 // will work because `I` cannot have any `ConstantExpr` within it. 169 for (Instruction *I : InstsToVisit) 170 I->replaceUsesOfWith(OldVal, NewVal); 171 } 172 173 // Given a value `Current`, return all Instructions that may contain `Current` 174 // in an expression. 175 // We need this auxiliary function, because if we have a 176 // `Constant` that is a user of `V`, we need to recurse into the 177 // `Constant`s uses to gather the root instruciton. 178 static void getInstructionUsersOfValue(Value *V, 179 SmallVector<Instruction *, 4> &Owners) { 180 if (auto *I = dyn_cast<Instruction>(V)) { 181 Owners.push_back(I); 182 } else { 183 // Anything that is a `User` must be a constant or an instruction. 184 auto *C = cast<Constant>(V); 185 for (Use &CUse : C->uses()) 186 getInstructionUsersOfValue(CUse.getUser(), Owners); 187 } 188 } 189 190 static void 191 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array, 192 SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) { 193 // We only want arrays. 194 ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType()); 195 if (!ArrayTy) 196 return; 197 Type *ElemTy = ArrayTy->getElementType(); 198 PointerType *ElemPtrTy = ElemTy->getPointerTo(); 199 200 // We only wish to replace arrays that are visible in the module they 201 // inhabit. Otherwise, our type edit from [T] to T* would be illegal across 202 // modules. 203 const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() || 204 Array.hasInternalLinkage() || 205 IgnoreLinkageForGlobals; 206 if (!OnlyVisibleInsideModule) 207 return; 208 209 if (!Array.hasInitializer() || 210 !isa<ConstantAggregateZero>(Array.getInitializer())) 211 return; 212 213 // At this point, we have committed to replacing this array. 214 ReplacedGlobals.insert(&Array); 215 216 std::string NewName = (Array.getName() + Twine(".toptr")).str(); 217 GlobalVariable *ReplacementToArr = 218 cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy)); 219 ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy)); 220 221 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 222 Twine FnName = Array.getName() + ".constructor"; 223 PollyIRBuilder Builder(M.getContext()); 224 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false); 225 const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 226 Function *F = Function::Create(Ty, Linkage, FnName, &M); 227 BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F); 228 Builder.SetInsertPoint(Start); 229 230 int ArraySizeInt = DL.getTypeAllocSizeInBits(ArrayTy) / 8; 231 Value *ArraySize = Builder.getInt64(ArraySizeInt); 232 ArraySize->setName("array.size"); 233 234 Value *AllocatedMemRaw = 235 Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw"); 236 Value *AllocatedMemTyped = 237 Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed"); 238 Builder.CreateStore(AllocatedMemTyped, ReplacementToArr); 239 Builder.CreateRetVoid(); 240 241 const int Priority = 0; 242 appendToGlobalCtors(M, F, Priority, ReplacementToArr); 243 244 SmallVector<Instruction *, 4> ArrayUserInstructions; 245 // Get all instructions that use array. We need to do this weird thing 246 // because `Constant`s that contain this array neeed to be expanded into 247 // instructions so that we can replace their parameters. `Constant`s cannot 248 // be edited easily, so we choose to convert all `Constant`s to 249 // `Instruction`s and handle all of the uses of `Array` uniformly. 250 for (Use &ArrayUse : Array.uses()) 251 getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions); 252 253 for (Instruction *UserOfArrayInst : ArrayUserInstructions) { 254 255 Builder.SetInsertPoint(UserOfArrayInst); 256 // <ty>** -> <ty>* 257 Value *ArrPtrLoaded = Builder.CreateLoad(ReplacementToArr, "arrptr.load"); 258 // <ty>* -> [ty]* 259 Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast( 260 ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast"); 261 rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder); 262 } 263 } 264 265 // We return all `allocas` that may need to be converted to a call to 266 // cudaMallocManaged. 267 static void getAllocasToBeManaged(Function &F, 268 SmallSet<AllocaInst *, 4> &Allocas) { 269 for (BasicBlock &BB : F) { 270 for (Instruction &I : BB) { 271 auto *Alloca = dyn_cast<AllocaInst>(&I); 272 if (!Alloca) 273 continue; 274 dbgs() << "Checking if " << *Alloca << "may be captured: "; 275 276 if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false, 277 /* StoreCaptures */ true)) { 278 Allocas.insert(Alloca); 279 DEBUG(dbgs() << "YES (captured)\n"); 280 } else { 281 DEBUG(dbgs() << "NO (not captured)\n"); 282 } 283 } 284 } 285 } 286 287 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca, 288 const DataLayout &DL) { 289 DEBUG(dbgs() << "rewriting: " << *Alloca << " to managed mem.\n"); 290 Module *M = Alloca->getModule(); 291 assert(M && "Alloca does not have a module"); 292 293 PollyIRBuilder Builder(M->getContext()); 294 Builder.SetInsertPoint(Alloca); 295 296 Value *MallocManagedFn = getOrCreatePollyMallocManaged(*Alloca->getModule()); 297 const int Size = DL.getTypeAllocSize(Alloca->getType()->getElementType()); 298 Value *SizeVal = Builder.getInt64(Size); 299 Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal}); 300 Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType()); 301 302 Function *F = Alloca->getFunction(); 303 assert(F && "Alloca has invalid function"); 304 305 Bitcasted->takeName(Alloca); 306 Alloca->replaceAllUsesWith(Bitcasted); 307 Alloca->eraseFromParent(); 308 309 for (BasicBlock &BB : *F) { 310 ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator()); 311 if (!Return) 312 continue; 313 Builder.SetInsertPoint(Return); 314 315 Value *FreeManagedFn = getOrCreatePollyFreeManaged(*M); 316 Builder.CreateCall(FreeManagedFn, {RawManagedMem}); 317 } 318 } 319 320 class ManagedMemoryRewritePass : public ModulePass { 321 public: 322 static char ID; 323 GPUArch Architecture; 324 GPURuntime Runtime; 325 326 ManagedMemoryRewritePass() : ModulePass(ID) {} 327 virtual bool runOnModule(Module &M) { 328 const DataLayout &DL = M.getDataLayout(); 329 330 Function *Malloc = M.getFunction("malloc"); 331 332 if (Malloc) { 333 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 334 assert(PollyMallocManaged && "unable to create polly_mallocManaged"); 335 Malloc->replaceAllUsesWith(PollyMallocManaged); 336 Malloc->eraseFromParent(); 337 } 338 339 Function *Free = M.getFunction("free"); 340 341 if (Free) { 342 Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M); 343 assert(PollyFreeManaged && "unable to create polly_freeManaged"); 344 Free->replaceAllUsesWith(PollyFreeManaged); 345 Free->eraseFromParent(); 346 } 347 348 SmallPtrSet<GlobalVariable *, 4> GlobalsToErase; 349 for (GlobalVariable &Global : M.globals()) 350 replaceGlobalArray(M, DL, Global, GlobalsToErase); 351 for (GlobalVariable *G : GlobalsToErase) 352 G->eraseFromParent(); 353 354 // Rewrite allocas to cudaMallocs if we are asked to do so. 355 if (RewriteAllocas) { 356 SmallSet<AllocaInst *, 4> AllocasToBeManaged; 357 for (Function &F : M.functions()) 358 getAllocasToBeManaged(F, AllocasToBeManaged); 359 360 for (AllocaInst *Alloca : AllocasToBeManaged) 361 rewriteAllocaAsManagedMemory(Alloca, DL); 362 } 363 364 return true; 365 } 366 }; 367 368 } // namespace 369 char ManagedMemoryRewritePass::ID = 42; 370 371 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch, 372 GPURuntime Runtime) { 373 ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass(); 374 pass->Runtime = Runtime; 375 pass->Architecture = Arch; 376 return pass; 377 } 378 379 INITIALIZE_PASS_BEGIN( 380 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 381 "Polly - Rewrite all allocations in heap & data section to managed memory", 382 false, false) 383 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration); 384 INITIALIZE_PASS_DEPENDENCY(DependenceInfo); 385 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass); 386 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); 387 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass); 388 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass); 389 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass); 390 INITIALIZE_PASS_END( 391 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 392 "Polly - Rewrite all allocations in heap & data section to managed memory", 393 false, false) 394