1 //===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // This pass builds the coroutine frame and outlines resume and destroy parts 10 // of the coroutine into separate functions. 11 // 12 // We present a coroutine to an LLVM as an ordinary function with suspension 13 // points marked up with intrinsics. We let the optimizer party on the coroutine 14 // as a single function for as long as possible. Shortly before the coroutine is 15 // eligible to be inlined into its callers, we split up the coroutine into parts 16 // corresponding to an initial, resume and destroy invocations of the coroutine, 17 // add them to the current SCC and restart the IPO pipeline to optimize the 18 // coroutine subfunctions we extracted before proceeding to the caller of the 19 // coroutine. 20 //===----------------------------------------------------------------------===// 21 22 #include "CoroInternal.h" 23 #include "llvm/Analysis/CallGraphSCCPass.h" 24 #include "llvm/IR/DebugInfoMetadata.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/InstIterator.h" 27 #include "llvm/IR/LegacyPassManager.h" 28 #include "llvm/IR/Verifier.h" 29 #include "llvm/Transforms/Scalar.h" 30 #include "llvm/Transforms/Utils/Cloning.h" 31 #include "llvm/Transforms/Utils/Local.h" 32 #include "llvm/Transforms/Utils/ValueMapper.h" 33 34 using namespace llvm; 35 36 #define DEBUG_TYPE "coro-split" 37 38 // Create an entry block for a resume function with a switch that will jump to 39 // suspend points. 
static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
  LLVMContext &C = F.getContext();

  // resume.entry:
  //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
  //  i32 2
  //  % index = load i32, i32* %index.addr
  //  switch i32 %index, label %unreachable [
  //    i32 0, label %resume.0
  //    i32 1, label %resume.1
  //    ...
  //  ]

  auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
  auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);

  IRBuilder<> Builder(NewEntry);
  auto *FramePtr = Shape.FramePtr;
  auto *FrameTy = Shape.FrameTy;
  auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
      FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
  auto *Index = Builder.CreateLoad(GepIndex, "index");
  auto *Switch =
      Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
  Shape.ResumeSwitch = Switch;

  size_t SuspendIndex = 0;
  for (CoroSuspendInst *S : Shape.CoroSuspends) {
    ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);

    // Replace CoroSave with a store to Index:
    //    %index.addr = getelementptr %f.frame... (index field number)
    //    store i32 0, i32* %index.addr1
    auto *Save = S->getCoroSave();
    Builder.SetInsertPoint(Save);
    if (S->isFinal()) {
      // Final suspend point is represented by storing zero in ResumeFnAddr.
      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
                                                          0, "ResumeFn.addr");
      auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
          cast<PointerType>(GepIndex->getType())->getElementType()));
      Builder.CreateStore(NullPtr, GepIndex);
    } else {
      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
          FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
      Builder.CreateStore(IndexVal, GepIndex);
    }
    // The save token has no further meaning once the index store is in place.
    Save->replaceAllUsesWith(ConstantTokenNone::get(C));
    Save->eraseFromParent();

    // Split block before and after coro.suspend and add a jump from an entry
    // switch:
    //
    //  whateverBB:
    //    whatever
    //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
    //    switch i8 %0, label %suspend[i8 0, label %resume
    //                                 i8 1, label %cleanup]
    // becomes:
    //
    //  whateverBB:
    //    whatever
    //    br label %resume.0.landing
    //
    //  resume.0: ; <--- jump from the switch in the resume.entry
    //    %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
    //    br label %resume.0.landing
    //
    //  resume.0.landing:
    //    %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
    //    switch i8 % 1, label %suspend [i8 0, label %resume
    //                                   i8 1, label %cleanup]

    auto *SuspendBB = S->getParent();
    auto *ResumeBB =
        SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
    auto *LandingBB = ResumeBB->splitBasicBlock(
        S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
    Switch->addCase(IndexVal, ResumeBB);

    cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
    auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
    S->replaceAllUsesWith(PN);
    // -1 on the fall-through (first execution) path, the suspend result on
    // the resumed path.
    PN->addIncoming(Builder.getInt8(-1), SuspendBB);
    PN->addIncoming(S, ResumeBB);

    ++SuspendIndex;
  }

  Builder.SetInsertPoint(UnreachBB);
  Builder.CreateUnreachable();

  return NewEntry;
}

// In Resumers, we replace fallthrough coro.end with ret void and delete the
// rest of the block.
static void replaceFallthroughCoroEnd(IntrinsicInst *End,
                                      ValueToValueMapTy &VMap) {
  auto *NewE = cast<IntrinsicInst>(VMap[End]);
  ReturnInst::Create(NewE->getContext(), nullptr, NewE);

  // Remove the rest of the block, by splitting it into an unreachable block.
  auto *BB = NewE->getParent();
  BB->splitBasicBlock(NewE);
  BB->getTerminator()->eraseFromParent();
}

// In Resumers, we replace unwind coro.end with True to force the immediate
// unwind to caller.
static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
  if (Shape.CoroEnds.empty())
    return;

  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
  auto *True = ConstantInt::getTrue(Context);
  for (CoroEndInst *CE : Shape.CoroEnds) {
    if (!CE->isUnwind())
      continue;

    // Operate on the clone's copy of this coro.end.
    auto *NewCE = cast<IntrinsicInst>(VMap[CE]);

    // If coro.end has an associated bundle, add cleanupret instruction.
    if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
      Value *FromPad = Bundle->Inputs[0];
      auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
      NewCE->getParent()->splitBasicBlock(NewCE);
      CleanupRet->getParent()->getTerminator()->eraseFromParent();
    }

    NewCE->replaceAllUsesWith(True);
    NewCE->eraseFromParent();
  }
}

// Rewrite final suspend point handling. We do not use suspend index to
// represent the final suspend point. Instead we zero-out ResumeFnAddr in the
// coroutine frame, since it is undefined behavior to resume a coroutine
// suspended at the final suspend point. Thus, in the resume function, we can
// simply remove the last case (when coro::Shape is built, the final suspend
// point (if present) is always the last element of CoroSuspends array).
// In the destroy function, we add a code sequence to check if ResumeFnAddress
// is Null, and if so, jump to the appropriate label to handle cleanup from the
// final suspend point.
static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
                               coro::Shape &Shape, SwitchInst *Switch,
                               bool IsDestroy) {
  assert(Shape.HasFinalSuspend);
  // The final suspend point, when present, is the last case added to the
  // switch (see createResumeEntryBlock).
  auto FinalCaseIt = std::prev(Switch->case_end());
  BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
  Switch->removeCase(FinalCaseIt);
  if (IsDestroy) {
    // In the destroy clone, branch to the final-suspend handler when
    // ResumeFnAddr in the frame is null.
    BasicBlock *OldSwitchBB = Switch->getParent();
    auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
    Builder.SetInsertPoint(OldSwitchBB->getTerminator());
    auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
                                                        0, 0, "ResumeFn.addr");
    auto *Load = Builder.CreateLoad(GepIndex);
    auto *NullPtr =
        ConstantPointerNull::get(cast<PointerType>(Load->getType()));
    auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
    Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
    OldSwitchBB->getTerminator()->eraseFromParent();
  }
}

// Create a resume clone by cloning the body of the original function, setting
// new entry block and replacing coro.suspend with an appropriate value to
// force resume or cleanup pass for every suspend point.
static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
                             BasicBlock *ResumeEntry, int8_t FnIndex) {
  Module *M = F.getParent();
  auto *FrameTy = Shape.FrameTy;
  // The clone's type is the resume-function type stored in field 0 of the
  // coroutine frame (a pointer to a function taking the frame pointer).
  auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
  auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());

  Function *NewF =
      Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
                       F.getName() + Suffix, M);
  NewF->addParamAttr(0, Attribute::NonNull);
  NewF->addParamAttr(0, Attribute::NoAlias);

  ValueToValueMapTy VMap;
  // Replace all args with undefs. The buildCoroutineFrame algorithm has
  // already rewritten access to the args that occurs after suspend points
  // with loads and stores to/from the coroutine frame.
  for (Argument &A : F.args())
    VMap[&A] = UndefValue::get(A.getType());

  SmallVector<ReturnInst *, 4> Returns;

  CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);

  // Remove old returns.
  for (ReturnInst *Return : Returns)
    changeToUnreachable(Return, /*UseLLVMTrap=*/false);

  // Remove old return attributes.
  NewF->removeAttributes(
      AttributeList::ReturnIndex,
      AttributeFuncs::typeIncompatible(NewF->getReturnType()));

  // Make AllocaSpillBlock the new entry block.
  auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
  auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
  Entry->moveBefore(&NewF->getEntryBlock());
  Entry->getTerminator()->eraseFromParent();
  BranchInst::Create(SwitchBB, Entry);
  Entry->setName("entry" + Suffix);

  // Clear all predecessors of the new entry block.
  auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
  Entry->replaceAllUsesWith(Switch->getDefaultDest());

  IRBuilder<> Builder(&NewF->getEntryBlock().front());

  // Remap frame pointer.
  Argument *NewFramePtr = &*NewF->arg_begin();
  Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
  NewFramePtr->takeName(OldFramePtr);
  OldFramePtr->replaceAllUsesWith(NewFramePtr);

  // Remap vFrame pointer.
  auto *NewVFrame = Builder.CreateBitCast(
      NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
  Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
  OldVFrame->replaceAllUsesWith(NewVFrame);

  // Rewrite final suspend handling as it is not done via switch (allows us to
  // remove the final case from the switch, since it is undefined behavior to
  // resume the coroutine suspended at the final suspend point).
  if (Shape.HasFinalSuspend) {
    auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
    bool IsDestroy = FnIndex != 0;
    handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
  }

  // Replace coro suspend with the appropriate resume index.
  // Replacing coro.suspend with (0) will result in control flow proceeding to
  // a resume label associated with a suspend point, replacing it with (1) will
  // result in control flow proceeding to a cleanup label associated with this
  // suspend point.
  auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
  for (CoroSuspendInst *CS : Shape.CoroSuspends) {
    auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
    MappedCS->replaceAllUsesWith(NewValue);
    MappedCS->eraseFromParent();
  }

  // Remove coro.end intrinsics.
  replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
  replaceUnwindCoroEnds(Shape, VMap);
  // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
  // to suppress deallocation code.
  coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
                        /*Elide=*/FnIndex == 2);

  NewF->setCallingConv(CallingConv::Fast);

  return NewF;
}

// Replace each coro.end in the original function with 'false' (no unwind to
// caller) and erase it.
static void removeCoroEnds(coro::Shape &Shape) {
  if (Shape.CoroEnds.empty())
    return;

  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
  auto *False = ConstantInt::getFalse(Context);

  for (CoroEndInst *CE : Shape.CoroEnds) {
    CE->replaceAllUsesWith(False);
    CE->eraseFromParent();
  }
}

// Replace all coro.size intrinsics with the allocation size of the computed
// coroutine frame type.
static void replaceFrameSize(coro::Shape &Shape) {
  if (Shape.CoroSizes.empty())
    return;

  // In the same function all coro.sizes should have the same result type.
  auto *SizeIntrin = Shape.CoroSizes.back();
  Module *M = SizeIntrin->getModule();
  const DataLayout &DL = M->getDataLayout();
  auto Size = DL.getTypeAllocSize(Shape.FrameTy);
  auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);

  for (CoroSizeInst *CS : Shape.CoroSizes) {
    CS->replaceAllUsesWith(SizeConstant);
    CS->eraseFromParent();
  }
}

// Create a global constant array containing pointers to functions provided and
// set Info parameter of CoroBegin to point at this constant. Example:
//
//   @f.resumers = internal constant [2 x void(%f.frame*)*]
//                   [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
//   define void @f() {
//     ...
//     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
//                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
//
// Assumes that all the functions have the same signature.
static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
                        std::initializer_list<Function *> Fns) {

  SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
  assert(!Args.empty());
  Function *Part = *Fns.begin();
  Module *M = Part->getParent();
  auto *ArrTy = ArrayType::get(Part->getType(), Args.size());

  auto *ConstVal = ConstantArray::get(ArrTy, Args);
  auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
                                GlobalVariable::PrivateLinkage, ConstVal,
                                F.getName() + Twine(".resumers"));

  // Update coro.begin instruction to refer to this constant.
  LLVMContext &C = F.getContext();
  auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
  CoroBegin->getId()->setInfo(BC);
}

// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
                            Function *DestroyFn, Function *CleanupFn) {

  // Insert the stores immediately after the frame pointer is established.
  IRBuilder<> Builder(Shape.FramePtr->getNextNode());
  auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
      Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
      "resume.addr");
  Builder.CreateStore(ResumeFn, ResumeAddr);

  Value *DestroyOrCleanupFn = DestroyFn;

  CoroIdInst *CoroId = Shape.CoroBegin->getId();
  if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
    // If there is a CoroAlloc and it returns false (meaning we elide the
    // allocation), use CleanupFn instead of DestroyFn.
    DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
  }

  auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
      Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
      "destroy.addr");
  Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
}

// Run a small cleanup pipeline over a function produced by the split to
// remove the dead code it exposes.
static void postSplitCleanup(Function &F) {
  removeUnreachableBlocks(F);
  llvm::legacy::FunctionPassManager FPM(F.getParent());

  FPM.add(createVerifierPass());
  FPM.add(createSCCPPass());
  FPM.add(createCFGSimplificationPass());
  FPM.add(createEarlyCSEPass());
  FPM.add(createCFGSimplificationPass());

  FPM.doInitialization();
  FPM.run(F);
  FPM.doFinalization();
}

// Coroutine has no suspend points. Remove heap allocation for the coroutine
// frame if possible.
static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
  auto *CoroId = CoroBegin->getId();
  auto *AllocInst = CoroId->getCoroAlloc();
  coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
  if (AllocInst) {
    // Replace the heap allocation with a stack allocation of the frame.
    IRBuilder<> Builder(AllocInst);
    // FIXME: Need to handle overaligned members.
    auto *Frame = Builder.CreateAlloca(FrameTy);
    auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
    AllocInst->replaceAllUsesWith(Builder.getFalse());
    AllocInst->eraseFromParent();
    CoroBegin->replaceAllUsesWith(VFrame);
  } else {
    CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
  }
  CoroBegin->eraseFromParent();
}

// look for a very simple pattern
//    coro.save
//    no other calls
//    resume or destroy call
//    coro.suspend
//
// If there are other calls between coro.save and coro.suspend, they can
// potentially resume or destroy the coroutine, so it is unsafe to eliminate a
// suspend point.
static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
                                 CoroBeginInst *CoroBegin) {
  auto *Save = Suspend->getCoroSave();
  auto *BB = Suspend->getParent();
  if (BB != Save->getParent())
    return false;

  CallSite SingleCallSite;

  // Check that we have only one CallSite.
  for (Instruction *I = Save->getNextNode(); I != Suspend;
       I = I->getNextNode()) {
    if (isa<CoroFrameInst>(I))
      continue;
    if (isa<CoroSubFnInst>(I))
      continue;
    if (CallSite CS = CallSite(I)) {
      if (SingleCallSite)
        return false;
      else
        SingleCallSite = CS;
    }
  }
  auto *CallInstr = SingleCallSite.getInstruction();
  if (!CallInstr)
    return false;

  auto *Callee = SingleCallSite.getCalledValue()->stripPointerCasts();

  // See if the callsite is for resumption or destruction of the coroutine.
  auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
  if (!SubFn)
    return false;

  // Does not refer to the current coroutine, we cannot do anything with it.
  if (SubFn->getFrame() != CoroBegin)
    return false;

  // Replace llvm.coro.suspend with the value that results in resumption over
  // the resume or cleanup path.
  Suspend->replaceAllUsesWith(SubFn->getRawIndex());
  Suspend->eraseFromParent();
  Save->eraseFromParent();

  // No longer need a call to coro.resume or coro.destroy.
  CallInstr->eraseFromParent();

  if (SubFn->user_empty())
    SubFn->eraseFromParent();

  return true;
}

// Remove suspend points that are simplified.
static void simplifySuspendPoints(coro::Shape &Shape) {
  auto &S = Shape.CoroSuspends;
  size_t I = 0, N = S.size();
  if (N == 0)
    return;
  for (;;) {
    if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
      // Swap the simplified suspend point past the end of the live range and
      // shrink it; resize below discards the simplified entries.
      if (--N == I)
        break;
      std::swap(S[I], S[N]);
      continue;
    }
    if (++I == N)
      break;
  }
  S.resize(N);
}

static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) {
  // Collect all blocks that we need to look for instructions to relocate.
  SmallPtrSet<BasicBlock *, 4> RelocBlocks;
  SmallVector<BasicBlock *, 4> Work;
  Work.push_back(CB->getParent());

  // Transitive closure over predecessors of the coro.begin block.
  do {
    BasicBlock *Current = Work.pop_back_val();
    for (BasicBlock *BB : predecessors(Current))
      if (RelocBlocks.count(BB) == 0) {
        RelocBlocks.insert(BB);
        Work.push_back(BB);
      }
  } while (!Work.empty());
  return RelocBlocks;
}

static SmallPtrSet<Instruction *, 8>
getNotRelocatableInstructions(CoroBeginInst *CoroBegin,
                              SmallPtrSetImpl<BasicBlock *> &RelocBlocks) {
  SmallPtrSet<Instruction *, 8> DoNotRelocate;
  // Collect all instructions that we should not relocate
  SmallVector<Instruction *, 8> Work;

  // Start with CoroBegin and terminators of all preceding blocks.
  Work.push_back(CoroBegin);
  BasicBlock *CoroBeginBB = CoroBegin->getParent();
  for (BasicBlock *BB : RelocBlocks)
    if (BB != CoroBeginBB)
      Work.push_back(BB->getTerminator());

  // For every instruction in the Work list, place its operands in DoNotRelocate
  // set.
  do {
    Instruction *Current = Work.pop_back_val();
    DoNotRelocate.insert(Current);
    for (Value *U : Current->operands()) {
      auto *I = dyn_cast<Instruction>(U);
      if (!I)
        continue;
      // Allocas stay put; the relocation loop skips them anyway.
      if (isa<AllocaInst>(U))
        continue;
      if (DoNotRelocate.count(I) == 0) {
        Work.push_back(I);
        DoNotRelocate.insert(I);
      }
    }
  } while (!Work.empty());
  return DoNotRelocate;
}

static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) {
  // Analyze which non-alloca instructions are needed for allocation and
  // relocate the rest to after coro.begin. We need to do it, since some of the
  // targets of those instructions may be placed into coroutine frame memory,
  // which becomes available only after the coro.begin intrinsic.

  auto BlockSet = getCoroBeginPredBlocks(CoroBegin);
  auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet);

  Instruction *InsertPt = CoroBegin->getNextNode();
  BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well.
  for (auto B = BB.begin(), E = BB.end(); B != E;) {
    Instruction &I = *B++; // advance before moveBefore invalidates position
    if (isa<AllocaInst>(&I))
      continue;
    if (&I == CoroBegin)
      break;
    if (DoNotRelocateSet.count(&I))
      continue;
    I.moveBefore(InsertPt);
  }
}

static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
  coro::Shape Shape(F);
  if (!Shape.CoroBegin)
    return;

  simplifySuspendPoints(Shape);
  relocateInstructionBefore(Shape.CoroBegin, F);
  buildCoroutineFrame(F, Shape);
  replaceFrameSize(Shape);

  // If there are no suspend points, no split required, just remove
  // the allocation and deallocation blocks, they are not needed.
  if (Shape.CoroSuspends.empty()) {
    handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
    removeCoroEnds(Shape);
    postSplitCleanup(F);
    coro::updateCallGraph(F, {}, CG, SCC);
    return;
  }

  auto *ResumeEntry = createResumeEntryBlock(F, Shape);
  auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
  auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
  auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);

  // We no longer need coro.end in F.
  removeCoroEnds(Shape);

  postSplitCleanup(F);
  postSplitCleanup(*ResumeClone);
  postSplitCleanup(*DestroyClone);
  postSplitCleanup(*CleanupClone);

  // Store addresses of resume/destroy/cleanup functions in the coroutine
  // frame.
  updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);

  // Create a constant array referring to resume/destroy/cleanup functions
  // pointed to by the last argument of @llvm.coro.info, so that CoroElide
  // pass can determine the correct function to call.
  setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});

  // Update call graph and add the functions we created to the SCC.
  coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
}

// When we see the coroutine the first time, we insert an indirect call to a
// devirt trigger function and mark the coroutine that it is now ready for
// split.
626 static void prepareForSplit(Function &F, CallGraph &CG) { 627 Module &M = *F.getParent(); 628 #ifndef NDEBUG 629 Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN); 630 assert(DevirtFn && "coro.devirt.trigger function not found"); 631 #endif 632 633 F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT); 634 635 // Insert an indirect call sequence that will be devirtualized by CoroElide 636 // pass: 637 // %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1) 638 // %1 = bitcast i8* %0 to void(i8*)* 639 // call void %1(i8* null) 640 coro::LowererBase Lowerer(M); 641 Instruction *InsertPt = F.getEntryBlock().getTerminator(); 642 auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(F.getContext())); 643 auto *DevirtFnAddr = 644 Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt); 645 auto *IndirectCall = CallInst::Create(DevirtFnAddr, Null, "", InsertPt); 646 647 // Update CG graph with an indirect call we just added. 648 CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode()); 649 } 650 651 // Make sure that there is a devirtualization trigger function that CoroSplit 652 // pass uses the force restart CGSCC pipeline. If devirt trigger function is not 653 // found, we will create one and add it to the current SCC. 
static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
  Module &M = CG.getModule();
  if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
    return;

  // private alwaysinline void @coro.devirt.trigger(i8*) { ret void }
  LLVMContext &C = M.getContext();
  auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
                                 /*IsVarArgs=*/false);
  Function *DevirtFn =
      Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
                       CORO_DEVIRT_TRIGGER_FN, &M);
  DevirtFn->addFnAttr(Attribute::AlwaysInline);
  auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
  ReturnInst::Create(C, Entry);

  auto *Node = CG.getOrInsertFunction(DevirtFn);

  // Add the new function to the current SCC.
  SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
  Nodes.push_back(Node);
  SCC.initialize(Nodes);
}

//===----------------------------------------------------------------------===//
//                              Top Level Driver
//===----------------------------------------------------------------------===//

namespace {

struct CoroSplit : public CallGraphSCCPass {
  static char ID; // Pass identification, replacement for typeid
  CoroSplit() : CallGraphSCCPass(ID) {
    initializeCoroSplitPass(*PassRegistry::getPassRegistry());
  }

  // True iff the module declares llvm.coro.begin; set in doInitialization.
  bool Run = false;

  // A coroutine is identified by the presence of coro.begin intrinsic, if
  // we don't have any, this pass has nothing to do.
  bool doInitialization(CallGraph &CG) override {
    Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
    return CallGraphSCCPass::doInitialization(CG);
  }

  bool runOnSCC(CallGraphSCC &SCC) override {
    if (!Run)
      return false;

    // Find coroutines for processing.
    SmallVector<Function *, 4> Coroutines;
    for (CallGraphNode *CGN : SCC)
      if (auto *F = CGN->getFunction())
        if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
          Coroutines.push_back(F);

    if (Coroutines.empty())
      return false;

    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
    createDevirtTriggerFunc(CG, SCC);

    for (Function *F : Coroutines) {
      Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
      StringRef Value = Attr.getValueAsString();
      DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
                   << "' state: " << Value << "\n");
      if (Value == UNPREPARED_FOR_SPLIT) {
        // First time we see this coroutine: mark it prepared; the actual
        // split happens when the pipeline revisits it.
        prepareForSplit(*F, CG);
        continue;
      }
      F->removeFnAttr(CORO_PRESPLIT_ATTR);
      splitCoroutine(*F, CG, SCC);
    }
    return true;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
  }
  StringRef getPassName() const override { return "Coroutine Splitting"; }
};

} // end anonymous namespace

char CoroSplit::ID = 0;
INITIALIZE_PASS(
    CoroSplit, "coro-split",
    "Split coroutine into a set of functions driving its state machine", false,
    false)

Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }