//===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass builds the coroutine frame and outlines resume and destroy parts
// of the coroutine into separate functions.
//
// We present a coroutine to LLVM as an ordinary function with suspension
// points marked up with intrinsics. We let the optimizer party on the coroutine
// as a single function for as long as possible. Shortly before the coroutine is
// eligible to be inlined into its callers, we split up the coroutine into parts
// corresponding to the initial, resume and destroy invocations of the
// coroutine, add them to the current SCC and restart the IPO pipeline to
// optimize the coroutine subfunctions we extracted before proceeding to the
// caller of the coroutine.
19 //===----------------------------------------------------------------------===// 20 21 #include "CoroInstr.h" 22 #include "CoroInternal.h" 23 #include "llvm/ADT/DenseMap.h" 24 #include "llvm/ADT/SmallPtrSet.h" 25 #include "llvm/ADT/SmallVector.h" 26 #include "llvm/ADT/StringRef.h" 27 #include "llvm/ADT/Twine.h" 28 #include "llvm/Analysis/CallGraph.h" 29 #include "llvm/Analysis/CallGraphSCCPass.h" 30 #include "llvm/Transforms/Utils/Local.h" 31 #include "llvm/IR/Argument.h" 32 #include "llvm/IR/Attributes.h" 33 #include "llvm/IR/BasicBlock.h" 34 #include "llvm/IR/CFG.h" 35 #include "llvm/IR/CallSite.h" 36 #include "llvm/IR/CallingConv.h" 37 #include "llvm/IR/Constants.h" 38 #include "llvm/IR/DataLayout.h" 39 #include "llvm/IR/DerivedTypes.h" 40 #include "llvm/IR/Function.h" 41 #include "llvm/IR/GlobalValue.h" 42 #include "llvm/IR/GlobalVariable.h" 43 #include "llvm/IR/IRBuilder.h" 44 #include "llvm/IR/InstIterator.h" 45 #include "llvm/IR/InstrTypes.h" 46 #include "llvm/IR/Instruction.h" 47 #include "llvm/IR/Instructions.h" 48 #include "llvm/IR/IntrinsicInst.h" 49 #include "llvm/IR/LLVMContext.h" 50 #include "llvm/IR/LegacyPassManager.h" 51 #include "llvm/IR/Module.h" 52 #include "llvm/IR/Type.h" 53 #include "llvm/IR/Value.h" 54 #include "llvm/IR/Verifier.h" 55 #include "llvm/Pass.h" 56 #include "llvm/Support/Casting.h" 57 #include "llvm/Support/Debug.h" 58 #include "llvm/Support/raw_ostream.h" 59 #include "llvm/Transforms/Scalar.h" 60 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 61 #include "llvm/Transforms/Utils/Cloning.h" 62 #include "llvm/Transforms/Utils/ValueMapper.h" 63 #include <cassert> 64 #include <cstddef> 65 #include <cstdint> 66 #include <initializer_list> 67 #include <iterator> 68 69 using namespace llvm; 70 71 #define DEBUG_TYPE "coro-split" 72 73 // Create an entry block for a resume function with a switch that will jump to 74 // suspend points. 
static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
  LLVMContext &C = F.getContext();

  // resume.entry:
  //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
  //  i32 2
  //  % index = load i32, i32* %index.addr
  //  switch i32 %index, label %unreachable [
  //    i32 0, label %resume.0
  //    i32 1, label %resume.1
  //    ...
  //  ]

  auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
  auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);

  IRBuilder<> Builder(NewEntry);
  auto *FramePtr = Shape.FramePtr;
  auto *FrameTy = Shape.FrameTy;
  auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
      FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
  auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
  auto *Switch =
      Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
  // Remember the dispatch switch; createClone later remaps it per-clone and
  // handleFinalSuspend may drop its final case.
  Shape.ResumeSwitch = Switch;

  size_t SuspendIndex = 0;
  for (CoroSuspendInst *S : Shape.CoroSuspends) {
    ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);

    // Replace CoroSave with a store to Index:
    //    %index.addr = getelementptr %f.frame... (index field number)
    //    store i32 0, i32* %index.addr1
    auto *Save = S->getCoroSave();
    Builder.SetInsertPoint(Save);
    if (S->isFinal()) {
      // Final suspend point is represented by storing zero in ResumeFnAddr
      // (frame field 0) rather than by a switch index; see handleFinalSuspend.
      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
                                                          0, "ResumeFn.addr");
      auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
          cast<PointerType>(GepIndex->getType())->getElementType()));
      Builder.CreateStore(NullPtr, GepIndex);
    } else {
      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
          FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
      Builder.CreateStore(IndexVal, GepIndex);
    }
    Save->replaceAllUsesWith(ConstantTokenNone::get(C));
    Save->eraseFromParent();

    // Split block before and after coro.suspend and add a jump from an entry
    // switch:
    //
    //  whateverBB:
    //    whatever
    //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
    //    switch i8 %0, label %suspend[i8 0, label %resume
    //                                 i8 1, label %cleanup]
    // becomes:
    //
    //  whateverBB:
    //    whatever
    //    br label %resume.0.landing
    //
    //  resume.0: ; <--- jump from the switch in the resume.entry
    //    %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
    //    br label %resume.0.landing
    //
    //  resume.0.landing:
    //    %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
    //    switch i8 % 1, label %suspend [i8 0, label %resume
    //                                   i8 1, label %cleanup]

    auto *SuspendBB = S->getParent();
    auto *ResumeBB =
        SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
    auto *LandingBB = ResumeBB->splitBasicBlock(
        S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
    Switch->addCase(IndexVal, ResumeBB);

    cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
    auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
    S->replaceAllUsesWith(PN);
    // On the fall-through (first-pass) edge the PHI yields -1, which selects
    // the default (%suspend) successor of the switch shown above.
    PN->addIncoming(Builder.getInt8(-1), SuspendBB);
    PN->addIncoming(S, ResumeBB);

    ++SuspendIndex;
  }

  Builder.SetInsertPoint(UnreachBB);
  Builder.CreateUnreachable();

  return NewEntry;
}

// In Resumers, we replace fallthrough coro.end with ret void and delete the
// rest of the block.
static void replaceFallthroughCoroEnd(IntrinsicInst *End,
                                      ValueToValueMapTy &VMap) {
  // Operate on the clone's copy of the coro.end, looked up through VMap.
  auto *NewE = cast<IntrinsicInst>(VMap[End]);
  ReturnInst::Create(NewE->getContext(), nullptr, NewE);

  // Remove the rest of the block, by splitting it into an unreachable block.
  auto *BB = NewE->getParent();
  BB->splitBasicBlock(NewE);
  BB->getTerminator()->eraseFromParent();
}

// In Resumers, we replace unwind coro.end with True to force the immediate
// unwind to caller.
static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
  if (Shape.CoroEnds.empty())
    return;

  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
  auto *True = ConstantInt::getTrue(Context);
  for (CoroEndInst *CE : Shape.CoroEnds) {
    if (!CE->isUnwind())
      continue;

    auto *NewCE = cast<IntrinsicInst>(VMap[CE]);

    // If coro.end has an associated funclet bundle, add a cleanupret
    // instruction returning to the enclosing pad's unwind destination.
    if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
      Value *FromPad = Bundle->Inputs[0];
      auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
      NewCE->getParent()->splitBasicBlock(NewCE);
      CleanupRet->getParent()->getTerminator()->eraseFromParent();
    }

    NewCE->replaceAllUsesWith(True);
    NewCE->eraseFromParent();
  }
}

// Rewrite final suspend point handling. We do not use suspend index to
// represent the final suspend point. Instead we zero-out ResumeFnAddr in the
// coroutine frame, since it is undefined behavior to resume a coroutine
// suspended at the final suspend point. Thus, in the resume function, we can
// simply remove the last case (when coro::Shape is built, the final suspend
// point (if present) is always the last element of CoroSuspends array).
// In the destroy function, we add a code sequence to check if ResumeFnAddress
// is Null, and if so, jump to the appropriate label to handle cleanup from the
// final suspend point.
static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
                               coro::Shape &Shape, SwitchInst *Switch,
                               bool IsDestroy) {
  assert(Shape.HasFinalSuspend);
  // The final suspend point is always the last case of the dispatch switch
  // (see the comment above); detach it from the switch.
  auto FinalCaseIt = std::prev(Switch->case_end());
  BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
  Switch->removeCase(FinalCaseIt);
  if (IsDestroy) {
    // In the destroy/cleanup clones we still must be able to reach the
    // final-suspend cleanup code: branch to it when ResumeFnAddr is null,
    // otherwise fall through to the (now shortened) switch.
    BasicBlock *OldSwitchBB = Switch->getParent();
    auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
    Builder.SetInsertPoint(OldSwitchBB->getTerminator());
    auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
                                                        0, 0, "ResumeFn.addr");
    auto *Load = Builder.CreateLoad(
        Shape.FrameTy->getElementType(coro::Shape::ResumeField), GepIndex);
    auto *NullPtr =
        ConstantPointerNull::get(cast<PointerType>(Load->getType()));
    auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
    Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
    OldSwitchBB->getTerminator()->eraseFromParent();
  }
}

// Create a resume clone by cloning the body of the original function, setting
// new entry block and replacing coro.suspend with an appropriate value to
// force resume or cleanup pass for every suspend point.
// FnIndex selects the clone flavor (see callers in splitCoroutine):
//   0 = resume, 1 = destroy, 2 = cleanup (frame allocation was elided).
static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
                             BasicBlock *ResumeEntry, int8_t FnIndex) {
  Module *M = F.getParent();
  auto *FrameTy = Shape.FrameTy;
  // The clone's signature is the resume-function type stored in frame field 0.
  auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
  auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());

  Function *NewF =
      Function::Create(FnTy, GlobalValue::LinkageTypes::ExternalLinkage,
                       F.getName() + Suffix, M);
  NewF->addParamAttr(0, Attribute::NonNull);
  NewF->addParamAttr(0, Attribute::NoAlias);

  ValueToValueMapTy VMap;
  // Replace all args with undefs. The buildCoroutineFrame algorithm has
  // already rewritten access to the args that occurs after suspend points with
  // loads and stores to/from the coroutine frame.
  for (Argument &A : F.args())
    VMap[&A] = UndefValue::get(A.getType());

  SmallVector<ReturnInst *, 4> Returns;

  CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
  NewF->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);

  // Remove old returns.
  for (ReturnInst *Return : Returns)
    changeToUnreachable(Return, /*UseLLVMTrap=*/false);

  // Remove old return attributes; the clone returns void regardless of what
  // the original function returned.
  NewF->removeAttributes(
      AttributeList::ReturnIndex,
      AttributeFuncs::typeIncompatible(NewF->getReturnType()));

  // Make AllocaSpillBlock the new entry block.
  auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
  auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
  Entry->moveBefore(&NewF->getEntryBlock());
  Entry->getTerminator()->eraseFromParent();
  BranchInst::Create(SwitchBB, Entry);
  Entry->setName("entry" + Suffix);

  // Clear all predecessors of the new entry block.
  auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
  Entry->replaceAllUsesWith(Switch->getDefaultDest());

  IRBuilder<> Builder(&NewF->getEntryBlock().front());

  // Remap frame pointer: in the clone the frame arrives as the sole argument.
  Argument *NewFramePtr = &*NewF->arg_begin();
  Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
  NewFramePtr->takeName(OldFramePtr);
  OldFramePtr->replaceAllUsesWith(NewFramePtr);

  // Remap vFrame pointer (the i8* view of the frame that coro.begin produced).
  auto *NewVFrame = Builder.CreateBitCast(
      NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
  Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
  OldVFrame->replaceAllUsesWith(NewVFrame);

  // Rewrite final suspend handling as it is not done via switch (allows to
  // remove final case from the switch, since it is undefined behavior to
  // resume the coroutine suspended at the final suspend point).
  if (Shape.HasFinalSuspend) {
    auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
    bool IsDestroy = FnIndex != 0;
    handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
  }

  // Replace coro suspend with the appropriate resume index.
  // Replacing coro.suspend with (0) will result in control flow proceeding to
  // a resume label associated with a suspend point, replacing it with (1) will
  // result in control flow proceeding to a cleanup label associated with this
  // suspend point.
  auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
  for (CoroSuspendInst *CS : Shape.CoroSuspends) {
    auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
    MappedCS->replaceAllUsesWith(NewValue);
    MappedCS->eraseFromParent();
  }

  // Remove coro.end intrinsics.
  replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
  replaceUnwindCoroEnds(Shape, VMap);
  // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
  // to suppress deallocation code.
  coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
                        /*Elide=*/FnIndex == 2);

  NewF->setCallingConv(CallingConv::Fast);

  return NewF;
}

// Replace every coro.end in the original function with 'false' (do not unwind
// to caller) and erase the intrinsic; the clones handled their copies already.
static void removeCoroEnds(coro::Shape &Shape) {
  if (Shape.CoroEnds.empty())
    return;

  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
  auto *False = ConstantInt::getFalse(Context);

  for (CoroEndInst *CE : Shape.CoroEnds) {
    CE->replaceAllUsesWith(False);
    CE->eraseFromParent();
  }
}

// Replace all llvm.coro.size intrinsics with the now-known frame size.
static void replaceFrameSize(coro::Shape &Shape) {
  if (Shape.CoroSizes.empty())
    return;

  // In the same function all coro.sizes should have the same result type.
  auto *SizeIntrin = Shape.CoroSizes.back();
  Module *M = SizeIntrin->getModule();
  const DataLayout &DL = M->getDataLayout();
  auto Size = DL.getTypeAllocSize(Shape.FrameTy);
  auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);

  for (CoroSizeInst *CS : Shape.CoroSizes) {
    CS->replaceAllUsesWith(SizeConstant);
    CS->eraseFromParent();
  }
}

// Create a global constant array containing pointers to functions provided and
// set Info parameter of CoroBegin to point at this constant. Example:
//
//   @f.resumers = internal constant [2 x void(%f.frame*)*]
//                    [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
//   define void @f() {
//     ...
//     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
//                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
//
// Assumes that all the functions have the same signature.
380 static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin, 381 std::initializer_list<Function *> Fns) { 382 SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end()); 383 assert(!Args.empty()); 384 Function *Part = *Fns.begin(); 385 Module *M = Part->getParent(); 386 auto *ArrTy = ArrayType::get(Part->getType(), Args.size()); 387 388 auto *ConstVal = ConstantArray::get(ArrTy, Args); 389 auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true, 390 GlobalVariable::PrivateLinkage, ConstVal, 391 F.getName() + Twine(".resumers")); 392 393 // Update coro.begin instruction to refer to this constant. 394 LLVMContext &C = F.getContext(); 395 auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C)); 396 CoroBegin->getId()->setInfo(BC); 397 } 398 399 // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame. 400 static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn, 401 Function *DestroyFn, Function *CleanupFn) { 402 IRBuilder<> Builder(Shape.FramePtr->getNextNode()); 403 auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32( 404 Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField, 405 "resume.addr"); 406 Builder.CreateStore(ResumeFn, ResumeAddr); 407 408 Value *DestroyOrCleanupFn = DestroyFn; 409 410 CoroIdInst *CoroId = Shape.CoroBegin->getId(); 411 if (CoroAllocInst *CA = CoroId->getCoroAlloc()) { 412 // If there is a CoroAlloc and it returns false (meaning we elide the 413 // allocation, use CleanupFn instead of DestroyFn). 
414 DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn); 415 } 416 417 auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32( 418 Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField, 419 "destroy.addr"); 420 Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr); 421 } 422 423 static void postSplitCleanup(Function &F) { 424 removeUnreachableBlocks(F); 425 legacy::FunctionPassManager FPM(F.getParent()); 426 427 FPM.add(createVerifierPass()); 428 FPM.add(createSCCPPass()); 429 FPM.add(createCFGSimplificationPass()); 430 FPM.add(createEarlyCSEPass()); 431 FPM.add(createCFGSimplificationPass()); 432 433 FPM.doInitialization(); 434 FPM.run(F); 435 FPM.doFinalization(); 436 } 437 438 // Assuming we arrived at the block NewBlock from Prev instruction, store 439 // PHI's incoming values in the ResolvedValues map. 440 static void 441 scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock, 442 DenseMap<Value *, Value *> &ResolvedValues) { 443 auto *PrevBB = Prev->getParent(); 444 for (PHINode &PN : NewBlock->phis()) { 445 auto V = PN.getIncomingValueForBlock(PrevBB); 446 // See if we already resolved it. 447 auto VI = ResolvedValues.find(V); 448 if (VI != ResolvedValues.end()) 449 V = VI->second; 450 // Remember the value. 451 ResolvedValues[&PN] = V; 452 } 453 } 454 455 // Replace a sequence of branches leading to a ret, with a clone of a ret 456 // instruction. Suspend instruction represented by a switch, track the PHI 457 // values and select the correct case successor when possible. 
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
  DenseMap<Value *, Value *> ResolvedValues;

  Instruction *I = InitialInst;
  // Walk forward through trivially foldable terminators; if we reach a ret,
  // replace InitialInst with a clone of that ret and report success.
  while (I->isTerminator()) {
    if (isa<ReturnInst>(I)) {
      if (I != InitialInst)
        ReplaceInstWithInst(InitialInst, I->clone());
      return true;
    }
    if (auto *BR = dyn_cast<BranchInst>(I)) {
      if (BR->isUnconditional()) {
        BasicBlock *BB = BR->getSuccessor(0);
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
        I = BB->getFirstNonPHIOrDbgOrLifetime();
        continue;
      }
    } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
      Value *V = SI->getCondition();
      // The switch condition may be a PHI already resolved to a constant
      // along the path we have walked.
      auto it = ResolvedValues.find(V);
      if (it != ResolvedValues.end())
        V = it->second;
      if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
        BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
        I = BB->getFirstNonPHIOrDbgOrLifetime();
        continue;
      }
    }
    return false;
  }
  return false;
}

// Add musttail to any resume instructions that is immediately followed by a
// suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
// for symmetrical coroutine control transfer (C++ Coroutines TS extension).
// This transformation is done only in the resume part of the coroutine that
// has identical signature and calling convention as the coro.resume call.
static void addMustTailToCoroResumes(Function &F) {
  bool changed = false;

  // Collect potential resume instructions.
  SmallVector<CallInst *, 4> Resumes;
  for (auto &I : instructions(F))
    if (auto *Call = dyn_cast<CallInst>(&I))
      if (auto *CalledValue = Call->getCalledValue())
        // CoroEarly pass replaced coro resumes with indirect calls to an
        // address returned by CoroSubFnInst intrinsic. See if it is one of
        // those.
        if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
          Resumes.push_back(Call);

  // Set musttail on those that are followed by a ret instruction.
  for (CallInst *Call : Resumes)
    if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
      Call->setTailCallKind(CallInst::TCK_MustTail);
      changed = true;
    }

  if (changed)
    removeUnreachableBlocks(F);
}

// Coroutine has no suspend points. Remove heap allocation for the coroutine
// frame if possible.
static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
  auto *CoroId = CoroBegin->getId();
  auto *AllocInst = CoroId->getCoroAlloc();
  coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
  if (AllocInst) {
    IRBuilder<> Builder(AllocInst);
    // FIXME: Need to handle overaligned members.
    // Put the frame on the stack instead; coro.alloc becomes 'false' so the
    // heap allocation path is never taken.
    auto *Frame = Builder.CreateAlloca(FrameTy);
    auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
    AllocInst->replaceAllUsesWith(Builder.getFalse());
    AllocInst->eraseFromParent();
    CoroBegin->replaceAllUsesWith(VFrame);
  } else {
    CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
  }
  CoroBegin->eraseFromParent();
}

// SimplifySuspendPoint needs to check that there is no calls between
// coro_save and coro_suspend, since any of the calls may potentially resume
// the coroutine and if that is the case we cannot eliminate the suspend point.
// Scans the half-open range [From, To); a null To scans to the end of the
// block.
static bool hasCallsInBlockBetween(Instruction *From, Instruction *To) {
  for (Instruction *I = From; I != To; I = I->getNextNode()) {
    // Assume that no intrinsic can resume the coroutine.
    if (isa<IntrinsicInst>(I))
      continue;

    if (CallSite(I))
      return true;
  }
  return false;
}

