//===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass builds the coroutine frame and outlines resume and destroy parts
// of the coroutine into separate functions.
//
// We present a coroutine to LLVM as an ordinary function with suspension
// points marked up with intrinsics. We let the optimizer party on the
// coroutine as a single function for as long as possible. Shortly before the
// coroutine is eligible to be inlined into its callers, we split up the
// coroutine into parts corresponding to the initial, resume and destroy
// invocations of the coroutine, add them to the current SCC and restart the
// IPO pipeline to optimize the coroutine subfunctions we extracted before
// proceeding to the caller of the coroutine.
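//
// As a rough sketch, under the switch-lowering ABI a coroutine @f is split
// into three subfunctions sharing one frame type (the names follow the
// suffixes used in this file):
//
//   define void @f.resume(%f.frame* %frame)  ; continue from a suspend point
//   define void @f.destroy(%f.frame* %frame) ; destroy a heap-allocated frame
//   define void @f.cleanup(%f.frame* %frame) ; destroy an elided frame
//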
//===----------------------------------------------------------------------===//

#include "CoroInstr.h"
#include "CoroInternal.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>

using namespace llvm;

#define DEBUG_TYPE "coro-split"

namespace {

/// A little helper class for building the coroutine clones.
class CoroCloner {
public:
  enum class Kind {
    /// The shared resume function for a switch lowering.
    SwitchResume,

    /// The shared unwind function for a switch lowering.
    SwitchUnwind,

    /// The shared cleanup function for a switch lowering.
    SwitchCleanup,

    /// An individual continuation function.
    Continuation,
  };
private:
  Function &OrigF;
  Function *NewF;
  const Twine &Suffix;
  coro::Shape &Shape;
  Kind FKind;
  ValueToValueMapTy VMap;
  IRBuilder<> Builder;
  Value *NewFramePtr = nullptr;
  Value *SwiftErrorSlot = nullptr;

  /// The active suspend instruction; meaningful only for continuation ABIs.
  AnyCoroSuspendInst *ActiveSuspend = nullptr;

public:
  /// Create a cloner for a switch lowering.
  CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape,
             Kind FKind)
    : OrigF(OrigF), NewF(nullptr), Suffix(Suffix), Shape(Shape),
      FKind(FKind), Builder(OrigF.getContext()) {
    assert(Shape.ABI == coro::ABI::Switch);
  }

  /// Create a cloner for a continuation lowering.
  CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape,
             Function *NewF, AnyCoroSuspendInst *ActiveSuspend)
    : OrigF(OrigF), NewF(NewF), Suffix(Suffix), Shape(Shape),
      FKind(Kind::Continuation), Builder(OrigF.getContext()),
      ActiveSuspend(ActiveSuspend) {
    assert(Shape.ABI == coro::ABI::Retcon ||
           Shape.ABI == coro::ABI::RetconOnce);
    assert(NewF && "need existing function for continuation");
    assert(ActiveSuspend && "need active suspend point for continuation");
  }

  Function *getFunction() const {
    assert(NewF != nullptr && "declaration not yet set");
    return NewF;
  }

  void create();

private:
  bool isSwitchDestroyFunction() {
    switch (FKind) {
    case Kind::Continuation:
    case Kind::SwitchResume:
      return false;
    case Kind::SwitchUnwind:
    case Kind::SwitchCleanup:
      return true;
    }
    llvm_unreachable("Unknown CoroCloner::Kind enum");
  }

  void createDeclaration();
  void replaceEntryBlock();
  Value *deriveNewFramePointer();
  void replaceRetconSuspendUses();
  void replaceCoroSuspends();
  void replaceCoroEnds();
  void replaceSwiftErrorOps();
  void handleFinalSuspend();
  void maybeFreeContinuationStorage();
};

} // end anonymous namespace

static void maybeFreeRetconStorage(IRBuilder<> &Builder, coro::Shape &Shape,
                                   Value *FramePtr, CallGraph *CG) {
  assert(Shape.ABI == coro::ABI::Retcon ||
         Shape.ABI == coro::ABI::RetconOnce);
  if (Shape.RetconLowering.IsFrameInlineInStorage)
    return;

  Shape.emitDealloc(Builder, FramePtr, CG);
}

/// Replace a non-unwind call to llvm.coro.end.
static void replaceFallthroughCoroEnd(CoroEndInst *End, coro::Shape &Shape,
                                      Value *FramePtr, bool InResume,
                                      CallGraph *CG) {
  // Start inserting right before the coro.end.
  IRBuilder<> Builder(End);

  // Create the return instruction.
  switch (Shape.ABI) {
  // The cloned functions in switch-lowering always return void.
  case coro::ABI::Switch:
    // coro.end doesn't immediately end the coroutine in the main function
    // in this lowering, because we need to deallocate the coroutine.
    if (!InResume)
      return;
    Builder.CreateRetVoid();
    break;

  // In unique continuation lowering, the continuations always return void.
  // But we may have implicitly allocated storage.
  case coro::ABI::RetconOnce:
    maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
    Builder.CreateRetVoid();
    break;

  // In non-unique continuation lowering, we signal completion by returning
  // a null continuation.
  case coro::ABI::Retcon: {
    maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
    auto RetTy = Shape.getResumeFunctionType()->getReturnType();
    auto RetStructTy = dyn_cast<StructType>(RetTy);
    PointerType *ContinuationTy =
      cast<PointerType>(RetStructTy ? RetStructTy->getElementType(0) : RetTy);

    Value *ReturnValue = ConstantPointerNull::get(ContinuationTy);
    if (RetStructTy) {
      ReturnValue = Builder.CreateInsertValue(UndefValue::get(RetStructTy),
                                              ReturnValue, 0);
    }
    Builder.CreateRet(ReturnValue);
    break;
  }
  }

  // Remove the rest of the block by splitting it into an unreachable block.
  auto *BB = End->getParent();
  BB->splitBasicBlock(End);
  BB->getTerminator()->eraseFromParent();
}

/// Replace an unwind call to llvm.coro.end.
static void replaceUnwindCoroEnd(CoroEndInst *End, coro::Shape &Shape,
                                 Value *FramePtr, bool InResume,
                                 CallGraph *CG) {
  IRBuilder<> Builder(End);

  switch (Shape.ABI) {
  // In switch-lowering, this does nothing in the main function.
  case coro::ABI::Switch:
    if (!InResume)
      return;
    break;

  // In continuation-lowering, this frees the continuation storage.
  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
    break;
  }

  // If coro.end has an associated bundle, add a cleanupret instruction.
  if (auto Bundle = End->getOperandBundle(LLVMContext::OB_funclet)) {
    auto *FromPad = cast<CleanupPadInst>(Bundle->Inputs[0]);
    auto *CleanupRet = Builder.CreateCleanupRet(FromPad, nullptr);
    End->getParent()->splitBasicBlock(End);
    CleanupRet->getParent()->getTerminator()->eraseFromParent();
  }
}

static void replaceCoroEnd(CoroEndInst *End, coro::Shape &Shape,
                           Value *FramePtr, bool InResume, CallGraph *CG) {
  if (End->isUnwind())
    replaceUnwindCoroEnd(End, Shape, FramePtr, InResume, CG);
  else
    replaceFallthroughCoroEnd(End, Shape, FramePtr, InResume, CG);

  auto &Context = End->getContext();
  End->replaceAllUsesWith(InResume ? ConstantInt::getTrue(Context)
                                   : ConstantInt::getFalse(Context));
  End->eraseFromParent();
}

// Create an entry block for a resume function with a switch that will jump to
// suspend points.
static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
  assert(Shape.ABI == coro::ABI::Switch);
  LLVMContext &C = F.getContext();

  // resume.entry:
  //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr,
  //                              i32 0, i32 2
  //  %index = load i32, i32* %index.addr
  //  switch i32 %index, label %unreachable [
  //    i32 0, label %resume.0
  //    i32 1, label %resume.1
  //    ...
  //  ]

  auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
  auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);

  IRBuilder<> Builder(NewEntry);
  auto *FramePtr = Shape.FramePtr;
  auto *FrameTy = Shape.FrameTy;
  auto *GepIndex = Builder.CreateStructGEP(
      FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Index, "index.addr");
  auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
  auto *Switch =
      Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
  Shape.SwitchLowering.ResumeSwitch = Switch;

  size_t SuspendIndex = 0;
  for (auto *AnyS : Shape.CoroSuspends) {
    auto *S = cast<CoroSuspendInst>(AnyS);
    ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);

    // Replace CoroSave with a store to Index:
    //    %index.addr = getelementptr %f.frame... (index field number)
    //    store i32 0, i32* %index.addr1
    auto *Save = S->getCoroSave();
    Builder.SetInsertPoint(Save);
    if (S->isFinal()) {
      // Final suspend point is represented by storing zero in ResumeFnAddr.
      auto *GepIndex = Builder.CreateStructGEP(
          FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Resume,
          "ResumeFn.addr");
      auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
          cast<PointerType>(GepIndex->getType())->getElementType()));
      Builder.CreateStore(NullPtr, GepIndex);
    } else {
      auto *GepIndex = Builder.CreateStructGEP(
          FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Index,
          "index.addr");
      Builder.CreateStore(IndexVal, GepIndex);
    }
    Save->replaceAllUsesWith(ConstantTokenNone::get(C));
    Save->eraseFromParent();

    // Split block before and after coro.suspend and add a jump from the entry
    // switch:
    //
    //  whateverBB:
    //    whatever
    //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
    //    switch i8 %0, label %suspend [i8 0, label %resume
    //                                  i8 1, label %cleanup]
    // becomes:
    //
    //  whateverBB:
    //     whatever
    //     br label %resume.0.landing
    //
    //  resume.0: ; <--- jump from the switch in the resume.entry
    //     %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
    //     br label %resume.0.landing
    //
    //  resume.0.landing:
    //     %1 = phi i8 [-1, %whateverBB], [%0, %resume.0]
    //     switch i8 %1, label %suspend [i8 0, label %resume
    //                                   i8 1, label %cleanup]

    auto *SuspendBB = S->getParent();
    auto *ResumeBB =
        SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
    auto *LandingBB = ResumeBB->splitBasicBlock(
        S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
    Switch->addCase(IndexVal, ResumeBB);

    cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
    auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
    S->replaceAllUsesWith(PN);
    PN->addIncoming(Builder.getInt8(-1), SuspendBB);
    PN->addIncoming(S, ResumeBB);

    ++SuspendIndex;
  }

  Builder.SetInsertPoint(UnreachBB);
  Builder.CreateUnreachable();

  Shape.SwitchLowering.ResumeEntryBlock = NewEntry;
}

// Rewrite final suspend point handling. We do not use the suspend index to
// represent the final suspend point. Instead we zero-out ResumeFnAddr in the
// coroutine frame, since it is undefined behavior to resume a coroutine
// suspended at the final suspend point. Thus, in the resume function, we can
// simply remove the last case (when coro::Shape is built, the final suspend
// point, if present, is always the last element of the CoroSuspends array).
// In the destroy function, we add a code sequence to check if ResumeFnAddr
// is null, and if so, jump to the appropriate label to handle cleanup from the
// final suspend point.
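//
// For illustration, the destroy clone roughly gains the following sequence in
// front of the resume switch (a sketch; value names and the label of the
// final-suspend cleanup block are illustrative):
//
//   %ResumeFn.addr = getelementptr %f.frame, %f.frame* %FramePtr, i32 0, i32 0
//   %ResumeFn = load void (%f.frame*)*, void (%f.frame*)** %ResumeFn.addr
//   %IsNull = icmp eq void (%f.frame*)* %ResumeFn, null
//   br i1 %IsNull, label %final.cleanup, label %Switch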
void CoroCloner::handleFinalSuspend() {
  assert(Shape.ABI == coro::ABI::Switch &&
         Shape.SwitchLowering.HasFinalSuspend);
  auto *Switch = cast<SwitchInst>(VMap[Shape.SwitchLowering.ResumeSwitch]);
  auto FinalCaseIt = std::prev(Switch->case_end());
  BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
  Switch->removeCase(FinalCaseIt);
  if (isSwitchDestroyFunction()) {
    BasicBlock *OldSwitchBB = Switch->getParent();
    auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
    Builder.SetInsertPoint(OldSwitchBB->getTerminator());
    auto *GepIndex = Builder.CreateStructGEP(
        Shape.FrameTy, NewFramePtr, coro::Shape::SwitchFieldIndex::Resume,
        "ResumeFn.addr");
    auto *Load = Builder.CreateLoad(Shape.getSwitchResumePointerType(),
                                    GepIndex);
    auto *Cond = Builder.CreateIsNull(Load);
    Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
    OldSwitchBB->getTerminator()->eraseFromParent();
  }
}

static Function *createCloneDeclaration(Function &OrigF, coro::Shape &Shape,
                                        const Twine &Suffix,
                                        Module::iterator InsertBefore) {
  Module *M = OrigF.getParent();
  auto *FnTy = Shape.getResumeFunctionType();

  Function *NewF =
      Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
                       OrigF.getName() + Suffix);
  NewF->addParamAttr(0, Attribute::NonNull);
  NewF->addParamAttr(0, Attribute::NoAlias);

  M->getFunctionList().insert(InsertBefore, NewF);

  return NewF;
}

/// Replace uses of the active llvm.coro.suspend.retcon call with the
/// arguments to the continuation function.
///
/// This assumes that the builder has a meaningful insertion point.
void CoroCloner::replaceRetconSuspendUses() {
  assert(Shape.ABI == coro::ABI::Retcon ||
         Shape.ABI == coro::ABI::RetconOnce);

  auto NewS = VMap[ActiveSuspend];
  if (NewS->use_empty()) return;

  // Copy out all the continuation arguments after the buffer pointer into
  // an easily-indexed data structure for convenience.
  SmallVector<Value*, 8> Args;
  for (auto I = std::next(NewF->arg_begin()), E = NewF->arg_end(); I != E; ++I)
    Args.push_back(&*I);

  // If the suspend returns a single scalar value, we can just do a simple
  // replacement.
  if (!isa<StructType>(NewS->getType())) {
    assert(Args.size() == 1);
    NewS->replaceAllUsesWith(Args.front());
    return;
  }

  // Try to peephole extracts of an aggregate return.
  for (auto UI = NewS->use_begin(), UE = NewS->use_end(); UI != UE; ) {
    auto EVI = dyn_cast<ExtractValueInst>((UI++)->getUser());
    if (!EVI || EVI->getNumIndices() != 1)
      continue;

    EVI->replaceAllUsesWith(Args[EVI->getIndices().front()]);
    EVI->eraseFromParent();
  }

  // If we have no remaining uses, we're done.
  if (NewS->use_empty()) return;

  // Otherwise, we need to create an aggregate.
  Value *Agg = UndefValue::get(NewS->getType());
  for (size_t I = 0, E = Args.size(); I != E; ++I)
    Agg = Builder.CreateInsertValue(Agg, Args[I], I);

  NewS->replaceAllUsesWith(Agg);
}

void CoroCloner::replaceCoroSuspends() {
  Value *SuspendResult;

  switch (Shape.ABI) {
  // In switch lowering, replace coro.suspend with the appropriate value
  // for the type of function we're extracting.
  // Replacing coro.suspend with 0 will result in control flow proceeding to
  // a resume label associated with the suspend point; replacing it with 1
  // will result in control flow proceeding to a cleanup label associated
  // with this suspend point.
  case coro::ABI::Switch:
    SuspendResult = Builder.getInt8(isSwitchDestroyFunction() ? 1 : 0);
    break;

  // In returned-continuation lowering, the arguments from earlier
  // continuations are theoretically arbitrary, and they should have been
  // spilled.
  case coro::ABI::RetconOnce:
  case coro::ABI::Retcon:
    return;
  }

  for (AnyCoroSuspendInst *CS : Shape.CoroSuspends) {
    // The active suspend was handled earlier.
    if (CS == ActiveSuspend) continue;

    auto *MappedCS = cast<AnyCoroSuspendInst>(VMap[CS]);
    MappedCS->replaceAllUsesWith(SuspendResult);
    MappedCS->eraseFromParent();
  }
}

void CoroCloner::replaceCoroEnds() {
  for (CoroEndInst *CE : Shape.CoroEnds) {
    // We use a null call graph because there's no call graph node for
    // the cloned function yet.  We'll just be rebuilding that later.
    auto NewCE = cast<CoroEndInst>(VMap[CE]);
    replaceCoroEnd(NewCE, Shape, NewFramePtr, /*in resume*/ true, nullptr);
  }
}

static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape,
                                 ValueToValueMapTy *VMap) {
  Value *CachedSlot = nullptr;
  auto getSwiftErrorSlot = [&](Type *ValueTy) -> Value * {
    if (CachedSlot) {
      assert(CachedSlot->getType()->getPointerElementType() == ValueTy &&
             "multiple swifterror slots in function with different types");
      return CachedSlot;
    }

    // Check if the function has a swifterror argument.
    for (auto &Arg : F.args()) {
      if (Arg.isSwiftError()) {
        CachedSlot = &Arg;
        assert(Arg.getType()->getPointerElementType() == ValueTy &&
               "swifterror argument does not have expected type");
        return &Arg;
      }
    }

    // Create a swifterror alloca.
    IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg());
    auto Alloca = Builder.CreateAlloca(ValueTy);
    Alloca->setSwiftError(true);

    CachedSlot = Alloca;
    return Alloca;
  };

  for (CallInst *Op : Shape.SwiftErrorOps) {
    auto MappedOp = VMap ? cast<CallInst>((*VMap)[Op]) : Op;
    IRBuilder<> Builder(MappedOp);

    // If there are no arguments, this is a 'get' operation.
    Value *MappedResult;
    if (Op->getNumArgOperands() == 0) {
      auto ValueTy = Op->getType();
      auto Slot = getSwiftErrorSlot(ValueTy);
      MappedResult = Builder.CreateLoad(ValueTy, Slot);
    } else {
      assert(Op->getNumArgOperands() == 1);
      auto Value = MappedOp->getArgOperand(0);
      auto ValueTy = Value->getType();
      auto Slot = getSwiftErrorSlot(ValueTy);
      Builder.CreateStore(Value, Slot);
      MappedResult = Slot;
    }

    MappedOp->replaceAllUsesWith(MappedResult);
    MappedOp->eraseFromParent();
  }

  // If we're updating the original function, we've invalidated SwiftErrorOps.
  if (VMap == nullptr) {
    Shape.SwiftErrorOps.clear();
  }
}

void CoroCloner::replaceSwiftErrorOps() {
  ::replaceSwiftErrorOps(*NewF, Shape, &VMap);
}

void CoroCloner::replaceEntryBlock() {
  // In the original function, the AllocaSpillBlock is a block immediately
  // following the allocation of the frame object which defines GEPs for
  // all the allocas that have been moved into the frame, and it ends by
  // branching to the original beginning of the coroutine.  Make this
  // the entry block of the cloned function.
  auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
  Entry->setName("entry" + Suffix);
  Entry->moveBefore(&NewF->getEntryBlock());
  Entry->getTerminator()->eraseFromParent();

  // Clear all predecessors of the new entry block.  There should be
  // exactly one predecessor, which we created when splitting out
  // AllocaSpillBlock to begin with.
  assert(Entry->hasOneUse());
  auto BranchToEntry = cast<BranchInst>(Entry->user_back());
  assert(BranchToEntry->isUnconditional());
  Builder.SetInsertPoint(BranchToEntry);
  Builder.CreateUnreachable();
  BranchToEntry->eraseFromParent();

  // TODO: move any allocas into Entry that weren't moved into the frame.
  // (Currently we move all allocas into the frame.)

  // Branch from the entry to the appropriate place.
  Builder.SetInsertPoint(Entry);
  switch (Shape.ABI) {
  case coro::ABI::Switch: {
    // In switch-lowering, we built a resume-entry block in the original
    // function.  Make the entry block branch to this.
    auto *SwitchBB =
      cast<BasicBlock>(VMap[Shape.SwitchLowering.ResumeEntryBlock]);
    Builder.CreateBr(SwitchBB);
    break;
  }

  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce: {
    // In continuation ABIs, we want to branch to immediately after the
    // active suspend point.  Earlier phases will have put the suspend in its
    // own basic block, so just thread our jump directly to its successor.
    auto MappedCS = cast<CoroSuspendRetconInst>(VMap[ActiveSuspend]);
    auto Branch = cast<BranchInst>(MappedCS->getNextNode());
    assert(Branch->isUnconditional());
    Builder.CreateBr(Branch->getSuccessor(0));
    break;
  }
  }
}

/// Derive the value of the new frame pointer.
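///
/// For the out-of-line retcon case this emits, roughly (a sketch):
///   %frame.ptr = bitcast i8* %storage to %f.frame**
///   %frame = load %f.frame*, %f.frame** %frame.ptr
/// In the switch and inline-retcon cases there is no load, just the argument
/// itself or a bitcast of it.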
Value *CoroCloner::deriveNewFramePointer() {
  // Builder should be inserting to the front of the new entry block.

  switch (Shape.ABI) {
  // In switch-lowering, the argument is the frame pointer.
  case coro::ABI::Switch:
    return &*NewF->arg_begin();

  // In continuation-lowering, the argument is the opaque storage.
  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce: {
    Argument *NewStorage = &*NewF->arg_begin();
    auto FramePtrTy = Shape.FrameTy->getPointerTo();

    // If the storage is inline, just bitcast the storage to the frame type.
    if (Shape.RetconLowering.IsFrameInlineInStorage)
      return Builder.CreateBitCast(NewStorage, FramePtrTy);

    // Otherwise, load the real frame from the opaque storage.
    auto FramePtrPtr =
      Builder.CreateBitCast(NewStorage, FramePtrTy->getPointerTo());
    return Builder.CreateLoad(FramePtrPtr);
  }
  }
  llvm_unreachable("bad ABI");
}

/// Clone the body of the original function into a resume function of
/// some sort.
void CoroCloner::create() {
  // Create the new function if we don't already have one.
  if (!NewF) {
    NewF = createCloneDeclaration(OrigF, Shape, Suffix,
                                  OrigF.getParent()->end());
  }

  // Replace all args with undefs. The buildCoroutineFrame algorithm has
  // already rewritten accesses to the args that occur after suspend points
  // with loads and stores to/from the coroutine frame.
  for (Argument &A : OrigF.args())
    VMap[&A] = UndefValue::get(A.getType());

  SmallVector<ReturnInst *, 4> Returns;

  // Ignore attempts to change certain attributes of the function.
  // TODO: maybe there should be a way to suppress this during cloning?
  auto savedVisibility = NewF->getVisibility();
  auto savedUnnamedAddr = NewF->getUnnamedAddr();
  auto savedDLLStorageClass = NewF->getDLLStorageClass();

  // NewF's linkage (which CloneFunctionInto does *not* change) might not
  // be compatible with the visibility of OrigF (which it *does* change),
  // so protect against that.
  auto savedLinkage = NewF->getLinkage();
  NewF->setLinkage(llvm::GlobalValue::ExternalLinkage);

  CloneFunctionInto(NewF, &OrigF, VMap, /*ModuleLevelChanges=*/true, Returns);

  NewF->setLinkage(savedLinkage);
  NewF->setVisibility(savedVisibility);
  NewF->setUnnamedAddr(savedUnnamedAddr);
  NewF->setDLLStorageClass(savedDLLStorageClass);

  auto &Context = NewF->getContext();

  // Replace the attributes of the new function:
  auto OrigAttrs = NewF->getAttributes();
  auto NewAttrs = AttributeList();

  switch (Shape.ABI) {
  case coro::ABI::Switch:
    // Bootstrap attributes by copying function attributes from the
    // original function.  This should include optimization settings and so on.
    NewAttrs = NewAttrs.addAttributes(Context, AttributeList::FunctionIndex,
                                      OrigAttrs.getFnAttributes());
    break;

  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    // If we have a continuation prototype, just use its attributes,
    // full-stop.
    NewAttrs = Shape.RetconLowering.ResumePrototype->getAttributes();
    break;
  }

  // Make the frame parameter nonnull and noalias.
  NewAttrs = NewAttrs.addParamAttribute(Context, 0, Attribute::NonNull);
  NewAttrs = NewAttrs.addParamAttribute(Context, 0, Attribute::NoAlias);

  switch (Shape.ABI) {
  // In these ABIs, the cloned functions always return 'void', and the
  // existing return sites are meaningless.  Note that for unique
  // continuations, this includes the returns associated with suspends;
  // this is fine because we can't suspend twice.
  case coro::ABI::Switch:
  case coro::ABI::RetconOnce:
    // Remove old returns.
    for (ReturnInst *Return : Returns)
      changeToUnreachable(Return, /*UseLLVMTrap=*/false);
    break;

  // With multi-suspend continuations, we'll already have eliminated the
  // original returns and inserted returns before all the suspend points,
  // so we want to leave any returns in place.
  case coro::ABI::Retcon:
    break;
  }

  NewF->setAttributes(NewAttrs);
  NewF->setCallingConv(Shape.getResumeFunctionCC());

  // Set up the new entry block.
  replaceEntryBlock();

  Builder.SetInsertPoint(&NewF->getEntryBlock().front());
  NewFramePtr = deriveNewFramePointer();

  // Remap frame pointer.
  Value *OldFramePtr = VMap[Shape.FramePtr];
  NewFramePtr->takeName(OldFramePtr);
  OldFramePtr->replaceAllUsesWith(NewFramePtr);

  // Remap vFrame pointer.
  auto *NewVFrame = Builder.CreateBitCast(
      NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
  Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
  OldVFrame->replaceAllUsesWith(NewVFrame);

  switch (Shape.ABI) {
  case coro::ABI::Switch:
    // Rewrite final suspend handling, as it is not done via the switch (this
    // allows us to remove the final case from the switch, since it is
    // undefined behavior to resume a coroutine suspended at the final
    // suspend point).
    if (Shape.SwitchLowering.HasFinalSuspend)
      handleFinalSuspend();
    break;

  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    // Replace uses of the active suspend with the corresponding
    // continuation-function arguments.
    assert(ActiveSuspend != nullptr &&
           "no active suspend when lowering a continuation-style coroutine");
    replaceRetconSuspendUses();
    break;
  }

  // Handle suspends.
  replaceCoroSuspends();

  // Handle swifterror.
  replaceSwiftErrorOps();

  // Remove coro.end intrinsics.
  replaceCoroEnds();

  // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
  // to suppress deallocation code.
  if (Shape.ABI == coro::ABI::Switch)
    coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
                          /*Elide=*/ FKind == CoroCloner::Kind::SwitchCleanup);
}

// Create a resume clone by cloning the body of the original function, setting
// a new entry block, and replacing coro.suspend with an appropriate value to
// force the resume or cleanup path at every suspend point.
static Function *createClone(Function &F, const Twine &Suffix,
                             coro::Shape &Shape, CoroCloner::Kind FKind) {
  CoroCloner Cloner(F, Suffix, Shape, FKind);
  Cloner.create();
  return Cloner.getFunction();
}

/// Remove calls to llvm.coro.end in the original function.
static void removeCoroEnds(coro::Shape &Shape, CallGraph *CG) {
  for (auto End : Shape.CoroEnds) {
    replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, CG);
  }
}

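// Replace all llvm.coro.size intrinsics with the statically known size of the
// frame, e.g. (the constant below is illustrative):
//
//   %size = call i64 @llvm.coro.size.i64()
// ==>
//   all uses of %size replaced with `i64 64`, and the calls removed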
static void replaceFrameSize(coro::Shape &Shape) {
  if (Shape.CoroSizes.empty())
    return;

  // In the same function all coro.sizes should have the same result type.
  auto *SizeIntrin = Shape.CoroSizes.back();
  Module *M = SizeIntrin->getModule();
  const DataLayout &DL = M->getDataLayout();
  auto Size = DL.getTypeAllocSize(Shape.FrameTy);
  auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);

  for (CoroSizeInst *CS : Shape.CoroSizes) {
    CS->replaceAllUsesWith(SizeConstant);
    CS->eraseFromParent();
  }
}

// Create a global constant array containing pointers to the functions
// provided, and set the Info parameter of CoroBegin to point at this constant.
// Example:
//
//   @f.resumers = internal constant [2 x void(%f.frame*)*]
//                    [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
//   define void @f() {
//     ...
//     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
//                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
//
// Assumes that all the functions have the same signature.
static void setCoroInfo(Function &F, coro::Shape &Shape,
                        ArrayRef<Function *> Fns) {
  // This only works under the switch-lowering ABI because coro elision
  // only works on the switch-lowering ABI.
  assert(Shape.ABI == coro::ABI::Switch);

  SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
  assert(!Args.empty());
  Function *Part = *Fns.begin();
  Module *M = Part->getParent();
  auto *ArrTy = ArrayType::get(Part->getType(), Args.size());

  auto *ConstVal = ConstantArray::get(ArrTy, Args);
  auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
                                GlobalVariable::PrivateLinkage, ConstVal,
                                F.getName() + Twine(".resumers"));

  // Update coro.begin instruction to refer to this constant.
  LLVMContext &C = F.getContext();
  auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
  Shape.getSwitchCoroId()->setInfo(BC);
}

// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
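//
// Roughly, this emits (a sketch; field indices and value names illustrative):
//
//   %resume.addr = getelementptr %f.frame, %f.frame* %FramePtr, i32 0, i32 0
//   store void (%f.frame*)* @f.resume, void (%f.frame*)** %resume.addr
//   %dc = select i1 %need.alloc, void (%f.frame*)* @f.destroy,
//                                void (%f.frame*)* @f.cleanup
//   %destroy.addr = getelementptr %f.frame, %f.frame* %FramePtr, i32 0, i32 1
//   store void (%f.frame*)* %dc, void (%f.frame*)** %destroy.addr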
static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
                            Function *DestroyFn, Function *CleanupFn) {
  assert(Shape.ABI == coro::ABI::Switch);

  IRBuilder<> Builder(Shape.FramePtr->getNextNode());
  auto *ResumeAddr = Builder.CreateStructGEP(
      Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume,
      "resume.addr");
  Builder.CreateStore(ResumeFn, ResumeAddr);

  Value *DestroyOrCleanupFn = DestroyFn;

  CoroIdInst *CoroId = Shape.getSwitchCoroId();
  if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
    // If there is a CoroAlloc and it returns false (meaning we elide the
    // allocation), use CleanupFn instead of DestroyFn.
    DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
  }

  auto *DestroyAddr = Builder.CreateStructGEP(
      Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy,
      "destroy.addr");
  Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
}

static void postSplitCleanup(Function &F) {
  removeUnreachableBlocks(F);

  // For now, we do a mandatory verification step because we don't
  // entirely trust this pass.  Note that we don't want to add a verifier
  // pass to FPM below because it will also verify all the global data.
  verifyFunction(F);

  legacy::FunctionPassManager FPM(F.getParent());

  FPM.add(createSCCPPass());
  FPM.add(createCFGSimplificationPass());
  FPM.add(createEarlyCSEPass());
  FPM.add(createCFGSimplificationPass());

  FPM.doInitialization();
  FPM.run(F);
  FPM.doFinalization();
}

// Assuming we arrived at basic block NewBlock from the instruction Prev, store
// the PHIs' incoming values in the ResolvedValues map.
static void
scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
                          DenseMap<Value *, Value *> &ResolvedValues) {
  auto *PrevBB = Prev->getParent();
  for (PHINode &PN : NewBlock->phis()) {
    auto V = PN.getIncomingValueForBlock(PrevBB);
    // See if we already resolved it.
    auto VI = ResolvedValues.find(V);
    if (VI != ResolvedValues.end())
      V = VI->second;
    // Remember the value.
    ResolvedValues[&PN] = V;
  }
}

// Replace a sequence of branches leading to a ret with a clone of that ret
// instruction. The suspend instruction is represented by a switch; track the
// PHI values and select the correct case successor when possible.
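//
// For example (a sketch), a chain like
//
//   br label %next
// next:
//   br label %exit
// exit:
//   ret void
//
// collapses so that the initial terminator is replaced by a direct `ret void`.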
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
  DenseMap<Value *, Value *> ResolvedValues;

  Instruction *I = InitialInst;
  while (I->isTerminator()) {
    if (isa<ReturnInst>(I)) {
      if (I != InitialInst)
        ReplaceInstWithInst(InitialInst, I->clone());
      return true;
    }
    if (auto *BR = dyn_cast<BranchInst>(I)) {
      if (BR->isUnconditional()) {
        BasicBlock *BB = BR->getSuccessor(0);
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
        I = BB->getFirstNonPHIOrDbgOrLifetime();
        continue;
      }
    } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
      Value *V = SI->getCondition();
      auto it = ResolvedValues.find(V);
      if (it != ResolvedValues.end())
        V = it->second;
      if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
        BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
        I = BB->getFirstNonPHIOrDbgOrLifetime();
        continue;
      }
    }
    return false;
  }
  return false;
}

// Add musttail to any resume instructions that are immediately followed by a
// suspend (i.e. ret). We do this even in -O0 to support guaranteed tail calls
// for symmetrical coroutine control transfer (C++ Coroutines TS extension).
// This transformation is done only in the resume part of the coroutine, which
// has a signature and calling convention identical to the coro.resume call.
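//
// For example (a sketch; %1 is the address produced by coro.subfn.addr, and
// the calling convention is illustrative):
//
//   call fastcc void %1(i8* %hdl)
//   ret void
// ==>
//   musttail call fastcc void %1(i8* %hdl)
//   ret void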
static void addMustTailToCoroResumes(Function &F) {
  bool changed = false;

  // Collect potential resume instructions.
  SmallVector<CallInst *, 4> Resumes;
  for (auto &I : instructions(F))
    if (auto *Call = dyn_cast<CallInst>(&I))
      if (auto *CalledValue = Call->getCalledValue())
        // The CoroEarly pass replaced coro.resume calls with indirect calls
        // to an address returned by the CoroSubFnInst intrinsic. See if this
        // is one of those.
        if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
          Resumes.push_back(Call);

  // Set musttail on those that are followed by a ret instruction.
  for (CallInst *Call : Resumes)
    if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
      Call->setTailCallKind(CallInst::TCK_MustTail);
      changed = true;
    }

  if (changed)
    removeUnreachableBlocks(F);
}

// The coroutine has no suspend points. Remove the heap allocation for the
// coroutine frame if possible.
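//
// For the switch ABI with a live coro.alloc, this roughly turns (a sketch):
//
//   %need.alloc = call i1 @llvm.coro.alloc(token %id)
//   ...
//   %hdl = call i8* @llvm.coro.begin(token %id, i8* %mem)
//
// into a stack allocation:
//
//   %frame = alloca %f.frame
//   %vFrame = bitcast %f.frame* %frame to i8*
//
// with %need.alloc replaced by `false` and %hdl by %vFrame, so the heap
// allocation and deallocation code becomes dead.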
static void handleNoSuspendCoroutine(coro::Shape &Shape) {
  auto *CoroBegin = Shape.CoroBegin;
  auto *CoroId = CoroBegin->getId();
  auto *AllocInst = CoroId->getCoroAlloc();
  switch (Shape.ABI) {
  case coro::ABI::Switch: {
    auto SwitchId = cast<CoroIdInst>(CoroId);
    coro::replaceCoroFree(SwitchId, /*Elide=*/AllocInst != nullptr);
    if (AllocInst) {
      IRBuilder<> Builder(AllocInst);
      // FIXME: Need to handle overaligned members.
      auto *Frame = Builder.CreateAlloca(Shape.FrameTy);
      auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
      AllocInst->replaceAllUsesWith(Builder.getFalse());
      AllocInst->eraseFromParent();
      CoroBegin->replaceAllUsesWith(VFrame);
    } else {
      CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
    }
    break;
  }

  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    CoroBegin->replaceAllUsesWith(UndefValue::get(CoroBegin->getType()));
    break;
  }

  CoroBegin->eraseFromParent();
}

// simplifySuspendPoint needs to check that there are no calls between
// coro.save and coro.suspend, since any of those calls may potentially resume
// the coroutine, in which case we cannot eliminate the suspend point.
static bool hasCallsInBlockBetween(Instruction *From, Instruction *To) {
  for (Instruction *I = From; I != To; I = I->getNextNode()) {
    // Assume that no intrinsic can resume the coroutine.
    if (isa<IntrinsicInst>(I))
      continue;

    if (CallSite(I))
      return true;
  }
  return false;
}

static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) {
  SmallPtrSet<BasicBlock *, 8> Set;
  SmallVector<BasicBlock *, 8> Worklist;

  Set.insert(SaveBB);
  Worklist.push_back(ResDesBB);

  // Accumulate all blocks between SaveBB and ResDesBB. Because the
  // CoroSaveIntr returns a token consumed by the suspend instruction, all
  // blocks in between will have to eventually hit SaveBB when going backwards
  // from ResDesBB.
  while (!Worklist.empty()) {
    auto *BB = Worklist.pop_back_val();
    Set.insert(BB);
    for (auto *Pred : predecessors(BB))
      if (Set.count(Pred) == 0)
        Worklist.push_back(Pred);
  }

  // SaveBB and ResDesBB are checked separately in hasCallsBetween.
  Set.erase(SaveBB);
  Set.erase(ResDesBB);

  for (auto *BB : Set)
    if (hasCallsInBlockBetween(BB->getFirstNonPHI(), nullptr))
      return true;

  return false;
}

static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) {
  auto *SaveBB = Save->getParent();
  auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent();

  if (SaveBB == ResumeOrDestroyBB)
    return hasCallsInBlockBetween(Save->getNextNode(), ResumeOrDestroy);

  // Any calls from Save to the end of the block?
  if (hasCallsInBlockBetween(Save->getNextNode(), nullptr))
    return true;

  // Any calls from the beginning of the block up to ResumeOrDestroy?
  if (hasCallsInBlockBetween(ResumeOrDestroyBB->getFirstNonPHI(),
                             ResumeOrDestroy))
    return true;

  // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB?
  if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB))
    return true;

  return false;
}


// If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the
// suspend point and replace it with normal control flow.
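//
// For example (a sketch; the resume call was already lowered by CoroEarly,
// and the calling convention is illustrative):
//
//   %save = call token @llvm.coro.save(i8* %hdl)
//   %addr = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0) ; 0 means resume
//   %fn = bitcast i8* %addr to void (i8*)*
//   call fastcc void %fn(i8* %hdl)
//   %s = call i8 @llvm.coro.suspend(token %save, i1 false)
// ==>
//   all uses of %s are replaced with `i8 0`, and the save, call and suspend
//   are erased, so control falls through to the resume path directly.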
static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
                                 CoroBeginInst *CoroBegin) {
  Instruction *Prev = Suspend->getPrevNode();
  if (!Prev) {
    auto *Pred = Suspend->getParent()->getSinglePredecessor();
    if (!Pred)
      return false;
    Prev = Pred->getTerminator();
  }

  CallSite CS{Prev};
  if (!CS)
    return false;

  auto *CallInstr = CS.getInstruction();

  auto *Callee = CS.getCalledValue()->stripPointerCasts();

  // See if the callsite is for resumption or destruction of the coroutine.
  auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
  if (!SubFn)
    return false;

  // If it does not refer to the current coroutine, we cannot do anything
  // with it.
  if (SubFn->getFrame() != CoroBegin)
    return false;

  // See if the transformation is safe. Specifically, see if there are any
  // calls in between Save and CallInstr. They can potentially resume the
  // coroutine, rendering this optimization unsafe.
  auto *Save = Suspend->getCoroSave();
  if (hasCallsBetween(Save, CallInstr))
    return false;

  // Replace llvm.coro.suspend with the value that results in resumption over
  // the resume or cleanup path.
  Suspend->replaceAllUsesWith(SubFn->getRawIndex());
  Suspend->eraseFromParent();
  Save->eraseFromParent();

  // No longer need a call to coro.resume or coro.destroy.
  if (auto *Invoke = dyn_cast<InvokeInst>(CallInstr)) {
    BranchInst::Create(Invoke->getNormalDest(), Invoke);
  }

  // Grab the CalledValue from CS before erasing the CallInstr.
  auto *CalledValue = CS.getCalledValue();
  CallInstr->eraseFromParent();

  // If it has no more users, remove it. Usually it is a bitcast of SubFn.
  if (CalledValue != SubFn && CalledValue->user_empty())
    if (auto *I = dyn_cast<Instruction>(CalledValue))
      I->eraseFromParent();

  // Now we are good to remove SubFn.
  if (SubFn->user_empty())
    SubFn->eraseFromParent();

  return true;
}

// Remove suspend points that are simplified.
static void simplifySuspendPoints(coro::Shape &Shape) {
  // Currently, the only simplification we do is switch-lowering-specific.
  if (Shape.ABI != coro::ABI::Switch)
    return;

  auto &S = Shape.CoroSuspends;
  size_t I = 0, N = S.size();
  if (N == 0)
    return;
  while (true) {
    if (simplifySuspendPoint(cast<CoroSuspendInst>(S[I]), Shape.CoroBegin)) {
      if (--N == I)
        break;
      std::swap(S[I], S[N]);
      continue;
    }
    if (++I == N)
      break;
  }
  S.resize(N);
}

static void splitSwitchCoroutine(Function &F, coro::Shape &Shape,
                                 SmallVectorImpl<Function *> &Clones) {
  assert(Shape.ABI == coro::ABI::Switch);

  createResumeEntryBlock(F, Shape);
  auto ResumeClone = createClone(F, ".resume", Shape,
                                 CoroCloner::Kind::SwitchResume);
  auto DestroyClone = createClone(F, ".destroy", Shape,
                                  CoroCloner::Kind::SwitchUnwind);
  auto CleanupClone = createClone(F, ".cleanup", Shape,
                                  CoroCloner::Kind::SwitchCleanup);

  postSplitCleanup(*ResumeClone);
  postSplitCleanup(*DestroyClone);
  postSplitCleanup(*CleanupClone);

  addMustTailToCoroResumes(*ResumeClone);

  // Store the addresses of the resume/destroy/cleanup functions in the
  // coroutine frame.
  updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);

  assert(Clones.empty());
  Clones.push_back(ResumeClone);
  Clones.push_back(DestroyClone);
  Clones.push_back(CleanupClone);

  // Create a constant array referring to the resume/destroy/cleanup functions,
  // pointed to by the last argument of @llvm.coro.info, so that the CoroElide
  // pass can determine the correct function to call.
  setCoroInfo(F, Shape, Clones);
}

static void splitRetconCoroutine(Function &F, coro::Shape &Shape,
                                 SmallVectorImpl<Function *> &Clones) {
  assert(Shape.ABI == coro::ABI::Retcon ||
         Shape.ABI == coro::ABI::RetconOnce);
  assert(Clones.empty());

  // Reset various things that the optimizer might have decided it
  // "knows" about the coroutine function due to not seeing a return.
  F.removeFnAttr(Attribute::NoReturn);
  F.removeAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
  F.removeAttribute(AttributeList::ReturnIndex, Attribute::NonNull);

  // Allocate the frame.
  auto *Id = cast<AnyCoroIdRetconInst>(Shape.CoroBegin->getId());
  Value *RawFramePtr;
  if (Shape.RetconLowering.IsFrameInlineInStorage) {
    RawFramePtr = Id->getStorage();
  } else {
    IRBuilder<> Builder(Id);

    // Determine the size of the frame.
    const DataLayout &DL = F.getParent()->getDataLayout();
    auto Size = DL.getTypeAllocSize(Shape.FrameTy);

    // Allocate.  We don't need to update the call graph node because we're
    // going to recompute it from scratch after splitting.
    RawFramePtr = Shape.emitAlloc(Builder, Builder.getInt64(Size), nullptr);
    RawFramePtr =
      Builder.CreateBitCast(RawFramePtr, Shape.CoroBegin->getType());

    // Stash the allocated frame pointer in the continuation storage.
    auto Dest = Builder.CreateBitCast(Id->getStorage(),
                                      RawFramePtr->getType()->getPointerTo());
    Builder.CreateStore(RawFramePtr, Dest);
  }

  // Map all uses of llvm.coro.begin to the allocated frame pointer.
  {
    // Make sure we don't invalidate Shape.FramePtr.
    TrackingVH<Instruction> Handle(Shape.FramePtr);
    Shape.CoroBegin->replaceAllUsesWith(RawFramePtr);
    Shape.FramePtr = Handle.getValPtr();
  }

  // Create a unique return block.
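  // For a retcon coroutine with two suspend points yielding an i32, it ends
  // up looking roughly like this (a sketch; names are illustrative):
  //
  //  coro.return:
  //    %cont = phi void (i8*)* [ @f.resume.0, %bb0 ], [ @f.resume.1, %bb1 ]
  //    %val = phi i32 [ %x, %bb0 ], [ %y, %bb1 ]
  //    %c = bitcast void (i8*)* %cont to i8*
  //    %agg = insertvalue { i8*, i32 } undef, i8* %c, 0
  //    %ret = insertvalue { i8*, i32 } %agg, i32 %val, 1
  //    ret { i8*, i32 } %ret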
  BasicBlock *ReturnBB = nullptr;
  SmallVector<PHINode *, 4> ReturnPHIs;

  // Create all the functions in order after the main function.
  auto NextF = std::next(F.getIterator());

  // Create a continuation function for each of the suspend points.
  Clones.reserve(Shape.CoroSuspends.size());
  for (size_t i = 0, e = Shape.CoroSuspends.size(); i != e; ++i) {
    auto Suspend = cast<CoroSuspendRetconInst>(Shape.CoroSuspends[i]);

    // Create the clone declaration.
    auto Continuation =
      createCloneDeclaration(F, Shape, ".resume." + Twine(i), NextF);
    Clones.push_back(Continuation);

    // Insert a branch to the unified return block immediately before
    // the suspend point.
    auto SuspendBB = Suspend->getParent();
    auto NewSuspendBB = SuspendBB->splitBasicBlock(Suspend);
    auto Branch = cast<BranchInst>(SuspendBB->getTerminator());

    // Create the unified return block.
    if (!ReturnBB) {
      // Place it before the first suspend.
      ReturnBB = BasicBlock::Create(F.getContext(), "coro.return", &F,
                                    NewSuspendBB);
      Shape.RetconLowering.ReturnBlock = ReturnBB;

      IRBuilder<> Builder(ReturnBB);

      // Create PHIs for all the return values.
      assert(ReturnPHIs.empty());

      // First, the continuation.
      ReturnPHIs.push_back(Builder.CreatePHI(Continuation->getType(),
                                             Shape.CoroSuspends.size()));

      // Next, all the directly-yielded values.
      for (auto ResultTy : Shape.getRetconResultTypes())
        ReturnPHIs.push_back(Builder.CreatePHI(ResultTy,
                                               Shape.CoroSuspends.size()));

      // Build the return value.
      auto RetTy = F.getReturnType();

      // Cast the continuation value if necessary.
      // We can't rely on the types matching up because that type would
      // have to be infinite.
      auto CastedContinuationTy =
        (ReturnPHIs.size() == 1 ? RetTy : RetTy->getStructElementType(0));
      auto *CastedContinuation =
        Builder.CreateBitCast(ReturnPHIs[0], CastedContinuationTy);

      Value *RetV;
      if (ReturnPHIs.size() == 1) {
        RetV = CastedContinuation;
      } else {
        RetV = UndefValue::get(RetTy);
        RetV = Builder.CreateInsertValue(RetV, CastedContinuation, 0);
        for (size_t I = 1, E = ReturnPHIs.size(); I != E; ++I)
          RetV = Builder.CreateInsertValue(RetV, ReturnPHIs[I], I);
      }

      Builder.CreateRet(RetV);
    }

    // Branch to the return block.
    Branch->setSuccessor(0, ReturnBB);
    ReturnPHIs[0]->addIncoming(Continuation, SuspendBB);
    size_t NextPHIIndex = 1;
    for (auto &VUse : Suspend->value_operands())
      ReturnPHIs[NextPHIIndex++]->addIncoming(&*VUse, SuspendBB);
    assert(NextPHIIndex == ReturnPHIs.size());
  }

  assert(Clones.size() == Shape.CoroSuspends.size());
  for (size_t i = 0, e = Shape.CoroSuspends.size(); i != e; ++i) {
    auto Suspend = Shape.CoroSuspends[i];
    auto Clone = Clones[i];

    CoroCloner(F, "resume." + Twine(i), Shape, Clone, Suspend).create();
  }
}

namespace {
  class PrettyStackTraceFunction : public PrettyStackTraceEntry {
    Function &F;
  public:
    PrettyStackTraceFunction(Function &F) : F(F) {}
    void print(raw_ostream &OS) const override {
      OS << "While splitting coroutine ";
      F.printAsOperand(OS, /*print type*/ false, F.getParent());
      OS << "\n";
    }
  };
}

static void splitCoroutine(Function &F, coro::Shape &Shape,
                           SmallVectorImpl<Function *> &Clones) {
  switch (Shape.ABI) {
  case coro::ABI::Switch:
    return splitSwitchCoroutine(F, Shape, Clones);
  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    return splitRetconCoroutine(F, Shape, Clones);
  }
  llvm_unreachable("bad ABI kind");
}

static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
  PrettyStackTraceFunction prettyStackTrace(F);

  // The suspend-crossing algorithm in buildCoroutineFrame gets tripped
  // up by uses in unreachable blocks, so remove them as a first pass.
  removeUnreachableBlocks(F);

  coro::Shape Shape(F);
  if (!Shape.CoroBegin)
    return;

  simplifySuspendPoints(Shape);
  buildCoroutineFrame(F, Shape);
  replaceFrameSize(Shape);

  SmallVector<Function*, 4> Clones;

  // If there are no suspend points, no split is required; just remove
  // the allocation and deallocation blocks, as they are not needed.
  if (Shape.CoroSuspends.empty()) {
    handleNoSuspendCoroutine(Shape);
  } else {
    splitCoroutine(F, Shape, Clones);
  }

  // Replace all the swifterror operations in the original function.
  // This invalidates SwiftErrorOps in the Shape.
  replaceSwiftErrorOps(F, Shape, nullptr);

  removeCoroEnds(Shape, &CG);
  postSplitCleanup(F);

  // Update call graph and add the functions we created to the SCC.
  coro::updateCallGraph(F, Clones, CG, SCC);
}

// When we see the coroutine for the first time, we insert an indirect call to
// a devirt trigger function and mark the coroutine as now being ready for
// split.
static void prepareForSplit(Function &F, CallGraph &CG) {
  Module &M = *F.getParent();
  LLVMContext &Context = F.getContext();
#ifndef NDEBUG
  Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
  assert(DevirtFn && "coro.devirt.trigger function not found");
#endif

  F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);

  // Insert an indirect call sequence that will be devirtualized by CoroElide
  // pass:
  //    %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
  //    %1 = bitcast i8* %0 to void(i8*)*
  //    call void %1(i8* null)
  coro::LowererBase Lowerer(M);
  Instruction *InsertPt = F.getEntryBlock().getTerminator();
  auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context));
  auto *DevirtFnAddr =
      Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
  FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Context),
                                         {Type::getInt8PtrTy(Context)}, false);
  auto *IndirectCall = CallInst::Create(FnTy, DevirtFnAddr, Null, "", InsertPt);
  // Update the call graph with the indirect call we just added.
  CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
}

// Make sure that there is a devirtualization trigger function that the
// CoroSplit pass uses to force restarting the CGSCC pipeline. If the devirt
// trigger function is not found, we will create one and add it to the current
// SCC.
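//
// The trigger function itself is trivial; roughly:
//
//   define private void @coro.devirt.trigger(i8*) alwaysinline {
//   entry:
//     ret void
//   }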
static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
  Module &M = CG.getModule();
  if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
    return;

  LLVMContext &C = M.getContext();
  auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
                                 /*isVarArg=*/false);
  Function *DevirtFn =
      Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
                       CORO_DEVIRT_TRIGGER_FN, &M);
  DevirtFn->addFnAttr(Attribute::AlwaysInline);
  auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
  ReturnInst::Create(C, Entry);

  auto *Node = CG.getOrInsertFunction(DevirtFn);

  SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
  Nodes.push_back(Node);
  SCC.initialize(Nodes);
}

/// Replace a call to llvm.coro.prepare.retcon.
static void replacePrepare(CallInst *Prepare, CallGraph &CG) {
  auto CastFn = Prepare->getArgOperand(0); // as an i8*
  auto Fn = CastFn->stripPointerCasts(); // as its original type

  // Find call graph nodes for the preparation.
  CallGraphNode *PrepareUserNode = nullptr, *FnNode = nullptr;
  if (auto ConcreteFn = dyn_cast<Function>(Fn)) {
    PrepareUserNode = CG[Prepare->getFunction()];
    FnNode = CG[ConcreteFn];
  }

  // Attempt to peephole this pattern:
  //    %0 = bitcast [[TYPE]] @some_function to i8*
  //    %1 = call @llvm.coro.prepare.retcon(i8* %0)
  //    %2 = bitcast %1 to [[TYPE]]
  // ==>
  //    %2 = @some_function
  for (auto UI = Prepare->use_begin(), UE = Prepare->use_end();
         UI != UE; ) {
    // Look for bitcasts back to the original function type.
    auto *Cast = dyn_cast<BitCastInst>((UI++)->getUser());
    if (!Cast || Cast->getType() != Fn->getType()) continue;

    // Check whether the replacement will introduce new direct calls.
    // If so, we'll need to update the call graph.
    if (PrepareUserNode) {
      for (auto &Use : Cast->uses()) {
        if (auto *CB = dyn_cast<CallBase>(Use.getUser())) {
          if (!CB->isCallee(&Use))
            continue;
          PrepareUserNode->removeCallEdgeFor(*CB);
          PrepareUserNode->addCalledFunction(CB, FnNode);
        }
      }
    }

    // Replace and remove the cast.
    Cast->replaceAllUsesWith(Fn);
    Cast->eraseFromParent();
  }

  // Replace any remaining uses with the function as an i8*.
  // This can never directly be a callee, so we don't need to update CG.
  Prepare->replaceAllUsesWith(CastFn);
  Prepare->eraseFromParent();

  // Kill dead bitcasts.
  while (auto *Cast = dyn_cast<BitCastInst>(CastFn)) {
    if (!Cast->use_empty()) break;
    CastFn = Cast->getOperand(0);
    Cast->eraseFromParent();
  }
}

/// Remove calls to llvm.coro.prepare.retcon, a barrier meant to prevent
/// IPO from operating on calls to a retcon coroutine before it's been
/// split.  This is only safe to do after we've split all retcon
/// coroutines in the module.  We can do this in this pass because
/// this pass does promise to split all retcon coroutines (as opposed to
/// switch coroutines, which are lowered in multiple stages).
static bool replaceAllPrepares(Function *PrepareFn, CallGraph &CG) {
  bool Changed = false;
  for (auto PI = PrepareFn->use_begin(), PE = PrepareFn->use_end();
         PI != PE; ) {
    // Intrinsics can only be used in calls.
    auto *Prepare = cast<CallInst>((PI++)->getUser());
    replacePrepare(Prepare, CG);
    Changed = true;
  }

  return Changed;
}

//===----------------------------------------------------------------------===//
//                              Top Level Driver
//===----------------------------------------------------------------------===//

namespace {

struct CoroSplit : public CallGraphSCCPass {
  static char ID; // Pass identification, replacement for typeid

  CoroSplit() : CallGraphSCCPass(ID) {
    initializeCoroSplitPass(*PassRegistry::getPassRegistry());
  }

  bool Run = false;

  // A coroutine is identified by the presence of the coro.begin intrinsic;
  // if we don't have any, this pass has nothing to do.
  bool doInitialization(CallGraph &CG) override {
    Run = coro::declaresIntrinsics(CG.getModule(),
                                   {"llvm.coro.begin",
                                    "llvm.coro.prepare.retcon"});
    return CallGraphSCCPass::doInitialization(CG);
  }

  bool runOnSCC(CallGraphSCC &SCC) override {
    if (!Run)
      return false;

    // Check for uses of llvm.coro.prepare.retcon.
    auto PrepareFn =
      SCC.getCallGraph().getModule().getFunction("llvm.coro.prepare.retcon");
    if (PrepareFn && PrepareFn->use_empty())
      PrepareFn = nullptr;

    // Find coroutines for processing.
    SmallVector<Function *, 4> Coroutines;
    for (CallGraphNode *CGN : SCC)
      if (auto *F = CGN->getFunction())
        if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
          Coroutines.push_back(F);

    if (Coroutines.empty() && !PrepareFn)
      return false;

    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();

    if (Coroutines.empty())
      return replaceAllPrepares(PrepareFn, CG);

    createDevirtTriggerFunc(CG, SCC);

    // Split all the coroutines.
    for (Function *F : Coroutines) {
      Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
      StringRef Value = Attr.getValueAsString();
      LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
                        << "' state: " << Value << "\n");
      if (Value == UNPREPARED_FOR_SPLIT) {
        prepareForSplit(*F, CG);
        continue;
      }
      F->removeFnAttr(CORO_PRESPLIT_ATTR);
      splitCoroutine(*F, CG, SCC);
    }

    if (PrepareFn)
      replaceAllPrepares(PrepareFn, CG);

    return true;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return "Coroutine Splitting"; }
};

} // end anonymous namespace

char CoroSplit::ID = 0;

INITIALIZE_PASS_BEGIN(
    CoroSplit, "coro-split",
    "Split coroutine into a set of functions driving its state machine", false,
    false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_END(
    CoroSplit, "coro-split",
    "Split coroutine into a set of functions driving its state machine", false,
    false)

Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }