//===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// This pass builds the coroutine frame and outlines resume and destroy parts
// of the coroutine into separate functions.
//
// We present a coroutine to LLVM as an ordinary function with suspension
// points marked up with intrinsics. We let the optimizer party on the coroutine
// as a single function for as long as possible. Shortly before the coroutine is
// eligible to be inlined into its callers, we split up the coroutine into parts
// corresponding to the initial, resume and destroy invocations of the
// coroutine, add them to the current SCC and restart the IPO pipeline to
// optimize the coroutine subfunctions we extracted before proceeding to the
// caller of the coroutine.
20 //===----------------------------------------------------------------------===// 21 22 #include "CoroInstr.h" 23 #include "CoroInternal.h" 24 #include "llvm/ADT/DenseMap.h" 25 #include "llvm/ADT/SmallPtrSet.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/ADT/StringRef.h" 28 #include "llvm/ADT/Twine.h" 29 #include "llvm/Analysis/CallGraph.h" 30 #include "llvm/Analysis/CallGraphSCCPass.h" 31 #include "llvm/IR/Argument.h" 32 #include "llvm/IR/Attributes.h" 33 #include "llvm/IR/BasicBlock.h" 34 #include "llvm/IR/CFG.h" 35 #include "llvm/IR/CallSite.h" 36 #include "llvm/IR/CallingConv.h" 37 #include "llvm/IR/Constants.h" 38 #include "llvm/IR/DataLayout.h" 39 #include "llvm/IR/DerivedTypes.h" 40 #include "llvm/IR/Function.h" 41 #include "llvm/IR/GlobalValue.h" 42 #include "llvm/IR/GlobalVariable.h" 43 #include "llvm/IR/IRBuilder.h" 44 #include "llvm/IR/InstIterator.h" 45 #include "llvm/IR/InstrTypes.h" 46 #include "llvm/IR/Instruction.h" 47 #include "llvm/IR/Instructions.h" 48 #include "llvm/IR/IntrinsicInst.h" 49 #include "llvm/IR/LLVMContext.h" 50 #include "llvm/IR/LegacyPassManager.h" 51 #include "llvm/IR/Module.h" 52 #include "llvm/IR/Type.h" 53 #include "llvm/IR/Value.h" 54 #include "llvm/IR/Verifier.h" 55 #include "llvm/Pass.h" 56 #include "llvm/Support/Casting.h" 57 #include "llvm/Support/Debug.h" 58 #include "llvm/Support/raw_ostream.h" 59 #include "llvm/Transforms/Scalar.h" 60 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 61 #include "llvm/Transforms/Utils/Cloning.h" 62 #include "llvm/Transforms/Utils/Local.h" 63 #include "llvm/Transforms/Utils/ValueMapper.h" 64 #include <cassert> 65 #include <cstddef> 66 #include <cstdint> 67 #include <initializer_list> 68 #include <iterator> 69 70 using namespace llvm; 71 72 #define DEBUG_TYPE "coro-split" 73 74 // Create an entry block for a resume function with a switch that will jump to 75 // suspend points. 
// Build the shared "resume.entry" dispatch block for the resume/destroy
// clones: load the suspend index from the coroutine frame and switch to the
// block following the matching suspend point. Also rewrites every
// coro.save/coro.suspend pair in F so that saving stores the suspend index
// (or nulls out ResumeFnAddr for the final suspend) and the suspend result is
// modeled with a PHI that the clones can later constant-fold.
static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
  LLVMContext &C = F.getContext();

  // resume.entry:
  //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
  //  i32 2
  //  % index = load i32, i32* %index.addr
  //  switch i32 %index, label %unreachable [
  //    i32 0, label %resume.0
  //    i32 1, label %resume.1
  //    ...
  //  ]

  auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
  auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);

  IRBuilder<> Builder(NewEntry);
  auto *FramePtr = Shape.FramePtr;
  auto *FrameTy = Shape.FrameTy;
  auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
      FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
  auto *Index = Builder.CreateLoad(GepIndex, "index");
  auto *Switch =
      Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
  Shape.ResumeSwitch = Switch;

  size_t SuspendIndex = 0;
  for (CoroSuspendInst *S : Shape.CoroSuspends) {
    ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);

    // Replace CoroSave with a store to Index:
    //    %index.addr = getelementptr %f.frame... (index field number)
    //    store i32 0, i32* %index.addr1
    auto *Save = S->getCoroSave();
    Builder.SetInsertPoint(Save);
    if (S->isFinal()) {
      // Final suspend point is represented by storing zero in ResumeFnAddr
      // (frame field 0), since it is UB to resume at the final suspend point
      // and the index is therefore not needed for it.
      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
                                                          0, "ResumeFn.addr");
      auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
          cast<PointerType>(GepIndex->getType())->getElementType()));
      Builder.CreateStore(NullPtr, GepIndex);
    } else {
      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
          FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
      Builder.CreateStore(IndexVal, GepIndex);
    }
    // coro.save produces a token; none of its users need it anymore.
    Save->replaceAllUsesWith(ConstantTokenNone::get(C));
    Save->eraseFromParent();

    // Split block before and after coro.suspend and add a jump from an entry
    // switch:
    //
    //  whateverBB:
    //    whatever
    //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
    //    switch i8 %0, label %suspend[i8 0, label %resume
    //                                 i8 1, label %cleanup]
    // becomes:
    //
    //  whateverBB:
    //    whatever
    //    br label %resume.0.landing
    //
    //  resume.0: ; <--- jump from the switch in the resume.entry
    //    %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
    //    br label %resume.0.landing
    //
    //  resume.0.landing:
    //    %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
    //    switch i8 % 1, label %suspend [i8 0, label %resume
    //                                   i8 1, label %cleanup]

    auto *SuspendBB = S->getParent();
    auto *ResumeBB =
        SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
    auto *LandingBB = ResumeBB->splitBasicBlock(
        S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
    Switch->addCase(IndexVal, ResumeBB);

    // Redirect the original fallthrough edge around the suspend call; the PHI
    // carries -1 on that path (matching no case, i.e. the default %suspend
    // successor of the per-suspend switch).
    cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
    auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
    S->replaceAllUsesWith(PN);
    PN->addIncoming(Builder.getInt8(-1), SuspendBB);
    PN->addIncoming(S, ResumeBB);

    ++SuspendIndex;
  }

  Builder.SetInsertPoint(UnreachBB);
  Builder.CreateUnreachable();

  return NewEntry;
}

// In Resumers, we replace fallthrough coro.end with ret void and delete the
// rest of the block. `End` is the coro.end in the original function; VMap
// translates it to its counterpart in the clone.
static void replaceFallthroughCoroEnd(IntrinsicInst *End,
                                      ValueToValueMapTy &VMap) {
  auto *NewE = cast<IntrinsicInst>(VMap[End]);
  ReturnInst::Create(NewE->getContext(), nullptr, NewE);

  // Remove the rest of the block, by splitting it into an unreachable block.
  auto *BB = NewE->getParent();
  BB->splitBasicBlock(NewE);
  BB->getTerminator()->eraseFromParent();
}

// In Resumers, we replace unwind coro.end with True to force the immediate
// unwind to caller.
static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
  if (Shape.CoroEnds.empty())
    return;

  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
  auto *True = ConstantInt::getTrue(Context);
  for (CoroEndInst *CE : Shape.CoroEnds) {
    if (!CE->isUnwind())
      continue;

    auto *NewCE = cast<IntrinsicInst>(VMap[CE]);

    // If coro.end has an associated bundle, add cleanupret instruction.
    // The "funclet" bundle names the cleanuppad we must return from.
    if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
      Value *FromPad = Bundle->Inputs[0];
      auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
      NewCE->getParent()->splitBasicBlock(NewCE);
      CleanupRet->getParent()->getTerminator()->eraseFromParent();
    }

    NewCE->replaceAllUsesWith(True);
    NewCE->eraseFromParent();
  }
}

// Rewrite final suspend point handling. We do not use suspend index to
// represent the final suspend point. Instead we zero-out ResumeFnAddr in the
// coroutine frame, since it is undefined behavior to resume a coroutine
// suspended at the final suspend point. Thus, in the resume function, we can
// simply remove the last case (when coro::Shape is built, the final suspend
// point (if present) is always the last element of CoroSuspends array).
// In the destroy function, we add a code sequence to check if ResumeFnAddress
// is Null, and if so, jump to the appropriate label to handle cleanup from the
// final suspend point.
// See the comment above: remove the final-suspend case from the resume switch
// and, for the destroy clone, branch on ResumeFnAddr == null instead.
static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
                               coro::Shape &Shape, SwitchInst *Switch,
                               bool IsDestroy) {
  assert(Shape.HasFinalSuspend);
  // The final suspend point is always the last case added to the switch.
  auto FinalCaseIt = std::prev(Switch->case_end());
  BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
  Switch->removeCase(FinalCaseIt);
  if (IsDestroy) {
    BasicBlock *OldSwitchBB = Switch->getParent();
    auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
    Builder.SetInsertPoint(OldSwitchBB->getTerminator());
    // Load ResumeFnAddr (frame field 0); null marks the final suspend point.
    auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
                                                        0, 0, "ResumeFn.addr");
    auto *Load = Builder.CreateLoad(GepIndex);
    auto *NullPtr =
        ConstantPointerNull::get(cast<PointerType>(Load->getType()));
    auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
    Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
    OldSwitchBB->getTerminator()->eraseFromParent();
  }
}

// Create a resume clone by cloning the body of the original function, setting
// new entry block and replacing coro.suspend with an appropriate value to
// force resume or cleanup pass for every suspend point. FnIndex selects the
// clone kind: 0 = resume, 1 = destroy, 2 = cleanup (destroy with elided
// deallocation).
static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
                             BasicBlock *ResumeEntry, int8_t FnIndex) {
  Module *M = F.getParent();
  auto *FrameTy = Shape.FrameTy;
  // The clone's signature is the resume-function type stored as frame field 0,
  // i.e. void(%f.frame*).
  auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
  auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());

  Function *NewF =
      Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
                       F.getName() + Suffix, M);
  NewF->addParamAttr(0, Attribute::NonNull);
  NewF->addParamAttr(0, Attribute::NoAlias);

  ValueToValueMapTy VMap;
  // Replace all args with undefs. The buildCoroutineFrame algorithm has
  // already rewritten accesses to the args that occur after suspend points
  // with loads and stores to/from the coroutine frame.
  for (Argument &A : F.args())
    VMap[&A] = UndefValue::get(A.getType());

  SmallVector<ReturnInst *, 4> Returns;

  CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);

  // Remove old returns.
  for (ReturnInst *Return : Returns)
    changeToUnreachable(Return, /*UseLLVMTrap=*/false);

  // Remove old return attributes: the original return type may be
  // incompatible with the clone's void return.
  NewF->removeAttributes(
      AttributeList::ReturnIndex,
      AttributeFuncs::typeIncompatible(NewF->getReturnType()));

  // Make AllocaSpillBlock the new entry block.
  auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
  auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
  Entry->moveBefore(&NewF->getEntryBlock());
  Entry->getTerminator()->eraseFromParent();
  BranchInst::Create(SwitchBB, Entry);
  Entry->setName("entry" + Suffix);

  // Clear all predecessors of the new entry block: redirect any branches to
  // it at the switch's (unreachable) default destination.
  auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
  Entry->replaceAllUsesWith(Switch->getDefaultDest());

  IRBuilder<> Builder(&NewF->getEntryBlock().front());

  // Remap frame pointer: the frame is now the clone's sole argument.
  Argument *NewFramePtr = &*NewF->arg_begin();
  Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
  NewFramePtr->takeName(OldFramePtr);
  OldFramePtr->replaceAllUsesWith(NewFramePtr);

  // Remap vFrame pointer (the i8* view produced by coro.begin).
  auto *NewVFrame = Builder.CreateBitCast(
      NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
  Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
  OldVFrame->replaceAllUsesWith(NewVFrame);

  // Rewrite final suspend handling, as it is not done via the switch (this
  // allows removing the final case from the switch, since it is undefined
  // behavior to resume the coroutine suspended at the final suspend point).
  if (Shape.HasFinalSuspend) {
    auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
    bool IsDestroy = FnIndex != 0;
    handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
  }

  // Replace coro suspend with the appropriate resume index.
  // Replacing coro.suspend with (0) will result in control flow proceeding to
  // a resume label associated with a suspend point, replacing it with (1) will
  // result in control flow proceeding to a cleanup label associated with this
  // suspend point.
  auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
  for (CoroSuspendInst *CS : Shape.CoroSuspends) {
    auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
    MappedCS->replaceAllUsesWith(NewValue);
    MappedCS->eraseFromParent();
  }

  // Remove coro.end intrinsics.
  replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
  replaceUnwindCoroEnds(Shape, VMap);
  // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
  // to suppress deallocation code.
  coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
                        /*Elide=*/FnIndex == 2);

  NewF->setCallingConv(CallingConv::Fast);

  return NewF;
}

// Replace every coro.end in the original function with 'false' (do not
// immediately return to the caller) and erase it.
static void removeCoroEnds(coro::Shape &Shape) {
  if (Shape.CoroEnds.empty())
    return;

  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
  auto *False = ConstantInt::getFalse(Context);

  for (CoroEndInst *CE : Shape.CoroEnds) {
    CE->replaceAllUsesWith(False);
    CE->eraseFromParent();
  }
}

// Replace all coro.size intrinsics with the now-known allocation size of the
// coroutine frame.
static void replaceFrameSize(coro::Shape &Shape) {
  if (Shape.CoroSizes.empty())
    return;

  // In the same function all coro.sizes should have the same result type.
  auto *SizeIntrin = Shape.CoroSizes.back();
  Module *M = SizeIntrin->getModule();
  const DataLayout &DL = M->getDataLayout();
  auto Size = DL.getTypeAllocSize(Shape.FrameTy);
  auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);

  for (CoroSizeInst *CS : Shape.CoroSizes) {
    CS->replaceAllUsesWith(SizeConstant);
    CS->eraseFromParent();
  }
}

// Create a global constant array containing pointers to functions provided and
// set Info parameter of CoroBegin to point at this constant. Example:
//
//   @f.resumers = internal constant [2 x void(%f.frame*)*]
//                    [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
//   define void @f() {
//     ...
//     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
//                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
//
// Assumes that all the functions have the same signature.
static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
                        std::initializer_list<Function *> Fns) {
  SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
  assert(!Args.empty());
  Function *Part = *Fns.begin();
  Module *M = Part->getParent();
  auto *ArrTy = ArrayType::get(Part->getType(), Args.size());

  auto *ConstVal = ConstantArray::get(ArrTy, Args);
  auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
                                GlobalVariable::PrivateLinkage, ConstVal,
                                F.getName() + Twine(".resumers"));

  // Update coro.begin instruction to refer to this constant.
  LLVMContext &C = F.getContext();
  auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
  CoroBegin->getId()->setInfo(BC);
}

// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
// See the comment above: write the resume and destroy function pointers into
// frame fields 0 and 1 right after the frame pointer is established.
static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
                            Function *DestroyFn, Function *CleanupFn) {
  IRBuilder<> Builder(Shape.FramePtr->getNextNode());
  auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
      Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
      "resume.addr");
  Builder.CreateStore(ResumeFn, ResumeAddr);

  Value *DestroyOrCleanupFn = DestroyFn;

  CoroIdInst *CoroId = Shape.CoroBegin->getId();
  if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
    // If there is a CoroAlloc and it returns false (meaning we elided the
    // allocation), use CleanupFn instead of DestroyFn.
    DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
  }

  auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
      Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
      "destroy.addr");
  Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
}

// Run a small cleanup pipeline over F (or a clone) after splitting: verify,
// then simplify with SCCP / CFG simplification / EarlyCSE.
static void postSplitCleanup(Function &F) {
  removeUnreachableBlocks(F);
  legacy::FunctionPassManager FPM(F.getParent());

  FPM.add(createVerifierPass());
  FPM.add(createSCCPPass());
  FPM.add(createCFGSimplificationPass());
  FPM.add(createEarlyCSEPass());
  FPM.add(createCFGSimplificationPass());

  FPM.doInitialization();
  FPM.run(F);
  FPM.doFinalization();
}

// Assuming we arrived at the block NewBlock from Prev instruction, store
// PHI's incoming values in the ResolvedValues map.
static void
scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
                          DenseMap<Value *, Value *> &ResolvedValues) {
  auto *PrevBB = Prev->getParent();
  auto *I = &*NewBlock->begin();
  // PHIs are always grouped at the start of a block; stop at the first
  // non-PHI instruction.
  while (auto PN = dyn_cast<PHINode>(I)) {
    auto V = PN->getIncomingValueForBlock(PrevBB);
    // See if we already resolved it.
    auto VI = ResolvedValues.find(V);
    if (VI != ResolvedValues.end())
      V = VI->second;
    // Remember the value.
    ResolvedValues[PN] = V;
    I = I->getNextNode();
  }
}

// Replace a sequence of branches leading to a ret, with a clone of a ret
// instruction. Suspend instructions are represented by a switch; track the
// PHI values and select the correct case successor when possible. Returns
// true if InitialInst was replaced (or already was) a return.
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
  DenseMap<Value *, Value *> ResolvedValues;

  Instruction *I = InitialInst;
  while (isa<TerminatorInst>(I)) {
    if (isa<ReturnInst>(I)) {
      if (I != InitialInst)
        ReplaceInstWithInst(InitialInst, I->clone());
      return true;
    }
    if (auto *BR = dyn_cast<BranchInst>(I)) {
      if (BR->isUnconditional()) {
        BasicBlock *BB = BR->getSuccessor(0);
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
        I = BB->getFirstNonPHIOrDbgOrLifetime();
        continue;
      }
    } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
      Value *V = SI->getCondition();
      auto it = ResolvedValues.find(V);
      if (it != ResolvedValues.end())
        V = it->second;
      // Only follow the switch if its condition resolves to a constant.
      if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
        BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
        I = BB->getFirstNonPHIOrDbgOrLifetime();
        continue;
      }
    }
    return false;
  }
  return false;
}

// Add musttail to any resume instructions that is immediately followed by a
// suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
// for symmetrical coroutine control transfer (C++ Coroutines TS extension).
// This transformation is done only in the resume part of the coroutine that
// has identical signature and calling convention as the coro.resume call.
static void addMustTailToCoroResumes(Function &F) {
  bool changed = false;

  // Collect potential resume instructions.
  SmallVector<CallInst *, 4> Resumes;
  for (auto &I : instructions(F))
    if (auto *Call = dyn_cast<CallInst>(&I))
      if (auto *CalledValue = Call->getCalledValue())
        // CoroEarly pass replaced coro resumes with indirect calls to an
        // address return by CoroSubFnInst intrinsic. See if it is one of those.
        if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
          Resumes.push_back(Call);

  // Set musttail on those that are followed by a ret instruction.
  for (CallInst *Call : Resumes)
    if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
      Call->setTailCallKind(CallInst::TCK_MustTail);
      changed = true;
    }

  // simplifyTerminatorLeadingToRet may have left orphaned blocks behind.
  if (changed)
    removeUnreachableBlocks(F);
}

// Coroutine has no suspend points. Remove heap allocation for the coroutine
// frame if possible.
static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
  auto *CoroId = CoroBegin->getId();
  auto *AllocInst = CoroId->getCoroAlloc();
  coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
  if (AllocInst) {
    IRBuilder<> Builder(AllocInst);
    // FIXME: Need to handle overaligned members.
    // Put the frame on the stack instead of the heap.
    auto *Frame = Builder.CreateAlloca(FrameTy);
    auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
    // coro.alloc returning false suppresses the dynamic allocation path.
    AllocInst->replaceAllUsesWith(Builder.getFalse());
    AllocInst->eraseFromParent();
    CoroBegin->replaceAllUsesWith(VFrame);
  } else {
    CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
  }
  CoroBegin->eraseFromParent();
}

// look for a very simple pattern
//    coro.save
//    no other calls
//    resume or destroy call
//    coro.suspend
//
// If there are other calls between coro.save and coro.suspend, they can
// potentially resume or destroy the coroutine, so it is unsafe to eliminate a
// suspend point.
// See the pattern description above. Returns true if the suspend point was
// eliminated (coro.suspend, coro.save and the resume/destroy call erased).
static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
                                 CoroBeginInst *CoroBegin) {
  auto *Save = Suspend->getCoroSave();
  auto *BB = Suspend->getParent();
  // Only handle the simple case where save and suspend share a block.
  if (BB != Save->getParent())
    return false;

  CallSite SingleCallSite;

  // Check that we have only one CallSite.
  for (Instruction *I = Save->getNextNode(); I != Suspend;
       I = I->getNextNode()) {
    // coro.frame and coro.subfn.addr are benign intrinsics, not real calls.
    if (isa<CoroFrameInst>(I))
      continue;
    if (isa<CoroSubFnInst>(I))
      continue;
    if (CallSite CS = CallSite(I)) {
      if (SingleCallSite)
        return false;
      else
        SingleCallSite = CS;
    }
  }
  auto *CallInstr = SingleCallSite.getInstruction();
  if (!CallInstr)
    return false;

  auto *Callee = SingleCallSite.getCalledValue()->stripPointerCasts();

  // See if the callsite is for resumption or destruction of the coroutine.
  auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
  if (!SubFn)
    return false;

  // Does not refer to the current coroutine, we cannot do anything with it.
  if (SubFn->getFrame() != CoroBegin)
    return false;

  // Replace llvm.coro.suspend with the value that results in resumption over
  // the resume or cleanup path.
  Suspend->replaceAllUsesWith(SubFn->getRawIndex());
  Suspend->eraseFromParent();
  Save->eraseFromParent();

  // No longer need a call to coro.resume or coro.destroy.
  CallInstr->eraseFromParent();

  // The subfn.addr intrinsic may still have other users; only drop it when
  // this was the last one.
  if (SubFn->user_empty())
    SubFn->eraseFromParent();

  return true;
}

// Remove suspend points that are simplified.
// See the comment above: try simplifySuspendPoint on every suspend point,
// compacting eliminated entries to the end of the vector (swap-with-last) and
// shrinking it once done.
static void simplifySuspendPoints(coro::Shape &Shape) {
  auto &S = Shape.CoroSuspends;
  size_t I = 0, N = S.size();
  if (N == 0)
    return;
  while (true) {
    if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
      if (--N == I)
        break;
      // Move the last un-examined suspend into the freed slot and retry it.
      std::swap(S[I], S[N]);
      continue;
    }
    if (++I == N)
      break;
  }
  S.resize(N);
}

// Return the set of blocks that can reach CoroBegin's block (transitive
// predecessors, excluding CoroBegin's own block unless it is in a cycle).
static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) {
  // Collect all blocks that we need to look for instructions to relocate.
  SmallPtrSet<BasicBlock *, 4> RelocBlocks;
  SmallVector<BasicBlock *, 4> Work;
  Work.push_back(CB->getParent());

  do {
    BasicBlock *Current = Work.pop_back_val();
    for (BasicBlock *BB : predecessors(Current))
      if (RelocBlocks.count(BB) == 0) {
        RelocBlocks.insert(BB);
        Work.push_back(BB);
      }
  } while (!Work.empty());
  return RelocBlocks;
}

// Compute the set of instructions that must stay before coro.begin: the
// transitive operands of coro.begin and of the terminators of all preceding
// blocks (allocas excepted — they always stay in the entry block).
static SmallPtrSet<Instruction *, 8>
getNotRelocatableInstructions(CoroBeginInst *CoroBegin,
                              SmallPtrSetImpl<BasicBlock *> &RelocBlocks) {
  SmallPtrSet<Instruction *, 8> DoNotRelocate;
  // Collect all instructions that we should not relocate
  SmallVector<Instruction *, 8> Work;

  // Start with CoroBegin and terminators of all preceding blocks.
  Work.push_back(CoroBegin);
  BasicBlock *CoroBeginBB = CoroBegin->getParent();
  for (BasicBlock *BB : RelocBlocks)
    if (BB != CoroBeginBB)
      Work.push_back(BB->getTerminator());

  // For every instruction in the Work list, place its operands in DoNotRelocate
  // set.
  do {
    Instruction *Current = Work.pop_back_val();
    DoNotRelocate.insert(Current);
    for (Value *U : Current->operands()) {
      auto *I = dyn_cast<Instruction>(U);
      if (!I)
        continue;
      if (isa<AllocaInst>(U))
        continue;
      if (DoNotRelocate.count(I) == 0) {
        Work.push_back(I);
        DoNotRelocate.insert(I);
      }
    }
  } while (!Work.empty());
  return DoNotRelocate;
}

static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) {
  // Analyze which non-alloca instructions are needed for allocation and
  // relocate the rest to after coro.begin. We need to do it, since some of the
  // targets of those instructions may be placed into coroutine frame memory,
  // which becomes available only after coro.begin intrinsic.

  auto BlockSet = getCoroBeginPredBlocks(CoroBegin);
  auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet);

  Instruction *InsertPt = CoroBegin->getNextNode();
  BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well.
  for (auto B = BB.begin(), E = BB.end(); B != E;) {
    Instruction &I = *B++;
    if (isa<AllocaInst>(&I))
      continue;
    if (&I == CoroBegin)
      break;
    if (DoNotRelocateSet.count(&I))
      continue;
    I.moveBefore(InsertPt);
  }
}

// The main splitting routine: build the frame, create the resume/destroy/
// cleanup clones, wire them into the frame and the coro.begin info, and
// register the new functions with the call graph / SCC.
static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
  coro::Shape Shape(F);
  if (!Shape.CoroBegin)
    return;

  simplifySuspendPoints(Shape);
  relocateInstructionBefore(Shape.CoroBegin, F);
  buildCoroutineFrame(F, Shape);
  replaceFrameSize(Shape);

  // If there are no suspend points, no split required, just remove
  // the allocation and deallocation blocks, they are not needed.
  if (Shape.CoroSuspends.empty()) {
    handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
    removeCoroEnds(Shape);
    postSplitCleanup(F);
    coro::updateCallGraph(F, {}, CG, SCC);
    return;
  }

  auto *ResumeEntry = createResumeEntryBlock(F, Shape);
  auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
  auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
  auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);

  // We no longer need coro.end in F.
  removeCoroEnds(Shape);

  postSplitCleanup(F);
  postSplitCleanup(*ResumeClone);
  postSplitCleanup(*DestroyClone);
  postSplitCleanup(*CleanupClone);

  addMustTailToCoroResumes(*ResumeClone);

  // Store addresses resume/destroy/cleanup functions in the coroutine frame.
  updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);

  // Create a constant array referring to resume/destroy/cleanup functions
  // pointed by the last argument of @llvm.coro.info, so that CoroElide pass
  // can determine the correct function to call.
  setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});

  // Update call graph and add the functions we created to the SCC.
  coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
}

// When we see the coroutine the first time, we insert an indirect call to a
// devirt trigger function and mark the coroutine that it is now ready for
// split.
747 static void prepareForSplit(Function &F, CallGraph &CG) { 748 Module &M = *F.getParent(); 749 #ifndef NDEBUG 750 Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN); 751 assert(DevirtFn && "coro.devirt.trigger function not found"); 752 #endif 753 754 F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT); 755 756 // Insert an indirect call sequence that will be devirtualized by CoroElide 757 // pass: 758 // %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1) 759 // %1 = bitcast i8* %0 to void(i8*)* 760 // call void %1(i8* null) 761 coro::LowererBase Lowerer(M); 762 Instruction *InsertPt = F.getEntryBlock().getTerminator(); 763 auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(F.getContext())); 764 auto *DevirtFnAddr = 765 Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt); 766 auto *IndirectCall = CallInst::Create(DevirtFnAddr, Null, "", InsertPt); 767 768 // Update CG graph with an indirect call we just added. 769 CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode()); 770 } 771 772 // Make sure that there is a devirtualization trigger function that CoroSplit 773 // pass uses the force restart CGSCC pipeline. If devirt trigger function is not 774 // found, we will create one and add it to the current SCC. 
// See the comment above: create the private always-inline no-op function
// @coro.devirt.trigger and add its call graph node to the current SCC.
static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
  Module &M = CG.getModule();
  if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
    return;

  LLVMContext &C = M.getContext();
  // Signature is void(i8*) to match the indirect call prepareForSplit emits.
  auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
                                 /*IsVarArgs=*/false);
  Function *DevirtFn =
      Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
                       CORO_DEVIRT_TRIGGER_FN, &M);
  DevirtFn->addFnAttr(Attribute::AlwaysInline);
  auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
  ReturnInst::Create(C, Entry);

  auto *Node = CG.getOrInsertFunction(DevirtFn);

  SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
  Nodes.push_back(Node);
  SCC.initialize(Nodes);
}

//===----------------------------------------------------------------------===//
//                              Top Level Driver
//===----------------------------------------------------------------------===//

namespace {

struct CoroSplit : public CallGraphSCCPass {
  static char ID; // Pass identification, replacement for typeid

  CoroSplit() : CallGraphSCCPass(ID) {
    initializeCoroSplitPass(*PassRegistry::getPassRegistry());
  }

  // Set in doInitialization; true only when the module declares
  // llvm.coro.begin, so runOnSCC can bail out cheaply otherwise.
  bool Run = false;

  // A coroutine is identified by the presence of coro.begin intrinsic, if
  // we don't have any, this pass has nothing to do.
  bool doInitialization(CallGraph &CG) override {
    Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
    return CallGraphSCCPass::doInitialization(CG);
  }

  bool runOnSCC(CallGraphSCC &SCC) override {
    if (!Run)
      return false;

    // Find coroutines for processing.
    SmallVector<Function *, 4> Coroutines;
    for (CallGraphNode *CGN : SCC)
      if (auto *F = CGN->getFunction())
        if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
          Coroutines.push_back(F);

    if (Coroutines.empty())
      return false;

    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
    createDevirtTriggerFunc(CG, SCC);

    for (Function *F : Coroutines) {
      // The attribute value encodes the coroutine's processing state:
      // unprepared coroutines are prepared on the first visit and split on
      // a later one.
      Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
      StringRef Value = Attr.getValueAsString();
      DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
                   << "' state: " << Value << "\n");
      if (Value == UNPREPARED_FOR_SPLIT) {
        prepareForSplit(*F, CG);
        continue;
      }
      F->removeFnAttr(CORO_PRESPLIT_ATTR);
      splitCoroutine(*F, CG, SCC);
    }
    return true;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return "Coroutine Splitting"; }
};

} // end anonymous namespace

char CoroSplit::ID = 0;

INITIALIZE_PASS(
    CoroSplit, "coro-split",
    "Split coroutine into a set of functions driving its state machine", false,
    false)

Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }