//===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// This pass builds the coroutine frame and outlines resume and destroy parts
// of the coroutine into separate functions.
//
// We present a coroutine to LLVM as an ordinary function with suspension
// points marked up with intrinsics. We let the optimizer party on the coroutine
// as a single function for as long as possible. Shortly before the coroutine is
// eligible to be inlined into its callers, we split up the coroutine into parts
// corresponding to initial, resume and destroy invocations of the coroutine,
// add them to the current SCC and restart the IPO pipeline to optimize the
// coroutine subfunctions we extracted before proceeding to the caller of the
// coroutine.
20 //===----------------------------------------------------------------------===// 21 22 #include "CoroInstr.h" 23 #include "CoroInternal.h" 24 #include "llvm/ADT/DenseMap.h" 25 #include "llvm/ADT/SmallPtrSet.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/ADT/StringRef.h" 28 #include "llvm/ADT/Twine.h" 29 #include "llvm/Analysis/CallGraph.h" 30 #include "llvm/Analysis/CallGraphSCCPass.h" 31 #include "llvm/Analysis/Utils/Local.h" 32 #include "llvm/IR/Argument.h" 33 #include "llvm/IR/Attributes.h" 34 #include "llvm/IR/BasicBlock.h" 35 #include "llvm/IR/CFG.h" 36 #include "llvm/IR/CallSite.h" 37 #include "llvm/IR/CallingConv.h" 38 #include "llvm/IR/Constants.h" 39 #include "llvm/IR/DataLayout.h" 40 #include "llvm/IR/DerivedTypes.h" 41 #include "llvm/IR/Function.h" 42 #include "llvm/IR/GlobalValue.h" 43 #include "llvm/IR/GlobalVariable.h" 44 #include "llvm/IR/IRBuilder.h" 45 #include "llvm/IR/InstIterator.h" 46 #include "llvm/IR/InstrTypes.h" 47 #include "llvm/IR/Instruction.h" 48 #include "llvm/IR/Instructions.h" 49 #include "llvm/IR/IntrinsicInst.h" 50 #include "llvm/IR/LLVMContext.h" 51 #include "llvm/IR/LegacyPassManager.h" 52 #include "llvm/IR/Module.h" 53 #include "llvm/IR/Type.h" 54 #include "llvm/IR/Value.h" 55 #include "llvm/IR/Verifier.h" 56 #include "llvm/Pass.h" 57 #include "llvm/Support/Casting.h" 58 #include "llvm/Support/Debug.h" 59 #include "llvm/Support/raw_ostream.h" 60 #include "llvm/Transforms/Scalar.h" 61 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 62 #include "llvm/Transforms/Utils/Cloning.h" 63 #include "llvm/Transforms/Utils/ValueMapper.h" 64 #include <cassert> 65 #include <cstddef> 66 #include <cstdint> 67 #include <initializer_list> 68 #include <iterator> 69 70 using namespace llvm; 71 72 #define DEBUG_TYPE "coro-split" 73 74 // Create an entry block for a resume function with a switch that will jump to 75 // suspend points. 
static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
  LLVMContext &C = F.getContext();

  // The new entry loads the suspend index from the coroutine frame and
  // dispatches on it:
  //
  // resume.entry:
  //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
  //  i32 2
  //  % index = load i32, i32* %index.addr
  //  switch i32 %index, label %unreachable [
  //    i32 0, label %resume.0
  //    i32 1, label %resume.1
  //    ...
  //  ]

  auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
  auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);

  IRBuilder<> Builder(NewEntry);
  auto *FramePtr = Shape.FramePtr;
  auto *FrameTy = Shape.FrameTy;
  auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
      FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
  auto *Index = Builder.CreateLoad(GepIndex, "index");
  auto *Switch =
      Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
  // Remember the switch so that createClone can later rewrite / prune it in
  // each of the resume/destroy/cleanup clones.
  Shape.ResumeSwitch = Switch;

  size_t SuspendIndex = 0;
  for (CoroSuspendInst *S : Shape.CoroSuspends) {
    ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);

    // Replace CoroSave with a store to Index:
    //    %index.addr = getelementptr %f.frame... (index field number)
    //    store i32 0, i32* %index.addr1
    auto *Save = S->getCoroSave();
    Builder.SetInsertPoint(Save);
    if (S->isFinal()) {
      // The final suspend point is represented by storing zero in
      // ResumeFnAddr rather than by a switch index: it is undefined behavior
      // to resume from the final suspend point, so no resume dispatch entry
      // is needed for it.
      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
                                                          0, "ResumeFn.addr");
      auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
          cast<PointerType>(GepIndex->getType())->getElementType()));
      Builder.CreateStore(NullPtr, GepIndex);
    } else {
      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
          FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
      Builder.CreateStore(IndexVal, GepIndex);
    }
    // coro.save produces a token; replace it with 'none' before erasing.
    Save->replaceAllUsesWith(ConstantTokenNone::get(C));
    Save->eraseFromParent();

    // Split block before and after coro.suspend and add a jump from an entry
    // switch:
    //
    //  whateverBB:
    //    whatever
    //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
    //    switch i8 %0, label %suspend[i8 0, label %resume
    //                                 i8 1, label %cleanup]
    // becomes:
    //
    //  whateverBB:
    //    whatever
    //    br label %resume.0.landing
    //
    //  resume.0: ; <--- jump from the switch in the resume.entry
    //    %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
    //    br label %resume.0.landing
    //
    //  resume.0.landing:
    //    %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
    //    switch i8 % 1, label %suspend [i8 0, label %resume
    //                                   i8 1, label %cleanup]

    auto *SuspendBB = S->getParent();
    auto *ResumeBB =
        SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
    auto *LandingBB = ResumeBB->splitBasicBlock(
        S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
    Switch->addCase(IndexVal, ResumeBB);

    // Redirect the fall-through edge from the original block around the
    // suspend, and merge both paths with a PHI (-1 along the fall-through).
    cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
    auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
    S->replaceAllUsesWith(PN);
    PN->addIncoming(Builder.getInt8(-1), SuspendBB);
    PN->addIncoming(S, ResumeBB);

    ++SuspendIndex;
  }

  // The default destination of the dispatch switch: an invalid index traps.
  Builder.SetInsertPoint(UnreachBB);
  Builder.CreateUnreachable();

  return NewEntry;
}

// In Resumers, we replace the fallthrough coro.end with ret void and delete
// the rest of the block.
static void replaceFallthroughCoroEnd(IntrinsicInst *End,
                                      ValueToValueMapTy &VMap) {
  auto *NewE = cast<IntrinsicInst>(VMap[End]);
  ReturnInst::Create(NewE->getContext(), nullptr, NewE);

  // Remove the rest of the block, by splitting it into an unreachable block.
  auto *BB = NewE->getParent();
  BB->splitBasicBlock(NewE);
  BB->getTerminator()->eraseFromParent();
}

// In Resumers, we replace unwind coro.end with True to force the immediate
// unwind to caller.
static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
  if (Shape.CoroEnds.empty())
    return;

  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
  auto *True = ConstantInt::getTrue(Context);
  for (CoroEndInst *CE : Shape.CoroEnds) {
    if (!CE->isUnwind())
      continue;

    auto *NewCE = cast<IntrinsicInst>(VMap[CE]);

    // If coro.end has an associated funclet bundle, add a cleanupret
    // instruction returning from the enclosing cleanuppad.
    if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
      Value *FromPad = Bundle->Inputs[0];
      auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
      NewCE->getParent()->splitBasicBlock(NewCE);
      CleanupRet->getParent()->getTerminator()->eraseFromParent();
    }

    NewCE->replaceAllUsesWith(True);
    NewCE->eraseFromParent();
  }
}

// Rewrite final suspend point handling. We do not use a suspend index to
// represent the final suspend point. Instead we zero-out ResumeFnAddr in the
// coroutine frame, since it is undefined behavior to resume a coroutine
// suspended at the final suspend point. Thus, in the resume function, we can
// simply remove the last case (when coro::Shape is built, the final suspend
// point (if present) is always the last element of the CoroSuspends array).
// In the destroy function, we add a code sequence to check if ResumeFnAddress
// is Null, and if so, jump to the appropriate label to handle cleanup from the
// final suspend point.
static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
                               coro::Shape &Shape, SwitchInst *Switch,
                               bool IsDestroy) {
  assert(Shape.HasFinalSuspend);
  // The final suspend point, when present, is always the last switch case.
  auto FinalCaseIt = std::prev(Switch->case_end());
  BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
  Switch->removeCase(FinalCaseIt);
  if (IsDestroy) {
    // In destroy/cleanup clones: before consulting the index, test whether
    // ResumeFnAddr was zeroed (the marker of the final suspend point) and, if
    // so, branch directly to the final-suspend cleanup block.
    BasicBlock *OldSwitchBB = Switch->getParent();
    auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
    Builder.SetInsertPoint(OldSwitchBB->getTerminator());
    auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
                                                        0, 0, "ResumeFn.addr");
    auto *Load = Builder.CreateLoad(GepIndex);
    auto *NullPtr =
        ConstantPointerNull::get(cast<PointerType>(Load->getType()));
    auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
    Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
    // Drop the unconditional branch created by splitBasicBlock; the
    // conditional branch above replaces it.
    OldSwitchBB->getTerminator()->eraseFromParent();
  }
}

// Create a resume clone by cloning the body of the original function, setting
// the new entry block and replacing coro.suspend with an appropriate value to
// force the resume or cleanup path at every suspend point.
// FnIndex: 0 = resume clone, 1 = destroy clone, 2 = cleanup clone.
static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
                             BasicBlock *ResumeEntry, int8_t FnIndex) {
  Module *M = F.getParent();
  auto *FrameTy = Shape.FrameTy;
  // The clone's signature is the resume-function type stored in the first
  // frame field: void (%f.Frame*).
  auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
  auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());

  Function *NewF =
      Function::Create(FnTy, GlobalValue::LinkageTypes::ExternalLinkage,
                       F.getName() + Suffix, M);
  NewF->addParamAttr(0, Attribute::NonNull);
  NewF->addParamAttr(0, Attribute::NoAlias);

  ValueToValueMapTy VMap;
  // Replace all args with undefs. The buildCoroutineFrame algorithm has
  // already rewritten accesses to the args that occur after suspend points
  // with loads and stores to/from the coroutine frame.
  for (Argument &A : F.args())
    VMap[&A] = UndefValue::get(A.getType());

  SmallVector<ReturnInst *, 4> Returns;

  CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
  NewF->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);

  // Remove old returns.
  for (ReturnInst *Return : Returns)
    changeToUnreachable(Return, /*UseLLVMTrap=*/false);

  // Remove old return attributes: the clone returns void, so attributes
  // carried over from F's return type may no longer be legal.
  NewF->removeAttributes(
      AttributeList::ReturnIndex,
      AttributeFuncs::typeIncompatible(NewF->getReturnType()));

  // Make AllocaSpillBlock the new entry block, branching straight to the
  // resume-dispatch switch.
  auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
  auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
  Entry->moveBefore(&NewF->getEntryBlock());
  Entry->getTerminator()->eraseFromParent();
  BranchInst::Create(SwitchBB, Entry);
  Entry->setName("entry" + Suffix);

  // Clear all predecessors of the new entry block.
  auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
  Entry->replaceAllUsesWith(Switch->getDefaultDest());

  IRBuilder<> Builder(&NewF->getEntryBlock().front());

  // Remap frame pointer: the clone receives the frame as its sole argument.
  Argument *NewFramePtr = &*NewF->arg_begin();
  Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
  NewFramePtr->takeName(OldFramePtr);
  OldFramePtr->replaceAllUsesWith(NewFramePtr);

  // Remap vFrame pointer (the i8* view of the frame produced by coro.begin).
  auto *NewVFrame = Builder.CreateBitCast(
      NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
  Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
  OldVFrame->replaceAllUsesWith(NewVFrame);

  // Rewrite final suspend handling, as it is not done via the switch (this
  // allows us to remove the final case from the switch, since it is undefined
  // behavior to resume the coroutine suspended at the final suspend point).
  if (Shape.HasFinalSuspend) {
    auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
    bool IsDestroy = FnIndex != 0;
    handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
  }

  // Replace coro.suspend with the appropriate resume index.
  // Replacing coro.suspend with (0) will result in control flow proceeding to
  // a resume label associated with a suspend point, replacing it with (1) will
  // result in control flow proceeding to a cleanup label associated with this
  // suspend point.
  auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
  for (CoroSuspendInst *CS : Shape.CoroSuspends) {
    auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
    MappedCS->replaceAllUsesWith(NewValue);
    MappedCS->eraseFromParent();
  }

  // Remove coro.end intrinsics.
  replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
  replaceUnwindCoroEnds(Shape, VMap);
  // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
  // to suppress deallocation code.
  coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
                        /*Elide=*/FnIndex == 2);

  NewF->setCallingConv(CallingConv::Fast);

  return NewF;
}

// Replace all coro.end intrinsics in the original function with 'false'
// (i.e. "did not unwind to caller") and remove them.
static void removeCoroEnds(coro::Shape &Shape) {
  if (Shape.CoroEnds.empty())
    return;

  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
  auto *False = ConstantInt::getFalse(Context);

  for (CoroEndInst *CE : Shape.CoroEnds) {
    CE->replaceAllUsesWith(False);
    CE->eraseFromParent();
  }
}

// Lower all coro.size intrinsics to the now-known allocation size of the
// coroutine frame.
static void replaceFrameSize(coro::Shape &Shape) {
  if (Shape.CoroSizes.empty())
    return;

  // In the same function all coro.sizes should have the same result type.
  auto *SizeIntrin = Shape.CoroSizes.back();
  Module *M = SizeIntrin->getModule();
  const DataLayout &DL = M->getDataLayout();
  auto Size = DL.getTypeAllocSize(Shape.FrameTy);
  auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);

  for (CoroSizeInst *CS : Shape.CoroSizes) {
    CS->replaceAllUsesWith(SizeConstant);
    CS->eraseFromParent();
  }
}

// Create a global constant array containing pointers to functions provided and
// set the Info parameter of CoroBegin to point at this constant. Example:
//
//   @f.resumers = internal constant [2 x void(%f.frame*)*]
//                    [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
//   define void @f() {
//     ...
//     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
//                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
//
// Assumes that all the functions have the same signature.
static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
                        std::initializer_list<Function *> Fns) {
  SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
  assert(!Args.empty());
  Function *Part = *Fns.begin();
  Module *M = Part->getParent();
  auto *ArrTy = ArrayType::get(Part->getType(), Args.size());

  auto *ConstVal = ConstantArray::get(ArrTy, Args);
  auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
                                GlobalVariable::PrivateLinkage, ConstVal,
                                F.getName() + Twine(".resumers"));

  // Update coro.begin instruction to refer to this constant.
  LLVMContext &C = F.getContext();
  auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
  CoroBegin->getId()->setInfo(BC);
}

// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
                            Function *DestroyFn, Function *CleanupFn) {
  // Insert the stores right after the frame pointer is established.
  IRBuilder<> Builder(Shape.FramePtr->getNextNode());
  auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
      Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
      "resume.addr");
  Builder.CreateStore(ResumeFn, ResumeAddr);

  Value *DestroyOrCleanupFn = DestroyFn;

  CoroIdInst *CoroId = Shape.CoroBegin->getId();
  if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
    // If there is a CoroAlloc and it returns false (meaning we elide the
    // allocation), use CleanupFn instead of DestroyFn.
    DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
  }

  auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
      Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
      "destroy.addr");
  Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
}

// Run a small cleanup pipeline over a post-split function to tidy up the IR
// left behind by frame building and cloning.
static void postSplitCleanup(Function &F) {
  removeUnreachableBlocks(F);
  legacy::FunctionPassManager FPM(F.getParent());

  FPM.add(createVerifierPass());
  FPM.add(createSCCPPass());
  FPM.add(createCFGSimplificationPass());
  FPM.add(createEarlyCSEPass());
  FPM.add(createCFGSimplificationPass());

  FPM.doInitialization();
  FPM.run(F);
  FPM.doFinalization();
}

// Assuming we arrived at the block NewBlock from the Prev instruction, store
// PHI's incoming values in the ResolvedValues map.
static void
scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
                          DenseMap<Value *, Value *> &ResolvedValues) {
  auto *PrevBB = Prev->getParent();
  for (PHINode &PN : NewBlock->phis()) {
    auto V = PN.getIncomingValueForBlock(PrevBB);
    // See if we already resolved it.
    auto VI = ResolvedValues.find(V);
    if (VI != ResolvedValues.end())
      V = VI->second;
    // Remember the value.
    ResolvedValues[&PN] = V;
  }
}

// Replace a sequence of branches leading to a ret with a clone of that ret
// instruction. Since a suspend instruction is represented by a switch, track
// the PHI values along the way and select the correct case successor when
// possible. Returns true if InitialInst was proven to lead to a ret.
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
  DenseMap<Value *, Value *> ResolvedValues;

  Instruction *I = InitialInst;
  while (isa<TerminatorInst>(I)) {
    if (isa<ReturnInst>(I)) {
      // Found the ret; hoist a clone of it into InitialInst's position.
      if (I != InitialInst)
        ReplaceInstWithInst(InitialInst, I->clone());
      return true;
    }
    if (auto *BR = dyn_cast<BranchInst>(I)) {
      if (BR->isUnconditional()) {
        BasicBlock *BB = BR->getSuccessor(0);
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
        I = BB->getFirstNonPHIOrDbgOrLifetime();
        continue;
      }
    } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
      Value *V = SI->getCondition();
      // The condition may be a PHI we already resolved to a constant.
      auto it = ResolvedValues.find(V);
      if (it != ResolvedValues.end())
        V = it->second;
      if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
        BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
        I = BB->getFirstNonPHIOrDbgOrLifetime();
        continue;
      }
    }
    // Conditional branch, unresolvable switch, or other terminator: give up.
    return false;
  }
  return false;
}

// Add musttail to any resume instructions that are immediately followed by a
// suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
// for symmetrical coroutine control transfer (C++ Coroutines TS extension).
// This transformation is done only in the resume part of the coroutine that
// has identical signature and calling convention as the coro.resume call.
static void addMustTailToCoroResumes(Function &F) {
  bool changed = false;

  // Collect potential resume instructions.
  SmallVector<CallInst *, 4> Resumes;
  for (auto &I : instructions(F))
    if (auto *Call = dyn_cast<CallInst>(&I))
      if (auto *CalledValue = Call->getCalledValue())
        // The CoroEarly pass replaced coro resumes with indirect calls to an
        // address returned by the CoroSubFnInst intrinsic. See if it is one
        // of those.
        if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
          Resumes.push_back(Call);

  // Set musttail on those that are followed by a ret instruction.
  for (CallInst *Call : Resumes)
    if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
      Call->setTailCallKind(CallInst::TCK_MustTail);
      changed = true;
    }

  if (changed)
    removeUnreachableBlocks(F);
}

// Coroutine has no suspend points. Remove the heap allocation for the
// coroutine frame if possible, replacing it with a stack alloca.
static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
  auto *CoroId = CoroBegin->getId();
  auto *AllocInst = CoroId->getCoroAlloc();
  // If we elide the allocation, coro.free must yield null so that the
  // matching deallocation code is suppressed.
  coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
  if (AllocInst) {
    IRBuilder<> Builder(AllocInst);
    // FIXME: Need to handle overaligned members.
    auto *Frame = Builder.CreateAlloca(FrameTy);
    auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
    // coro.alloc answers "should I allocate?"; false routes callers around
    // the heap-allocation path.
    AllocInst->replaceAllUsesWith(Builder.getFalse());
    AllocInst->eraseFromParent();
    CoroBegin->replaceAllUsesWith(VFrame);
  } else {
    CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
  }
  CoroBegin->eraseFromParent();
}

// look for a very simple pattern
//    coro.save
//    no other calls
//    resume or destroy call
//    coro.suspend
//
// If there are other calls between coro.save and coro.suspend, they can
// potentially resume or destroy the coroutine, so it is unsafe to eliminate a
// suspend point.
550 static bool simplifySuspendPoint(CoroSuspendInst *Suspend, 551 CoroBeginInst *CoroBegin) { 552 auto *Save = Suspend->getCoroSave(); 553 auto *BB = Suspend->getParent(); 554 if (BB != Save->getParent()) 555 return false; 556 557 CallSite SingleCallSite; 558 559 // Check that we have only one CallSite. 560 for (Instruction *I = Save->getNextNode(); I != Suspend; 561 I = I->getNextNode()) { 562 if (isa<CoroFrameInst>(I)) 563 continue; 564 if (isa<CoroSubFnInst>(I)) 565 continue; 566 if (CallSite CS = CallSite(I)) { 567 if (SingleCallSite) 568 return false; 569 else 570 SingleCallSite = CS; 571 } 572 } 573 auto *CallInstr = SingleCallSite.getInstruction(); 574 if (!CallInstr) 575 return false; 576 577 auto *Callee = SingleCallSite.getCalledValue()->stripPointerCasts(); 578 579 // See if the callsite is for resumption or destruction of the coroutine. 580 auto *SubFn = dyn_cast<CoroSubFnInst>(Callee); 581 if (!SubFn) 582 return false; 583 584 // Does not refer to the current coroutine, we cannot do anything with it. 585 if (SubFn->getFrame() != CoroBegin) 586 return false; 587 588 // Replace llvm.coro.suspend with the value that results in resumption over 589 // the resume or cleanup path. 590 Suspend->replaceAllUsesWith(SubFn->getRawIndex()); 591 Suspend->eraseFromParent(); 592 Save->eraseFromParent(); 593 594 // No longer need a call to coro.resume or coro.destroy. 595 CallInstr->eraseFromParent(); 596 597 if (SubFn->user_empty()) 598 SubFn->eraseFromParent(); 599 600 return true; 601 } 602 603 // Remove suspend points that are simplified. 
static void simplifySuspendPoints(coro::Shape &Shape) {
  auto &S = Shape.CoroSuspends;
  size_t I = 0, N = S.size();
  if (N == 0)
    return;
  // Compact the array in place: a simplified suspend is swapped with the
  // last live element, and N shrinks; otherwise advance I.
  while (true) {
    if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
      if (--N == I)
        break;
      std::swap(S[I], S[N]);
      continue;
    }
    if (++I == N)
      break;
  }
  S.resize(N);
}

// Collect all blocks from which coro.begin is reachable backwards, i.e. all
// blocks whose instructions may be needed to perform the frame allocation.
static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) {
  // Collect all blocks that we need to look for instructions to relocate.
  SmallPtrSet<BasicBlock *, 4> RelocBlocks;
  SmallVector<BasicBlock *, 4> Work;
  Work.push_back(CB->getParent());

  do {
    BasicBlock *Current = Work.pop_back_val();
    for (BasicBlock *BB : predecessors(Current))
      if (RelocBlocks.count(BB) == 0) {
        RelocBlocks.insert(BB);
        Work.push_back(BB);
      }
  } while (!Work.empty());
  return RelocBlocks;
}

