1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Async/Passes.h"
18 #include "mlir/Dialect/SCF/SCF.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "mlir/Transforms/RegionUtils.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/Support/Debug.h"
27 
28 using namespace mlir;
29 using namespace mlir::async;
30 
31 #define DEBUG_TYPE "async-to-async-runtime"
32 // Prefix for functions outlined from `async.execute` op regions.
33 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
34 
35 namespace {
36 
/// Module pass that lowers high level async operations to `async.coro.*` and
/// `async.runtime.*` operations.
class AsyncToAsyncRuntimePass
    : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
public:
  AsyncToAsyncRuntimePass() = default;
  // Pass entry point; defined out of line.
  void runOnOperation() override;
};
43 
44 } // namespace
45 
46 //===----------------------------------------------------------------------===//
47 // async.execute op outlining to the coroutine functions.
48 //===----------------------------------------------------------------------===//
49 
50 /// Function targeted for coroutine transformation has two additional blocks at
51 /// the end: coroutine cleanup and coroutine suspension.
52 ///
/// async.await op lowering additionally creates a resume block for each
54 /// operation to enable non-blocking waiting via coroutine suspension.
55 namespace {
56 struct CoroMachinery {
57   FuncOp func;
58 
59   // Async execute region returns a completion token, and an async value for
60   // each yielded value.
61   //
62   //   %token, %result = async.execute -> !async.value<T> {
63   //     %0 = constant ... : T
64   //     async.yield %0 : T
65   //   }
66   Value asyncToken; // token representing completion of the async region
67   llvm::SmallVector<Value, 4> returnValues; // returned async values
68 
69   Value coroHandle; // coroutine handle (!async.coro.handle value)
70   Block *setError;  // switch completion token and all values to error state
71   Block *cleanup;   // coroutine cleanup block
72   Block *suspend;   // coroutine suspension block
73 };
74 } // namespace
75 
/// Builds a coroutine template compatible with LLVM coroutines switched-resume
77 /// lowering using `async.runtime.*` and `async.coro.*` operations.
78 ///
79 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
80 ///
81 ///  - `entry` block sets up the coroutine.
82 ///  - `set_error` block sets completion token and async values state to error.
83 ///  - `cleanup` block cleans up the coroutine state.
///  - `suspend` block after the @llvm.coro.end() defines what value will be
85 ///    returned to the initial caller of a coroutine. Everything before the
86 ///    @llvm.coro.end() will be executed at every suspension point.
87 ///
88 /// Coroutine structure (only the important bits):
89 ///
90 ///   func @async_execute_fn(<function-arguments>)
91 ///        -> (!async.token, !async.value<T>)
92 ///   {
93 ///     ^entry(<function-arguments>):
94 ///       %token = <async token> : !async.token    // create async runtime token
95 ///       %value = <async value> : !async.value<T> // create async value
96 ///       %id = async.coro.id                      // create a coroutine id
97 ///       %hdl = async.coro.begin %id              // create a coroutine handle
98 ///       br ^cleanup
99 ///
100 ///     ^set_error: // this block created lazily only if needed (see code below)
101 ///       async.runtime.set_error %token : !async.token
102 ///       async.runtime.set_error %value : !async.value<T>
103 ///       br ^cleanup
104 ///
105 ///     ^cleanup:
106 ///       async.coro.free %hdl // delete the coroutine state
107 ///       br ^suspend
108 ///
109 ///     ^suspend:
110 ///       async.coro.end %hdl // marks the end of a coroutine
111 ///       return %token, %value : !async.token, !async.value<T>
112 ///   }
113 ///
114 /// The actual code for the async.execute operation body region will be inserted
115 /// before the entry block terminator.
116 ///
117 ///
118 static CoroMachinery setupCoroMachinery(FuncOp func) {
119   assert(func.getBody().empty() && "Function must have empty body");
120 
121   MLIRContext *ctx = func.getContext();
122   Block *entryBlock = func.addEntryBlock();
123 
124   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
125 
126   // ------------------------------------------------------------------------ //
127   // Allocate async token/values that we will return from a ramp function.
128   // ------------------------------------------------------------------------ //
129   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
130 
131   llvm::SmallVector<Value, 4> retValues;
132   for (auto resType : func.getCallableResults().drop_front())
133     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
134 
135   // ------------------------------------------------------------------------ //
136   // Initialize coroutine: get coroutine id and coroutine handle.
137   // ------------------------------------------------------------------------ //
138   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
139   auto coroHdlOp =
140       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
141 
142   Block *cleanupBlock = func.addBlock();
143   Block *suspendBlock = func.addBlock();
144 
145   // ------------------------------------------------------------------------ //
146   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
147   // ------------------------------------------------------------------------ //
148   builder.setInsertionPointToStart(cleanupBlock);
149   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
150 
151   // Branch into the suspend block.
152   builder.create<BranchOp>(suspendBlock);
153 
154   // ------------------------------------------------------------------------ //
155   // Coroutine suspend block: mark the end of a coroutine and return allocated
156   // async token.
157   // ------------------------------------------------------------------------ //
158   builder.setInsertionPointToStart(suspendBlock);
159 
160   // Mark the end of a coroutine: async.coro.end
161   builder.create<CoroEndOp>(coroHdlOp.handle());
162 
163   // Return created `async.token` and `async.values` from the suspend block.
164   // This will be the return value of a coroutine ramp function.
165   SmallVector<Value, 4> ret{retToken};
166   ret.insert(ret.end(), retValues.begin(), retValues.end());
167   builder.create<ReturnOp>(ret);
168 
169   // Branch from the entry block to the cleanup block to create a valid CFG.
170   builder.setInsertionPointToEnd(entryBlock);
171   builder.create<BranchOp>(cleanupBlock);
172 
173   // `async.await` op lowering will create resume blocks for async
174   // continuations, and will conditionally branch to cleanup or suspend blocks.
175 
176   CoroMachinery machinery;
177   machinery.func = func;
178   machinery.asyncToken = retToken;
179   machinery.returnValues = retValues;
180   machinery.coroHandle = coroHdlOp.handle();
181   machinery.setError = nullptr; // created lazily only if needed
182   machinery.cleanup = cleanupBlock;
183   machinery.suspend = suspendBlock;
184   return machinery;
185 }
186 
187 // Lazily creates `set_error` block only if it is required for lowering to the
188 // runtime operations (see for example lowering of assert operation).
189 static Block *setupSetErrorBlock(CoroMachinery &coro) {
190   if (coro.setError)
191     return coro.setError;
192 
193   coro.setError = coro.func.addBlock();
194   coro.setError->moveBefore(coro.cleanup);
195 
196   auto builder =
197       ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
198 
199   // Coroutine set_error block: set error on token and all returned values.
200   builder.create<RuntimeSetErrorOp>(coro.asyncToken);
201   for (Value retValue : coro.returnValues)
202     builder.create<RuntimeSetErrorOp>(retValue);
203 
204   // Branch into the cleanup block.
205   builder.create<BranchOp>(coro.cleanup);
206 
207   return coro.setError;
208 }
209 
/// Outline the body region attached to the `async.execute` op into a standalone
/// function.
///
/// The new function is inserted into `symbolTable` with private visibility and
/// prepared for coroutine lowering. The original `async.execute` op is replaced
/// with a call to the outlined function and erased, so this is not a
/// reversible transformation.
///
/// Returns the outlined function together with its coroutine machinery.
static std::pair<FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
  ModuleOp module = execute->getParentOfType<ModuleOp>();

  MLIRContext *ctx = module.getContext();
  Location loc = execute.getLoc();

  // Collect all outlined function inputs: dependencies first, then async
  // operands, then every value used inside the body but defined above it.
  // NOTE(review): SetVector de-duplicates its elements, while the argument
  // indexing below (`i` and `numDependencies + i`) is positional; this looks
  // like it assumes dependencies/operands contain no repeated values — confirm.
  SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
                                        execute.dependencies().end());
  functionInputs.insert(execute.operands().begin(), execute.operands().end());
  getUsedValuesDefinedAbove(execute.body(), functionInputs);

  // Collect types for the outlined function inputs and outputs.
  auto typesRange = llvm::map_range(
      functionInputs, [](Value value) { return value.getType(); });
  SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
  auto outputTypes = execute.getResultTypes();

  auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
  auto funcAttrs = ArrayRef<NamedAttribute>();

  // TODO: Derive outlined function name from the parent FuncOp (support
  // multiple nested async.execute operations).
  FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
  symbolTable.insert(func);

  SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);

  // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
  // blocks, adding async.coro operations and setting up control flow.
  CoroMachinery coro = setupCoroMachinery(func);

  // Suspend async function at the end of an entry block, and resume it using
  // Async resume operation (execution will be resumed in a thread managed by
  // the async runtime).
  Block *entryBlock = &func.getBlocks().front();
  auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);

  // Save the coroutine state: async.coro.save
  auto coroSaveOp =
      builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);

  // Pass coroutine to the runtime to be resumed on a runtime managed thread.
  builder.create<RuntimeResumeOp>(coro.coroHandle);

  // Split the entry block before its terminator (the branch to the cleanup
  // block created by `setupCoroMachinery`). The terminator moves into the new
  // `resume` block, leaving `suspended` without a terminator.
  auto *terminatorOp = entryBlock->getTerminator();
  Block *suspended = terminatorOp->getBlock();
  Block *resume = suspended->splitBlock(terminatorOp);

  // Add async.coro.suspend as a suspended block terminator.
  builder.setInsertionPointToEnd(suspended);
  builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
                                coro.cleanup);

  size_t numDependencies = execute.dependencies().size();
  size_t numOperands = execute.operands().size();

  // Await on all dependencies before starting to execute the body region.
  // Dependencies are the leading function arguments.
  builder.setInsertionPointToStart(resume);
  for (size_t i = 0; i < numDependencies; ++i)
    builder.create<AwaitOp>(func.getArgument(i));

  // Await on all async value operands and unwrap the payload. Operands follow
  // the dependencies in the function argument list.
  SmallVector<Value, 4> unwrappedOperands(numOperands);
  for (size_t i = 0; i < numOperands; ++i) {
    Value operand = func.getArgument(numDependencies + i);
    unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
  }

  // Map from function inputs defined above the execute op to the function
  // arguments, and from the body region arguments to the unwrapped operands.
  BlockAndValueMapping valueMapping;
  valueMapping.map(functionInputs, func.getArguments());
  valueMapping.map(execute.body().getArguments(), unwrappedOperands);

  // Clone all operations from the execute operation body into the outlined
  // function body (at the builder insertion point inside the `resume` block).
  for (Operation &op : execute.body().getOps())
    builder.clone(op, valueMapping);

  // Replace the original `async.execute` with a call to outlined function.
  ImplicitLocOpBuilder callBuilder(loc, execute);
  auto callOutlinedFunc = callBuilder.create<CallOp>(
      func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
  execute.replaceAllUsesWith(callOutlinedFunc.getResults());
  execute.erase();

  return {func, coro};
}
305 
306 //===----------------------------------------------------------------------===//
307 // Convert async.create_group operation to async.runtime.create_group
308 //===----------------------------------------------------------------------===//
309 
310 namespace {
311 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
312 public:
313   using OpConversionPattern::OpConversionPattern;
314 
315   LogicalResult
316   matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
317                   ConversionPatternRewriter &rewriter) const override {
318     rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
319         op, GroupType::get(op->getContext()), operands);
320     return success();
321   }
322 };
323 } // namespace
324 
325 //===----------------------------------------------------------------------===//
326 // Convert async.add_to_group operation to async.runtime.add_to_group.
327 //===----------------------------------------------------------------------===//
328 
329 namespace {
330 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
331 public:
332   using OpConversionPattern::OpConversionPattern;
333 
334   LogicalResult
335   matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands,
336                   ConversionPatternRewriter &rewriter) const override {
337     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
338         op, rewriter.getIndexType(), operands);
339     return success();
340   }
341 };
342 } // namespace
343 
344 //===----------------------------------------------------------------------===//
345 // Convert async.await and async.await_all operations to the async.runtime.await
346 // or async.runtime.await_and_resume operations.
347 //===----------------------------------------------------------------------===//
348 
349 namespace {
350 template <typename AwaitType, typename AwaitableType>
351 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
352   using AwaitAdaptor = typename AwaitType::Adaptor;
353 
354 public:
355   AwaitOpLoweringBase(MLIRContext *ctx,
356                       llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
357       : OpConversionPattern<AwaitType>(ctx),
358         outlinedFunctions(outlinedFunctions) {}
359 
360   LogicalResult
361   matchAndRewrite(AwaitType op, ArrayRef<Value> operands,
362                   ConversionPatternRewriter &rewriter) const override {
363     // We can only await on one the `AwaitableType` (for `await` it can be
364     // a `token` or a `value`, for `await_all` it must be a `group`).
365     if (!op.operand().getType().template isa<AwaitableType>())
366       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
367 
368     // Check if await operation is inside the outlined coroutine function.
369     auto func = op->template getParentOfType<FuncOp>();
370     auto outlined = outlinedFunctions.find(func);
371     const bool isInCoroutine = outlined != outlinedFunctions.end();
372 
373     Location loc = op->getLoc();
374     Value operand = AwaitAdaptor(operands).operand();
375 
376     // Inside regular functions we use the blocking wait operation to wait for
377     // the async object (token, value or group) to become available.
378     if (!isInCoroutine)
379       rewriter.create<RuntimeAwaitOp>(loc, operand);
380 
381     // Inside the coroutine we convert await operation into coroutine suspension
382     // point, and resume execution asynchronously.
383     if (isInCoroutine) {
384       CoroMachinery &coro = outlined->getSecond();
385       Block *suspended = op->getBlock();
386 
387       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
388       MLIRContext *ctx = op->getContext();
389 
390       // Save the coroutine state and resume on a runtime managed thread when
391       // the operand becomes available.
392       auto coroSaveOp =
393           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
394       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
395 
396       // Split the entry block before the await operation.
397       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
398 
399       // Add async.coro.suspend as a suspended block terminator.
400       builder.setInsertionPointToEnd(suspended);
401       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
402                                     coro.cleanup);
403 
404       // Split the resume block into error checking and continuation.
405       Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
406 
407       // Check if the awaited value is in the error state.
408       builder.setInsertionPointToStart(resume);
409       auto isError =
410           builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand);
411       builder.create<CondBranchOp>(isError,
412                                    /*trueDest=*/setupSetErrorBlock(coro),
413                                    /*trueArgs=*/ArrayRef<Value>(),
414                                    /*falseDest=*/continuation,
415                                    /*falseArgs=*/ArrayRef<Value>());
416 
417       // Make sure that replacement value will be constructed in the
418       // continuation block.
419       rewriter.setInsertionPointToStart(continuation);
420     }
421 
422     // Erase or replace the await operation with the new value.
423     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
424       rewriter.replaceOp(op, replaceWith);
425     else
426       rewriter.eraseOp(op);
427 
428     return success();
429   }
430 
431   virtual Value getReplacementValue(AwaitType op, Value operand,
432                                     ConversionPatternRewriter &rewriter) const {
433     return Value();
434   }
435 
436 private:
437   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
438 };
439 
/// Lowering for `async.await` with a token operand.
///
/// Keeps the base class `getReplacementValue` (which returns a null value), so
/// the await operation is erased rather than replaced.
class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
  using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;

public:
  // Inherit the (context, outlined functions) constructor.
  using Base::Base;
};
447 
448 /// Lowering for `async.await` with a value operand.
449 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
450   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
451 
452 public:
453   using Base::Base;
454 
455   Value
456   getReplacementValue(AwaitOp op, Value operand,
457                       ConversionPatternRewriter &rewriter) const override {
458     // Load from the async value storage.
459     auto valueType = operand.getType().cast<ValueType>().getValueType();
460     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
461   }
462 };
463 
/// Lowering for `async.await_all` operation (the operand must be a group).
///
/// Keeps the base class `getReplacementValue` (which returns a null value), so
/// the await operation is erased rather than replaced.
class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
  using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;

public:
  // Inherit the (context, outlined functions) constructor.
  using Base::Base;
};
471 
472 } // namespace
473 
474 //===----------------------------------------------------------------------===//
475 // Convert async.yield operation to async.runtime operations.
476 //===----------------------------------------------------------------------===//
477 
478 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
479 public:
480   YieldOpLowering(
481       MLIRContext *ctx,
482       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
483       : OpConversionPattern<async::YieldOp>(ctx),
484         outlinedFunctions(outlinedFunctions) {}
485 
486   LogicalResult
487   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
488                   ConversionPatternRewriter &rewriter) const override {
489     // Check if yield operation is inside the async coroutine function.
490     auto func = op->template getParentOfType<FuncOp>();
491     auto outlined = outlinedFunctions.find(func);
492     if (outlined == outlinedFunctions.end())
493       return rewriter.notifyMatchFailure(
494           op, "operation is not inside the async coroutine function");
495 
496     Location loc = op->getLoc();
497     const CoroMachinery &coro = outlined->getSecond();
498 
499     // Store yielded values into the async values storage and switch async
500     // values state to available.
501     for (auto tuple : llvm::zip(operands, coro.returnValues)) {
502       Value yieldValue = std::get<0>(tuple);
503       Value asyncValue = std::get<1>(tuple);
504       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
505       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
506     }
507 
508     // Switch the coroutine completion token to available state.
509     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
510 
511     return success();
512   }
513 
514 private:
515   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
516 };
517 
518 //===----------------------------------------------------------------------===//
519 // Convert std.assert operation to cond_br into `set_error` block.
520 //===----------------------------------------------------------------------===//
521 
522 class AssertOpLowering : public OpConversionPattern<AssertOp> {
523 public:
524   AssertOpLowering(MLIRContext *ctx,
525                    llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
526       : OpConversionPattern<AssertOp>(ctx),
527         outlinedFunctions(outlinedFunctions) {}
528 
529   LogicalResult
530   matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
531                   ConversionPatternRewriter &rewriter) const override {
532     // Check if assert operation is inside the async coroutine function.
533     auto func = op->template getParentOfType<FuncOp>();
534     auto outlined = outlinedFunctions.find(func);
535     if (outlined == outlinedFunctions.end())
536       return rewriter.notifyMatchFailure(
537           op, "operation is not inside the async coroutine function");
538 
539     Location loc = op->getLoc();
540     CoroMachinery &coro = outlined->getSecond();
541 
542     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
543     rewriter.setInsertionPointToEnd(cont->getPrevNode());
544     rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(),
545                                   /*trueDest=*/cont,
546                                   /*trueArgs=*/ArrayRef<Value>(),
547                                   /*falseDest=*/setupSetErrorBlock(coro),
548                                   /*falseArgs=*/ArrayRef<Value>());
549     rewriter.eraseOp(op);
550 
551     return success();
552   }
553 
554 private:
555   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
556 };
557 
558 //===----------------------------------------------------------------------===//
559 
560 void AsyncToAsyncRuntimePass::runOnOperation() {
561   ModuleOp module = getOperation();
562   SymbolTable symbolTable(module);
563 
564   // Outline all `async.execute` body regions into async functions (coroutines).
565   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
566 
567   module.walk([&](ExecuteOp execute) {
568     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
569   });
570 
571   LLVM_DEBUG({
572     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
573                  << " functions built from async.execute operations\n";
574   });
575 
576   // Returns true if operation is inside the coroutine.
577   auto isInCoroutine = [&](Operation *op) -> bool {
578     auto parentFunc = op->getParentOfType<FuncOp>();
579     return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
580   };
581 
582   // Lower async operations to async.runtime operations.
583   MLIRContext *ctx = module->getContext();
584   RewritePatternSet asyncPatterns(ctx);
585 
586   // Conversion to async runtime augments original CFG with the coroutine CFG,
587   // and we have to make sure that structured control flow operations with async
588   // operations in nested regions will be converted to branch-based control flow
589   // before we add the coroutine basic blocks.
590   populateLoopToStdConversionPatterns(asyncPatterns);
591 
592   // Async lowering does not use type converter because it must preserve all
593   // types for async.runtime operations.
594   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
595   asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
596                     AwaitAllOpLowering, YieldOpLowering>(ctx,
597                                                          outlinedFunctions);
598 
599   // Lower assertions to conditional branches into error blocks.
600   asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
601 
602   // All high level async operations must be lowered to the runtime operations.
603   ConversionTarget runtimeTarget(*ctx);
604   runtimeTarget.addLegalDialect<AsyncDialect>();
605   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
606   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
607 
608   // Decide if structured control flow has to be lowered to branch-based CFG.
609   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
610     auto walkResult = op->walk([&](Operation *nested) {
611       bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
612       return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
613                                               : WalkResult::advance();
614     });
615     return !walkResult.wasInterrupted();
616   });
617   runtimeTarget.addLegalOp<BranchOp, CondBranchOp>();
618 
619   // Assertions must be converted to runtime errors inside async functions.
620   runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
621     auto func = op->getParentOfType<FuncOp>();
622     return outlinedFunctions.find(func) == outlinedFunctions.end();
623   });
624 
625   if (failed(applyPartialConversion(module, runtimeTarget,
626                                     std::move(asyncPatterns)))) {
627     signalPassFailure();
628     return;
629   }
630 }
631 
/// Creates a new instance of the pass lowering high level async operations to
/// async.coro and async.runtime operations.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
  return std::make_unique<AsyncToAsyncRuntimePass>();
}
635