1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Take a module and rewrite: 11 // 1. `malloc` -> `polly_mallocManaged` 12 // 2. `free` -> `polly_freeManaged` 13 // 3. global arrays with initializers -> global arrays that are initialized 14 // with a constructor call to 15 // `polly_mallocManaged`. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "polly/CodeGen/CodeGeneration.h" 20 #include "polly/CodeGen/IslAst.h" 21 #include "polly/CodeGen/IslNodeBuilder.h" 22 #include "polly/CodeGen/PPCGCodeGeneration.h" 23 #include "polly/CodeGen/Utils.h" 24 #include "polly/DependenceInfo.h" 25 #include "polly/LinkAllPasses.h" 26 #include "polly/Options.h" 27 #include "polly/ScopDetection.h" 28 #include "polly/ScopInfo.h" 29 #include "polly/Support/SCEVValidator.h" 30 #include "llvm/Analysis/AliasAnalysis.h" 31 #include "llvm/Analysis/BasicAliasAnalysis.h" 32 #include "llvm/Analysis/CaptureTracking.h" 33 #include "llvm/Analysis/GlobalsModRef.h" 34 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" 35 #include "llvm/Analysis/TargetLibraryInfo.h" 36 #include "llvm/Analysis/TargetTransformInfo.h" 37 #include "llvm/IR/LegacyPassManager.h" 38 #include "llvm/IR/Verifier.h" 39 #include "llvm/IRReader/IRReader.h" 40 #include "llvm/Linker/Linker.h" 41 #include "llvm/Support/TargetRegistry.h" 42 #include "llvm/Support/TargetSelect.h" 43 #include "llvm/Target/TargetMachine.h" 44 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 45 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 46 #include "llvm/Transforms/Utils/ModuleUtils.h" 47 48 static cl::opt<bool> RewriteAllocas( 49 "polly-acc-rewrite-allocas", 50 cl::desc( 51 "Ask the managed memory rewriter to also rewrite alloca instructions"), 52 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 53 54 static cl::opt<bool> IgnoreLinkageForGlobals( 55 "polly-acc-rewrite-ignore-linkage-for-globals", 56 cl::desc( 57 "By default, we only rewrite globals with internal linkage. This flag " 58 "enables rewriting of globals regardless of linkage"), 59 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 60 61 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory" 62 namespace { 63 64 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) { 65 const char *Name = "polly_mallocManaged"; 66 Function *F = M.getFunction(Name); 67 68 // If F is not available, declare it. 69 if (!F) { 70 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 71 PollyIRBuilder Builder(M.getContext()); 72 // TODO: How do I get `size_t`? I assume from DataLayout? 73 FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(), 74 {Builder.getInt64Ty()}, false); 75 F = Function::Create(Ty, Linkage, Name, &M); 76 } 77 78 return F; 79 } 80 81 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) { 82 const char *Name = "polly_freeManaged"; 83 Function *F = M.getFunction(Name); 84 85 // If F is not available, declare it. 86 if (!F) { 87 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 88 PollyIRBuilder Builder(M.getContext()); 89 // TODO: How do I get `size_t`? I assume from DataLayout? 90 FunctionType *Ty = 91 FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false); 92 F = Function::Create(Ty, Linkage, Name, &M); 93 } 94 95 return F; 96 } 97 98 // Expand a constant expression `Cur`, which is used at instruction `Parent` 99 // at index `index`. 100 // Since a constant expression can expand to multiple instructions, store all 101 // the expands into a set called `Expands`. 102 // Note that this goes inorder on the constant expression tree. 103 // A * ((B * D) + C) 104 // will be processed with first A, then B * D, then B, then D, and then C. 105 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can 106 // have something like this: 107 // * 108 // / \ 109 // \ / 110 // (D) 111 // 112 // For the purposes of this expansion, we expand the two occurences of D 113 // separately. Therefore, we expand the DAG into the tree: 114 // * 115 // / \ 116 // D D 117 // TODO: We don't _have_to do this, but this is the simplest solution. 118 // We can write a solution that keeps track of which constants have been 119 // already expanded. 120 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder, 121 Instruction *Parent, int index, 122 SmallPtrSet<Instruction *, 4> &Expands) { 123 assert(Cur && "invalid constant expression passed"); 124 Instruction *I = Cur->getAsInstruction(); 125 assert(I && "unable to convert ConstantExpr to Instruction"); 126 127 DEBUG(dbgs() << "Expanding ConstantExpression: " << *Cur 128 << " | in Instruction: " << *I << "\n";); 129 130 // Invalidate `Cur` so that no one after this point uses `Cur`. Rather, 131 // they should mutate `I`. 132 Cur = nullptr; 133 134 Expands.insert(I); 135 Parent->setOperand(index, I); 136 137 // The things that `Parent` uses (its operands) should be created 138 // before `Parent`. 139 Builder.SetInsertPoint(Parent); 140 Builder.Insert(I); 141 142 for (unsigned i = 0; i < I->getNumOperands(); i++) { 143 Value *Op = I->getOperand(i); 144 assert(isa<Constant>(Op) && "constant must have a constant operand"); 145 146 if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op)) 147 expandConstantExpr(CExprOp, Builder, I, i, Expands); 148 } 149 } 150 151 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite 152 // `ConstantExpr`s that are used in the `Inst`. 153 // Note that `replaceAllUsesWith` is insufficient for this purpose because it 154 // does not rewrite values in `ConstantExpr`s. 155 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal, 156 PollyIRBuilder &Builder) { 157 158 // This contains a set of instructions in which OldVal must be replaced. 159 // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s 160 // from `Inst`s arguments. 161 // We need to go through this process because `replaceAllUsesWith` does not 162 // actually edit `ConstantExpr`s. 163 SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst}; 164 165 // Expand all `ConstantExpr`s and place it in `InstsToVisit`. 166 for (unsigned i = 0; i < Inst->getNumOperands(); i++) { 167 Value *Operand = Inst->getOperand(i); 168 if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand)) 169 expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit); 170 } 171 172 // Now visit each instruction and use `replaceUsesOfWith`. We know that 173 // will work because `I` cannot have any `ConstantExpr` within it. 174 for (Instruction *I : InstsToVisit) 175 I->replaceUsesOfWith(OldVal, NewVal); 176 } 177 178 // Given a value `Current`, return all Instructions that may contain `Current` 179 // in an expression. 180 // We need this auxiliary function, because if we have a 181 // `Constant` that is a user of `V`, we need to recurse into the 182 // `Constant`s uses to gather the root instruciton. 183 static void getInstructionUsersOfValue(Value *V, 184 SmallVector<Instruction *, 4> &Owners) { 185 if (auto *I = dyn_cast<Instruction>(V)) { 186 Owners.push_back(I); 187 } else { 188 // Anything that is a `User` must be a constant or an instruction. 189 auto *C = cast<Constant>(V); 190 for (Use &CUse : C->uses()) 191 getInstructionUsersOfValue(CUse.getUser(), Owners); 192 } 193 } 194 195 static void 196 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array, 197 SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) { 198 // We only want arrays. 199 ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType()); 200 if (!ArrayTy) 201 return; 202 Type *ElemTy = ArrayTy->getElementType(); 203 PointerType *ElemPtrTy = ElemTy->getPointerTo(); 204 205 // We only wish to replace arrays that are visible in the module they 206 // inhabit. Otherwise, our type edit from [T] to T* would be illegal across 207 // modules. 208 const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() || 209 Array.hasInternalLinkage() || 210 IgnoreLinkageForGlobals; 211 if (!OnlyVisibleInsideModule) 212 return; 213 214 if (!Array.hasInitializer() || 215 !isa<ConstantAggregateZero>(Array.getInitializer())) 216 return; 217 218 // At this point, we have committed to replacing this array. 219 ReplacedGlobals.insert(&Array); 220 221 std::string NewName = (Array.getName() + Twine(".toptr")).str(); 222 GlobalVariable *ReplacementToArr = 223 cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy)); 224 ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy)); 225 226 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 227 Twine FnName = Array.getName() + ".constructor"; 228 PollyIRBuilder Builder(M.getContext()); 229 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false); 230 const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 231 Function *F = Function::Create(Ty, Linkage, FnName, &M); 232 BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F); 233 Builder.SetInsertPoint(Start); 234 235 int ArraySizeInt = DL.getTypeAllocSizeInBits(ArrayTy) / 8; 236 Value *ArraySize = Builder.getInt64(ArraySizeInt); 237 ArraySize->setName("array.size"); 238 239 Value *AllocatedMemRaw = 240 Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw"); 241 Value *AllocatedMemTyped = 242 Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed"); 243 Builder.CreateStore(AllocatedMemTyped, ReplacementToArr); 244 Builder.CreateRetVoid(); 245 246 const int Priority = 0; 247 appendToGlobalCtors(M, F, Priority, ReplacementToArr); 248 249 SmallVector<Instruction *, 4> ArrayUserInstructions; 250 // Get all instructions that use array. We need to do this weird thing 251 // because `Constant`s that contain this array neeed to be expanded into 252 // instructions so that we can replace their parameters. `Constant`s cannot 253 // be edited easily, so we choose to convert all `Constant`s to 254 // `Instruction`s and handle all of the uses of `Array` uniformly. 255 for (Use &ArrayUse : Array.uses()) 256 getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions); 257 258 for (Instruction *UserOfArrayInst : ArrayUserInstructions) { 259 260 Builder.SetInsertPoint(UserOfArrayInst); 261 // <ty>** -> <ty>* 262 Value *ArrPtrLoaded = Builder.CreateLoad(ReplacementToArr, "arrptr.load"); 263 // <ty>* -> [ty]* 264 Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast( 265 ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast"); 266 rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder); 267 } 268 } 269 270 // We return all `allocas` that may need to be converted to a call to 271 // cudaMallocManaged. 272 static void getAllocasToBeManaged(Function &F, 273 SmallSet<AllocaInst *, 4> &Allocas) { 274 for (BasicBlock &BB : F) { 275 for (Instruction &I : BB) { 276 auto *Alloca = dyn_cast<AllocaInst>(&I); 277 if (!Alloca) 278 continue; 279 dbgs() << "Checking if " << *Alloca << "may be captured: "; 280 281 if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false, 282 /* StoreCaptures */ true)) { 283 Allocas.insert(Alloca); 284 DEBUG(dbgs() << "YES (captured)\n"); 285 } else { 286 DEBUG(dbgs() << "NO (not captured)\n"); 287 } 288 } 289 } 290 } 291 292 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca, 293 const DataLayout &DL) { 294 DEBUG(dbgs() << "rewriting: " << *Alloca << " to managed mem.\n"); 295 Module *M = Alloca->getModule(); 296 assert(M && "Alloca does not have a module"); 297 298 PollyIRBuilder Builder(M->getContext()); 299 Builder.SetInsertPoint(Alloca); 300 301 Value *MallocManagedFn = getOrCreatePollyMallocManaged(*Alloca->getModule()); 302 const int Size = DL.getTypeAllocSize(Alloca->getType()->getElementType()); 303 Value *SizeVal = Builder.getInt64(Size); 304 Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal}); 305 Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType()); 306 307 Function *F = Alloca->getFunction(); 308 assert(F && "Alloca has invalid function"); 309 310 Bitcasted->takeName(Alloca); 311 Alloca->replaceAllUsesWith(Bitcasted); 312 Alloca->eraseFromParent(); 313 314 for (BasicBlock &BB : *F) { 315 ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator()); 316 if (!Return) 317 continue; 318 Builder.SetInsertPoint(Return); 319 320 Value *FreeManagedFn = getOrCreatePollyFreeManaged(*M); 321 Builder.CreateCall(FreeManagedFn, {RawManagedMem}); 322 } 323 } 324 325 // Replace all uses of `Old` with `New`, even inside `ConstantExpr`. 326 // 327 // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function 328 // actually does replace it in `ConstantExpr`. The caveat is that if there is 329 // a use that is *outside* a function (say, at global declarations), we fail. 330 // So, this is meant to be used on values which we know will only be used 331 // within functions. 332 // 333 // This process works by looking through the uses of `Old`. If it finds a 334 // `ConstantExpr`, it recursively looks for the owning instruction. 335 // Then, it expands all the `ConstantExpr` to instructions and replaces 336 // `Old` with `New` in the expanded instructions. 337 static void replaceAllUsesAndConstantUses(Value *Old, Value *New, 338 PollyIRBuilder &Builder) { 339 SmallVector<Instruction *, 4> UserInstructions; 340 // Get all instructions that use array. We need to do this weird thing 341 // because `Constant`s that contain this array neeed to be expanded into 342 // instructions so that we can replace their parameters. `Constant`s cannot 343 // be edited easily, so we choose to convert all `Constant`s to 344 // `Instruction`s and handle all of the uses of `Array` uniformly. 345 for (Use &ArrayUse : Old->uses()) 346 getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions); 347 348 for (Instruction *I : UserInstructions) 349 rewriteOldValToNew(I, Old, New, Builder); 350 } 351 352 class ManagedMemoryRewritePass : public ModulePass { 353 public: 354 static char ID; 355 GPUArch Architecture; 356 GPURuntime Runtime; 357 358 ManagedMemoryRewritePass() : ModulePass(ID) {} 359 virtual bool runOnModule(Module &M) { 360 const DataLayout &DL = M.getDataLayout(); 361 362 Function *Malloc = M.getFunction("malloc"); 363 364 if (Malloc) { 365 PollyIRBuilder Builder(M.getContext()); 366 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 367 assert(PollyMallocManaged && "unable to create polly_mallocManaged"); 368 369 replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder); 370 Malloc->eraseFromParent(); 371 } 372 373 Function *Free = M.getFunction("free"); 374 375 if (Free) { 376 PollyIRBuilder Builder(M.getContext()); 377 Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M); 378 assert(PollyFreeManaged && "unable to create polly_freeManaged"); 379 380 replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder); 381 Free->eraseFromParent(); 382 } 383 384 SmallPtrSet<GlobalVariable *, 4> GlobalsToErase; 385 for (GlobalVariable &Global : M.globals()) 386 replaceGlobalArray(M, DL, Global, GlobalsToErase); 387 for (GlobalVariable *G : GlobalsToErase) 388 G->eraseFromParent(); 389 390 // Rewrite allocas to cudaMallocs if we are asked to do so. 391 if (RewriteAllocas) { 392 SmallSet<AllocaInst *, 4> AllocasToBeManaged; 393 for (Function &F : M.functions()) 394 getAllocasToBeManaged(F, AllocasToBeManaged); 395 396 for (AllocaInst *Alloca : AllocasToBeManaged) 397 rewriteAllocaAsManagedMemory(Alloca, DL); 398 } 399 400 return true; 401 } 402 }; 403 404 } // namespace 405 char ManagedMemoryRewritePass::ID = 42; 406 407 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch, 408 GPURuntime Runtime) { 409 ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass(); 410 pass->Runtime = Runtime; 411 pass->Architecture = Arch; 412 return pass; 413 } 414 415 INITIALIZE_PASS_BEGIN( 416 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 417 "Polly - Rewrite all allocations in heap & data section to managed memory", 418 false, false) 419 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration); 420 INITIALIZE_PASS_DEPENDENCY(DependenceInfo); 421 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass); 422 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); 423 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass); 424 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass); 425 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass); 426 INITIALIZE_PASS_END( 427 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 428 "Polly - Rewrite all allocations in heap & data section to managed memory", 429 false, false) 430