// Compute the transitive closure of instructions that must stay before
// coro.begin because the frame allocation (or control flow reaching it)
// depends on them.
static SmallPtrSet<Instruction *, 8>
getNotRelocatableInstructions(CoroBeginInst *CoroBegin,
                              SmallPtrSetImpl<BasicBlock *> &RelocBlocks) {
  SmallPtrSet<Instruction *, 8> DoNotRelocate;
  // Collect all instructions that we should not relocate
  SmallVector<Instruction *, 8> Work;

  // Start with CoroBegin and terminators of all preceding blocks.
  Work.push_back(CoroBegin);
  BasicBlock *CoroBeginBB = CoroBegin->getParent();
  for (BasicBlock *BB : RelocBlocks)
    if (BB != CoroBeginBB)
      Work.push_back(BB->getTerminator());

  // For every instruction in the Work list, place its operands in the
  // DoNotRelocate set.
  do {
    Instruction *Current = Work.pop_back_val();
    DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current << "\n");
    DoNotRelocate.insert(Current);
    for (Value *U : Current->operands()) {
      auto *I = dyn_cast<Instruction>(U);
      if (!I)
        continue;

      if (auto *A = dyn_cast<AllocaInst>(I)) {
        // Stores to alloca instructions that occur before the coroutine frame
        // is allocated should not be moved; the stored values may be used by
        // the coroutine frame allocator. The operands to those stores must
        // also remain in place.
        for (const auto &User : A->users())
          if (auto *SI = dyn_cast<llvm::StoreInst>(User))
            if (RelocBlocks.count(SI->getParent()) != 0 &&
                DoNotRelocate.count(SI) == 0) {
              Work.push_back(SI);
              DoNotRelocate.insert(SI);
            }
        continue;
      }

      if (DoNotRelocate.count(I) == 0) {
        Work.push_back(I);
        DoNotRelocate.insert(I);
      }
    }
  } while (!Work.empty());
  return DoNotRelocate;
}

static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) {
  // Analyze which non-alloca instructions are needed for allocation and
  // relocate the rest to after coro.begin. We need to do it, since some of the
  // targets of those instructions may be placed into coroutine frame memory,
  // which becomes available only after the coro.begin intrinsic.

  auto BlockSet = getCoroBeginPredBlocks(CoroBegin);
  auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet);

  Instruction *InsertPt = CoroBegin->getNextNode();
  BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well.
  for (auto B = BB.begin(), E = BB.end(); B != E;) {
    Instruction &I = *B++;
    // Allocas stay in the entry block; everything else not needed for
    // allocation moves to just after coro.begin.
    if (isa<AllocaInst>(&I))
      continue;
    if (&I == CoroBegin)
      break;
    if (DoNotRelocateSet.count(&I))
      continue;
    I.moveBefore(InsertPt);
  }
}

// Top-level transform for one coroutine: build the frame and outline the
// resume/destroy/cleanup parts into separate functions.
static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
  coro::Shape Shape(F);
  if (!Shape.CoroBegin)
    return;

  simplifySuspendPoints(Shape);
  relocateInstructionBefore(Shape.CoroBegin, F);
  buildCoroutineFrame(F, Shape);
  replaceFrameSize(Shape);

  // If there are no suspend points, no split required, just remove
  // the allocation and deallocation blocks, they are not needed.
  if (Shape.CoroSuspends.empty()) {
    handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
    removeCoroEnds(Shape);
    postSplitCleanup(F);
    coro::updateCallGraph(F, {}, CG, SCC);
    return;
  }

  auto *ResumeEntry = createResumeEntryBlock(F, Shape);
  auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
  auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
  auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);

  // We no longer need coro.end in F.
  removeCoroEnds(Shape);

  postSplitCleanup(F);
  postSplitCleanup(*ResumeClone);
  postSplitCleanup(*DestroyClone);
  postSplitCleanup(*CleanupClone);

  addMustTailToCoroResumes(*ResumeClone);

  // Store addresses of resume/destroy/cleanup functions in the coroutine
  // frame.
  updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);

  // Create a constant array referring to resume/destroy/cleanup functions
  // pointed by the last argument of @llvm.coro.info, so that CoroElide pass
  // can determine the correct function to call.
  setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});

  // Update call graph and add the functions we created to the SCC.
  coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
}

// When we see the coroutine the first time, we insert an indirect call to a
// devirt trigger function and mark the coroutine that it is now ready for
// split.
static void prepareForSplit(Function &F, CallGraph &CG) {
  Module &M = *F.getParent();
#ifndef NDEBUG
  Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
  assert(DevirtFn && "coro.devirt.trigger function not found");
#endif

  F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);

  // Insert an indirect call sequence that will be devirtualized by CoroElide
  // pass:
  //    %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
  //    %1 = bitcast i8* %0 to void(i8*)*
  //    call void %1(i8* null)
  coro::LowererBase Lowerer(M);
  Instruction *InsertPt = F.getEntryBlock().getTerminator();
  auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(F.getContext()));
  auto *DevirtFnAddr =
      Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
  auto *IndirectCall = CallInst::Create(DevirtFnAddr, Null, "", InsertPt);

  // Update CG graph with the indirect call we just added.
  CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
}

// Make sure that there is a devirtualization trigger function that the
// CoroSplit pass uses to force a restart of the CGSCC pipeline. If the devirt
// trigger function is not found, we will create one and add it to the current
// SCC.
static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
  Module &M = CG.getModule();
  if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
    return;

  LLVMContext &C = M.getContext();
  // void @coro.devirt.trigger(i8*) { ret void } -- an empty always-inline
  // function; CoroElide devirtualizes calls to it, which forces the CGSCC
  // pass manager to revisit this SCC.
  auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
                                 /*IsVarArgs=*/false);
  Function *DevirtFn =
      Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
                       CORO_DEVIRT_TRIGGER_FN, &M);
  DevirtFn->addFnAttr(Attribute::AlwaysInline);
  auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
  ReturnInst::Create(C, Entry);

  auto *Node = CG.getOrInsertFunction(DevirtFn);

  // Extend the current SCC with the new function so it is processed in the
  // same pipeline iteration.
  SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
  Nodes.push_back(Node);
  SCC.initialize(Nodes);
}

//===----------------------------------------------------------------------===//
//                              Top Level Driver
//===----------------------------------------------------------------------===//

namespace {

struct CoroSplit : public CallGraphSCCPass {
  static char ID; // Pass identification, replacement for typeid

  CoroSplit() : CallGraphSCCPass(ID) {
    initializeCoroSplitPass(*PassRegistry::getPassRegistry());
  }

  // Set in doInitialization; true if the module declares llvm.coro.begin.
  bool Run = false;

  // A coroutine is identified by the presence of coro.begin intrinsic, if
  // we don't have any, this pass has nothing to do.
  bool doInitialization(CallGraph &CG) override {
    Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
    return CallGraphSCCPass::doInitialization(CG);
  }

  bool runOnSCC(CallGraphSCC &SCC) override {
    if (!Run)
      return false;

    // Find coroutines for processing (those carrying the pre-split
    // attribute set by CoroEarly).
    SmallVector<Function *, 4> Coroutines;
    for (CallGraphNode *CGN : SCC)
      if (auto *F = CGN->getFunction())
        if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
          Coroutines.push_back(F);

    if (Coroutines.empty())
      return false;

    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
    createDevirtTriggerFunc(CG, SCC);

    for (Function *F : Coroutines) {
      Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
      StringRef Value = Attr.getValueAsString();
      DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
                   << "' state: " << Value << "\n");
      if (Value == UNPREPARED_FOR_SPLIT) {
        // First visit: insert the devirt trigger call and come back to this
        // coroutine on the pipeline restart.
        prepareForSplit(*F, CG);
        continue;
      }
      // Second visit: do the actual split.
      F->removeFnAttr(CORO_PRESPLIT_ATTR);
      splitCoroutine(*F, CG, SCC);
    }
    return true;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return "Coroutine Splitting"; }
};

} // end anonymous namespace

char CoroSplit::ID = 0;

INITIALIZE_PASS(
    CoroSplit, "coro-split",
    "Split coroutine into a set of functions driving its state machine", false,
    false)

Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }