1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Take a module and rewrite: 10 // 1. `malloc` -> `polly_mallocManaged` 11 // 2. `free` -> `polly_freeManaged` 12 // 3. global arrays with initializers -> global arrays that are initialized 13 // with a constructor call to 14 // `polly_mallocManaged`. 15 // 16 //===----------------------------------------------------------------------===// 17 18 #include "polly/CodeGen/CodeGeneration.h" 19 #include "polly/CodeGen/IslAst.h" 20 #include "polly/CodeGen/IslNodeBuilder.h" 21 #include "polly/CodeGen/PPCGCodeGeneration.h" 22 #include "polly/CodeGen/Utils.h" 23 #include "polly/DependenceInfo.h" 24 #include "polly/LinkAllPasses.h" 25 #include "polly/Options.h" 26 #include "polly/ScopDetection.h" 27 #include "polly/ScopInfo.h" 28 #include "polly/Support/SCEVValidator.h" 29 #include "llvm/Analysis/AliasAnalysis.h" 30 #include "llvm/Analysis/BasicAliasAnalysis.h" 31 #include "llvm/Analysis/CaptureTracking.h" 32 #include "llvm/Analysis/GlobalsModRef.h" 33 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" 34 #include "llvm/Analysis/TargetLibraryInfo.h" 35 #include "llvm/Analysis/TargetTransformInfo.h" 36 #include "llvm/IR/LegacyPassManager.h" 37 #include "llvm/IR/Verifier.h" 38 #include "llvm/IRReader/IRReader.h" 39 #include "llvm/Linker/Linker.h" 40 #include "llvm/Support/TargetRegistry.h" 41 #include "llvm/Support/TargetSelect.h" 42 #include "llvm/Target/TargetMachine.h" 43 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 44 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 45 #include "llvm/Transforms/Utils/ModuleUtils.h" 46 47 static cl::opt<bool> RewriteAllocas( 48 "polly-acc-rewrite-allocas", 49 cl::desc( 50 "Ask the managed memory rewriter to also rewrite alloca instructions"), 51 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 52 53 static cl::opt<bool> IgnoreLinkageForGlobals( 54 "polly-acc-rewrite-ignore-linkage-for-globals", 55 cl::desc( 56 "By default, we only rewrite globals with internal linkage. This flag " 57 "enables rewriting of globals regardless of linkage"), 58 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 59 60 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory" 61 namespace { 62 63 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) { 64 const char *Name = "polly_mallocManaged"; 65 Function *F = M.getFunction(Name); 66 67 // If F is not available, declare it. 68 if (!F) { 69 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 70 PollyIRBuilder Builder(M.getContext()); 71 // TODO: How do I get `size_t`? I assume from DataLayout? 72 FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(), 73 {Builder.getInt64Ty()}, false); 74 F = Function::Create(Ty, Linkage, Name, &M); 75 } 76 77 return F; 78 } 79 80 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) { 81 const char *Name = "polly_freeManaged"; 82 Function *F = M.getFunction(Name); 83 84 // If F is not available, declare it. 85 if (!F) { 86 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 87 PollyIRBuilder Builder(M.getContext()); 88 // TODO: How do I get `size_t`? I assume from DataLayout? 89 FunctionType *Ty = 90 FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false); 91 F = Function::Create(Ty, Linkage, Name, &M); 92 } 93 94 return F; 95 } 96 97 // Expand a constant expression `Cur`, which is used at instruction `Parent` 98 // at index `index`. 99 // Since a constant expression can expand to multiple instructions, store all 100 // the expands into a set called `Expands`. 101 // Note that this goes inorder on the constant expression tree. 102 // A * ((B * D) + C) 103 // will be processed with first A, then B * D, then B, then D, and then C. 104 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can 105 // have something like this: 106 // * 107 // / \ 108 // \ / 109 // (D) 110 // 111 // For the purposes of this expansion, we expand the two occurences of D 112 // separately. Therefore, we expand the DAG into the tree: 113 // * 114 // / \ 115 // D D 116 // TODO: We don't _have_to do this, but this is the simplest solution. 117 // We can write a solution that keeps track of which constants have been 118 // already expanded. 119 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder, 120 Instruction *Parent, int index, 121 SmallPtrSet<Instruction *, 4> &Expands) { 122 assert(Cur && "invalid constant expression passed"); 123 Instruction *I = Cur->getAsInstruction(); 124 assert(I && "unable to convert ConstantExpr to Instruction"); 125 126 LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur 127 << ") in Instruction: (" << *I << ")\n";); 128 129 // Invalidate `Cur` so that no one after this point uses `Cur`. Rather, 130 // they should mutate `I`. 131 Cur = nullptr; 132 133 Expands.insert(I); 134 Parent->setOperand(index, I); 135 136 // The things that `Parent` uses (its operands) should be created 137 // before `Parent`. 138 Builder.SetInsertPoint(Parent); 139 Builder.Insert(I); 140 141 for (unsigned i = 0; i < I->getNumOperands(); i++) { 142 Value *Op = I->getOperand(i); 143 assert(isa<Constant>(Op) && "constant must have a constant operand"); 144 145 if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op)) 146 expandConstantExpr(CExprOp, Builder, I, i, Expands); 147 } 148 } 149 150 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite 151 // `ConstantExpr`s that are used in the `Inst`. 152 // Note that `replaceAllUsesWith` is insufficient for this purpose because it 153 // does not rewrite values in `ConstantExpr`s. 154 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal, 155 PollyIRBuilder &Builder) { 156 157 // This contains a set of instructions in which OldVal must be replaced. 158 // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s 159 // from `Inst`s arguments. 160 // We need to go through this process because `replaceAllUsesWith` does not 161 // actually edit `ConstantExpr`s. 162 SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst}; 163 164 // Expand all `ConstantExpr`s and place it in `InstsToVisit`. 165 for (unsigned i = 0; i < Inst->getNumOperands(); i++) { 166 Value *Operand = Inst->getOperand(i); 167 if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand)) 168 expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit); 169 } 170 171 // Now visit each instruction and use `replaceUsesOfWith`. We know that 172 // will work because `I` cannot have any `ConstantExpr` within it. 173 for (Instruction *I : InstsToVisit) 174 I->replaceUsesOfWith(OldVal, NewVal); 175 } 176 177 // Given a value `Current`, return all Instructions that may contain `Current` 178 // in an expression. 179 // We need this auxiliary function, because if we have a 180 // `Constant` that is a user of `V`, we need to recurse into the 181 // `Constant`s uses to gather the root instruciton. 182 static void getInstructionUsersOfValue(Value *V, 183 SmallVector<Instruction *, 4> &Owners) { 184 if (auto *I = dyn_cast<Instruction>(V)) { 185 Owners.push_back(I); 186 } else { 187 // Anything that is a `User` must be a constant or an instruction. 188 auto *C = cast<Constant>(V); 189 for (Use &CUse : C->uses()) 190 getInstructionUsersOfValue(CUse.getUser(), Owners); 191 } 192 } 193 194 static void 195 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array, 196 SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) { 197 // We only want arrays. 198 ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType()); 199 if (!ArrayTy) 200 return; 201 Type *ElemTy = ArrayTy->getElementType(); 202 PointerType *ElemPtrTy = ElemTy->getPointerTo(); 203 204 // We only wish to replace arrays that are visible in the module they 205 // inhabit. Otherwise, our type edit from [T] to T* would be illegal across 206 // modules. 207 const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() || 208 Array.hasInternalLinkage() || 209 IgnoreLinkageForGlobals; 210 if (!OnlyVisibleInsideModule) { 211 LLVM_DEBUG( 212 dbgs() << "Not rewriting (" << Array 213 << ") to managed memory " 214 "because it could be visible externally. To force rewrite, " 215 "use -polly-acc-rewrite-ignore-linkage-for-globals.\n"); 216 return; 217 } 218 219 if (!Array.hasInitializer() || 220 !isa<ConstantAggregateZero>(Array.getInitializer())) { 221 LLVM_DEBUG(dbgs() << "Not rewriting (" << Array 222 << ") to managed memory " 223 "because it has an initializer which is " 224 "not a zeroinitializer.\n"); 225 return; 226 } 227 228 // At this point, we have committed to replacing this array. 229 ReplacedGlobals.insert(&Array); 230 231 std::string NewName = Array.getName(); 232 NewName += ".toptr"; 233 GlobalVariable *ReplacementToArr = 234 cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy)); 235 ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy)); 236 237 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 238 std::string FnName = Array.getName(); 239 FnName += ".constructor"; 240 PollyIRBuilder Builder(M.getContext()); 241 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false); 242 const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 243 Function *F = Function::Create(Ty, Linkage, FnName, &M); 244 BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F); 245 Builder.SetInsertPoint(Start); 246 247 const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy); 248 Value *ArraySize = Builder.getInt64(ArraySizeInt); 249 ArraySize->setName("array.size"); 250 251 Value *AllocatedMemRaw = 252 Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw"); 253 Value *AllocatedMemTyped = 254 Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed"); 255 Builder.CreateStore(AllocatedMemTyped, ReplacementToArr); 256 Builder.CreateRetVoid(); 257 258 const int Priority = 0; 259 appendToGlobalCtors(M, F, Priority, ReplacementToArr); 260 261 SmallVector<Instruction *, 4> ArrayUserInstructions; 262 // Get all instructions that use array. We need to do this weird thing 263 // because `Constant`s that contain this array neeed to be expanded into 264 // instructions so that we can replace their parameters. `Constant`s cannot 265 // be edited easily, so we choose to convert all `Constant`s to 266 // `Instruction`s and handle all of the uses of `Array` uniformly. 267 for (Use &ArrayUse : Array.uses()) 268 getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions); 269 270 for (Instruction *UserOfArrayInst : ArrayUserInstructions) { 271 272 Builder.SetInsertPoint(UserOfArrayInst); 273 // <ty>** -> <ty>* 274 Value *ArrPtrLoaded = Builder.CreateLoad(ReplacementToArr, "arrptr.load"); 275 // <ty>* -> [ty]* 276 Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast( 277 ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast"); 278 rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder); 279 } 280 } 281 282 // We return all `allocas` that may need to be converted to a call to 283 // cudaMallocManaged. 284 static void getAllocasToBeManaged(Function &F, 285 SmallSet<AllocaInst *, 4> &Allocas) { 286 for (BasicBlock &BB : F) { 287 for (Instruction &I : BB) { 288 auto *Alloca = dyn_cast<AllocaInst>(&I); 289 if (!Alloca) 290 continue; 291 LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: "); 292 293 if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false, 294 /* StoreCaptures */ true)) { 295 Allocas.insert(Alloca); 296 LLVM_DEBUG(dbgs() << "YES (captured).\n"); 297 } else { 298 LLVM_DEBUG(dbgs() << "NO (not captured).\n"); 299 } 300 } 301 } 302 } 303 304 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca, 305 const DataLayout &DL) { 306 LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n"); 307 Module *M = Alloca->getModule(); 308 assert(M && "Alloca does not have a module"); 309 310 PollyIRBuilder Builder(M->getContext()); 311 Builder.SetInsertPoint(Alloca); 312 313 Function *MallocManagedFn = 314 getOrCreatePollyMallocManaged(*Alloca->getModule()); 315 const uint64_t Size = 316 DL.getTypeAllocSize(Alloca->getType()->getElementType()); 317 Value *SizeVal = Builder.getInt64(Size); 318 Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal}); 319 Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType()); 320 321 Function *F = Alloca->getFunction(); 322 assert(F && "Alloca has invalid function"); 323 324 Bitcasted->takeName(Alloca); 325 Alloca->replaceAllUsesWith(Bitcasted); 326 Alloca->eraseFromParent(); 327 328 for (BasicBlock &BB : *F) { 329 ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator()); 330 if (!Return) 331 continue; 332 Builder.SetInsertPoint(Return); 333 334 Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M); 335 Builder.CreateCall(FreeManagedFn, {RawManagedMem}); 336 } 337 } 338 339 // Replace all uses of `Old` with `New`, even inside `ConstantExpr`. 340 // 341 // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function 342 // actually does replace it in `ConstantExpr`. The caveat is that if there is 343 // a use that is *outside* a function (say, at global declarations), we fail. 344 // So, this is meant to be used on values which we know will only be used 345 // within functions. 346 // 347 // This process works by looking through the uses of `Old`. If it finds a 348 // `ConstantExpr`, it recursively looks for the owning instruction. 349 // Then, it expands all the `ConstantExpr` to instructions and replaces 350 // `Old` with `New` in the expanded instructions. 351 static void replaceAllUsesAndConstantUses(Value *Old, Value *New, 352 PollyIRBuilder &Builder) { 353 SmallVector<Instruction *, 4> UserInstructions; 354 // Get all instructions that use array. We need to do this weird thing 355 // because `Constant`s that contain this array neeed to be expanded into 356 // instructions so that we can replace their parameters. `Constant`s cannot 357 // be edited easily, so we choose to convert all `Constant`s to 358 // `Instruction`s and handle all of the uses of `Array` uniformly. 359 for (Use &ArrayUse : Old->uses()) 360 getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions); 361 362 for (Instruction *I : UserInstructions) 363 rewriteOldValToNew(I, Old, New, Builder); 364 } 365 366 class ManagedMemoryRewritePass : public ModulePass { 367 public: 368 static char ID; 369 GPUArch Architecture; 370 GPURuntime Runtime; 371 372 ManagedMemoryRewritePass() : ModulePass(ID) {} 373 virtual bool runOnModule(Module &M) { 374 const DataLayout &DL = M.getDataLayout(); 375 376 Function *Malloc = M.getFunction("malloc"); 377 378 if (Malloc) { 379 PollyIRBuilder Builder(M.getContext()); 380 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 381 assert(PollyMallocManaged && "unable to create polly_mallocManaged"); 382 383 replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder); 384 Malloc->eraseFromParent(); 385 } 386 387 Function *Free = M.getFunction("free"); 388 389 if (Free) { 390 PollyIRBuilder Builder(M.getContext()); 391 Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M); 392 assert(PollyFreeManaged && "unable to create polly_freeManaged"); 393 394 replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder); 395 Free->eraseFromParent(); 396 } 397 398 SmallPtrSet<GlobalVariable *, 4> GlobalsToErase; 399 for (GlobalVariable &Global : M.globals()) 400 replaceGlobalArray(M, DL, Global, GlobalsToErase); 401 for (GlobalVariable *G : GlobalsToErase) 402 G->eraseFromParent(); 403 404 // Rewrite allocas to cudaMallocs if we are asked to do so. 405 if (RewriteAllocas) { 406 SmallSet<AllocaInst *, 4> AllocasToBeManaged; 407 for (Function &F : M.functions()) 408 getAllocasToBeManaged(F, AllocasToBeManaged); 409 410 for (AllocaInst *Alloca : AllocasToBeManaged) 411 rewriteAllocaAsManagedMemory(Alloca, DL); 412 } 413 414 return true; 415 } 416 }; 417 } // namespace 418 char ManagedMemoryRewritePass::ID = 42; 419 420 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch, 421 GPURuntime Runtime) { 422 ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass(); 423 pass->Runtime = Runtime; 424 pass->Architecture = Arch; 425 return pass; 426 } 427 428 INITIALIZE_PASS_BEGIN( 429 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 430 "Polly - Rewrite all allocations in heap & data section to managed memory", 431 false, false) 432 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration); 433 INITIALIZE_PASS_DEPENDENCY(DependenceInfo); 434 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass); 435 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); 436 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass); 437 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass); 438 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass); 439 INITIALIZE_PASS_END( 440 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 441 "Polly - Rewrite all allocations in heap & data section to managed memory", 442 false, false) 443