static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) {
  SmallPtrSet<BasicBlock *, 8> Set;
  SmallVector<BasicBlock *, 8> Worklist;

  Set.insert(SaveBB);
  Worklist.push_back(ResDesBB);

  // Accumulate all blocks between SaveBB and ResDesBB. Because CoroSaveIntr
  // returns a token consumed by suspend instruction, all blocks in between
  // will have to eventually hit SaveBB when going backwards from ResDesBB.
  while (!Worklist.empty()) {
    auto *BB = Worklist.pop_back_val();
    Set.insert(BB);
    for (auto *Pred : predecessors(BB))
      if (Set.count(Pred) == 0)
        Worklist.push_back(Pred);
  }

  // SaveBB and ResDesBB are checked separately in hasCallsBetween.
  Set.erase(SaveBB);
  Set.erase(ResDesBB);

  for (auto *BB : Set)
    if (hasCallsInBlockBetween(BB->getFirstNonPHI(), nullptr))
      return true;

  return false;
}

static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) {
  auto *SaveBB = Save->getParent();
  auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent();

  if (SaveBB == ResumeOrDestroyBB)
    return hasCallsInBlockBetween(Save->getNextNode(), ResumeOrDestroy);

  // Any calls from Save to the end of the block?
  if (hasCallsInBlockBetween(Save->getNextNode(), nullptr))
    return true;

  // Any calls from beginning of the block up to ResumeOrDestroy?
  if (hasCallsInBlockBetween(ResumeOrDestroyBB->getFirstNonPHI(),
                             ResumeOrDestroy))
    return true;

  // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB?
  if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB))
    return true;

  return false;
}

// If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the
// suspend point and replace it with normal control flow.
static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
                                 CoroBeginInst *CoroBegin) {
  // The candidate resume/destroy call immediately precedes the suspend,
  // either in the same block or as the terminator of a sole predecessor.
  Instruction *Prev = Suspend->getPrevNode();
  if (!Prev) {
    auto *Pred = Suspend->getParent()->getSinglePredecessor();
    if (!Pred)
      return false;
    Prev = Pred->getTerminator();
  }

  CallSite CS{Prev};
  if (!CS)
    return false;

  auto *CallInstr = CS.getInstruction();

  auto *Callee = CS.getCalledValue()->stripPointerCasts();

  // See if the callsite is for resumption or destruction of the coroutine.
  auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
  if (!SubFn)
    return false;

  // Does not refer to the current coroutine, we cannot do anything with it.
  if (SubFn->getFrame() != CoroBegin)
    return false;

  // See if the transformation is safe. Specifically, see if there are any
  // calls in between Save and CallInstr. They can potentially resume the
  // coroutine rendering this optimization unsafe.
  auto *Save = Suspend->getCoroSave();
  if (hasCallsBetween(Save, CallInstr))
    return false;

  // Replace llvm.coro.suspend with the value that results in resumption over
  // the resume or cleanup path.
  Suspend->replaceAllUsesWith(SubFn->getRawIndex());
  Suspend->eraseFromParent();
  Save->eraseFromParent();

  // No longer need a call to coro.resume or coro.destroy.
  if (auto *Invoke = dyn_cast<InvokeInst>(CallInstr)) {
    BranchInst::Create(Invoke->getNormalDest(), Invoke);
  }

  // Grab the CalledValue from CS before erasing the CallInstr.
  auto *CalledValue = CS.getCalledValue();
  CallInstr->eraseFromParent();

  // If no more users remove it. Usually it is a bitcast of SubFn.
  if (CalledValue != SubFn && CalledValue->user_empty())
    if (auto *I = dyn_cast<Instruction>(CalledValue))
      I->eraseFromParent();

  // Now we are good to remove SubFn.
  if (SubFn->user_empty())
    SubFn->eraseFromParent();

  return true;
}

// Remove suspend points that are simplified.
static void simplifySuspendPoints(coro::Shape &Shape) {
  // Compact the array in place: simplified suspend points are swapped past
  // the live range [0, N) and dropped by the final resize.
  auto &S = Shape.CoroSuspends;
  size_t I = 0, N = S.size();
  if (N == 0)
    return;
  while (true) {
    if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
      if (--N == I)
        break;
      std::swap(S[I], S[N]);
      continue;
    }
    if (++I == N)
      break;
  }
  S.resize(N);
}

// Returns the set of blocks from which control can reach the block containing
// coro.begin (i.e. the code that runs before the frame is allocated).
static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) {
  // Collect all blocks that we need to look for instructions to relocate.
  SmallPtrSet<BasicBlock *, 4> RelocBlocks;
  SmallVector<BasicBlock *, 4> Work;
  Work.push_back(CB->getParent());

  do {
    BasicBlock *Current = Work.pop_back_val();
    for (BasicBlock *BB : predecessors(Current))
      if (RelocBlocks.count(BB) == 0) {
        RelocBlocks.insert(BB);
        Work.push_back(BB);
      }
  } while (!Work.empty());
  return RelocBlocks;
}

static SmallPtrSet<Instruction *, 8>
getNotRelocatableInstructions(CoroBeginInst *CoroBegin,
                              SmallPtrSetImpl<BasicBlock *> &RelocBlocks) {
  SmallPtrSet<Instruction *, 8> DoNotRelocate;
  // Collect all instructions that we should not relocate.
  SmallVector<Instruction *, 8> Work;

  // Start with CoroBegin and terminators of all preceding blocks.
  Work.push_back(CoroBegin);
  BasicBlock *CoroBeginBB = CoroBegin->getParent();
  for (BasicBlock *BB : RelocBlocks)
    if (BB != CoroBeginBB)
      Work.push_back(BB->getTerminator());

  // For every instruction in the Work list, place its operands in DoNotRelocate
  // set.
  do {
    Instruction *Current = Work.pop_back_val();
    LLVM_DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current << "\n");
    DoNotRelocate.insert(Current);
    for (Value *U : Current->operands()) {
      auto *I = dyn_cast<Instruction>(U);
      if (!I)
        continue;

      if (auto *A = dyn_cast<AllocaInst>(I)) {
        // Stores to alloca instructions that occur before the coroutine frame
        // is allocated should not be moved; the stored values may be used by
        // the coroutine frame allocator. The operands to those stores must also
        // remain in place.
        for (const auto &User : A->users())
          if (auto *SI = dyn_cast<llvm::StoreInst>(User))
            if (RelocBlocks.count(SI->getParent()) != 0 &&
                DoNotRelocate.count(SI) == 0) {
              Work.push_back(SI);
              DoNotRelocate.insert(SI);
            }
        continue;
      }

      if (DoNotRelocate.count(I) == 0) {
        Work.push_back(I);
        DoNotRelocate.insert(I);
      }
    }
  } while (!Work.empty());
  return DoNotRelocate;
}

static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) {
  // Analyze which non-alloca instructions are needed for allocation and
  // relocate the rest to after coro.begin. We need to do it, since some of the
  // targets of those instructions may be placed into coroutine frame memory
  // which becomes available only after the coro.begin intrinsic.

  auto BlockSet = getCoroBeginPredBlocks(CoroBegin);
  auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet);

  Instruction *InsertPt = CoroBegin->getNextNode();
  BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well.
  for (auto B = BB.begin(), E = BB.end(); B != E;) {
    Instruction &I = *B++;
    if (isa<AllocaInst>(&I))
      continue;
    if (&I == CoroBegin)
      break;
    if (DoNotRelocateSet.count(&I))
      continue;
    I.moveBefore(InsertPt);
  }
}

// Per-coroutine driver: builds the frame and produces the resume/destroy/
// cleanup clones, then wires them into the frame, the coro.id info, and the
// call graph.
static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
  EliminateUnreachableBlocks(F);

  coro::Shape Shape(F);
  if (!Shape.CoroBegin)
    return;

  simplifySuspendPoints(Shape);
  relocateInstructionBefore(Shape.CoroBegin, F);
  buildCoroutineFrame(F, Shape);
  replaceFrameSize(Shape);

  // If there are no suspend points, no split required, just remove
  // the allocation and deallocation blocks, they are not needed.
  if (Shape.CoroSuspends.empty()) {
    handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
    removeCoroEnds(Shape);
    postSplitCleanup(F);
    coro::updateCallGraph(F, {}, CG, SCC);
    return;
  }

  auto *ResumeEntry = createResumeEntryBlock(F, Shape);
  auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
  auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
  auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);

  // We no longer need coro.end in F.
  removeCoroEnds(Shape);

  postSplitCleanup(F);
  postSplitCleanup(*ResumeClone);
  postSplitCleanup(*DestroyClone);
  postSplitCleanup(*CleanupClone);

  addMustTailToCoroResumes(*ResumeClone);

  // Store addresses of resume/destroy/cleanup functions in the coroutine
  // frame.
  updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);

  // Create a constant array referring to resume/destroy/cleanup functions
  // pointed by the last argument of @llvm.coro.info, so that CoroElide pass
  // can determine the correct function to call.
  setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});

  // Update call graph and add the functions we created to the SCC.
  coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
}

// When we see the coroutine the first time, we insert an indirect call to a
// devirt trigger function and mark the coroutine that it is now ready for
// split.
static void prepareForSplit(Function &F, CallGraph &CG) {
  Module &M = *F.getParent();
  LLVMContext &Context = F.getContext();
#ifndef NDEBUG
  Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
  assert(DevirtFn && "coro.devirt.trigger function not found");
#endif

  F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);

  // Insert an indirect call sequence that will be devirtualized by CoroElide
  // pass:
  //    %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
  //    %1 = bitcast i8* %0 to void(i8*)*
  //    call void %1(i8* null)
  coro::LowererBase Lowerer(M);
  Instruction *InsertPt = F.getEntryBlock().getTerminator();
  auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context));
  auto *DevirtFnAddr =
      Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
  FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Context),
                                         {Type::getInt8PtrTy(Context)}, false);
  auto *IndirectCall = CallInst::Create(FnTy, DevirtFnAddr, Null, "", InsertPt);

  // Update CG graph with an indirect call we just added.
  CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
}

// Make sure that there is a devirtualization trigger function that CoroSplit
// pass uses to force restarting the CGSCC pipeline. If the devirt trigger
// function is not found, we will create one and add it to the current SCC.
static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
  Module &M = CG.getModule();
  if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
    return;

  // Create a private, always-inline "void(i8*)" function with an empty body.
  LLVMContext &C = M.getContext();
  auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
                                 /*IsVarArgs=*/false);
  Function *DevirtFn =
      Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
                       CORO_DEVIRT_TRIGGER_FN, &M);
  DevirtFn->addFnAttr(Attribute::AlwaysInline);
  auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
  ReturnInst::Create(C, Entry);

  // Register the new function in the call graph and add it to the current SCC
  // so the CGSCC pipeline revisits it.
  auto *Node = CG.getOrInsertFunction(DevirtFn);

  SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
  Nodes.push_back(Node);
  SCC.initialize(Nodes);
}

//===----------------------------------------------------------------------===//
// Top Level Driver
//===----------------------------------------------------------------------===//

namespace {

struct CoroSplit : public CallGraphSCCPass {
  static char ID; // Pass identification, replacement for typeid

  CoroSplit() : CallGraphSCCPass(ID) {
    initializeCoroSplitPass(*PassRegistry::getPassRegistry());
  }

  // Set in doInitialization(); true when the module declares llvm.coro.begin.
  bool Run = false;

  // A coroutine is identified by the presence of coro.begin intrinsic, if
  // we don't have any, this pass has nothing to do.
  bool doInitialization(CallGraph &CG) override {
    Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
    return CallGraphSCCPass::doInitialization(CG);
  }

  bool runOnSCC(CallGraphSCC &SCC) override {
    if (!Run)
      return false;

    // Find coroutines for processing.
    SmallVector<Function *, 4> Coroutines;
    for (CallGraphNode *CGN : SCC)
      if (auto *F = CGN->getFunction())
        if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
          Coroutines.push_back(F);

    if (Coroutines.empty())
      return false;

    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
    createDevirtTriggerFunc(CG, SCC);

    for (Function *F : Coroutines) {
      Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
      StringRef Value = Attr.getValueAsString();
      LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
                        << "' state: " << Value << "\n");
      if (Value == UNPREPARED_FOR_SPLIT) {
        // First sighting: insert the devirt trigger call and revisit this
        // coroutine after the pipeline restarts.
        prepareForSplit(*F, CG);
        continue;
      }
      F->removeFnAttr(CORO_PRESPLIT_ATTR);
      splitCoroutine(*F, CG, SCC);
    }
    return true;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return "Coroutine Splitting"; }
};

} // end anonymous namespace

char CoroSplit::ID = 0;

INITIALIZE_PASS(
    CoroSplit, "coro-split",
    "Split coroutine into a set of functions driving its state machine", false,
    false)

Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }