1 //===- Coroutines.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the common infrastructure for Coroutine Passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Transforms/Coroutines.h"
14 #include "CoroInstr.h"
15 #include "CoroInternal.h"
16 #include "llvm-c/Transforms/Coroutines.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/Analysis/CallGraph.h"
20 #include "llvm/Analysis/CallGraphSCCPass.h"
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/DerivedTypes.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/InstIterator.h"
26 #include "llvm/IR/Instructions.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/LegacyPassManager.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/InitializePasses.h"
33 #include "llvm/Support/Casting.h"
34 #include "llvm/Support/ErrorHandling.h"
35 #include "llvm/Transforms/IPO.h"
36 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
37 #include "llvm/Transforms/Utils/Local.h"
38 #include <cassert>
39 #include <cstddef>
40 #include <utility>
41 
42 using namespace llvm;
43 
44 void llvm::initializeCoroutines(PassRegistry &Registry) {
45   initializeCoroEarlyLegacyPass(Registry);
46   initializeCoroSplitLegacyPass(Registry);
47   initializeCoroElideLegacyPass(Registry);
48   initializeCoroCleanupLegacyPass(Registry);
49 }
50 
51 static void addCoroutineOpt0Passes(const PassManagerBuilder &Builder,
52                                    legacy::PassManagerBase &PM) {
53   PM.add(createCoroSplitLegacyPass());
54   PM.add(createCoroElideLegacyPass());
55 
56   PM.add(createBarrierNoopPass());
57   PM.add(createCoroCleanupLegacyPass());
58 }
59 
60 static void addCoroutineEarlyPasses(const PassManagerBuilder &Builder,
61                                     legacy::PassManagerBase &PM) {
62   PM.add(createCoroEarlyLegacyPass());
63 }
64 
65 static void addCoroutineScalarOptimizerPasses(const PassManagerBuilder &Builder,
66                                               legacy::PassManagerBase &PM) {
67   PM.add(createCoroElideLegacyPass());
68 }
69 
70 static void addCoroutineSCCPasses(const PassManagerBuilder &Builder,
71                                   legacy::PassManagerBase &PM) {
72   PM.add(createCoroSplitLegacyPass(Builder.OptLevel != 0));
73 }
74 
75 static void addCoroutineOptimizerLastPasses(const PassManagerBuilder &Builder,
76                                             legacy::PassManagerBase &PM) {
77   PM.add(createCoroCleanupLegacyPass());
78 }
79 
80 void llvm::addCoroutinePassesToExtensionPoints(PassManagerBuilder &Builder) {
81   Builder.addExtension(PassManagerBuilder::EP_EarlyAsPossible,
82                        addCoroutineEarlyPasses);
83   Builder.addExtension(PassManagerBuilder::EP_EnabledOnOptLevel0,
84                        addCoroutineOpt0Passes);
85   Builder.addExtension(PassManagerBuilder::EP_CGSCCOptimizerLate,
86                        addCoroutineSCCPasses);
87   Builder.addExtension(PassManagerBuilder::EP_ScalarOptimizerLate,
88                        addCoroutineScalarOptimizerPasses);
89   Builder.addExtension(PassManagerBuilder::EP_OptimizerLast,
90                        addCoroutineOptimizerLastPasses);
91 }
92 
// Construct the lowerer base class and initialize its members: the module,
// its LLVMContext, the i8* type, the function type void(i8*) used as the
// resume-function type (see makeSubFnCall), and a null i8* constant.
// NOTE: ResumeFnType and NullPtr are built from Int8Ptr, which is built from
// Context; members initialize in declaration order, so this relies on that
// order in the class declaration -- TODO confirm if members are reordered.
coro::LowererBase::LowererBase(Module &M)
    : TheModule(M), Context(M.getContext()),
      Int8Ptr(Type::getInt8PtrTy(Context)),
      ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
                                     /*isVarArg=*/false)),
      NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
100 
101 // Creates a sequence of instructions to obtain a resume function address using
102 // llvm.coro.subfn.addr. It generates the following sequence:
103 //
104 //    call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index)
105 //    bitcast i8* %2 to void(i8*)*
106 
107 Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
108                                         Instruction *InsertPt) {
109   auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
110   auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
111 
112   assert(Index >= CoroSubFnInst::IndexFirst &&
113          Index < CoroSubFnInst::IndexLast &&
114          "makeSubFnCall: Index value out of range");
115   auto *Call = CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt);
116 
117   auto *Bitcast =
118       new BitCastInst(Call, ResumeFnType->getPointerTo(), "", InsertPt);
119   return Bitcast;
120 }
121 
#ifndef NDEBUG
// Debug-only helper (compiled out under NDEBUG) used by the assertion in
// coro::declaresIntrinsics. Returns true if Intrinsic::lookupLLVMIntrinsicByName
// finds Name in the table of coroutine intrinsic names below (it returns -1
// when there is no match).
static bool isCoroutineIntrinsicName(StringRef Name) {
  // NOTE: Must be sorted!
  // NOTE(review): lookupLLVMIntrinsicByName presumably searches this table
  // by bisection, which is why lexicographic order is required; insert new
  // entries in sorted position.
  static const char *const CoroIntrinsics[] = {
      "llvm.coro.alloc",
      "llvm.coro.async.context.alloc",
      "llvm.coro.async.context.dealloc",
      "llvm.coro.async.resume",
      "llvm.coro.async.size.replace",
      "llvm.coro.async.store_resume",
      "llvm.coro.begin",
      "llvm.coro.destroy",
      "llvm.coro.done",
      "llvm.coro.end",
      "llvm.coro.end.async",
      "llvm.coro.frame",
      "llvm.coro.free",
      "llvm.coro.id",
      "llvm.coro.id.async",
      "llvm.coro.id.retcon",
      "llvm.coro.id.retcon.once",
      "llvm.coro.noop",
      "llvm.coro.prepare.async",
      "llvm.coro.prepare.retcon",
      "llvm.coro.promise",
      "llvm.coro.resume",
      "llvm.coro.save",
      "llvm.coro.size",
      "llvm.coro.subfn.addr",
      "llvm.coro.suspend",
      "llvm.coro.suspend.async",
      "llvm.coro.suspend.retcon",
  };
  return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
}
#endif
158 
159 // Verifies if a module has named values listed. Also, in debug mode verifies
160 // that names are intrinsic names.
161 bool coro::declaresIntrinsics(const Module &M,
162                               const std::initializer_list<StringRef> List) {
163   for (StringRef Name : List) {
164     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
165     if (M.getNamedValue(Name))
166       return true;
167   }
168 
169   return false;
170 }
171 
172 // Replace all coro.frees associated with the provided CoroId either with 'null'
173 // if Elide is true and with its frame parameter otherwise.
174 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
175   SmallVector<CoroFreeInst *, 4> CoroFrees;
176   for (User *U : CoroId->users())
177     if (auto CF = dyn_cast<CoroFreeInst>(U))
178       CoroFrees.push_back(CF);
179 
180   if (CoroFrees.empty())
181     return;
182 
183   Value *Replacement =
184       Elide ? ConstantPointerNull::get(Type::getInt8PtrTy(CoroId->getContext()))
185             : CoroFrees.front()->getFrame();
186 
187   for (CoroFreeInst *CF : CoroFrees) {
188     CF->replaceAllUsesWith(Replacement);
189     CF->eraseFromParent();
190   }
191 }
192 
193 // FIXME: This code is stolen from CallGraph::addToCallGraph(Function *F), which
194 // happens to be private. It is better for this functionality exposed by the
195 // CallGraph.
196 static void buildCGN(CallGraph &CG, CallGraphNode *Node) {
197   Function *F = Node->getFunction();
198 
199   // Look for calls by this function.
200   for (Instruction &I : instructions(F))
201     if (auto *Call = dyn_cast<CallBase>(&I)) {
202       const Function *Callee = Call->getCalledFunction();
203       if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID()))
204         // Indirect calls of intrinsics are not allowed so no need to check.
205         // We can be more precise here by using TargetArg returned by
206         // Intrinsic::isLeaf.
207         Node->addCalledFunction(Call, CG.getCallsExternalNode());
208       else if (!Callee->isIntrinsic())
209         Node->addCalledFunction(Call, CG.getOrInsertFunction(Callee));
210     }
211 }
212 
213 // Rebuild CGN after we extracted parts of the code from ParentFunc into
214 // NewFuncs. Builds CGNs for the NewFuncs and adds them to the current SCC.
215 void coro::updateCallGraph(Function &ParentFunc, ArrayRef<Function *> NewFuncs,
216                            CallGraph &CG, CallGraphSCC &SCC) {
217   // Rebuild CGN from scratch for the ParentFunc
218   auto *ParentNode = CG[&ParentFunc];
219   ParentNode->removeAllCalledFunctions();
220   buildCGN(CG, ParentNode);
221 
222   SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
223 
224   for (Function *F : NewFuncs) {
225     CallGraphNode *Callee = CG.getOrInsertFunction(F);
226     Nodes.push_back(Callee);
227     buildCGN(CG, Callee);
228   }
229 
230   SCC.initialize(Nodes);
231 }
232 
233 static void clear(coro::Shape &Shape) {
234   Shape.CoroBegin = nullptr;
235   Shape.CoroEnds.clear();
236   Shape.CoroSizes.clear();
237   Shape.CoroSuspends.clear();
238 
239   Shape.FrameTy = nullptr;
240   Shape.FramePtr = nullptr;
241   Shape.AllocaSpillBlock = nullptr;
242 }
243 
244 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
245                                     CoroSuspendInst *SuspendInst) {
246   Module *M = SuspendInst->getModule();
247   auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
248   auto *SaveInst =
249       cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst));
250   assert(!SuspendInst->getCoroSave());
251   SuspendInst->setArgOperand(0, SaveInst);
252   return SaveInst;
253 }
254 
255 // Collect "interesting" coroutine intrinsics.
256 void coro::Shape::buildFrom(Function &F) {
257   bool HasFinalSuspend = false;
258   size_t FinalSuspendIndex = 0;
259   clear(*this);
260   SmallVector<CoroFrameInst *, 8> CoroFrames;
261   SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;
262 
263   for (Instruction &I : instructions(F)) {
264     if (auto II = dyn_cast<IntrinsicInst>(&I)) {
265       switch (II->getIntrinsicID()) {
266       default:
267         continue;
268       case Intrinsic::coro_size:
269         CoroSizes.push_back(cast<CoroSizeInst>(II));
270         break;
271       case Intrinsic::coro_frame:
272         CoroFrames.push_back(cast<CoroFrameInst>(II));
273         break;
274       case Intrinsic::coro_save:
275         // After optimizations, coro_suspends using this coro_save might have
276         // been removed, remember orphaned coro_saves to remove them later.
277         if (II->use_empty())
278           UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
279         break;
280       case Intrinsic::coro_suspend_async: {
281         auto *Suspend = cast<CoroSuspendAsyncInst>(II);
282         Suspend->checkWellFormed();
283         CoroSuspends.push_back(Suspend);
284         break;
285       }
286       case Intrinsic::coro_suspend_retcon: {
287         auto Suspend = cast<CoroSuspendRetconInst>(II);
288         CoroSuspends.push_back(Suspend);
289         break;
290       }
291       case Intrinsic::coro_suspend: {
292         auto Suspend = cast<CoroSuspendInst>(II);
293         CoroSuspends.push_back(Suspend);
294         if (Suspend->isFinal()) {
295           if (HasFinalSuspend)
296             report_fatal_error(
297               "Only one suspend point can be marked as final");
298           HasFinalSuspend = true;
299           FinalSuspendIndex = CoroSuspends.size() - 1;
300         }
301         break;
302       }
303       case Intrinsic::coro_begin: {
304         auto CB = cast<CoroBeginInst>(II);
305 
306         // Ignore coro id's that aren't pre-split.
307         auto Id = dyn_cast<CoroIdInst>(CB->getId());
308         if (Id && !Id->getInfo().isPreSplit())
309           break;
310 
311         if (CoroBegin)
312           report_fatal_error(
313                 "coroutine should have exactly one defining @llvm.coro.begin");
314         CB->addRetAttr(Attribute::NonNull);
315         CB->addRetAttr(Attribute::NoAlias);
316         CB->removeFnAttr(Attribute::NoDuplicate);
317         CoroBegin = CB;
318         break;
319       }
320       case Intrinsic::coro_end_async:
321       case Intrinsic::coro_end:
322         CoroEnds.push_back(cast<AnyCoroEndInst>(II));
323         if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) {
324           AsyncEnd->checkWellFormed();
325         }
326         if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) {
327           // Make sure that the fallthrough coro.end is the first element in the
328           // CoroEnds vector.
329           // Note: I don't think this is neccessary anymore.
330           if (CoroEnds.size() > 1) {
331             if (CoroEnds.front()->isFallthrough())
332               report_fatal_error(
333                   "Only one coro.end can be marked as fallthrough");
334             std::swap(CoroEnds.front(), CoroEnds.back());
335           }
336         }
337         break;
338       }
339     }
340   }
341 
342   // If for some reason, we were not able to find coro.begin, bailout.
343   if (!CoroBegin) {
344     // Replace coro.frame which are supposed to be lowered to the result of
345     // coro.begin with undef.
346     auto *Undef = UndefValue::get(Type::getInt8PtrTy(F.getContext()));
347     for (CoroFrameInst *CF : CoroFrames) {
348       CF->replaceAllUsesWith(Undef);
349       CF->eraseFromParent();
350     }
351 
352     // Replace all coro.suspend with undef and remove related coro.saves if
353     // present.
354     for (AnyCoroSuspendInst *CS : CoroSuspends) {
355       CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
356       CS->eraseFromParent();
357       if (auto *CoroSave = CS->getCoroSave())
358         CoroSave->eraseFromParent();
359     }
360 
361     // Replace all coro.ends with unreachable instruction.
362     for (AnyCoroEndInst *CE : CoroEnds)
363       changeToUnreachable(CE);
364 
365     return;
366   }
367 
368   auto Id = CoroBegin->getId();
369   switch (auto IdIntrinsic = Id->getIntrinsicID()) {
370   case Intrinsic::coro_id: {
371     auto SwitchId = cast<CoroIdInst>(Id);
372     this->ABI = coro::ABI::Switch;
373     this->SwitchLowering.HasFinalSuspend = HasFinalSuspend;
374     this->SwitchLowering.ResumeSwitch = nullptr;
375     this->SwitchLowering.PromiseAlloca = SwitchId->getPromise();
376     this->SwitchLowering.ResumeEntryBlock = nullptr;
377 
378     for (auto AnySuspend : CoroSuspends) {
379       auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
380       if (!Suspend) {
381 #ifndef NDEBUG
382         AnySuspend->dump();
383 #endif
384         report_fatal_error("coro.id must be paired with coro.suspend");
385       }
386 
387       if (!Suspend->getCoroSave())
388         createCoroSave(CoroBegin, Suspend);
389     }
390     break;
391   }
392   case Intrinsic::coro_id_async: {
393     auto *AsyncId = cast<CoroIdAsyncInst>(Id);
394     AsyncId->checkWellFormed();
395     this->ABI = coro::ABI::Async;
396     this->AsyncLowering.Context = AsyncId->getStorage();
397     this->AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
398     this->AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
399     this->AsyncLowering.ContextAlignment =
400         AsyncId->getStorageAlignment().value();
401     this->AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
402     this->AsyncLowering.AsyncCC = F.getCallingConv();
403     break;
404   };
405   case Intrinsic::coro_id_retcon:
406   case Intrinsic::coro_id_retcon_once: {
407     auto ContinuationId = cast<AnyCoroIdRetconInst>(Id);
408     ContinuationId->checkWellFormed();
409     this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon
410                   ? coro::ABI::Retcon
411                   : coro::ABI::RetconOnce);
412     auto Prototype = ContinuationId->getPrototype();
413     this->RetconLowering.ResumePrototype = Prototype;
414     this->RetconLowering.Alloc = ContinuationId->getAllocFunction();
415     this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
416     this->RetconLowering.ReturnBlock = nullptr;
417     this->RetconLowering.IsFrameInlineInStorage = false;
418 
419     // Determine the result value types, and make sure they match up with
420     // the values passed to the suspends.
421     auto ResultTys = getRetconResultTypes();
422     auto ResumeTys = getRetconResumeTypes();
423 
424     for (auto AnySuspend : CoroSuspends) {
425       auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
426       if (!Suspend) {
427 #ifndef NDEBUG
428         AnySuspend->dump();
429 #endif
430         report_fatal_error("coro.id.retcon.* must be paired with "
431                            "coro.suspend.retcon");
432       }
433 
434       // Check that the argument types of the suspend match the results.
435       auto SI = Suspend->value_begin(), SE = Suspend->value_end();
436       auto RI = ResultTys.begin(), RE = ResultTys.end();
437       for (; SI != SE && RI != RE; ++SI, ++RI) {
438         auto SrcTy = (*SI)->getType();
439         if (SrcTy != *RI) {
440           // The optimizer likes to eliminate bitcasts leading into variadic
441           // calls, but that messes with our invariants.  Re-insert the
442           // bitcast and ignore this type mismatch.
443           if (CastInst::isBitCastable(SrcTy, *RI)) {
444             auto BCI = new BitCastInst(*SI, *RI, "", Suspend);
445             SI->set(BCI);
446             continue;
447           }
448 
449 #ifndef NDEBUG
450           Suspend->dump();
451           Prototype->getFunctionType()->dump();
452 #endif
453           report_fatal_error("argument to coro.suspend.retcon does not "
454                              "match corresponding prototype function result");
455         }
456       }
457       if (SI != SE || RI != RE) {
458 #ifndef NDEBUG
459         Suspend->dump();
460         Prototype->getFunctionType()->dump();
461 #endif
462         report_fatal_error("wrong number of arguments to coro.suspend.retcon");
463       }
464 
465       // Check that the result type of the suspend matches the resume types.
466       Type *SResultTy = Suspend->getType();
467       ArrayRef<Type*> SuspendResultTys;
468       if (SResultTy->isVoidTy()) {
469         // leave as empty array
470       } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
471         SuspendResultTys = SResultStructTy->elements();
472       } else {
473         // forms an ArrayRef using SResultTy, be careful
474         SuspendResultTys = SResultTy;
475       }
476       if (SuspendResultTys.size() != ResumeTys.size()) {
477 #ifndef NDEBUG
478         Suspend->dump();
479         Prototype->getFunctionType()->dump();
480 #endif
481         report_fatal_error("wrong number of results from coro.suspend.retcon");
482       }
483       for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
484         if (SuspendResultTys[I] != ResumeTys[I]) {
485 #ifndef NDEBUG
486           Suspend->dump();
487           Prototype->getFunctionType()->dump();
488 #endif
489           report_fatal_error("result from coro.suspend.retcon does not "
490                              "match corresponding prototype function param");
491         }
492       }
493     }
494     break;
495   }
496 
497   default:
498     llvm_unreachable("coro.begin is not dependent on a coro.id call");
499   }
500 
501   // The coro.free intrinsic is always lowered to the result of coro.begin.
502   for (CoroFrameInst *CF : CoroFrames) {
503     CF->replaceAllUsesWith(CoroBegin);
504     CF->eraseFromParent();
505   }
506 
507   // Move final suspend to be the last element in the CoroSuspends vector.
508   if (ABI == coro::ABI::Switch &&
509       SwitchLowering.HasFinalSuspend &&
510       FinalSuspendIndex != CoroSuspends.size() - 1)
511     std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
512 
513   // Remove orphaned coro.saves.
514   for (CoroSaveInst *CoroSave : UnusedCoroSaves)
515     CoroSave->eraseFromParent();
516 }
517 
518 static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
519   Call->setCallingConv(Callee->getCallingConv());
520   // TODO: attributes?
521 }
522 
523 static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
524   if (CG)
525     (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
526 }
527 
528 Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
529                               CallGraph *CG) const {
530   switch (ABI) {
531   case coro::ABI::Switch:
532     llvm_unreachable("can't allocate memory in coro switch-lowering");
533 
534   case coro::ABI::Retcon:
535   case coro::ABI::RetconOnce: {
536     auto Alloc = RetconLowering.Alloc;
537     Size = Builder.CreateIntCast(Size,
538                                  Alloc->getFunctionType()->getParamType(0),
539                                  /*is signed*/ false);
540     auto *Call = Builder.CreateCall(Alloc, Size);
541     propagateCallAttrsFromCallee(Call, Alloc);
542     addCallToCallGraph(CG, Call, Alloc);
543     return Call;
544   }
545   case coro::ABI::Async:
546     llvm_unreachable("can't allocate memory in coro async-lowering");
547   }
548   llvm_unreachable("Unknown coro::ABI enum");
549 }
550 
551 void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
552                               CallGraph *CG) const {
553   switch (ABI) {
554   case coro::ABI::Switch:
555     llvm_unreachable("can't allocate memory in coro switch-lowering");
556 
557   case coro::ABI::Retcon:
558   case coro::ABI::RetconOnce: {
559     auto Dealloc = RetconLowering.Dealloc;
560     Ptr = Builder.CreateBitCast(Ptr,
561                                 Dealloc->getFunctionType()->getParamType(0));
562     auto *Call = Builder.CreateCall(Dealloc, Ptr);
563     propagateCallAttrsFromCallee(Call, Dealloc);
564     addCallToCallGraph(CG, Call, Dealloc);
565     return;
566   }
567   case coro::ABI::Async:
568     llvm_unreachable("can't allocate memory in coro async-lowering");
569   }
570   llvm_unreachable("Unknown coro::ABI enum");
571 }
572 
573 [[noreturn]] static void fail(const Instruction *I, const char *Reason,
574                               Value *V) {
575 #ifndef NDEBUG
576   I->dump();
577   if (V) {
578     errs() << "  Value: ";
579     V->printAsOperand(llvm::errs());
580     errs() << '\n';
581   }
582 #endif
583   report_fatal_error(Reason);
584 }
585 
586 /// Check that the given value is a well-formed prototype for the
587 /// llvm.coro.id.retcon.* intrinsics.
588 static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
589   auto F = dyn_cast<Function>(V->stripPointerCasts());
590   if (!F)
591     fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
592 
593   auto FT = F->getFunctionType();
594 
595   if (isa<CoroIdRetconInst>(I)) {
596     bool ResultOkay;
597     if (FT->getReturnType()->isPointerTy()) {
598       ResultOkay = true;
599     } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
600       ResultOkay = (!SRetTy->isOpaque() &&
601                     SRetTy->getNumElements() > 0 &&
602                     SRetTy->getElementType(0)->isPointerTy());
603     } else {
604       ResultOkay = false;
605     }
606     if (!ResultOkay)
607       fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
608               "result", F);
609 
610     if (FT->getReturnType() !=
611           I->getFunction()->getFunctionType()->getReturnType())
612       fail(I, "llvm.coro.id.retcon prototype return type must be same as"
613               "current function return type", F);
614   } else {
615     // No meaningful validation to do here for llvm.coro.id.unique.once.
616   }
617 
618   if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
619     fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
620             "its first parameter", F);
621 }
622 
623 /// Check that the given value is a well-formed allocator.
624 static void checkWFAlloc(const Instruction *I, Value *V) {
625   auto F = dyn_cast<Function>(V->stripPointerCasts());
626   if (!F)
627     fail(I, "llvm.coro.* allocator not a Function", V);
628 
629   auto FT = F->getFunctionType();
630   if (!FT->getReturnType()->isPointerTy())
631     fail(I, "llvm.coro.* allocator must return a pointer", F);
632 
633   if (FT->getNumParams() != 1 ||
634       !FT->getParamType(0)->isIntegerTy())
635     fail(I, "llvm.coro.* allocator must take integer as only param", F);
636 }
637 
638 /// Check that the given value is a well-formed deallocator.
639 static void checkWFDealloc(const Instruction *I, Value *V) {
640   auto F = dyn_cast<Function>(V->stripPointerCasts());
641   if (!F)
642     fail(I, "llvm.coro.* deallocator not a Function", V);
643 
644   auto FT = F->getFunctionType();
645   if (!FT->getReturnType()->isVoidTy())
646     fail(I, "llvm.coro.* deallocator must return void", F);
647 
648   if (FT->getNumParams() != 1 ||
649       !FT->getParamType(0)->isPointerTy())
650     fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
651 }
652 
653 static void checkConstantInt(const Instruction *I, Value *V,
654                              const char *Reason) {
655   if (!isa<ConstantInt>(V)) {
656     fail(I, Reason, V);
657   }
658 }
659 
660 void AnyCoroIdRetconInst::checkWellFormed() const {
661   checkConstantInt(this, getArgOperand(SizeArg),
662                    "size argument to coro.id.retcon.* must be constant");
663   checkConstantInt(this, getArgOperand(AlignArg),
664                    "alignment argument to coro.id.retcon.* must be constant");
665   checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
666   checkWFAlloc(this, getArgOperand(AllocArg));
667   checkWFDealloc(this, getArgOperand(DeallocArg));
668 }
669 
670 static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
671   auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
672   if (!AsyncFuncPtrAddr)
673     fail(I, "llvm.coro.id.async async function pointer not a global", V);
674 
675   auto *StructTy =
676       cast<StructType>(AsyncFuncPtrAddr->getType()->getPointerElementType());
677   if (StructTy->isOpaque() || !StructTy->isPacked() ||
678       StructTy->getNumElements() != 2 ||
679       !StructTy->getElementType(0)->isIntegerTy(32) ||
680       !StructTy->getElementType(1)->isIntegerTy(32))
681     fail(I,
682          "llvm.coro.id.async async function pointer argument's type is not "
683          "<{i32, i32}>",
684          V);
685 }
686 
687 void CoroIdAsyncInst::checkWellFormed() const {
688   checkConstantInt(this, getArgOperand(SizeArg),
689                    "size argument to coro.id.async must be constant");
690   checkConstantInt(this, getArgOperand(AlignArg),
691                    "alignment argument to coro.id.async must be constant");
692   checkConstantInt(this, getArgOperand(StorageArg),
693                    "storage argument offset to coro.id.async must be constant");
694   checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
695 }
696 
697 static void checkAsyncContextProjectFunction(const Instruction *I,
698                                              Function *F) {
699   auto *FunTy = cast<FunctionType>(F->getValueType());
700   if (!FunTy->getReturnType()->isPointerTy() ||
701       !FunTy->getReturnType()->getPointerElementType()->isIntegerTy(8))
702     fail(I,
703          "llvm.coro.suspend.async resume function projection function must "
704          "return an i8* type",
705          F);
706   if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy() ||
707       !FunTy->getParamType(0)->getPointerElementType()->isIntegerTy(8))
708     fail(I,
709          "llvm.coro.suspend.async resume function projection function must "
710          "take one i8* type as parameter",
711          F);
712 }
713 
714 void CoroSuspendAsyncInst::checkWellFormed() const {
715   checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
716 }
717 
718 void CoroAsyncEndInst::checkWellFormed() const {
719   auto *MustTailCallFunc = getMustTailCallFunction();
720   if (!MustTailCallFunc)
721     return;
722   auto *FnTy =
723       cast<FunctionType>(MustTailCallFunc->getType()->getPointerElementType());
724   if (FnTy->getNumParams() != (arg_size() - 3))
725     fail(this,
726          "llvm.coro.end.async must tail call function argument type must "
727          "match the tail arguments",
728          MustTailCallFunc);
729 }
730 
731 void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) {
732   unwrap(PM)->add(createCoroEarlyLegacyPass());
733 }
734 
735 void LLVMAddCoroSplitPass(LLVMPassManagerRef PM) {
736   unwrap(PM)->add(createCoroSplitLegacyPass());
737 }
738 
739 void LLVMAddCoroElidePass(LLVMPassManagerRef PM) {
740   unwrap(PM)->add(createCoroElideLegacyPass());
741 }
742 
743 void LLVMAddCoroCleanupPass(LLVMPassManagerRef PM) {
744   unwrap(PM)->add(createCoroCleanupLegacyPass());
745 }
746 
747 void
748 LLVMPassManagerBuilderAddCoroutinePassesToExtensionPoints(LLVMPassManagerBuilderRef PMB) {
749   PassManagerBuilder *Builder = unwrap(PMB);
750   addCoroutinePassesToExtensionPoints(*Builder);
751 }
752