1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Async/IR/Async.h"
18 #include "mlir/Dialect/Async/Passes.h"
19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/RegionUtils.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/Support/Debug.h"
29 
30 using namespace mlir;
31 using namespace mlir::async;
32 
33 #define DEBUG_TYPE "async-to-async-runtime"
// Base symbol name given to every function outlined from an `async.execute`
// op region.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
36 
37 namespace {
38 
39 class AsyncToAsyncRuntimePass
40     : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
41 public:
42   AsyncToAsyncRuntimePass() = default;
43   void runOnOperation() override;
44 };
45 
46 } // namespace
47 
48 //===----------------------------------------------------------------------===//
49 // async.execute op outlining to the coroutine functions.
50 //===----------------------------------------------------------------------===//
51 
52 /// Function targeted for coroutine transformation has two additional blocks at
53 /// the end: coroutine cleanup and coroutine suspension.
54 ///
/// async.await op lowering additionally creates a resume block for each
56 /// operation to enable non-blocking waiting via coroutine suspension.
57 namespace {
58 struct CoroMachinery {
59   func::FuncOp func;
60 
61   // Async execute region returns a completion token, and an async value for
62   // each yielded value.
63   //
64   //   %token, %result = async.execute -> !async.value<T> {
65   //     %0 = arith.constant ... : T
66   //     async.yield %0 : T
67   //   }
68   Value asyncToken; // token representing completion of the async region
69   llvm::SmallVector<Value, 4> returnValues; // returned async values
70 
71   Value coroHandle; // coroutine handle (!async.coro.handle value)
72   Block *entry;     // coroutine entry block
73   Block *setError;  // switch completion token and all values to error state
74   Block *cleanup;   // coroutine cleanup block
75   Block *suspend;   // coroutine suspension block
76 };
77 } // namespace
78 
79 /// Utility to partially update the regular function CFG to the coroutine CFG
80 /// compatible with LLVM coroutines switched-resume lowering using
81 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
82 /// that branches into preexisting entry block. Also inserts trailing blocks.
83 ///
84 /// The result types of the passed `func` must start with an `async.token`
85 /// and be continued with some number of `async.value`s.
86 ///
87 /// The func given to this function needs to have been preprocessed to have
88 /// either branch or yield ops as terminators. Branches to the cleanup block are
89 /// inserted after each yield.
90 ///
91 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
92 ///
93 ///  - `entry` block sets up the coroutine.
94 ///  - `set_error` block sets completion token and async values state to error.
95 ///  - `cleanup` block cleans up the coroutine state.
///  - `suspend` block after the @llvm.coro.end() defines what value will be
97 ///    returned to the initial caller of a coroutine. Everything before the
98 ///    @llvm.coro.end() will be executed at every suspension point.
99 ///
100 /// Coroutine structure (only the important bits):
101 ///
102 ///   func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
103 ///   {
104 ///     ^entry(<function-arguments>):
105 ///       %token = <async token> : !async.token    // create async runtime token
106 ///       %value = <async value> : !async.value<T> // create async value
107 ///       %id = async.coro.id                      // create a coroutine id
108 ///       %hdl = async.coro.begin %id              // create a coroutine handle
109 ///       cf.br ^preexisting_entry_block
110 ///
111 ///     /*  preexisting blocks modified to branch to the cleanup block */
112 ///
113 ///     ^set_error: // this block created lazily only if needed (see code below)
114 ///       async.runtime.set_error %token : !async.token
115 ///       async.runtime.set_error %value : !async.value<T>
116 ///       cf.br ^cleanup
117 ///
118 ///     ^cleanup:
119 ///       async.coro.free %hdl // delete the coroutine state
120 ///       cf.br ^suspend
121 ///
122 ///     ^suspend:
123 ///       async.coro.end %hdl // marks the end of a coroutine
124 ///       return %token, %value : !async.token, !async.value<T>
125 ///   }
126 ///
127 static CoroMachinery setupCoroMachinery(func::FuncOp func) {
128   assert(!func.getBlocks().empty() && "Function must have an entry block");
129 
130   MLIRContext *ctx = func.getContext();
131   Block *entryBlock = &func.getBlocks().front();
132   Block *originalEntryBlock =
133       entryBlock->splitBlock(entryBlock->getOperations().begin());
134   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
135 
136   // ------------------------------------------------------------------------ //
137   // Allocate async token/values that we will return from a ramp function.
138   // ------------------------------------------------------------------------ //
139   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
140 
141   llvm::SmallVector<Value, 4> retValues;
142   for (auto resType : func.getCallableResults().drop_front())
143     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
144 
145   // ------------------------------------------------------------------------ //
146   // Initialize coroutine: get coroutine id and coroutine handle.
147   // ------------------------------------------------------------------------ //
148   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
149   auto coroHdlOp =
150       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
151   builder.create<cf::BranchOp>(originalEntryBlock);
152 
153   Block *cleanupBlock = func.addBlock();
154   Block *suspendBlock = func.addBlock();
155 
156   // ------------------------------------------------------------------------ //
157   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
158   // ------------------------------------------------------------------------ //
159   builder.setInsertionPointToStart(cleanupBlock);
160   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
161 
162   // Branch into the suspend block.
163   builder.create<cf::BranchOp>(suspendBlock);
164 
165   // ------------------------------------------------------------------------ //
166   // Coroutine suspend block: mark the end of a coroutine and return allocated
167   // async token.
168   // ------------------------------------------------------------------------ //
169   builder.setInsertionPointToStart(suspendBlock);
170 
171   // Mark the end of a coroutine: async.coro.end
172   builder.create<CoroEndOp>(coroHdlOp.handle());
173 
174   // Return created `async.token` and `async.values` from the suspend block.
175   // This will be the return value of a coroutine ramp function.
176   SmallVector<Value, 4> ret{retToken};
177   ret.insert(ret.end(), retValues.begin(), retValues.end());
178   builder.create<func::ReturnOp>(ret);
179 
180   // `async.await` op lowering will create resume blocks for async
181   // continuations, and will conditionally branch to cleanup or suspend blocks.
182 
183   for (Block &block : func.getBody().getBlocks()) {
184     if (&block == entryBlock || &block == cleanupBlock ||
185         &block == suspendBlock)
186       continue;
187     Operation *terminator = block.getTerminator();
188     if (auto yield = dyn_cast<YieldOp>(terminator)) {
189       builder.setInsertionPointToEnd(&block);
190       builder.create<cf::BranchOp>(cleanupBlock);
191     }
192   }
193 
194   // The switch-resumed API based coroutine should be marked with
195   // "coroutine.presplit" attribute with value "0" to mark the function as a
196   // coroutine.
197   func->setAttr("passthrough", builder.getArrayAttr(builder.getArrayAttr(
198                                    {builder.getStringAttr("coroutine.presplit"),
199                                     builder.getStringAttr("0")})));
200 
201   CoroMachinery machinery;
202   machinery.func = func;
203   machinery.asyncToken = retToken;
204   machinery.returnValues = retValues;
205   machinery.coroHandle = coroHdlOp.handle();
206   machinery.entry = entryBlock;
207   machinery.setError = nullptr; // created lazily only if needed
208   machinery.cleanup = cleanupBlock;
209   machinery.suspend = suspendBlock;
210   return machinery;
211 }
212 
213 // Lazily creates `set_error` block only if it is required for lowering to the
214 // runtime operations (see for example lowering of assert operation).
215 static Block *setupSetErrorBlock(CoroMachinery &coro) {
216   if (coro.setError)
217     return coro.setError;
218 
219   coro.setError = coro.func.addBlock();
220   coro.setError->moveBefore(coro.cleanup);
221 
222   auto builder =
223       ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
224 
225   // Coroutine set_error block: set error on token and all returned values.
226   builder.create<RuntimeSetErrorOp>(coro.asyncToken);
227   for (Value retValue : coro.returnValues)
228     builder.create<RuntimeSetErrorOp>(retValue);
229 
230   // Branch into the cleanup block.
231   builder.create<cf::BranchOp>(coro.cleanup);
232 
233   return coro.setError;
234 }
235 
236 /// Outline the body region attached to the `async.execute` op into a standalone
237 /// function.
238 ///
239 /// Note that this is not reversible transformation.
240 static std::pair<func::FuncOp, CoroMachinery>
241 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
242   ModuleOp module = execute->getParentOfType<ModuleOp>();
243 
244   MLIRContext *ctx = module.getContext();
245   Location loc = execute.getLoc();
246 
247   // Make sure that all constants will be inside the outlined async function to
248   // reduce the number of function arguments.
249   cloneConstantsIntoTheRegion(execute.body());
250 
251   // Collect all outlined function inputs.
252   SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
253                                         execute.dependencies().end());
254   functionInputs.insert(execute.operands().begin(), execute.operands().end());
255   getUsedValuesDefinedAbove(execute.body(), functionInputs);
256 
257   // Collect types for the outlined function inputs and outputs.
258   auto typesRange = llvm::map_range(
259       functionInputs, [](Value value) { return value.getType(); });
260   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
261   auto outputTypes = execute.getResultTypes();
262 
263   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
264   auto funcAttrs = ArrayRef<NamedAttribute>();
265 
266   // TODO: Derive outlined function name from the parent FuncOp (support
267   // multiple nested async.execute operations).
268   func::FuncOp func =
269       func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
270   symbolTable.insert(func);
271 
272   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
273   auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
274 
275   // Prepare for coroutine conversion by creating the body of the function.
276   {
277     size_t numDependencies = execute.dependencies().size();
278     size_t numOperands = execute.operands().size();
279 
280     // Await on all dependencies before starting to execute the body region.
281     for (size_t i = 0; i < numDependencies; ++i)
282       builder.create<AwaitOp>(func.getArgument(i));
283 
284     // Await on all async value operands and unwrap the payload.
285     SmallVector<Value, 4> unwrappedOperands(numOperands);
286     for (size_t i = 0; i < numOperands; ++i) {
287       Value operand = func.getArgument(numDependencies + i);
288       unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
289     }
290 
291     // Map from function inputs defined above the execute op to the function
292     // arguments.
293     BlockAndValueMapping valueMapping;
294     valueMapping.map(functionInputs, func.getArguments());
295     valueMapping.map(execute.body().getArguments(), unwrappedOperands);
296 
297     // Clone all operations from the execute operation body into the outlined
298     // function body.
299     for (Operation &op : execute.body().getOps())
300       builder.clone(op, valueMapping);
301   }
302 
303   // Adding entry/cleanup/suspend blocks.
304   CoroMachinery coro = setupCoroMachinery(func);
305 
306   // Suspend async function at the end of an entry block, and resume it using
307   // Async resume operation (execution will be resumed in a thread managed by
308   // the async runtime).
309   {
310     cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
311     builder.setInsertionPointToEnd(coro.entry);
312 
313     // Save the coroutine state: async.coro.save
314     auto coroSaveOp =
315         builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
316 
317     // Pass coroutine to the runtime to be resumed on a runtime managed
318     // thread.
319     builder.create<RuntimeResumeOp>(coro.coroHandle);
320 
321     // Add async.coro.suspend as a suspended block terminator.
322     builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
323                                   branch.getDest(), coro.cleanup);
324 
325     branch.erase();
326   }
327 
328   // Replace the original `async.execute` with a call to outlined function.
329   {
330     ImplicitLocOpBuilder callBuilder(loc, execute);
331     auto callOutlinedFunc = callBuilder.create<func::CallOp>(
332         func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
333     execute.replaceAllUsesWith(callOutlinedFunc.getResults());
334     execute.erase();
335   }
336 
337   return {func, coro};
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // Convert async.create_group operation to async.runtime.create_group
342 //===----------------------------------------------------------------------===//
343 
344 namespace {
345 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
346 public:
347   using OpConversionPattern::OpConversionPattern;
348 
349   LogicalResult
350   matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
351                   ConversionPatternRewriter &rewriter) const override {
352     rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
353         op, GroupType::get(op->getContext()), adaptor.getOperands());
354     return success();
355   }
356 };
357 } // namespace
358 
359 //===----------------------------------------------------------------------===//
360 // Convert async.add_to_group operation to async.runtime.add_to_group.
361 //===----------------------------------------------------------------------===//
362 
363 namespace {
364 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
365 public:
366   using OpConversionPattern::OpConversionPattern;
367 
368   LogicalResult
369   matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
370                   ConversionPatternRewriter &rewriter) const override {
371     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
372         op, rewriter.getIndexType(), adaptor.getOperands());
373     return success();
374   }
375 };
376 } // namespace
377 
378 //===----------------------------------------------------------------------===//
379 // Convert async.await and async.await_all operations to the async.runtime.await
380 // or async.runtime.await_and_resume operations.
381 //===----------------------------------------------------------------------===//
382 
383 namespace {
384 template <typename AwaitType, typename AwaitableType>
385 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
386   using AwaitAdaptor = typename AwaitType::Adaptor;
387 
388 public:
389   AwaitOpLoweringBase(
390       MLIRContext *ctx,
391       llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
392       : OpConversionPattern<AwaitType>(ctx),
393         outlinedFunctions(outlinedFunctions) {}
394 
395   LogicalResult
396   matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
397                   ConversionPatternRewriter &rewriter) const override {
398     // We can only await on one the `AwaitableType` (for `await` it can be
399     // a `token` or a `value`, for `await_all` it must be a `group`).
400     if (!op.operand().getType().template isa<AwaitableType>())
401       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
402 
403     // Check if await operation is inside the outlined coroutine function.
404     auto func = op->template getParentOfType<func::FuncOp>();
405     auto outlined = outlinedFunctions.find(func);
406     const bool isInCoroutine = outlined != outlinedFunctions.end();
407 
408     Location loc = op->getLoc();
409     Value operand = adaptor.operand();
410 
411     Type i1 = rewriter.getI1Type();
412 
413     // Inside regular functions we use the blocking wait operation to wait for
414     // the async object (token, value or group) to become available.
415     if (!isInCoroutine) {
416       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
417       builder.create<RuntimeAwaitOp>(loc, operand);
418 
419       // Assert that the awaited operands is not in the error state.
420       Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
421       Value notError = builder.create<arith::XOrIOp>(
422           isError, builder.create<arith::ConstantOp>(
423                        loc, i1, builder.getIntegerAttr(i1, 1)));
424 
425       builder.create<cf::AssertOp>(notError,
426                                    "Awaited async operand is in error state");
427     }
428 
429     // Inside the coroutine we convert await operation into coroutine suspension
430     // point, and resume execution asynchronously.
431     if (isInCoroutine) {
432       CoroMachinery &coro = outlined->getSecond();
433       Block *suspended = op->getBlock();
434 
435       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
436       MLIRContext *ctx = op->getContext();
437 
438       // Save the coroutine state and resume on a runtime managed thread when
439       // the operand becomes available.
440       auto coroSaveOp =
441           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
442       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
443 
444       // Split the entry block before the await operation.
445       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
446 
447       // Add async.coro.suspend as a suspended block terminator.
448       builder.setInsertionPointToEnd(suspended);
449       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
450                                     coro.cleanup);
451 
452       // Split the resume block into error checking and continuation.
453       Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
454 
455       // Check if the awaited value is in the error state.
456       builder.setInsertionPointToStart(resume);
457       auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
458       builder.create<cf::CondBranchOp>(isError,
459                                        /*trueDest=*/setupSetErrorBlock(coro),
460                                        /*trueArgs=*/ArrayRef<Value>(),
461                                        /*falseDest=*/continuation,
462                                        /*falseArgs=*/ArrayRef<Value>());
463 
464       // Make sure that replacement value will be constructed in the
465       // continuation block.
466       rewriter.setInsertionPointToStart(continuation);
467     }
468 
469     // Erase or replace the await operation with the new value.
470     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
471       rewriter.replaceOp(op, replaceWith);
472     else
473       rewriter.eraseOp(op);
474 
475     return success();
476   }
477 
478   virtual Value getReplacementValue(AwaitType op, Value operand,
479                                     ConversionPatternRewriter &rewriter) const {
480     return Value();
481   }
482 
483 private:
484   llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
485 };
486 
487 /// Lowering for `async.await` with a token operand.
488 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
489   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
490 
491 public:
492   using Base::Base;
493 };
494 
495 /// Lowering for `async.await` with a value operand.
496 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
497   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
498 
499 public:
500   using Base::Base;
501 
502   Value
503   getReplacementValue(AwaitOp op, Value operand,
504                       ConversionPatternRewriter &rewriter) const override {
505     // Load from the async value storage.
506     auto valueType = operand.getType().cast<ValueType>().getValueType();
507     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
508   }
509 };
510 
511 /// Lowering for `async.await_all` operation.
512 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
513   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
514 
515 public:
516   using Base::Base;
517 };
518 
519 } // namespace
520 
521 //===----------------------------------------------------------------------===//
522 // Convert async.yield operation to async.runtime operations.
523 //===----------------------------------------------------------------------===//
524 
525 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
526 public:
527   YieldOpLowering(
528       MLIRContext *ctx,
529       const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
530       : OpConversionPattern<async::YieldOp>(ctx),
531         outlinedFunctions(outlinedFunctions) {}
532 
533   LogicalResult
534   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
535                   ConversionPatternRewriter &rewriter) const override {
536     // Check if yield operation is inside the async coroutine function.
537     auto func = op->template getParentOfType<func::FuncOp>();
538     auto outlined = outlinedFunctions.find(func);
539     if (outlined == outlinedFunctions.end())
540       return rewriter.notifyMatchFailure(
541           op, "operation is not inside the async coroutine function");
542 
543     Location loc = op->getLoc();
544     const CoroMachinery &coro = outlined->getSecond();
545 
546     // Store yielded values into the async values storage and switch async
547     // values state to available.
548     for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
549       Value yieldValue = std::get<0>(tuple);
550       Value asyncValue = std::get<1>(tuple);
551       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
552       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
553     }
554 
555     // Switch the coroutine completion token to available state.
556     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
557 
558     return success();
559   }
560 
561 private:
562   const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
563 };
564 
565 //===----------------------------------------------------------------------===//
566 // Convert cf.assert operation to cf.cond_br into `set_error` block.
567 //===----------------------------------------------------------------------===//
568 
569 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
570 public:
571   AssertOpLowering(
572       MLIRContext *ctx,
573       llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
574       : OpConversionPattern<cf::AssertOp>(ctx),
575         outlinedFunctions(outlinedFunctions) {}
576 
577   LogicalResult
578   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
579                   ConversionPatternRewriter &rewriter) const override {
580     // Check if assert operation is inside the async coroutine function.
581     auto func = op->template getParentOfType<func::FuncOp>();
582     auto outlined = outlinedFunctions.find(func);
583     if (outlined == outlinedFunctions.end())
584       return rewriter.notifyMatchFailure(
585           op, "operation is not inside the async coroutine function");
586 
587     Location loc = op->getLoc();
588     CoroMachinery &coro = outlined->getSecond();
589 
590     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
591     rewriter.setInsertionPointToEnd(cont->getPrevNode());
592     rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
593                                       /*trueDest=*/cont,
594                                       /*trueArgs=*/ArrayRef<Value>(),
595                                       /*falseDest=*/setupSetErrorBlock(coro),
596                                       /*falseArgs=*/ArrayRef<Value>());
597     rewriter.eraseOp(op);
598 
599     return success();
600   }
601 
602 private:
603   llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
604 };
605 
606 //===----------------------------------------------------------------------===//
607 
608 /// Rewrite a func as a coroutine by:
609 /// 1) Wrapping the results into `async.value`.
610 /// 2) Prepending the results with `async.token`.
611 /// 3) Setting up coroutine blocks.
612 /// 4) Rewriting return ops as yield op and branch op into the suspend block.
613 static CoroMachinery rewriteFuncAsCoroutine(func::FuncOp func) {
614   auto *ctx = func->getContext();
615   auto loc = func.getLoc();
616   SmallVector<Type> resultTypes;
617   resultTypes.reserve(func.getCallableResults().size());
618   llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
619                   [](Type type) { return ValueType::get(type); });
620   func.setType(
621       FunctionType::get(ctx, func.getFunctionType().getInputs(), resultTypes));
622   func.insertResult(0, TokenType::get(ctx), {});
623   for (Block &block : func.getBlocks()) {
624     Operation *terminator = block.getTerminator();
625     if (auto returnOp = dyn_cast<func::ReturnOp>(*terminator)) {
626       ImplicitLocOpBuilder builder(loc, returnOp);
627       builder.create<YieldOp>(returnOp.getOperands());
628       returnOp.erase();
629     }
630   }
631   return setupCoroMachinery(func);
632 }
633 
634 /// Rewrites a call into a function that has been rewritten as a coroutine.
635 ///
636 /// The invocation of this function is safe only when call ops are traversed in
637 /// reverse order of how they appear in a single block. See `funcsToCoroutines`.
638 static void rewriteCallsiteForCoroutine(func::CallOp oldCall,
639                                         func::FuncOp func) {
640   auto loc = func.getLoc();
641   ImplicitLocOpBuilder callBuilder(loc, oldCall);
642   auto newCall = callBuilder.create<func::CallOp>(
643       func.getName(), func.getCallableResults(), oldCall.getArgOperands());
644 
645   // Await on the async token and all the value results and unwrap the latter.
646   callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
647   SmallVector<Value> unwrappedResults;
648   unwrappedResults.reserve(newCall->getResults().size() - 1);
649   for (Value result : newCall.getResults().drop_front())
650     unwrappedResults.push_back(
651         callBuilder.create<AwaitOp>(loc, result).result());
652   // Careful, when result of a call is piped into another call this could lead
653   // to a dangling pointer.
654   oldCall.replaceAllUsesWith(unwrappedResults);
655   oldCall.erase();
656 }
657 
658 static bool isAllowedToBlock(func::FuncOp func) {
659   return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
660 }
661 
662 static LogicalResult funcsToCoroutines(
663     ModuleOp module,
664     llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) {
665   // The following code supports the general case when 2 functions mutually
666   // recurse into each other. Because of this and that we are relying on
667   // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
668   // a FuncOp while inserting an equivalent coroutine, because that could lead
669   // to dangling pointers.
670 
671   SmallVector<func::FuncOp> funcWorklist;
672 
673   // Careful, it's okay to add a func to the worklist multiple times if and only
674   // if the loop processing the worklist will skip the functions that have
675   // already been converted to coroutines.
676   auto addToWorklist = [&](func::FuncOp func) {
677     if (isAllowedToBlock(func))
678       return;
679     // N.B. To refactor this code into a separate pass the lookup in
680     // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
681     // func and recognizing if it has a coroutine structure is messy. Passing
682     // this dict between the passes is ugly.
683     if (isAllowedToBlock(func) ||
684         outlinedFunctions.find(func) == outlinedFunctions.end()) {
685       for (Operation &op : func.getBody().getOps()) {
686         if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
687           funcWorklist.push_back(func);
688           break;
689         }
690       }
691     }
692   };
693 
694   // Traverse in post-order collecting for each func op the await ops it has.
695   for (func::FuncOp func : module.getOps<func::FuncOp>())
696     addToWorklist(func);
697 
698   SymbolTableCollection symbolTable;
699   SymbolUserMap symbolUserMap(symbolTable, module);
700 
701   // Rewrite funcs, while updating call sites and adding them to the worklist.
702   while (!funcWorklist.empty()) {
703     auto func = funcWorklist.pop_back_val();
704     auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
705     if (!insertion.second)
706       // This function has already been processed because this is either
707       // the corecursive case, or a caller with multiple calls to a newly
708       // created corouting. Either way, skip updating the call sites.
709       continue;
710     insertion.first->second = rewriteFuncAsCoroutine(func);
711     SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
712                                    symbolUserMap.getUsers(func).end());
713     // If there are multiple calls from the same block they need to be traversed
714     // in reverse order so that symbolUserMap references are not invalidated
715     // when updating the users of the call op which is earlier in the block.
716     llvm::sort(users, [](Operation *a, Operation *b) {
717       Block *blockA = a->getBlock();
718       Block *blockB = b->getBlock();
719       // Impose arbitrary order on blocks so that there is a well-defined order.
720       return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b));
721     });
722     // Rewrite the callsites to await on results of the newly created coroutine.
723     for (Operation *op : users) {
724       if (func::CallOp call = dyn_cast<func::CallOp>(*op)) {
725         func::FuncOp caller = call->getParentOfType<func::FuncOp>();
726         rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
727         addToWorklist(caller);
728       } else {
729         op->emitError("Unexpected reference to func referenced by symbol");
730         return failure();
731       }
732     }
733   }
734   return success();
735 }
736 
737 //===----------------------------------------------------------------------===//
738 void AsyncToAsyncRuntimePass::runOnOperation() {
739   ModuleOp module = getOperation();
740   SymbolTable symbolTable(module);
741 
742   // Outline all `async.execute` body regions into async functions (coroutines).
743   llvm::DenseMap<func::FuncOp, CoroMachinery> outlinedFunctions;
744 
745   module.walk([&](ExecuteOp execute) {
746     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
747   });
748 
749   LLVM_DEBUG({
750     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
751                  << " functions built from async.execute operations\n";
752   });
753 
754   // Returns true if operation is inside the coroutine.
755   auto isInCoroutine = [&](Operation *op) -> bool {
756     auto parentFunc = op->getParentOfType<func::FuncOp>();
757     return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
758   };
759 
760   if (eliminateBlockingAwaitOps &&
761       failed(funcsToCoroutines(module, outlinedFunctions))) {
762     signalPassFailure();
763     return;
764   }
765 
766   // Lower async operations to async.runtime operations.
767   MLIRContext *ctx = module->getContext();
768   RewritePatternSet asyncPatterns(ctx);
769 
770   // Conversion to async runtime augments original CFG with the coroutine CFG,
771   // and we have to make sure that structured control flow operations with async
772   // operations in nested regions will be converted to branch-based control flow
773   // before we add the coroutine basic blocks.
774   populateSCFToControlFlowConversionPatterns(asyncPatterns);
775 
776   // Async lowering does not use type converter because it must preserve all
777   // types for async.runtime operations.
778   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
779   asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
780                     AwaitAllOpLowering, YieldOpLowering>(ctx,
781                                                          outlinedFunctions);
782 
783   // Lower assertions to conditional branches into error blocks.
784   asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
785 
786   // All high level async operations must be lowered to the runtime operations.
787   ConversionTarget runtimeTarget(*ctx);
788   runtimeTarget.addLegalDialect<AsyncDialect>();
789   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
790   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
791 
792   // Decide if structured control flow has to be lowered to branch-based CFG.
793   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
794     auto walkResult = op->walk([&](Operation *nested) {
795       bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
796       return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
797                                               : WalkResult::advance();
798     });
799     return !walkResult.wasInterrupted();
800   });
801   runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
802                            func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
803 
804   // Assertions must be converted to runtime errors inside async functions.
805   runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
806       [&](cf::AssertOp op) -> bool {
807         auto func = op->getParentOfType<func::FuncOp>();
808         return outlinedFunctions.find(func) == outlinedFunctions.end();
809       });
810 
811   if (eliminateBlockingAwaitOps)
812     runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
813         [&](RuntimeAwaitOp op) -> bool {
814           return isAllowedToBlock(op->getParentOfType<func::FuncOp>());
815         });
816 
817   if (failed(applyPartialConversion(module, runtimeTarget,
818                                     std::move(asyncPatterns)))) {
819     signalPassFailure();
820     return;
821   }
822 }
823 
824 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
825   return std::make_unique<AsyncToAsyncRuntimePass>();
826 }
827