1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "PassDetail.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

#include <functional>
28 
29 using namespace mlir;
30 using namespace mlir::async;
31 
32 #define DEBUG_TYPE "async-to-async-runtime"
33 // Prefix for functions outlined from `async.execute` op regions.
34 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
35 
36 namespace {
37 
38 class AsyncToAsyncRuntimePass
39     : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
40 public:
41   AsyncToAsyncRuntimePass() = default;
42   void runOnOperation() override;
43 };
44 
45 } // namespace
46 
47 //===----------------------------------------------------------------------===//
48 // async.execute op outlining to the coroutine functions.
49 //===----------------------------------------------------------------------===//
50 
/// Function targeted for coroutine transformation has two additional blocks at
/// the end: coroutine cleanup and coroutine suspension (plus a `set_error`
/// block placed before cleanup, created lazily only when needed).
///
/// async.await op lowering additionally creates a resume block for each
/// operation to enable non-blocking waiting via coroutine suspension.
56 namespace {
57 struct CoroMachinery {
58   FuncOp func;
59 
60   // Async execute region returns a completion token, and an async value for
61   // each yielded value.
62   //
63   //   %token, %result = async.execute -> !async.value<T> {
64   //     %0 = arith.constant ... : T
65   //     async.yield %0 : T
66   //   }
67   Value asyncToken; // token representing completion of the async region
68   llvm::SmallVector<Value, 4> returnValues; // returned async values
69 
70   Value coroHandle; // coroutine handle (!async.coro.handle value)
71   Block *entry;     // coroutine entry block
72   Block *setError;  // switch completion token and all values to error state
73   Block *cleanup;   // coroutine cleanup block
74   Block *suspend;   // coroutine suspension block
75 };
76 } // namespace
77 
78 /// Utility to partially update the regular function CFG to the coroutine CFG
79 /// compatible with LLVM coroutines switched-resume lowering using
80 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
81 /// that branches into preexisting entry block. Also inserts trailing blocks.
82 ///
83 /// The result types of the passed `func` must start with an `async.token`
84 /// and be continued with some number of `async.value`s.
85 ///
86 /// The func given to this function needs to have been preprocessed to have
87 /// either branch or yield ops as terminators. Branches to the cleanup block are
88 /// inserted after each yield.
89 ///
90 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
91 ///
92 ///  - `entry` block sets up the coroutine.
93 ///  - `set_error` block sets completion token and async values state to error.
94 ///  - `cleanup` block cleans up the coroutine state.
///  - `suspend` block after the @llvm.coro.end() defines what value will be
96 ///    returned to the initial caller of a coroutine. Everything before the
97 ///    @llvm.coro.end() will be executed at every suspension point.
98 ///
99 /// Coroutine structure (only the important bits):
100 ///
101 ///   func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
102 ///   {
103 ///     ^entry(<function-arguments>):
104 ///       %token = <async token> : !async.token    // create async runtime token
105 ///       %value = <async value> : !async.value<T> // create async value
106 ///       %id = async.coro.id                      // create a coroutine id
107 ///       %hdl = async.coro.begin %id              // create a coroutine handle
108 ///       br ^preexisting_entry_block
109 ///
110 ///     /*  preexisting blocks modified to branch to the cleanup block */
111 ///
112 ///     ^set_error: // this block created lazily only if needed (see code below)
113 ///       async.runtime.set_error %token : !async.token
114 ///       async.runtime.set_error %value : !async.value<T>
115 ///       br ^cleanup
116 ///
117 ///     ^cleanup:
118 ///       async.coro.free %hdl // delete the coroutine state
119 ///       br ^suspend
120 ///
121 ///     ^suspend:
122 ///       async.coro.end %hdl // marks the end of a coroutine
123 ///       return %token, %value : !async.token, !async.value<T>
124 ///   }
125 ///
126 static CoroMachinery setupCoroMachinery(FuncOp func) {
127   assert(!func.getBlocks().empty() && "Function must have an entry block");
128 
129   MLIRContext *ctx = func.getContext();
130   Block *entryBlock = &func.getBlocks().front();
131   Block *originalEntryBlock =
132       entryBlock->splitBlock(entryBlock->getOperations().begin());
133   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
134 
135   // ------------------------------------------------------------------------ //
136   // Allocate async token/values that we will return from a ramp function.
137   // ------------------------------------------------------------------------ //
138   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
139 
140   llvm::SmallVector<Value, 4> retValues;
141   for (auto resType : func.getCallableResults().drop_front())
142     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
143 
144   // ------------------------------------------------------------------------ //
145   // Initialize coroutine: get coroutine id and coroutine handle.
146   // ------------------------------------------------------------------------ //
147   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
148   auto coroHdlOp =
149       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
150   builder.create<BranchOp>(originalEntryBlock);
151 
152   Block *cleanupBlock = func.addBlock();
153   Block *suspendBlock = func.addBlock();
154 
155   // ------------------------------------------------------------------------ //
156   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
157   // ------------------------------------------------------------------------ //
158   builder.setInsertionPointToStart(cleanupBlock);
159   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
160 
161   // Branch into the suspend block.
162   builder.create<BranchOp>(suspendBlock);
163 
164   // ------------------------------------------------------------------------ //
165   // Coroutine suspend block: mark the end of a coroutine and return allocated
166   // async token.
167   // ------------------------------------------------------------------------ //
168   builder.setInsertionPointToStart(suspendBlock);
169 
170   // Mark the end of a coroutine: async.coro.end
171   builder.create<CoroEndOp>(coroHdlOp.handle());
172 
173   // Return created `async.token` and `async.values` from the suspend block.
174   // This will be the return value of a coroutine ramp function.
175   SmallVector<Value, 4> ret{retToken};
176   ret.insert(ret.end(), retValues.begin(), retValues.end());
177   builder.create<ReturnOp>(ret);
178 
179   // `async.await` op lowering will create resume blocks for async
180   // continuations, and will conditionally branch to cleanup or suspend blocks.
181 
182   for (Block &block : func.body().getBlocks()) {
183     if (&block == entryBlock || &block == cleanupBlock ||
184         &block == suspendBlock)
185       continue;
186     Operation *terminator = block.getTerminator();
187     if (auto yield = dyn_cast<YieldOp>(terminator)) {
188       builder.setInsertionPointToEnd(&block);
189       builder.create<BranchOp>(cleanupBlock);
190     }
191   }
192 
193   // The switch-resumed API based coroutine should be marked with
194   // "coroutine.presplit" attribute with value "0" to mark the function as a
195   // coroutine.
196   func->setAttr("passthrough", builder.getArrayAttr(builder.getArrayAttr(
197                                    {builder.getStringAttr("coroutine.presplit"),
198                                     builder.getStringAttr("0")})));
199 
200   CoroMachinery machinery;
201   machinery.func = func;
202   machinery.asyncToken = retToken;
203   machinery.returnValues = retValues;
204   machinery.coroHandle = coroHdlOp.handle();
205   machinery.entry = entryBlock;
206   machinery.setError = nullptr; // created lazily only if needed
207   machinery.cleanup = cleanupBlock;
208   machinery.suspend = suspendBlock;
209   return machinery;
210 }
211 
212 // Lazily creates `set_error` block only if it is required for lowering to the
213 // runtime operations (see for example lowering of assert operation).
214 static Block *setupSetErrorBlock(CoroMachinery &coro) {
215   if (coro.setError)
216     return coro.setError;
217 
218   coro.setError = coro.func.addBlock();
219   coro.setError->moveBefore(coro.cleanup);
220 
221   auto builder =
222       ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
223 
224   // Coroutine set_error block: set error on token and all returned values.
225   builder.create<RuntimeSetErrorOp>(coro.asyncToken);
226   for (Value retValue : coro.returnValues)
227     builder.create<RuntimeSetErrorOp>(retValue);
228 
229   // Branch into the cleanup block.
230   builder.create<BranchOp>(coro.cleanup);
231 
232   return coro.setError;
233 }
234 
235 /// Outline the body region attached to the `async.execute` op into a standalone
236 /// function.
237 ///
/// Note that this is not a reversible transformation.
239 static std::pair<FuncOp, CoroMachinery>
240 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
241   ModuleOp module = execute->getParentOfType<ModuleOp>();
242 
243   MLIRContext *ctx = module.getContext();
244   Location loc = execute.getLoc();
245 
246   // Make sure that all constants will be inside the outlined async function to
247   // reduce the number of function arguments.
248   cloneConstantsIntoTheRegion(execute.body());
249 
250   // Collect all outlined function inputs.
251   SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
252                                         execute.dependencies().end());
253   functionInputs.insert(execute.operands().begin(), execute.operands().end());
254   getUsedValuesDefinedAbove(execute.body(), functionInputs);
255 
256   // Collect types for the outlined function inputs and outputs.
257   auto typesRange = llvm::map_range(
258       functionInputs, [](Value value) { return value.getType(); });
259   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
260   auto outputTypes = execute.getResultTypes();
261 
262   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
263   auto funcAttrs = ArrayRef<NamedAttribute>();
264 
265   // TODO: Derive outlined function name from the parent FuncOp (support
266   // multiple nested async.execute operations).
267   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
268   symbolTable.insert(func);
269 
270   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
271   auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
272 
273   // Prepare for coroutine conversion by creating the body of the function.
274   {
275     size_t numDependencies = execute.dependencies().size();
276     size_t numOperands = execute.operands().size();
277 
278     // Await on all dependencies before starting to execute the body region.
279     for (size_t i = 0; i < numDependencies; ++i)
280       builder.create<AwaitOp>(func.getArgument(i));
281 
282     // Await on all async value operands and unwrap the payload.
283     SmallVector<Value, 4> unwrappedOperands(numOperands);
284     for (size_t i = 0; i < numOperands; ++i) {
285       Value operand = func.getArgument(numDependencies + i);
286       unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
287     }
288 
289     // Map from function inputs defined above the execute op to the function
290     // arguments.
291     BlockAndValueMapping valueMapping;
292     valueMapping.map(functionInputs, func.getArguments());
293     valueMapping.map(execute.body().getArguments(), unwrappedOperands);
294 
295     // Clone all operations from the execute operation body into the outlined
296     // function body.
297     for (Operation &op : execute.body().getOps())
298       builder.clone(op, valueMapping);
299   }
300 
301   // Adding entry/cleanup/suspend blocks.
302   CoroMachinery coro = setupCoroMachinery(func);
303 
304   // Suspend async function at the end of an entry block, and resume it using
305   // Async resume operation (execution will be resumed in a thread managed by
306   // the async runtime).
307   {
308     BranchOp branch = cast<BranchOp>(coro.entry->getTerminator());
309     builder.setInsertionPointToEnd(coro.entry);
310 
311     // Save the coroutine state: async.coro.save
312     auto coroSaveOp =
313         builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
314 
315     // Pass coroutine to the runtime to be resumed on a runtime managed
316     // thread.
317     builder.create<RuntimeResumeOp>(coro.coroHandle);
318 
319     // Add async.coro.suspend as a suspended block terminator.
320     builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
321                                   branch.getDest(), coro.cleanup);
322 
323     branch.erase();
324   }
325 
326   // Replace the original `async.execute` with a call to outlined function.
327   {
328     ImplicitLocOpBuilder callBuilder(loc, execute);
329     auto callOutlinedFunc = callBuilder.create<CallOp>(
330         func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
331     execute.replaceAllUsesWith(callOutlinedFunc.getResults());
332     execute.erase();
333   }
334 
335   return {func, coro};
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // Convert async.create_group operation to async.runtime.create_group
340 //===----------------------------------------------------------------------===//
341 
342 namespace {
343 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
344 public:
345   using OpConversionPattern::OpConversionPattern;
346 
347   LogicalResult
348   matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
349                   ConversionPatternRewriter &rewriter) const override {
350     rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
351         op, GroupType::get(op->getContext()), adaptor.getOperands());
352     return success();
353   }
354 };
355 } // namespace
356 
357 //===----------------------------------------------------------------------===//
358 // Convert async.add_to_group operation to async.runtime.add_to_group.
359 //===----------------------------------------------------------------------===//
360 
361 namespace {
362 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
363 public:
364   using OpConversionPattern::OpConversionPattern;
365 
366   LogicalResult
367   matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
368                   ConversionPatternRewriter &rewriter) const override {
369     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
370         op, rewriter.getIndexType(), adaptor.getOperands());
371     return success();
372   }
373 };
374 } // namespace
375 
376 //===----------------------------------------------------------------------===//
377 // Convert async.await and async.await_all operations to the async.runtime.await
378 // or async.runtime.await_and_resume operations.
379 //===----------------------------------------------------------------------===//
380 
381 namespace {
382 template <typename AwaitType, typename AwaitableType>
383 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
384   using AwaitAdaptor = typename AwaitType::Adaptor;
385 
386 public:
387   AwaitOpLoweringBase(MLIRContext *ctx,
388                       llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
389       : OpConversionPattern<AwaitType>(ctx),
390         outlinedFunctions(outlinedFunctions) {}
391 
392   LogicalResult
393   matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
394                   ConversionPatternRewriter &rewriter) const override {
395     // We can only await on one the `AwaitableType` (for `await` it can be
396     // a `token` or a `value`, for `await_all` it must be a `group`).
397     if (!op.operand().getType().template isa<AwaitableType>())
398       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
399 
400     // Check if await operation is inside the outlined coroutine function.
401     auto func = op->template getParentOfType<FuncOp>();
402     auto outlined = outlinedFunctions.find(func);
403     const bool isInCoroutine = outlined != outlinedFunctions.end();
404 
405     Location loc = op->getLoc();
406     Value operand = adaptor.operand();
407 
408     Type i1 = rewriter.getI1Type();
409 
410     // Inside regular functions we use the blocking wait operation to wait for
411     // the async object (token, value or group) to become available.
412     if (!isInCoroutine) {
413       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
414       builder.create<RuntimeAwaitOp>(loc, operand);
415 
416       // Assert that the awaited operands is not in the error state.
417       Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
418       Value notError = builder.create<arith::XOrIOp>(
419           isError, builder.create<arith::ConstantOp>(
420                        loc, i1, builder.getIntegerAttr(i1, 1)));
421 
422       builder.create<AssertOp>(notError,
423                                "Awaited async operand is in error state");
424     }
425 
426     // Inside the coroutine we convert await operation into coroutine suspension
427     // point, and resume execution asynchronously.
428     if (isInCoroutine) {
429       CoroMachinery &coro = outlined->getSecond();
430       Block *suspended = op->getBlock();
431 
432       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
433       MLIRContext *ctx = op->getContext();
434 
435       // Save the coroutine state and resume on a runtime managed thread when
436       // the operand becomes available.
437       auto coroSaveOp =
438           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
439       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
440 
441       // Split the entry block before the await operation.
442       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
443 
444       // Add async.coro.suspend as a suspended block terminator.
445       builder.setInsertionPointToEnd(suspended);
446       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
447                                     coro.cleanup);
448 
449       // Split the resume block into error checking and continuation.
450       Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
451 
452       // Check if the awaited value is in the error state.
453       builder.setInsertionPointToStart(resume);
454       auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
455       builder.create<CondBranchOp>(isError,
456                                    /*trueDest=*/setupSetErrorBlock(coro),
457                                    /*trueArgs=*/ArrayRef<Value>(),
458                                    /*falseDest=*/continuation,
459                                    /*falseArgs=*/ArrayRef<Value>());
460 
461       // Make sure that replacement value will be constructed in the
462       // continuation block.
463       rewriter.setInsertionPointToStart(continuation);
464     }
465 
466     // Erase or replace the await operation with the new value.
467     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
468       rewriter.replaceOp(op, replaceWith);
469     else
470       rewriter.eraseOp(op);
471 
472     return success();
473   }
474 
475   virtual Value getReplacementValue(AwaitType op, Value operand,
476                                     ConversionPatternRewriter &rewriter) const {
477     return Value();
478   }
479 
480 private:
481   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
482 };
483 
484 /// Lowering for `async.await` with a token operand.
485 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
486   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
487 
488 public:
489   using Base::Base;
490 };
491 
492 /// Lowering for `async.await` with a value operand.
493 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
494   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
495 
496 public:
497   using Base::Base;
498 
499   Value
500   getReplacementValue(AwaitOp op, Value operand,
501                       ConversionPatternRewriter &rewriter) const override {
502     // Load from the async value storage.
503     auto valueType = operand.getType().cast<ValueType>().getValueType();
504     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
505   }
506 };
507 
508 /// Lowering for `async.await_all` operation.
509 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
510   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
511 
512 public:
513   using Base::Base;
514 };
515 
516 } // namespace
517 
518 //===----------------------------------------------------------------------===//
519 // Convert async.yield operation to async.runtime operations.
520 //===----------------------------------------------------------------------===//
521 
522 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
523 public:
524   YieldOpLowering(
525       MLIRContext *ctx,
526       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
527       : OpConversionPattern<async::YieldOp>(ctx),
528         outlinedFunctions(outlinedFunctions) {}
529 
530   LogicalResult
531   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
532                   ConversionPatternRewriter &rewriter) const override {
533     // Check if yield operation is inside the async coroutine function.
534     auto func = op->template getParentOfType<FuncOp>();
535     auto outlined = outlinedFunctions.find(func);
536     if (outlined == outlinedFunctions.end())
537       return rewriter.notifyMatchFailure(
538           op, "operation is not inside the async coroutine function");
539 
540     Location loc = op->getLoc();
541     const CoroMachinery &coro = outlined->getSecond();
542 
543     // Store yielded values into the async values storage and switch async
544     // values state to available.
545     for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
546       Value yieldValue = std::get<0>(tuple);
547       Value asyncValue = std::get<1>(tuple);
548       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
549       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
550     }
551 
552     // Switch the coroutine completion token to available state.
553     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
554 
555     return success();
556   }
557 
558 private:
559   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
560 };
561 
562 //===----------------------------------------------------------------------===//
563 // Convert std.assert operation to cond_br into `set_error` block.
564 //===----------------------------------------------------------------------===//
565 
566 class AssertOpLowering : public OpConversionPattern<AssertOp> {
567 public:
568   AssertOpLowering(MLIRContext *ctx,
569                    llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
570       : OpConversionPattern<AssertOp>(ctx),
571         outlinedFunctions(outlinedFunctions) {}
572 
573   LogicalResult
574   matchAndRewrite(AssertOp op, OpAdaptor adaptor,
575                   ConversionPatternRewriter &rewriter) const override {
576     // Check if assert operation is inside the async coroutine function.
577     auto func = op->template getParentOfType<FuncOp>();
578     auto outlined = outlinedFunctions.find(func);
579     if (outlined == outlinedFunctions.end())
580       return rewriter.notifyMatchFailure(
581           op, "operation is not inside the async coroutine function");
582 
583     Location loc = op->getLoc();
584     CoroMachinery &coro = outlined->getSecond();
585 
586     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
587     rewriter.setInsertionPointToEnd(cont->getPrevNode());
588     rewriter.create<CondBranchOp>(loc, adaptor.getArg(),
589                                   /*trueDest=*/cont,
590                                   /*trueArgs=*/ArrayRef<Value>(),
591                                   /*falseDest=*/setupSetErrorBlock(coro),
592                                   /*falseArgs=*/ArrayRef<Value>());
593     rewriter.eraseOp(op);
594 
595     return success();
596   }
597 
598 private:
599   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
600 };
601 
602 //===----------------------------------------------------------------------===//
603 
/// Rewrite a func as a coroutine by:
/// 1) Wrapping the results into `async.value`.
/// 2) Prepending the results with `async.token`.
/// 3) Rewriting return ops as yield ops.
/// 4) Setting up coroutine blocks (which also appends a branch into the
///    cleanup block after each yield op).
609 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) {
610   auto *ctx = func->getContext();
611   auto loc = func.getLoc();
612   SmallVector<Type> resultTypes;
613   resultTypes.reserve(func.getCallableResults().size());
614   llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
615                   [](Type type) { return ValueType::get(type); });
616   func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes));
617   func.insertResult(0, TokenType::get(ctx), {});
618   for (Block &block : func.getBlocks()) {
619     Operation *terminator = block.getTerminator();
620     if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) {
621       ImplicitLocOpBuilder builder(loc, returnOp);
622       builder.create<YieldOp>(returnOp.getOperands());
623       returnOp.erase();
624     }
625   }
626   return setupCoroMachinery(func);
627 }
628 
629 /// Rewrites a call into a function that has been rewritten as a coroutine.
630 ///
631 /// The invocation of this function is safe only when call ops are traversed in
632 /// reverse order of how they appear in a single block. See `funcsToCoroutines`.
633 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
634   auto loc = func.getLoc();
635   ImplicitLocOpBuilder callBuilder(loc, oldCall);
636   auto newCall = callBuilder.create<CallOp>(
637       func.getName(), func.getCallableResults(), oldCall.getArgOperands());
638 
639   // Await on the async token and all the value results and unwrap the latter.
640   callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
641   SmallVector<Value> unwrappedResults;
642   unwrappedResults.reserve(newCall->getResults().size() - 1);
643   for (Value result : newCall.getResults().drop_front())
644     unwrappedResults.push_back(
645         callBuilder.create<AwaitOp>(loc, result).result());
646   // Careful, when result of a call is piped into another call this could lead
647   // to a dangling pointer.
648   oldCall.replaceAllUsesWith(unwrappedResults);
649   oldCall.erase();
650 }
651 
652 static bool isAllowedToBlock(FuncOp func) {
653   return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
654 }
655 
656 static LogicalResult
657 funcsToCoroutines(ModuleOp module,
658                   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
659   // The following code supports the general case when 2 functions mutually
660   // recurse into each other. Because of this and that we are relying on
661   // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
662   // a FuncOp while inserting an equivalent coroutine, because that could lead
663   // to dangling pointers.
664 
665   SmallVector<FuncOp> funcWorklist;
666 
667   // Careful, it's okay to add a func to the worklist multiple times if and only
668   // if the loop processing the worklist will skip the functions that have
669   // already been converted to coroutines.
670   auto addToWorklist = [&](FuncOp func) {
671     if (isAllowedToBlock(func))
672       return;
673     // N.B. To refactor this code into a separate pass the lookup in
674     // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
675     // func and recognizing if it has a coroutine structure is messy. Passing
676     // this dict between the passes is ugly.
677     if (isAllowedToBlock(func) ||
678         outlinedFunctions.find(func) == outlinedFunctions.end()) {
679       for (Operation &op : func.body().getOps()) {
680         if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
681           funcWorklist.push_back(func);
682           break;
683         }
684       }
685     }
686   };
687 
688   // Traverse in post-order collecting for each func op the await ops it has.
689   for (FuncOp func : module.getOps<FuncOp>())
690     addToWorklist(func);
691 
692   SymbolTableCollection symbolTable;
693   SymbolUserMap symbolUserMap(symbolTable, module);
694 
695   // Rewrite funcs, while updating call sites and adding them to the worklist.
696   while (!funcWorklist.empty()) {
697     auto func = funcWorklist.pop_back_val();
698     auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
699     if (!insertion.second)
700       // This function has already been processed because this is either
701       // the corecursive case, or a caller with multiple calls to a newly
702       // created corouting. Either way, skip updating the call sites.
703       continue;
704     insertion.first->second = rewriteFuncAsCoroutine(func);
705     SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
706                                    symbolUserMap.getUsers(func).end());
707     // If there are multiple calls from the same block they need to be traversed
708     // in reverse order so that symbolUserMap references are not invalidated
709     // when updating the users of the call op which is earlier in the block.
710     llvm::sort(users, [](Operation *a, Operation *b) {
711       Block *blockA = a->getBlock();
712       Block *blockB = b->getBlock();
713       // Impose arbitrary order on blocks so that there is a well-defined order.
714       return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b));
715     });
716     // Rewrite the callsites to await on results of the newly created coroutine.
717     for (Operation *op : users) {
718       if (CallOp call = dyn_cast<mlir::CallOp>(*op)) {
719         FuncOp caller = call->getParentOfType<FuncOp>();
720         rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
721         addToWorklist(caller);
722       } else {
723         op->emitError("Unexpected reference to func referenced by symbol");
724         return failure();
725       }
726     }
727   }
728   return success();
729 }
730 
731 //===----------------------------------------------------------------------===//
732 void AsyncToAsyncRuntimePass::runOnOperation() {
733   ModuleOp module = getOperation();
734   SymbolTable symbolTable(module);
735 
736   // Outline all `async.execute` body regions into async functions (coroutines).
737   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
738 
739   module.walk([&](ExecuteOp execute) {
740     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
741   });
742 
743   LLVM_DEBUG({
744     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
745                  << " functions built from async.execute operations\n";
746   });
747 
748   // Returns true if operation is inside the coroutine.
749   auto isInCoroutine = [&](Operation *op) -> bool {
750     auto parentFunc = op->getParentOfType<FuncOp>();
751     return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
752   };
753 
754   if (eliminateBlockingAwaitOps &&
755       failed(funcsToCoroutines(module, outlinedFunctions))) {
756     signalPassFailure();
757     return;
758   }
759 
760   // Lower async operations to async.runtime operations.
761   MLIRContext *ctx = module->getContext();
762   RewritePatternSet asyncPatterns(ctx);
763 
764   // Conversion to async runtime augments original CFG with the coroutine CFG,
765   // and we have to make sure that structured control flow operations with async
766   // operations in nested regions will be converted to branch-based control flow
767   // before we add the coroutine basic blocks.
768   populateLoopToStdConversionPatterns(asyncPatterns);
769 
770   // Async lowering does not use type converter because it must preserve all
771   // types for async.runtime operations.
772   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
773   asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
774                     AwaitAllOpLowering, YieldOpLowering>(ctx,
775                                                          outlinedFunctions);
776 
777   // Lower assertions to conditional branches into error blocks.
778   asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
779 
780   // All high level async operations must be lowered to the runtime operations.
781   ConversionTarget runtimeTarget(*ctx);
782   runtimeTarget.addLegalDialect<AsyncDialect>();
783   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
784   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
785 
786   // Decide if structured control flow has to be lowered to branch-based CFG.
787   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
788     auto walkResult = op->walk([&](Operation *nested) {
789       bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
790       return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
791                                               : WalkResult::advance();
792     });
793     return !walkResult.wasInterrupted();
794   });
795   runtimeTarget.addLegalOp<AssertOp, arith::XOrIOp, arith::ConstantOp,
796                            ConstantOp, BranchOp, CondBranchOp>();
797 
798   // Assertions must be converted to runtime errors inside async functions.
799   runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
800     auto func = op->getParentOfType<FuncOp>();
801     return outlinedFunctions.find(func) == outlinedFunctions.end();
802   });
803 
804   if (eliminateBlockingAwaitOps)
805     runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
806         [&](RuntimeAwaitOp op) -> bool {
807           return isAllowedToBlock(op->getParentOfType<FuncOp>());
808         });
809 
810   if (failed(applyPartialConversion(module, runtimeTarget,
811                                     std::move(asyncPatterns)))) {
812     signalPassFailure();
813     return;
814   }
815 }
816 
817 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
818   return std::make_unique<AsyncToAsyncRuntimePass>();
819 }
820