1 //===-- Coroutines.cpp ----------------------------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This file implements the common infrastructure for Coroutine Passes.
10 //===----------------------------------------------------------------------===//
11 
12 #include "CoroInternal.h"
13 #include "llvm/Analysis/CallGraphSCCPass.h"
14 #include "llvm/IR/InstIterator.h"
15 #include "llvm/IR/LegacyPassManager.h"
16 #include "llvm/IR/Verifier.h"
17 #include "llvm/InitializePasses.h"
18 #include "llvm/Transforms/IPO.h"
19 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
20 #include "llvm/Transforms/Utils/Local.h"
21 
22 using namespace llvm;
23 
24 void llvm::initializeCoroutines(PassRegistry &Registry) {
25   initializeCoroEarlyPass(Registry);
26   initializeCoroSplitPass(Registry);
27   initializeCoroElidePass(Registry);
28   initializeCoroCleanupPass(Registry);
29 }
30 
31 static void addCoroutineOpt0Passes(const PassManagerBuilder &Builder,
32                                    legacy::PassManagerBase &PM) {
33   PM.add(createCoroSplitPass());
34   PM.add(createCoroElidePass());
35 
36   PM.add(createBarrierNoopPass());
37   PM.add(createCoroCleanupPass());
38 }
39 
40 static void addCoroutineEarlyPasses(const PassManagerBuilder &Builder,
41                                     legacy::PassManagerBase &PM) {
42   PM.add(createCoroEarlyPass());
43 }
44 
45 static void addCoroutineScalarOptimizerPasses(const PassManagerBuilder &Builder,
46                                               legacy::PassManagerBase &PM) {
47   PM.add(createCoroElidePass());
48 }
49 
50 static void addCoroutineSCCPasses(const PassManagerBuilder &Builder,
51                                   legacy::PassManagerBase &PM) {
52   PM.add(createCoroSplitPass());
53 }
54 
55 static void addCoroutineOptimizerLastPasses(const PassManagerBuilder &Builder,
56                                             legacy::PassManagerBase &PM) {
57   PM.add(createCoroCleanupPass());
58 }
59 
60 void llvm::addCoroutinePassesToExtensionPoints(PassManagerBuilder &Builder) {
61   Builder.addExtension(PassManagerBuilder::EP_EarlyAsPossible,
62                        addCoroutineEarlyPasses);
63   Builder.addExtension(PassManagerBuilder::EP_EnabledOnOptLevel0,
64                        addCoroutineOpt0Passes);
65   Builder.addExtension(PassManagerBuilder::EP_CGSCCOptimizerLate,
66                        addCoroutineSCCPasses);
67   Builder.addExtension(PassManagerBuilder::EP_ScalarOptimizerLate,
68                        addCoroutineScalarOptimizerPasses);
69   Builder.addExtension(PassManagerBuilder::EP_OptimizerLast,
70                        addCoroutineOptimizerLastPasses);
71 }
72 
// Construct the lowerer base class and initialize its members: remember the
// module and its LLVMContext, and cache the i8* type, the `void (i8*)` resume
// function type, and an i8* null constant for use by derived lowerers.
// NOTE: later initializers read earlier members (Int8Ptr uses Context;
// ResumeFnType and NullPtr use Int8Ptr). C++ initializes members in
// declaration order, so those members must be declared in this order in the
// class definition.
coro::LowererBase::LowererBase(Module &M)
    : TheModule(M), Context(M.getContext()),
      Int8Ptr(Type::getInt8PtrTy(Context)),
      ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
                                     /*isVarArg=*/false)),
      NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
80 
81 // Creates a sequence of instructions to obtain a resume function address using
82 // llvm.coro.subfn.addr. It generates the following sequence:
83 //
84 //    call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index)
85 //    bitcast i8* %2 to void(i8*)*
86 
87 Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
88                                         Instruction *InsertPt) {
89   auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
90   auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
91 
92   assert(Index >= CoroSubFnInst::IndexFirst &&
93          Index < CoroSubFnInst::IndexLast &&
94          "makeSubFnCall: Index value out of range");
95   auto *Call = CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt);
96 
97   auto *Bitcast =
98       new BitCastInst(Call, ResumeFnType->getPointerTo(), "", InsertPt);
99   return Bitcast;
100 }
101 
102 #ifndef NDEBUG
103 static bool isCoroutineIntrinsicName(StringRef Name) {
104   // NOTE: Must be sorted!
105   static const char *const CoroIntrinsics[] = {
106       "llvm.coro.alloc",   "llvm.coro.begin",   "llvm.coro.destroy",
107       "llvm.coro.done",    "llvm.coro.end",     "llvm.coro.frame",
108       "llvm.coro.free",    "llvm.coro.id",      "llvm.coro.param",
109       "llvm.coro.promise", "llvm.coro.resume",  "llvm.coro.save",
110       "llvm.coro.size",    "llvm.coro.subfn.addr", "llvm.coro.suspend",
111   };
112   return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
113 }
114 #endif
115 
116 // Verifies if a module has named values listed. Also, in debug mode verifies
117 // that names are intrinsic names.
118 bool coro::declaresIntrinsics(Module &M,
119                               std::initializer_list<StringRef> List) {
120 
121   for (StringRef Name : List) {
122     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
123     if (M.getNamedValue(Name))
124       return true;
125   }
126 
127   return false;
128 }
129 
130 // Replace all coro.frees associated with the provided CoroId either with 'null'
131 // if Elide is true and with its frame parameter otherwise.
132 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
133   SmallVector<CoroFreeInst *, 4> CoroFrees;
134   for (User *U : CoroId->users())
135     if (auto CF = dyn_cast<CoroFreeInst>(U))
136       CoroFrees.push_back(CF);
137 
138   if (CoroFrees.empty())
139     return;
140 
141   Value *Replacement =
142       Elide ? ConstantPointerNull::get(Type::getInt8PtrTy(CoroId->getContext()))
143             : CoroFrees.front()->getFrame();
144 
145   for (CoroFreeInst *CF : CoroFrees) {
146     CF->replaceAllUsesWith(Replacement);
147     CF->eraseFromParent();
148   }
149 }
150 
151 // FIXME: This code is stolen from CallGraph::addToCallGraph(Function *F), which
152 // happens to be private. It is better for this functionality exposed by the
153 // CallGraph.
154 static void buildCGN(CallGraph &CG, CallGraphNode *Node) {
155   Function *F = Node->getFunction();
156 
157   // Look for calls by this function.
158   for (Instruction &I : instructions(F))
159     if (CallSite CS = CallSite(cast<Value>(&I))) {
160       const Function *Callee = CS.getCalledFunction();
161       if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID()))
162         // Indirect calls of intrinsics are not allowed so no need to check.
163         // We can be more precise here by using TargetArg returned by
164         // Intrinsic::isLeaf.
165         Node->addCalledFunction(CS, CG.getCallsExternalNode());
166       else if (!Callee->isIntrinsic())
167         Node->addCalledFunction(CS, CG.getOrInsertFunction(Callee));
168     }
169 }
170 
171 // Rebuild CGN after we extracted parts of the code from ParentFunc into
172 // NewFuncs. Builds CGNs for the NewFuncs and adds them to the current SCC.
173 void coro::updateCallGraph(Function &ParentFunc, ArrayRef<Function *> NewFuncs,
174                            CallGraph &CG, CallGraphSCC &SCC) {
175   // Rebuild CGN from scratch for the ParentFunc
176   auto *ParentNode = CG[&ParentFunc];
177   ParentNode->removeAllCalledFunctions();
178   buildCGN(CG, ParentNode);
179 
180   SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
181 
182   for (Function *F : NewFuncs) {
183     CallGraphNode *Callee = CG.getOrInsertFunction(F);
184     Nodes.push_back(Callee);
185     buildCGN(CG, Callee);
186   }
187 
188   SCC.initialize(Nodes);
189 }
190 
191 static void clear(coro::Shape &Shape) {
192   Shape.CoroBegin = nullptr;
193   Shape.CoroEnds.clear();
194   Shape.CoroSizes.clear();
195   Shape.CoroSuspends.clear();
196 
197   Shape.FrameTy = nullptr;
198   Shape.FramePtr = nullptr;
199   Shape.AllocaSpillBlock = nullptr;
200   Shape.ResumeSwitch = nullptr;
201   Shape.PromiseAlloca = nullptr;
202   Shape.HasFinalSuspend = false;
203 }
204 
205 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
206                                     CoroSuspendInst *SuspendInst) {
207   Module *M = SuspendInst->getModule();
208   auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
209   auto *SaveInst =
210       cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst));
211   assert(!SuspendInst->getCoroSave());
212   SuspendInst->setArgOperand(0, SaveInst);
213   return SaveInst;
214 }
215 
216 // Collect "interesting" coroutine intrinsics.
217 void coro::Shape::buildFrom(Function &F) {
218   size_t FinalSuspendIndex = 0;
219   clear(*this);
220   SmallVector<CoroFrameInst *, 8> CoroFrames;
221   for (Instruction &I : instructions(F)) {
222     if (auto II = dyn_cast<IntrinsicInst>(&I)) {
223       switch (II->getIntrinsicID()) {
224       default:
225         continue;
226       case Intrinsic::coro_size:
227         CoroSizes.push_back(cast<CoroSizeInst>(II));
228         break;
229       case Intrinsic::coro_frame:
230         CoroFrames.push_back(cast<CoroFrameInst>(II));
231         break;
232       case Intrinsic::coro_suspend:
233         CoroSuspends.push_back(cast<CoroSuspendInst>(II));
234         if (CoroSuspends.back()->isFinal()) {
235           if (HasFinalSuspend)
236             report_fatal_error(
237               "Only one suspend point can be marked as final");
238           HasFinalSuspend = true;
239           FinalSuspendIndex = CoroSuspends.size() - 1;
240         }
241         break;
242       case Intrinsic::coro_begin: {
243         auto CB = cast<CoroBeginInst>(II);
244         if (CB->getId()->getInfo().isPreSplit()) {
245           if (CoroBegin)
246             report_fatal_error(
247                 "coroutine should have exactly one defining @llvm.coro.begin");
248           CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull);
249           CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NoAlias);
250           CB->removeAttribute(AttributeSet::FunctionIndex,
251                               Attribute::NoDuplicate);
252           CoroBegin = CB;
253         }
254         break;
255       }
256       case Intrinsic::coro_end:
257         CoroEnds.push_back(cast<CoroEndInst>(II));
258         if (CoroEnds.back()->isFallthrough()) {
259           // Make sure that the fallthrough coro.end is the first element in the
260           // CoroEnds vector.
261           if (CoroEnds.size() > 1) {
262             if (CoroEnds.front()->isFallthrough())
263               report_fatal_error(
264                   "Only one coro.end can be marked as fallthrough");
265             std::swap(CoroEnds.front(), CoroEnds.back());
266           }
267         }
268         break;
269       }
270     }
271   }
272 
273   // If for some reason, we were not able to find coro.begin, bailout.
274   if (!CoroBegin) {
275     // Replace coro.frame which are supposed to be lowered to the result of
276     // coro.begin with undef.
277     auto *Undef = UndefValue::get(Type::getInt8PtrTy(F.getContext()));
278     for (CoroFrameInst *CF : CoroFrames) {
279       CF->replaceAllUsesWith(Undef);
280       CF->eraseFromParent();
281     }
282 
283     // Replace all coro.suspend with undef and remove related coro.saves if
284     // present.
285     for (CoroSuspendInst *CS : CoroSuspends) {
286       CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
287       CS->eraseFromParent();
288       if (auto *CoroSave = CS->getCoroSave())
289         CoroSave->eraseFromParent();
290     }
291 
292     // Replace all coro.ends with unreachable instruction.
293     for (CoroEndInst *CE : CoroEnds)
294       changeToUnreachable(CE, /*UseLLVMTrap=*/false);
295 
296     return;
297   }
298 
299   // The coro.free intrinsic is always lowered to the result of coro.begin.
300   for (CoroFrameInst *CF : CoroFrames) {
301     CF->replaceAllUsesWith(CoroBegin);
302     CF->eraseFromParent();
303   }
304 
305   // Canonicalize coro.suspend by inserting a coro.save if needed.
306   for (CoroSuspendInst *CS : CoroSuspends)
307     if (!CS->getCoroSave())
308       createCoroSave(CoroBegin, CS);
309 
310   // Move final suspend to be the last element in the CoroSuspends vector.
311   if (HasFinalSuspend &&
312       FinalSuspendIndex != CoroSuspends.size() - 1)
313     std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
314 }
315