1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Take a module and rewrite: 10 // 1. `malloc` -> `polly_mallocManaged` 11 // 2. `free` -> `polly_freeManaged` 12 // 3. global arrays with initializers -> global arrays that are initialized 13 // with a constructor call to 14 // `polly_mallocManaged`. 15 // 16 //===----------------------------------------------------------------------===// 17 18 #include "polly/CodeGen/IRBuilder.h" 19 #include "polly/CodeGen/PPCGCodeGeneration.h" 20 #include "polly/DependenceInfo.h" 21 #include "polly/LinkAllPasses.h" 22 #include "polly/Options.h" 23 #include "polly/ScopDetection.h" 24 #include "llvm/ADT/SmallSet.h" 25 #include "llvm/Analysis/CaptureTracking.h" 26 #include "llvm/InitializePasses.h" 27 #include "llvm/Transforms/Utils/ModuleUtils.h" 28 29 using namespace polly; 30 31 static cl::opt<bool> RewriteAllocas( 32 "polly-acc-rewrite-allocas", 33 cl::desc( 34 "Ask the managed memory rewriter to also rewrite alloca instructions"), 35 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 36 37 static cl::opt<bool> IgnoreLinkageForGlobals( 38 "polly-acc-rewrite-ignore-linkage-for-globals", 39 cl::desc( 40 "By default, we only rewrite globals with internal linkage. This flag " 41 "enables rewriting of globals regardless of linkage"), 42 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 43 44 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory" 45 namespace { 46 47 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) { 48 const char *Name = "polly_mallocManaged"; 49 Function *F = M.getFunction(Name); 50 51 // If F is not available, declare it. 52 if (!F) { 53 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 54 PollyIRBuilder Builder(M.getContext()); 55 // TODO: How do I get `size_t`? I assume from DataLayout? 56 FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(), 57 {Builder.getInt64Ty()}, false); 58 F = Function::Create(Ty, Linkage, Name, &M); 59 } 60 61 return F; 62 } 63 64 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) { 65 const char *Name = "polly_freeManaged"; 66 Function *F = M.getFunction(Name); 67 68 // If F is not available, declare it. 69 if (!F) { 70 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 71 PollyIRBuilder Builder(M.getContext()); 72 // TODO: How do I get `size_t`? I assume from DataLayout? 73 FunctionType *Ty = 74 FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false); 75 F = Function::Create(Ty, Linkage, Name, &M); 76 } 77 78 return F; 79 } 80 81 // Expand a constant expression `Cur`, which is used at instruction `Parent` 82 // at index `index`. 83 // Since a constant expression can expand to multiple instructions, store all 84 // the expands into a set called `Expands`. 85 // Note that this goes inorder on the constant expression tree. 86 // A * ((B * D) + C) 87 // will be processed with first A, then B * D, then B, then D, and then C. 88 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can 89 // have something like this: 90 // * 91 // / \ 92 // \ / 93 // (D) 94 // 95 // For the purposes of this expansion, we expand the two occurences of D 96 // separately. Therefore, we expand the DAG into the tree: 97 // * 98 // / \ 99 // D D 100 // TODO: We don't _have_to do this, but this is the simplest solution. 101 // We can write a solution that keeps track of which constants have been 102 // already expanded. 103 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder, 104 Instruction *Parent, int index, 105 SmallPtrSet<Instruction *, 4> &Expands) { 106 assert(Cur && "invalid constant expression passed"); 107 Instruction *I = Cur->getAsInstruction(); 108 assert(I && "unable to convert ConstantExpr to Instruction"); 109 110 LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur 111 << ") in Instruction: (" << *I << ")\n";); 112 113 // Invalidate `Cur` so that no one after this point uses `Cur`. Rather, 114 // they should mutate `I`. 115 Cur = nullptr; 116 117 Expands.insert(I); 118 Parent->setOperand(index, I); 119 120 // The things that `Parent` uses (its operands) should be created 121 // before `Parent`. 122 Builder.SetInsertPoint(Parent); 123 Builder.Insert(I); 124 125 for (unsigned i = 0; i < I->getNumOperands(); i++) { 126 Value *Op = I->getOperand(i); 127 assert(isa<Constant>(Op) && "constant must have a constant operand"); 128 129 if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op)) 130 expandConstantExpr(CExprOp, Builder, I, i, Expands); 131 } 132 } 133 134 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite 135 // `ConstantExpr`s that are used in the `Inst`. 136 // Note that `replaceAllUsesWith` is insufficient for this purpose because it 137 // does not rewrite values in `ConstantExpr`s. 138 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal, 139 PollyIRBuilder &Builder) { 140 141 // This contains a set of instructions in which OldVal must be replaced. 142 // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s 143 // from `Inst`s arguments. 144 // We need to go through this process because `replaceAllUsesWith` does not 145 // actually edit `ConstantExpr`s. 146 SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst}; 147 148 // Expand all `ConstantExpr`s and place it in `InstsToVisit`. 149 for (unsigned i = 0; i < Inst->getNumOperands(); i++) { 150 Value *Operand = Inst->getOperand(i); 151 if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand)) 152 expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit); 153 } 154 155 // Now visit each instruction and use `replaceUsesOfWith`. We know that 156 // will work because `I` cannot have any `ConstantExpr` within it. 157 for (Instruction *I : InstsToVisit) 158 I->replaceUsesOfWith(OldVal, NewVal); 159 } 160 161 // Given a value `Current`, return all Instructions that may contain `Current` 162 // in an expression. 163 // We need this auxiliary function, because if we have a 164 // `Constant` that is a user of `V`, we need to recurse into the 165 // `Constant`s uses to gather the root instruciton. 166 static void getInstructionUsersOfValue(Value *V, 167 SmallVector<Instruction *, 4> &Owners) { 168 if (auto *I = dyn_cast<Instruction>(V)) { 169 Owners.push_back(I); 170 } else { 171 // Anything that is a `User` must be a constant or an instruction. 172 auto *C = cast<Constant>(V); 173 for (Use &CUse : C->uses()) 174 getInstructionUsersOfValue(CUse.getUser(), Owners); 175 } 176 } 177 178 static void 179 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array, 180 SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) { 181 // We only want arrays. 182 ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType()); 183 if (!ArrayTy) 184 return; 185 Type *ElemTy = ArrayTy->getElementType(); 186 PointerType *ElemPtrTy = ElemTy->getPointerTo(); 187 188 // We only wish to replace arrays that are visible in the module they 189 // inhabit. Otherwise, our type edit from [T] to T* would be illegal across 190 // modules. 191 const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() || 192 Array.hasInternalLinkage() || 193 IgnoreLinkageForGlobals; 194 if (!OnlyVisibleInsideModule) { 195 LLVM_DEBUG( 196 dbgs() << "Not rewriting (" << Array 197 << ") to managed memory " 198 "because it could be visible externally. To force rewrite, " 199 "use -polly-acc-rewrite-ignore-linkage-for-globals.\n"); 200 return; 201 } 202 203 if (!Array.hasInitializer() || 204 !isa<ConstantAggregateZero>(Array.getInitializer())) { 205 LLVM_DEBUG(dbgs() << "Not rewriting (" << Array 206 << ") to managed memory " 207 "because it has an initializer which is " 208 "not a zeroinitializer.\n"); 209 return; 210 } 211 212 // At this point, we have committed to replacing this array. 213 ReplacedGlobals.insert(&Array); 214 215 std::string NewName = Array.getName().str(); 216 NewName += ".toptr"; 217 GlobalVariable *ReplacementToArr = 218 cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy)); 219 ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy)); 220 221 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 222 std::string FnName = Array.getName().str(); 223 FnName += ".constructor"; 224 PollyIRBuilder Builder(M.getContext()); 225 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false); 226 const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 227 Function *F = Function::Create(Ty, Linkage, FnName, &M); 228 BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F); 229 Builder.SetInsertPoint(Start); 230 231 const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy); 232 Value *ArraySize = Builder.getInt64(ArraySizeInt); 233 ArraySize->setName("array.size"); 234 235 Value *AllocatedMemRaw = 236 Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw"); 237 Value *AllocatedMemTyped = 238 Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed"); 239 Builder.CreateStore(AllocatedMemTyped, ReplacementToArr); 240 Builder.CreateRetVoid(); 241 242 const int Priority = 0; 243 appendToGlobalCtors(M, F, Priority, ReplacementToArr); 244 245 SmallVector<Instruction *, 4> ArrayUserInstructions; 246 // Get all instructions that use array. We need to do this weird thing 247 // because `Constant`s that contain this array neeed to be expanded into 248 // instructions so that we can replace their parameters. `Constant`s cannot 249 // be edited easily, so we choose to convert all `Constant`s to 250 // `Instruction`s and handle all of the uses of `Array` uniformly. 251 for (Use &ArrayUse : Array.uses()) 252 getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions); 253 254 for (Instruction *UserOfArrayInst : ArrayUserInstructions) { 255 256 Builder.SetInsertPoint(UserOfArrayInst); 257 // <ty>** -> <ty>* 258 Value *ArrPtrLoaded = Builder.CreateLoad(ReplacementToArr, "arrptr.load"); 259 // <ty>* -> [ty]* 260 Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast( 261 ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast"); 262 rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder); 263 } 264 } 265 266 // We return all `allocas` that may need to be converted to a call to 267 // cudaMallocManaged. 268 static void getAllocasToBeManaged(Function &F, 269 SmallSet<AllocaInst *, 4> &Allocas) { 270 for (BasicBlock &BB : F) { 271 for (Instruction &I : BB) { 272 auto *Alloca = dyn_cast<AllocaInst>(&I); 273 if (!Alloca) 274 continue; 275 LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: "); 276 277 if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false, 278 /* StoreCaptures */ true)) { 279 Allocas.insert(Alloca); 280 LLVM_DEBUG(dbgs() << "YES (captured).\n"); 281 } else { 282 LLVM_DEBUG(dbgs() << "NO (not captured).\n"); 283 } 284 } 285 } 286 } 287 288 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca, 289 const DataLayout &DL) { 290 LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n"); 291 Module *M = Alloca->getModule(); 292 assert(M && "Alloca does not have a module"); 293 294 PollyIRBuilder Builder(M->getContext()); 295 Builder.SetInsertPoint(Alloca); 296 297 Function *MallocManagedFn = 298 getOrCreatePollyMallocManaged(*Alloca->getModule()); 299 const uint64_t Size = 300 DL.getTypeAllocSize(Alloca->getType()->getElementType()); 301 Value *SizeVal = Builder.getInt64(Size); 302 Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal}); 303 Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType()); 304 305 Function *F = Alloca->getFunction(); 306 assert(F && "Alloca has invalid function"); 307 308 Bitcasted->takeName(Alloca); 309 Alloca->replaceAllUsesWith(Bitcasted); 310 Alloca->eraseFromParent(); 311 312 for (BasicBlock &BB : *F) { 313 ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator()); 314 if (!Return) 315 continue; 316 Builder.SetInsertPoint(Return); 317 318 Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M); 319 Builder.CreateCall(FreeManagedFn, {RawManagedMem}); 320 } 321 } 322 323 // Replace all uses of `Old` with `New`, even inside `ConstantExpr`. 324 // 325 // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function 326 // actually does replace it in `ConstantExpr`. The caveat is that if there is 327 // a use that is *outside* a function (say, at global declarations), we fail. 328 // So, this is meant to be used on values which we know will only be used 329 // within functions. 330 // 331 // This process works by looking through the uses of `Old`. If it finds a 332 // `ConstantExpr`, it recursively looks for the owning instruction. 333 // Then, it expands all the `ConstantExpr` to instructions and replaces 334 // `Old` with `New` in the expanded instructions. 335 static void replaceAllUsesAndConstantUses(Value *Old, Value *New, 336 PollyIRBuilder &Builder) { 337 SmallVector<Instruction *, 4> UserInstructions; 338 // Get all instructions that use array. We need to do this weird thing 339 // because `Constant`s that contain this array neeed to be expanded into 340 // instructions so that we can replace their parameters. `Constant`s cannot 341 // be edited easily, so we choose to convert all `Constant`s to 342 // `Instruction`s and handle all of the uses of `Array` uniformly. 343 for (Use &ArrayUse : Old->uses()) 344 getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions); 345 346 for (Instruction *I : UserInstructions) 347 rewriteOldValToNew(I, Old, New, Builder); 348 } 349 350 class ManagedMemoryRewritePass : public ModulePass { 351 public: 352 static char ID; 353 GPUArch Architecture; 354 GPURuntime Runtime; 355 356 ManagedMemoryRewritePass() : ModulePass(ID) {} 357 virtual bool runOnModule(Module &M) { 358 const DataLayout &DL = M.getDataLayout(); 359 360 Function *Malloc = M.getFunction("malloc"); 361 362 if (Malloc) { 363 PollyIRBuilder Builder(M.getContext()); 364 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 365 assert(PollyMallocManaged && "unable to create polly_mallocManaged"); 366 367 replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder); 368 Malloc->eraseFromParent(); 369 } 370 371 Function *Free = M.getFunction("free"); 372 373 if (Free) { 374 PollyIRBuilder Builder(M.getContext()); 375 Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M); 376 assert(PollyFreeManaged && "unable to create polly_freeManaged"); 377 378 replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder); 379 Free->eraseFromParent(); 380 } 381 382 SmallPtrSet<GlobalVariable *, 4> GlobalsToErase; 383 for (GlobalVariable &Global : M.globals()) 384 replaceGlobalArray(M, DL, Global, GlobalsToErase); 385 for (GlobalVariable *G : GlobalsToErase) 386 G->eraseFromParent(); 387 388 // Rewrite allocas to cudaMallocs if we are asked to do so. 389 if (RewriteAllocas) { 390 SmallSet<AllocaInst *, 4> AllocasToBeManaged; 391 for (Function &F : M.functions()) 392 getAllocasToBeManaged(F, AllocasToBeManaged); 393 394 for (AllocaInst *Alloca : AllocasToBeManaged) 395 rewriteAllocaAsManagedMemory(Alloca, DL); 396 } 397 398 return true; 399 } 400 }; 401 } // namespace 402 char ManagedMemoryRewritePass::ID = 42; 403 404 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch, 405 GPURuntime Runtime) { 406 ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass(); 407 pass->Runtime = Runtime; 408 pass->Architecture = Arch; 409 return pass; 410 } 411 412 INITIALIZE_PASS_BEGIN( 413 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 414 "Polly - Rewrite all allocations in heap & data section to managed memory", 415 false, false) 416 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration); 417 INITIALIZE_PASS_DEPENDENCY(DependenceInfo); 418 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass); 419 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); 420 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass); 421 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass); 422 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass); 423 INITIALIZE_PASS_END( 424 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 425 "Polly - Rewrite all allocations in heap & data section to managed memory", 426 false, false) 427