1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Take a module and rewrite: 10 // 1. `malloc` -> `polly_mallocManaged` 11 // 2. `free` -> `polly_freeManaged` 12 // 3. global arrays with initializers -> global arrays that are initialized 13 // with a constructor call to 14 // `polly_mallocManaged`. 15 // 16 //===----------------------------------------------------------------------===// 17 18 #include "polly/CodeGen/IRBuilder.h" 19 #include "polly/CodeGen/PPCGCodeGeneration.h" 20 #include "polly/DependenceInfo.h" 21 #include "polly/LinkAllPasses.h" 22 #include "polly/Options.h" 23 #include "polly/ScopDetection.h" 24 #include "llvm/ADT/SmallSet.h" 25 #include "llvm/Analysis/CaptureTracking.h" 26 #include "llvm/Transforms/Utils/ModuleUtils.h" 27 28 using namespace polly; 29 30 static cl::opt<bool> RewriteAllocas( 31 "polly-acc-rewrite-allocas", 32 cl::desc( 33 "Ask the managed memory rewriter to also rewrite alloca instructions"), 34 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 35 36 static cl::opt<bool> IgnoreLinkageForGlobals( 37 "polly-acc-rewrite-ignore-linkage-for-globals", 38 cl::desc( 39 "By default, we only rewrite globals with internal linkage. This flag " 40 "enables rewriting of globals regardless of linkage"), 41 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 42 43 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory" 44 namespace { 45 46 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) { 47 const char *Name = "polly_mallocManaged"; 48 Function *F = M.getFunction(Name); 49 50 // If F is not available, declare it. 51 if (!F) { 52 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 53 PollyIRBuilder Builder(M.getContext()); 54 // TODO: How do I get `size_t`? I assume from DataLayout? 55 FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(), 56 {Builder.getInt64Ty()}, false); 57 F = Function::Create(Ty, Linkage, Name, &M); 58 } 59 60 return F; 61 } 62 63 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) { 64 const char *Name = "polly_freeManaged"; 65 Function *F = M.getFunction(Name); 66 67 // If F is not available, declare it. 68 if (!F) { 69 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 70 PollyIRBuilder Builder(M.getContext()); 71 // TODO: How do I get `size_t`? I assume from DataLayout? 72 FunctionType *Ty = 73 FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false); 74 F = Function::Create(Ty, Linkage, Name, &M); 75 } 76 77 return F; 78 } 79 80 // Expand a constant expression `Cur`, which is used at instruction `Parent` 81 // at index `index`. 82 // Since a constant expression can expand to multiple instructions, store all 83 // the expands into a set called `Expands`. 84 // Note that this goes inorder on the constant expression tree. 85 // A * ((B * D) + C) 86 // will be processed with first A, then B * D, then B, then D, and then C. 87 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can 88 // have something like this: 89 // * 90 // / \ 91 // \ / 92 // (D) 93 // 94 // For the purposes of this expansion, we expand the two occurences of D 95 // separately. Therefore, we expand the DAG into the tree: 96 // * 97 // / \ 98 // D D 99 // TODO: We don't _have_to do this, but this is the simplest solution. 100 // We can write a solution that keeps track of which constants have been 101 // already expanded. 102 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder, 103 Instruction *Parent, int index, 104 SmallPtrSet<Instruction *, 4> &Expands) { 105 assert(Cur && "invalid constant expression passed"); 106 Instruction *I = Cur->getAsInstruction(); 107 assert(I && "unable to convert ConstantExpr to Instruction"); 108 109 LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur 110 << ") in Instruction: (" << *I << ")\n";); 111 112 // Invalidate `Cur` so that no one after this point uses `Cur`. Rather, 113 // they should mutate `I`. 114 Cur = nullptr; 115 116 Expands.insert(I); 117 Parent->setOperand(index, I); 118 119 // The things that `Parent` uses (its operands) should be created 120 // before `Parent`. 121 Builder.SetInsertPoint(Parent); 122 Builder.Insert(I); 123 124 for (unsigned i = 0; i < I->getNumOperands(); i++) { 125 Value *Op = I->getOperand(i); 126 assert(isa<Constant>(Op) && "constant must have a constant operand"); 127 128 if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op)) 129 expandConstantExpr(CExprOp, Builder, I, i, Expands); 130 } 131 } 132 133 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite 134 // `ConstantExpr`s that are used in the `Inst`. 135 // Note that `replaceAllUsesWith` is insufficient for this purpose because it 136 // does not rewrite values in `ConstantExpr`s. 137 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal, 138 PollyIRBuilder &Builder) { 139 140 // This contains a set of instructions in which OldVal must be replaced. 141 // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s 142 // from `Inst`s arguments. 143 // We need to go through this process because `replaceAllUsesWith` does not 144 // actually edit `ConstantExpr`s. 145 SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst}; 146 147 // Expand all `ConstantExpr`s and place it in `InstsToVisit`. 148 for (unsigned i = 0; i < Inst->getNumOperands(); i++) { 149 Value *Operand = Inst->getOperand(i); 150 if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand)) 151 expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit); 152 } 153 154 // Now visit each instruction and use `replaceUsesOfWith`. We know that 155 // will work because `I` cannot have any `ConstantExpr` within it. 156 for (Instruction *I : InstsToVisit) 157 I->replaceUsesOfWith(OldVal, NewVal); 158 } 159 160 // Given a value `Current`, return all Instructions that may contain `Current` 161 // in an expression. 162 // We need this auxiliary function, because if we have a 163 // `Constant` that is a user of `V`, we need to recurse into the 164 // `Constant`s uses to gather the root instruciton. 165 static void getInstructionUsersOfValue(Value *V, 166 SmallVector<Instruction *, 4> &Owners) { 167 if (auto *I = dyn_cast<Instruction>(V)) { 168 Owners.push_back(I); 169 } else { 170 // Anything that is a `User` must be a constant or an instruction. 171 auto *C = cast<Constant>(V); 172 for (Use &CUse : C->uses()) 173 getInstructionUsersOfValue(CUse.getUser(), Owners); 174 } 175 } 176 177 static void 178 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array, 179 SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) { 180 // We only want arrays. 181 ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType()); 182 if (!ArrayTy) 183 return; 184 Type *ElemTy = ArrayTy->getElementType(); 185 PointerType *ElemPtrTy = ElemTy->getPointerTo(); 186 187 // We only wish to replace arrays that are visible in the module they 188 // inhabit. Otherwise, our type edit from [T] to T* would be illegal across 189 // modules. 190 const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() || 191 Array.hasInternalLinkage() || 192 IgnoreLinkageForGlobals; 193 if (!OnlyVisibleInsideModule) { 194 LLVM_DEBUG( 195 dbgs() << "Not rewriting (" << Array 196 << ") to managed memory " 197 "because it could be visible externally. To force rewrite, " 198 "use -polly-acc-rewrite-ignore-linkage-for-globals.\n"); 199 return; 200 } 201 202 if (!Array.hasInitializer() || 203 !isa<ConstantAggregateZero>(Array.getInitializer())) { 204 LLVM_DEBUG(dbgs() << "Not rewriting (" << Array 205 << ") to managed memory " 206 "because it has an initializer which is " 207 "not a zeroinitializer.\n"); 208 return; 209 } 210 211 // At this point, we have committed to replacing this array. 212 ReplacedGlobals.insert(&Array); 213 214 std::string NewName = Array.getName(); 215 NewName += ".toptr"; 216 GlobalVariable *ReplacementToArr = 217 cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy)); 218 ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy)); 219 220 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 221 std::string FnName = Array.getName(); 222 FnName += ".constructor"; 223 PollyIRBuilder Builder(M.getContext()); 224 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false); 225 const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 226 Function *F = Function::Create(Ty, Linkage, FnName, &M); 227 BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F); 228 Builder.SetInsertPoint(Start); 229 230 const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy); 231 Value *ArraySize = Builder.getInt64(ArraySizeInt); 232 ArraySize->setName("array.size"); 233 234 Value *AllocatedMemRaw = 235 Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw"); 236 Value *AllocatedMemTyped = 237 Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed"); 238 Builder.CreateStore(AllocatedMemTyped, ReplacementToArr); 239 Builder.CreateRetVoid(); 240 241 const int Priority = 0; 242 appendToGlobalCtors(M, F, Priority, ReplacementToArr); 243 244 SmallVector<Instruction *, 4> ArrayUserInstructions; 245 // Get all instructions that use array. We need to do this weird thing 246 // because `Constant`s that contain this array neeed to be expanded into 247 // instructions so that we can replace their parameters. `Constant`s cannot 248 // be edited easily, so we choose to convert all `Constant`s to 249 // `Instruction`s and handle all of the uses of `Array` uniformly. 250 for (Use &ArrayUse : Array.uses()) 251 getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions); 252 253 for (Instruction *UserOfArrayInst : ArrayUserInstructions) { 254 255 Builder.SetInsertPoint(UserOfArrayInst); 256 // <ty>** -> <ty>* 257 Value *ArrPtrLoaded = Builder.CreateLoad(ReplacementToArr, "arrptr.load"); 258 // <ty>* -> [ty]* 259 Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast( 260 ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast"); 261 rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder); 262 } 263 } 264 265 // We return all `allocas` that may need to be converted to a call to 266 // cudaMallocManaged. 267 static void getAllocasToBeManaged(Function &F, 268 SmallSet<AllocaInst *, 4> &Allocas) { 269 for (BasicBlock &BB : F) { 270 for (Instruction &I : BB) { 271 auto *Alloca = dyn_cast<AllocaInst>(&I); 272 if (!Alloca) 273 continue; 274 LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: "); 275 276 if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false, 277 /* StoreCaptures */ true)) { 278 Allocas.insert(Alloca); 279 LLVM_DEBUG(dbgs() << "YES (captured).\n"); 280 } else { 281 LLVM_DEBUG(dbgs() << "NO (not captured).\n"); 282 } 283 } 284 } 285 } 286 287 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca, 288 const DataLayout &DL) { 289 LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n"); 290 Module *M = Alloca->getModule(); 291 assert(M && "Alloca does not have a module"); 292 293 PollyIRBuilder Builder(M->getContext()); 294 Builder.SetInsertPoint(Alloca); 295 296 Function *MallocManagedFn = 297 getOrCreatePollyMallocManaged(*Alloca->getModule()); 298 const uint64_t Size = 299 DL.getTypeAllocSize(Alloca->getType()->getElementType()); 300 Value *SizeVal = Builder.getInt64(Size); 301 Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal}); 302 Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType()); 303 304 Function *F = Alloca->getFunction(); 305 assert(F && "Alloca has invalid function"); 306 307 Bitcasted->takeName(Alloca); 308 Alloca->replaceAllUsesWith(Bitcasted); 309 Alloca->eraseFromParent(); 310 311 for (BasicBlock &BB : *F) { 312 ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator()); 313 if (!Return) 314 continue; 315 Builder.SetInsertPoint(Return); 316 317 Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M); 318 Builder.CreateCall(FreeManagedFn, {RawManagedMem}); 319 } 320 } 321 322 // Replace all uses of `Old` with `New`, even inside `ConstantExpr`. 323 // 324 // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function 325 // actually does replace it in `ConstantExpr`. The caveat is that if there is 326 // a use that is *outside* a function (say, at global declarations), we fail. 327 // So, this is meant to be used on values which we know will only be used 328 // within functions. 329 // 330 // This process works by looking through the uses of `Old`. If it finds a 331 // `ConstantExpr`, it recursively looks for the owning instruction. 332 // Then, it expands all the `ConstantExpr` to instructions and replaces 333 // `Old` with `New` in the expanded instructions. 334 static void replaceAllUsesAndConstantUses(Value *Old, Value *New, 335 PollyIRBuilder &Builder) { 336 SmallVector<Instruction *, 4> UserInstructions; 337 // Get all instructions that use array. We need to do this weird thing 338 // because `Constant`s that contain this array neeed to be expanded into 339 // instructions so that we can replace their parameters. `Constant`s cannot 340 // be edited easily, so we choose to convert all `Constant`s to 341 // `Instruction`s and handle all of the uses of `Array` uniformly. 342 for (Use &ArrayUse : Old->uses()) 343 getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions); 344 345 for (Instruction *I : UserInstructions) 346 rewriteOldValToNew(I, Old, New, Builder); 347 } 348 349 class ManagedMemoryRewritePass : public ModulePass { 350 public: 351 static char ID; 352 GPUArch Architecture; 353 GPURuntime Runtime; 354 355 ManagedMemoryRewritePass() : ModulePass(ID) {} 356 virtual bool runOnModule(Module &M) { 357 const DataLayout &DL = M.getDataLayout(); 358 359 Function *Malloc = M.getFunction("malloc"); 360 361 if (Malloc) { 362 PollyIRBuilder Builder(M.getContext()); 363 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 364 assert(PollyMallocManaged && "unable to create polly_mallocManaged"); 365 366 replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder); 367 Malloc->eraseFromParent(); 368 } 369 370 Function *Free = M.getFunction("free"); 371 372 if (Free) { 373 PollyIRBuilder Builder(M.getContext()); 374 Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M); 375 assert(PollyFreeManaged && "unable to create polly_freeManaged"); 376 377 replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder); 378 Free->eraseFromParent(); 379 } 380 381 SmallPtrSet<GlobalVariable *, 4> GlobalsToErase; 382 for (GlobalVariable &Global : M.globals()) 383 replaceGlobalArray(M, DL, Global, GlobalsToErase); 384 for (GlobalVariable *G : GlobalsToErase) 385 G->eraseFromParent(); 386 387 // Rewrite allocas to cudaMallocs if we are asked to do so. 388 if (RewriteAllocas) { 389 SmallSet<AllocaInst *, 4> AllocasToBeManaged; 390 for (Function &F : M.functions()) 391 getAllocasToBeManaged(F, AllocasToBeManaged); 392 393 for (AllocaInst *Alloca : AllocasToBeManaged) 394 rewriteAllocaAsManagedMemory(Alloca, DL); 395 } 396 397 return true; 398 } 399 }; 400 } // namespace 401 char ManagedMemoryRewritePass::ID = 42; 402 403 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch, 404 GPURuntime Runtime) { 405 ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass(); 406 pass->Runtime = Runtime; 407 pass->Architecture = Arch; 408 return pass; 409 } 410 411 INITIALIZE_PASS_BEGIN( 412 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 413 "Polly - Rewrite all allocations in heap & data section to managed memory", 414 false, false) 415 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration); 416 INITIALIZE_PASS_DEPENDENCY(DependenceInfo); 417 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass); 418 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); 419 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass); 420 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass); 421 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass); 422 INITIALIZE_PASS_END( 423 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 424 "Polly - Rewrite all allocations in heap & data section to managed memory", 425 false, false) 426