1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Take a module and rewrite: 10 // 1. `malloc` -> `polly_mallocManaged` 11 // 2. `free` -> `polly_freeManaged` 12 // 3. global arrays with initializers -> global arrays that are initialized 13 // with a constructor call to 14 // `polly_mallocManaged`. 15 // 16 //===----------------------------------------------------------------------===// 17 18 #include "polly/CodeGen/IRBuilder.h" 19 #include "polly/CodeGen/PPCGCodeGeneration.h" 20 #include "polly/DependenceInfo.h" 21 #include "polly/LinkAllPasses.h" 22 #include "polly/Options.h" 23 #include "polly/ScopDetection.h" 24 #include "llvm/ADT/SmallSet.h" 25 #include "llvm/Analysis/CaptureTracking.h" 26 #include "llvm/InitializePasses.h" 27 #include "llvm/Transforms/Utils/ModuleUtils.h" 28 29 using namespace llvm; 30 using namespace polly; 31 32 static cl::opt<bool> RewriteAllocas( 33 "polly-acc-rewrite-allocas", 34 cl::desc( 35 "Ask the managed memory rewriter to also rewrite alloca instructions"), 36 cl::Hidden, cl::cat(PollyCategory)); 37 38 static cl::opt<bool> IgnoreLinkageForGlobals( 39 "polly-acc-rewrite-ignore-linkage-for-globals", 40 cl::desc( 41 "By default, we only rewrite globals with internal linkage. This flag " 42 "enables rewriting of globals regardless of linkage"), 43 cl::Hidden, cl::cat(PollyCategory)); 44 45 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory" 46 namespace { 47 48 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) { 49 const char *Name = "polly_mallocManaged"; 50 Function *F = M.getFunction(Name); 51 52 // If F is not available, declare it. 53 if (!F) { 54 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 55 PollyIRBuilder Builder(M.getContext()); 56 // TODO: How do I get `size_t`? I assume from DataLayout? 57 FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(), 58 {Builder.getInt64Ty()}, false); 59 F = Function::Create(Ty, Linkage, Name, &M); 60 } 61 62 return F; 63 } 64 65 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) { 66 const char *Name = "polly_freeManaged"; 67 Function *F = M.getFunction(Name); 68 69 // If F is not available, declare it. 70 if (!F) { 71 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 72 PollyIRBuilder Builder(M.getContext()); 73 // TODO: How do I get `size_t`? I assume from DataLayout? 74 FunctionType *Ty = 75 FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false); 76 F = Function::Create(Ty, Linkage, Name, &M); 77 } 78 79 return F; 80 } 81 82 // Expand a constant expression `Cur`, which is used at instruction `Parent` 83 // at index `index`. 84 // Since a constant expression can expand to multiple instructions, store all 85 // the expands into a set called `Expands`. 86 // Note that this goes inorder on the constant expression tree. 87 // A * ((B * D) + C) 88 // will be processed with first A, then B * D, then B, then D, and then C. 89 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can 90 // have something like this: 91 // * 92 // / \ 93 // \ / 94 // (D) 95 // 96 // For the purposes of this expansion, we expand the two occurences of D 97 // separately. Therefore, we expand the DAG into the tree: 98 // * 99 // / \ 100 // D D 101 // TODO: We don't _have_to do this, but this is the simplest solution. 102 // We can write a solution that keeps track of which constants have been 103 // already expanded. 104 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder, 105 Instruction *Parent, int index, 106 SmallPtrSet<Instruction *, 4> &Expands) { 107 assert(Cur && "invalid constant expression passed"); 108 Instruction *I = Cur->getAsInstruction(); 109 assert(I && "unable to convert ConstantExpr to Instruction"); 110 111 LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur 112 << ") in Instruction: (" << *I << ")\n";); 113 114 // Invalidate `Cur` so that no one after this point uses `Cur`. Rather, 115 // they should mutate `I`. 116 Cur = nullptr; 117 118 Expands.insert(I); 119 Parent->setOperand(index, I); 120 121 // The things that `Parent` uses (its operands) should be created 122 // before `Parent`. 123 Builder.SetInsertPoint(Parent); 124 Builder.Insert(I); 125 126 for (unsigned i = 0; i < I->getNumOperands(); i++) { 127 Value *Op = I->getOperand(i); 128 assert(isa<Constant>(Op) && "constant must have a constant operand"); 129 130 if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op)) 131 expandConstantExpr(CExprOp, Builder, I, i, Expands); 132 } 133 } 134 135 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite 136 // `ConstantExpr`s that are used in the `Inst`. 137 // Note that `replaceAllUsesWith` is insufficient for this purpose because it 138 // does not rewrite values in `ConstantExpr`s. 139 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal, 140 PollyIRBuilder &Builder) { 141 142 // This contains a set of instructions in which OldVal must be replaced. 143 // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s 144 // from `Inst`s arguments. 145 // We need to go through this process because `replaceAllUsesWith` does not 146 // actually edit `ConstantExpr`s. 147 SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst}; 148 149 // Expand all `ConstantExpr`s and place it in `InstsToVisit`. 150 for (unsigned i = 0; i < Inst->getNumOperands(); i++) { 151 Value *Operand = Inst->getOperand(i); 152 if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand)) 153 expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit); 154 } 155 156 // Now visit each instruction and use `replaceUsesOfWith`. We know that 157 // will work because `I` cannot have any `ConstantExpr` within it. 158 for (Instruction *I : InstsToVisit) 159 I->replaceUsesOfWith(OldVal, NewVal); 160 } 161 162 // Given a value `Current`, return all Instructions that may contain `Current` 163 // in an expression. 164 // We need this auxiliary function, because if we have a 165 // `Constant` that is a user of `V`, we need to recurse into the 166 // `Constant`s uses to gather the root instruciton. 167 static void getInstructionUsersOfValue(Value *V, 168 SmallVector<Instruction *, 4> &Owners) { 169 if (auto *I = dyn_cast<Instruction>(V)) { 170 Owners.push_back(I); 171 } else { 172 // Anything that is a `User` must be a constant or an instruction. 173 auto *C = cast<Constant>(V); 174 for (Use &CUse : C->uses()) 175 getInstructionUsersOfValue(CUse.getUser(), Owners); 176 } 177 } 178 179 static void 180 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array, 181 SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) { 182 // We only want arrays. 183 ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getValueType()); 184 if (!ArrayTy) 185 return; 186 Type *ElemTy = ArrayTy->getElementType(); 187 PointerType *ElemPtrTy = ElemTy->getPointerTo(); 188 189 // We only wish to replace arrays that are visible in the module they 190 // inhabit. Otherwise, our type edit from [T] to T* would be illegal across 191 // modules. 192 const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() || 193 Array.hasInternalLinkage() || 194 IgnoreLinkageForGlobals; 195 if (!OnlyVisibleInsideModule) { 196 LLVM_DEBUG( 197 dbgs() << "Not rewriting (" << Array 198 << ") to managed memory " 199 "because it could be visible externally. To force rewrite, " 200 "use -polly-acc-rewrite-ignore-linkage-for-globals.\n"); 201 return; 202 } 203 204 if (!Array.hasInitializer() || 205 !isa<ConstantAggregateZero>(Array.getInitializer())) { 206 LLVM_DEBUG(dbgs() << "Not rewriting (" << Array 207 << ") to managed memory " 208 "because it has an initializer which is " 209 "not a zeroinitializer.\n"); 210 return; 211 } 212 213 // At this point, we have committed to replacing this array. 214 ReplacedGlobals.insert(&Array); 215 216 std::string NewName = Array.getName().str(); 217 NewName += ".toptr"; 218 GlobalVariable *ReplacementToArr = 219 cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy)); 220 ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy)); 221 222 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 223 std::string FnName = Array.getName().str(); 224 FnName += ".constructor"; 225 PollyIRBuilder Builder(M.getContext()); 226 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false); 227 const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage; 228 Function *F = Function::Create(Ty, Linkage, FnName, &M); 229 BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F); 230 Builder.SetInsertPoint(Start); 231 232 const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy); 233 Value *ArraySize = Builder.getInt64(ArraySizeInt); 234 ArraySize->setName("array.size"); 235 236 Value *AllocatedMemRaw = 237 Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw"); 238 Value *AllocatedMemTyped = 239 Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed"); 240 Builder.CreateStore(AllocatedMemTyped, ReplacementToArr); 241 Builder.CreateRetVoid(); 242 243 const int Priority = 0; 244 appendToGlobalCtors(M, F, Priority, ReplacementToArr); 245 246 SmallVector<Instruction *, 4> ArrayUserInstructions; 247 // Get all instructions that use array. We need to do this weird thing 248 // because `Constant`s that contain this array neeed to be expanded into 249 // instructions so that we can replace their parameters. `Constant`s cannot 250 // be edited easily, so we choose to convert all `Constant`s to 251 // `Instruction`s and handle all of the uses of `Array` uniformly. 252 for (Use &ArrayUse : Array.uses()) 253 getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions); 254 255 for (Instruction *UserOfArrayInst : ArrayUserInstructions) { 256 257 Builder.SetInsertPoint(UserOfArrayInst); 258 // <ty>** -> <ty>* 259 Value *ArrPtrLoaded = 260 Builder.CreateLoad(ElemPtrTy, ReplacementToArr, "arrptr.load"); 261 // <ty>* -> [ty]* 262 Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast( 263 ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast"); 264 rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder); 265 } 266 } 267 268 // We return all `allocas` that may need to be converted to a call to 269 // cudaMallocManaged. 270 static void getAllocasToBeManaged(Function &F, 271 SmallSet<AllocaInst *, 4> &Allocas) { 272 for (BasicBlock &BB : F) { 273 for (Instruction &I : BB) { 274 auto *Alloca = dyn_cast<AllocaInst>(&I); 275 if (!Alloca) 276 continue; 277 LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: "); 278 279 if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false, 280 /* StoreCaptures */ true)) { 281 Allocas.insert(Alloca); 282 LLVM_DEBUG(dbgs() << "YES (captured).\n"); 283 } else { 284 LLVM_DEBUG(dbgs() << "NO (not captured).\n"); 285 } 286 } 287 } 288 } 289 290 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca, 291 const DataLayout &DL) { 292 LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n"); 293 Module *M = Alloca->getModule(); 294 assert(M && "Alloca does not have a module"); 295 296 PollyIRBuilder Builder(M->getContext()); 297 Builder.SetInsertPoint(Alloca); 298 299 Function *MallocManagedFn = 300 getOrCreatePollyMallocManaged(*Alloca->getModule()); 301 const uint64_t Size = DL.getTypeAllocSize(Alloca->getAllocatedType()); 302 Value *SizeVal = Builder.getInt64(Size); 303 Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal}); 304 Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType()); 305 306 Function *F = Alloca->getFunction(); 307 assert(F && "Alloca has invalid function"); 308 309 Bitcasted->takeName(Alloca); 310 Alloca->replaceAllUsesWith(Bitcasted); 311 Alloca->eraseFromParent(); 312 313 for (BasicBlock &BB : *F) { 314 ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator()); 315 if (!Return) 316 continue; 317 Builder.SetInsertPoint(Return); 318 319 Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M); 320 Builder.CreateCall(FreeManagedFn, {RawManagedMem}); 321 } 322 } 323 324 // Replace all uses of `Old` with `New`, even inside `ConstantExpr`. 325 // 326 // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function 327 // actually does replace it in `ConstantExpr`. The caveat is that if there is 328 // a use that is *outside* a function (say, at global declarations), we fail. 329 // So, this is meant to be used on values which we know will only be used 330 // within functions. 331 // 332 // This process works by looking through the uses of `Old`. If it finds a 333 // `ConstantExpr`, it recursively looks for the owning instruction. 334 // Then, it expands all the `ConstantExpr` to instructions and replaces 335 // `Old` with `New` in the expanded instructions. 336 static void replaceAllUsesAndConstantUses(Value *Old, Value *New, 337 PollyIRBuilder &Builder) { 338 SmallVector<Instruction *, 4> UserInstructions; 339 // Get all instructions that use array. We need to do this weird thing 340 // because `Constant`s that contain this array neeed to be expanded into 341 // instructions so that we can replace their parameters. `Constant`s cannot 342 // be edited easily, so we choose to convert all `Constant`s to 343 // `Instruction`s and handle all of the uses of `Array` uniformly. 344 for (Use &ArrayUse : Old->uses()) 345 getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions); 346 347 for (Instruction *I : UserInstructions) 348 rewriteOldValToNew(I, Old, New, Builder); 349 } 350 351 class ManagedMemoryRewritePass final : public ModulePass { 352 public: 353 static char ID; 354 GPUArch Architecture; 355 GPURuntime Runtime; 356 357 ManagedMemoryRewritePass() : ModulePass(ID) {} 358 bool runOnModule(Module &M) override { 359 const DataLayout &DL = M.getDataLayout(); 360 361 Function *Malloc = M.getFunction("malloc"); 362 363 if (Malloc) { 364 PollyIRBuilder Builder(M.getContext()); 365 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M); 366 assert(PollyMallocManaged && "unable to create polly_mallocManaged"); 367 368 replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder); 369 Malloc->eraseFromParent(); 370 } 371 372 Function *Free = M.getFunction("free"); 373 374 if (Free) { 375 PollyIRBuilder Builder(M.getContext()); 376 Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M); 377 assert(PollyFreeManaged && "unable to create polly_freeManaged"); 378 379 replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder); 380 Free->eraseFromParent(); 381 } 382 383 SmallPtrSet<GlobalVariable *, 4> GlobalsToErase; 384 for (GlobalVariable &Global : M.globals()) 385 replaceGlobalArray(M, DL, Global, GlobalsToErase); 386 for (GlobalVariable *G : GlobalsToErase) 387 G->eraseFromParent(); 388 389 // Rewrite allocas to cudaMallocs if we are asked to do so. 390 if (RewriteAllocas) { 391 SmallSet<AllocaInst *, 4> AllocasToBeManaged; 392 for (Function &F : M.functions()) 393 getAllocasToBeManaged(F, AllocasToBeManaged); 394 395 for (AllocaInst *Alloca : AllocasToBeManaged) 396 rewriteAllocaAsManagedMemory(Alloca, DL); 397 } 398 399 return true; 400 } 401 }; 402 } // namespace 403 char ManagedMemoryRewritePass::ID = 42; 404 405 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch, 406 GPURuntime Runtime) { 407 ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass(); 408 pass->Runtime = Runtime; 409 pass->Architecture = Arch; 410 return pass; 411 } 412 413 INITIALIZE_PASS_BEGIN( 414 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 415 "Polly - Rewrite all allocations in heap & data section to managed memory", 416 false, false) 417 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration); 418 INITIALIZE_PASS_DEPENDENCY(DependenceInfo); 419 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass); 420 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); 421 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass); 422 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass); 423 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass); 424 INITIALIZE_PASS_END( 425 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory", 426 "Polly - Rewrite all allocations in heap & data section to managed memory", 427 false, false) 428