1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Async/Passes.h"
18 #include "mlir/Dialect/SCF/SCF.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "mlir/Transforms/RegionUtils.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/Support/Debug.h"
27 
28 using namespace mlir;
29 using namespace mlir::async;
30 
#define DEBUG_TYPE "async-to-async-runtime"

/// Name given to every function outlined from an `async.execute` op region.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
34 
35 namespace {
36 
37 class AsyncToAsyncRuntimePass
38     : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
39 public:
40   AsyncToAsyncRuntimePass() = default;
41   void runOnOperation() override;
42 };
43 
44 } // namespace
45 
46 //===----------------------------------------------------------------------===//
47 // async.execute op outlining to the coroutine functions.
48 //===----------------------------------------------------------------------===//
49 
50 /// Function targeted for coroutine transformation has two additional blocks at
51 /// the end: coroutine cleanup and coroutine suspension.
52 ///
53 /// async.await op lowering additionaly creates a resume block for each
54 /// operation to enable non-blocking waiting via coroutine suspension.
55 namespace {
56 struct CoroMachinery {
57   FuncOp func;
58 
59   // Async execute region returns a completion token, and an async value for
60   // each yielded value.
61   //
62   //   %token, %result = async.execute -> !async.value<T> {
63   //     %0 = constant ... : T
64   //     async.yield %0 : T
65   //   }
66   Value asyncToken; // token representing completion of the async region
67   llvm::SmallVector<Value, 4> returnValues; // returned async values
68 
69   Value coroHandle; // coroutine handle (!async.coro.handle value)
70   Block *setError;  // switch completion token and all values to error state
71   Block *cleanup;   // coroutine cleanup block
72   Block *suspend;   // coroutine suspension block
73 };
74 } // namespace
75 
76 /// Utility to partially update the regular function CFG to the coroutine CFG
77 /// compatible with LLVM coroutines switched-resume lowering using
78 /// `async.runtime.*` and `async.coro.*` operations. Modifies the entry block
79 /// by prepending its ops with coroutine setup. Also inserts trailing blocks.
80 ///
81 /// The result types of the passed `func` must start with an `async.token`
82 /// and be continued with some number of `async.value`s.
83 ///
84 /// It's up to the caller of this function to fix up the terminators of the
85 /// preexisting blocks of the passed func op. If the passed `func` is legal,
86 /// this typically means rewriting every return op as a yield op and a branch op
87 /// to the suspend block.
88 ///
89 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
90 ///
91 ///  - `entry` block sets up the coroutine.
92 ///  - `set_error` block sets completion token and async values state to error.
93 ///  - `cleanup` block cleans up the coroutine state.
94 ///  - `suspend block after the @llvm.coro.end() defines what value will be
95 ///    returned to the initial caller of a coroutine. Everything before the
96 ///    @llvm.coro.end() will be executed at every suspension point.
97 ///
98 /// Coroutine structure (only the important bits):
99 ///
100 ///   func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
101 ///   {
102 ///     ^entry(<function-arguments>):
103 ///       %token = <async token> : !async.token    // create async runtime token
104 ///       %value = <async value> : !async.value<T> // create async value
105 ///       %id = async.coro.id                      // create a coroutine id
106 ///       %hdl = async.coro.begin %id              // create a coroutine handle
107 ///       /* other ops of the preexisting entry block */
108 ///
109 ///     /* other preexisting blocks */
110 ///
111 ///     ^set_error: // this block created lazily only if needed (see code below)
112 ///       async.runtime.set_error %token : !async.token
113 ///       async.runtime.set_error %value : !async.value<T>
114 ///       br ^cleanup
115 ///
116 ///     ^cleanup:
117 ///       async.coro.free %hdl // delete the coroutine state
118 ///       br ^suspend
119 ///
120 ///     ^suspend:
121 ///       async.coro.end %hdl // marks the end of a coroutine
122 ///       return %token, %value : !async.token, !async.value<T>
123 ///   }
124 ///
125 static CoroMachinery setupCoroMachinery(FuncOp func) {
126   assert(!func.getBlocks().empty() && "Function must have an entry block");
127 
128   MLIRContext *ctx = func.getContext();
129   Block *entryBlock = &func.getBlocks().front();
130   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
131 
132   // ------------------------------------------------------------------------ //
133   // Allocate async token/values that we will return from a ramp function.
134   // ------------------------------------------------------------------------ //
135   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
136 
137   llvm::SmallVector<Value, 4> retValues;
138   for (auto resType : func.getCallableResults().drop_front())
139     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
140 
141   // ------------------------------------------------------------------------ //
142   // Initialize coroutine: get coroutine id and coroutine handle.
143   // ------------------------------------------------------------------------ //
144   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
145   auto coroHdlOp =
146       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
147 
148   Block *cleanupBlock = func.addBlock();
149   Block *suspendBlock = func.addBlock();
150 
151   // ------------------------------------------------------------------------ //
152   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
153   // ------------------------------------------------------------------------ //
154   builder.setInsertionPointToStart(cleanupBlock);
155   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
156 
157   // Branch into the suspend block.
158   builder.create<BranchOp>(suspendBlock);
159 
160   // ------------------------------------------------------------------------ //
161   // Coroutine suspend block: mark the end of a coroutine and return allocated
162   // async token.
163   // ------------------------------------------------------------------------ //
164   builder.setInsertionPointToStart(suspendBlock);
165 
166   // Mark the end of a coroutine: async.coro.end
167   builder.create<CoroEndOp>(coroHdlOp.handle());
168 
169   // Return created `async.token` and `async.values` from the suspend block.
170   // This will be the return value of a coroutine ramp function.
171   SmallVector<Value, 4> ret{retToken};
172   ret.insert(ret.end(), retValues.begin(), retValues.end());
173   builder.create<ReturnOp>(ret);
174 
175   // `async.await` op lowering will create resume blocks for async
176   // continuations, and will conditionally branch to cleanup or suspend blocks.
177 
178   CoroMachinery machinery;
179   machinery.func = func;
180   machinery.asyncToken = retToken;
181   machinery.returnValues = retValues;
182   machinery.coroHandle = coroHdlOp.handle();
183   machinery.setError = nullptr; // created lazily only if needed
184   machinery.cleanup = cleanupBlock;
185   machinery.suspend = suspendBlock;
186   return machinery;
187 }
188 
189 // Lazily creates `set_error` block only if it is required for lowering to the
190 // runtime operations (see for example lowering of assert operation).
191 static Block *setupSetErrorBlock(CoroMachinery &coro) {
192   if (coro.setError)
193     return coro.setError;
194 
195   coro.setError = coro.func.addBlock();
196   coro.setError->moveBefore(coro.cleanup);
197 
198   auto builder =
199       ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
200 
201   // Coroutine set_error block: set error on token and all returned values.
202   builder.create<RuntimeSetErrorOp>(coro.asyncToken);
203   for (Value retValue : coro.returnValues)
204     builder.create<RuntimeSetErrorOp>(retValue);
205 
206   // Branch into the cleanup block.
207   builder.create<BranchOp>(coro.cleanup);
208 
209   return coro.setError;
210 }
211 
212 /// Outline the body region attached to the `async.execute` op into a standalone
213 /// function.
214 ///
215 /// Note that this is not reversible transformation.
216 static std::pair<FuncOp, CoroMachinery>
217 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
218   ModuleOp module = execute->getParentOfType<ModuleOp>();
219 
220   MLIRContext *ctx = module.getContext();
221   Location loc = execute.getLoc();
222 
223   // Collect all outlined function inputs.
224   SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
225                                         execute.dependencies().end());
226   functionInputs.insert(execute.operands().begin(), execute.operands().end());
227   getUsedValuesDefinedAbove(execute.body(), functionInputs);
228 
229   // Collect types for the outlined function inputs and outputs.
230   auto typesRange = llvm::map_range(
231       functionInputs, [](Value value) { return value.getType(); });
232   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
233   auto outputTypes = execute.getResultTypes();
234 
235   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
236   auto funcAttrs = ArrayRef<NamedAttribute>();
237 
238   // TODO: Derive outlined function name from the parent FuncOp (support
239   // multiple nested async.execute operations).
240   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
241   symbolTable.insert(func);
242 
243   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
244 
245   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
246   // blocks, adding async.coro operations and setting up control flow.
247   func.addEntryBlock();
248   CoroMachinery coro = setupCoroMachinery(func);
249 
250   // Suspend async function at the end of an entry block, and resume it using
251   // Async resume operation (execution will be resumed in a thread managed by
252   // the async runtime).
253   Block *entryBlock = &func.getBlocks().front();
254   auto builder = ImplicitLocOpBuilder::atBlockEnd(loc, entryBlock);
255 
256   // Save the coroutine state: async.coro.save
257   auto coroSaveOp =
258       builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
259 
260   // Pass coroutine to the runtime to be resumed on a runtime managed thread.
261   builder.create<RuntimeResumeOp>(coro.coroHandle);
262   builder.create<BranchOp>(coro.cleanup);
263 
264   // Split the entry block before the terminator (branch to suspend block).
265   auto *terminatorOp = entryBlock->getTerminator();
266   Block *suspended = terminatorOp->getBlock();
267   Block *resume = suspended->splitBlock(terminatorOp);
268 
269   // Add async.coro.suspend as a suspended block terminator.
270   builder.setInsertionPointToEnd(suspended);
271   builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
272                                 coro.cleanup);
273 
274   size_t numDependencies = execute.dependencies().size();
275   size_t numOperands = execute.operands().size();
276 
277   // Await on all dependencies before starting to execute the body region.
278   builder.setInsertionPointToStart(resume);
279   for (size_t i = 0; i < numDependencies; ++i)
280     builder.create<AwaitOp>(func.getArgument(i));
281 
282   // Await on all async value operands and unwrap the payload.
283   SmallVector<Value, 4> unwrappedOperands(numOperands);
284   for (size_t i = 0; i < numOperands; ++i) {
285     Value operand = func.getArgument(numDependencies + i);
286     unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
287   }
288 
289   // Map from function inputs defined above the execute op to the function
290   // arguments.
291   BlockAndValueMapping valueMapping;
292   valueMapping.map(functionInputs, func.getArguments());
293   valueMapping.map(execute.body().getArguments(), unwrappedOperands);
294 
295   // Clone all operations from the execute operation body into the outlined
296   // function body.
297   for (Operation &op : execute.body().getOps())
298     builder.clone(op, valueMapping);
299 
300   // Replace the original `async.execute` with a call to outlined function.
301   ImplicitLocOpBuilder callBuilder(loc, execute);
302   auto callOutlinedFunc = callBuilder.create<CallOp>(
303       func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
304   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
305   execute.erase();
306 
307   return {func, coro};
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // Convert async.create_group operation to async.runtime.create_group
312 //===----------------------------------------------------------------------===//
313 
314 namespace {
315 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
316 public:
317   using OpConversionPattern::OpConversionPattern;
318 
319   LogicalResult
320   matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
321                   ConversionPatternRewriter &rewriter) const override {
322     rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
323         op, GroupType::get(op->getContext()), operands);
324     return success();
325   }
326 };
327 } // namespace
328 
329 //===----------------------------------------------------------------------===//
330 // Convert async.add_to_group operation to async.runtime.add_to_group.
331 //===----------------------------------------------------------------------===//
332 
333 namespace {
334 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
335 public:
336   using OpConversionPattern::OpConversionPattern;
337 
338   LogicalResult
339   matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands,
340                   ConversionPatternRewriter &rewriter) const override {
341     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
342         op, rewriter.getIndexType(), operands);
343     return success();
344   }
345 };
346 } // namespace
347 
348 //===----------------------------------------------------------------------===//
349 // Convert async.await and async.await_all operations to the async.runtime.await
350 // or async.runtime.await_and_resume operations.
351 //===----------------------------------------------------------------------===//
352 
353 namespace {
354 template <typename AwaitType, typename AwaitableType>
355 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
356   using AwaitAdaptor = typename AwaitType::Adaptor;
357 
358 public:
359   AwaitOpLoweringBase(MLIRContext *ctx,
360                       llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
361       : OpConversionPattern<AwaitType>(ctx),
362         outlinedFunctions(outlinedFunctions) {}
363 
364   LogicalResult
365   matchAndRewrite(AwaitType op, ArrayRef<Value> operands,
366                   ConversionPatternRewriter &rewriter) const override {
367     // We can only await on one the `AwaitableType` (for `await` it can be
368     // a `token` or a `value`, for `await_all` it must be a `group`).
369     if (!op.operand().getType().template isa<AwaitableType>())
370       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
371 
372     // Check if await operation is inside the outlined coroutine function.
373     auto func = op->template getParentOfType<FuncOp>();
374     auto outlined = outlinedFunctions.find(func);
375     const bool isInCoroutine = outlined != outlinedFunctions.end();
376 
377     Location loc = op->getLoc();
378     Value operand = AwaitAdaptor(operands).operand();
379 
380     // Inside regular functions we use the blocking wait operation to wait for
381     // the async object (token, value or group) to become available.
382     if (!isInCoroutine)
383       rewriter.create<RuntimeAwaitOp>(loc, operand);
384 
385     // Inside the coroutine we convert await operation into coroutine suspension
386     // point, and resume execution asynchronously.
387     if (isInCoroutine) {
388       CoroMachinery &coro = outlined->getSecond();
389       Block *suspended = op->getBlock();
390 
391       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
392       MLIRContext *ctx = op->getContext();
393 
394       // Save the coroutine state and resume on a runtime managed thread when
395       // the operand becomes available.
396       auto coroSaveOp =
397           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
398       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
399 
400       // Split the entry block before the await operation.
401       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
402 
403       // Add async.coro.suspend as a suspended block terminator.
404       builder.setInsertionPointToEnd(suspended);
405       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
406                                     coro.cleanup);
407 
408       // Split the resume block into error checking and continuation.
409       Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
410 
411       // Check if the awaited value is in the error state.
412       builder.setInsertionPointToStart(resume);
413       auto isError =
414           builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand);
415       builder.create<CondBranchOp>(isError,
416                                    /*trueDest=*/setupSetErrorBlock(coro),
417                                    /*trueArgs=*/ArrayRef<Value>(),
418                                    /*falseDest=*/continuation,
419                                    /*falseArgs=*/ArrayRef<Value>());
420 
421       // Make sure that replacement value will be constructed in the
422       // continuation block.
423       rewriter.setInsertionPointToStart(continuation);
424     }
425 
426     // Erase or replace the await operation with the new value.
427     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
428       rewriter.replaceOp(op, replaceWith);
429     else
430       rewriter.eraseOp(op);
431 
432     return success();
433   }
434 
435   virtual Value getReplacementValue(AwaitType op, Value operand,
436                                     ConversionPatternRewriter &rewriter) const {
437     return Value();
438   }
439 
440 private:
441   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
442 };
443 
444 /// Lowering for `async.await` with a token operand.
445 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
446   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
447 
448 public:
449   using Base::Base;
450 };
451 
452 /// Lowering for `async.await` with a value operand.
453 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
454   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
455 
456 public:
457   using Base::Base;
458 
459   Value
460   getReplacementValue(AwaitOp op, Value operand,
461                       ConversionPatternRewriter &rewriter) const override {
462     // Load from the async value storage.
463     auto valueType = operand.getType().cast<ValueType>().getValueType();
464     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
465   }
466 };
467 
468 /// Lowering for `async.await_all` operation.
469 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
470   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
471 
472 public:
473   using Base::Base;
474 };
475 
476 } // namespace
477 
478 //===----------------------------------------------------------------------===//
479 // Convert async.yield operation to async.runtime operations.
480 //===----------------------------------------------------------------------===//
481 
482 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
483 public:
484   YieldOpLowering(
485       MLIRContext *ctx,
486       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
487       : OpConversionPattern<async::YieldOp>(ctx),
488         outlinedFunctions(outlinedFunctions) {}
489 
490   LogicalResult
491   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
492                   ConversionPatternRewriter &rewriter) const override {
493     // Check if yield operation is inside the async coroutine function.
494     auto func = op->template getParentOfType<FuncOp>();
495     auto outlined = outlinedFunctions.find(func);
496     if (outlined == outlinedFunctions.end())
497       return rewriter.notifyMatchFailure(
498           op, "operation is not inside the async coroutine function");
499 
500     Location loc = op->getLoc();
501     const CoroMachinery &coro = outlined->getSecond();
502 
503     // Store yielded values into the async values storage and switch async
504     // values state to available.
505     for (auto tuple : llvm::zip(operands, coro.returnValues)) {
506       Value yieldValue = std::get<0>(tuple);
507       Value asyncValue = std::get<1>(tuple);
508       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
509       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
510     }
511 
512     // Switch the coroutine completion token to available state.
513     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
514 
515     return success();
516   }
517 
518 private:
519   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
520 };
521 
522 //===----------------------------------------------------------------------===//
523 // Convert std.assert operation to cond_br into `set_error` block.
524 //===----------------------------------------------------------------------===//
525 
526 class AssertOpLowering : public OpConversionPattern<AssertOp> {
527 public:
528   AssertOpLowering(MLIRContext *ctx,
529                    llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
530       : OpConversionPattern<AssertOp>(ctx),
531         outlinedFunctions(outlinedFunctions) {}
532 
533   LogicalResult
534   matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
535                   ConversionPatternRewriter &rewriter) const override {
536     // Check if assert operation is inside the async coroutine function.
537     auto func = op->template getParentOfType<FuncOp>();
538     auto outlined = outlinedFunctions.find(func);
539     if (outlined == outlinedFunctions.end())
540       return rewriter.notifyMatchFailure(
541           op, "operation is not inside the async coroutine function");
542 
543     Location loc = op->getLoc();
544     CoroMachinery &coro = outlined->getSecond();
545 
546     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
547     rewriter.setInsertionPointToEnd(cont->getPrevNode());
548     rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(),
549                                   /*trueDest=*/cont,
550                                   /*trueArgs=*/ArrayRef<Value>(),
551                                   /*falseDest=*/setupSetErrorBlock(coro),
552                                   /*falseArgs=*/ArrayRef<Value>());
553     rewriter.eraseOp(op);
554 
555     return success();
556   }
557 
558 private:
559   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
560 };
561 
562 //===----------------------------------------------------------------------===//
563 
564 /// Rewrite a func as a coroutine by:
565 /// 1) Wrapping the results into `async.value`.
566 /// 2) Prepending the results with `async.token`.
567 /// 3) Setting up coroutine blocks.
568 /// 4) Rewriting return ops as yield op and branch op into the suspend block.
569 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) {
570   auto *ctx = func->getContext();
571   auto loc = func.getLoc();
572   SmallVector<Type> resultTypes;
573   resultTypes.reserve(func.getCallableResults().size());
574   llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
575                   [](Type type) { return ValueType::get(type); });
576   func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes));
577   func.insertResult(0, TokenType::get(ctx), {});
578   CoroMachinery coro = setupCoroMachinery(func);
579   for (Block &block : func.getBlocks()) {
580     if (&block == coro.suspend)
581       continue;
582 
583     Operation *terminator = block.getTerminator();
584     if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) {
585       ImplicitLocOpBuilder builder(loc, returnOp);
586       builder.create<YieldOp>(returnOp.getOperands());
587       builder.create<BranchOp>(coro.cleanup);
588       returnOp.erase();
589     }
590   }
591   return coro;
592 }
593 
594 /// Rewrites a call into a function that has been rewritten as a coroutine.
595 ///
596 /// The invocation of this function is safe only when call ops are traversed in
597 /// reverse order of how they appear in a single block. See `funcsToCoroutines`.
598 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
599   auto loc = func.getLoc();
600   ImplicitLocOpBuilder callBuilder(loc, oldCall);
601   auto newCall = callBuilder.create<CallOp>(
602       func.getName(), func.getCallableResults(), oldCall.getArgOperands());
603 
604   // Await on the async token and all the value results and unwrap the latter.
605   callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
606   SmallVector<Value> unwrappedResults;
607   unwrappedResults.reserve(newCall->getResults().size() - 1);
608   for (Value result : newCall.getResults().drop_front())
609     unwrappedResults.push_back(
610         callBuilder.create<AwaitOp>(loc, result).result());
611   // Careful, when result of a call is piped into another call this could lead
612   // to a dangling pointer.
613   oldCall.replaceAllUsesWith(unwrappedResults);
614   oldCall.erase();
615 }
616 
617 static LogicalResult
618 funcsToCoroutines(ModuleOp module,
619                   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
620   // The following code supports the general case when 2 functions mutually
621   // recurse into each other. Because of this and that we are relying on
622   // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
623   // a FuncOp while inserting an equivalent coroutine, because that could lead
624   // to dangling pointers.
625 
626   SmallVector<FuncOp> funcWorklist;
627 
628   // Careful, it's okay to add a func to the worklist multiple times if and only
629   // if the loop processing the worklist will skip the functions that have
630   // already been converted to coroutines.
631   auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) {
632     // N.B. To refactor this code into a separate pass the lookup in
633     // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
634     // func and recognizing if it has a coroutine structure is messy. Passing
635     // this dict between the passes is ugly.
636     if (outlinedFunctions.find(func) == outlinedFunctions.end()) {
637       for (Operation &op : func.body().getOps()) {
638         if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
639           funcWorklist.push_back(func);
640           break;
641         }
642       }
643     }
644   };
645 
646   // Traverse in post-order collecting for each func op the await ops it has.
647   for (FuncOp func : module.getOps<FuncOp>())
648     addToWorklist(func);
649 
650   SymbolTableCollection symbolTable;
651   SymbolUserMap symbolUserMap(symbolTable, module);
652 
653   // Rewrite funcs, while updating call sites and adding them to the worklist.
654   while (!funcWorklist.empty()) {
655     auto func = funcWorklist.pop_back_val();
656     auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
657     if (!insertion.second)
658       // This function has already been processed because this is either
659       // the corecursive case, or a caller with multiple calls to a newly
660       // created corouting. Either way, skip updating the call sites.
661       continue;
662     insertion.first->second = rewriteFuncAsCoroutine(func);
663     SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
664                                    symbolUserMap.getUsers(func).end());
665     // If there are multiple calls from the same block they need to be traversed
666     // in reverse order so that symbolUserMap references are not invalidated
667     // when updating the users of the call op which is earlier in the block.
668     llvm::sort(users, [](Operation *a, Operation *b) {
669       Block *blockA = a->getBlock();
670       Block *blockB = b->getBlock();
671       // Impose arbitrary order on blocks so that there is a well-defined order.
672       return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b));
673     });
674     // Rewrite the callsites to await on results of the newly created coroutine.
675     for (Operation *op : users) {
676       if (CallOp call = dyn_cast<mlir::CallOp>(*op)) {
677         FuncOp caller = call->getParentOfType<FuncOp>();
678         rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
679         addToWorklist(caller);
680       } else {
681         op->emitError("Unexpected reference to func referenced by symbol");
682         return failure();
683       }
684     }
685   }
686   return success();
687 }
688 
689 //===----------------------------------------------------------------------===//
690 void AsyncToAsyncRuntimePass::runOnOperation() {
691   ModuleOp module = getOperation();
692   SymbolTable symbolTable(module);
693 
694   // Outline all `async.execute` body regions into async functions (coroutines).
695   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
696 
697   module.walk([&](ExecuteOp execute) {
698     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
699   });
700 
701   LLVM_DEBUG({
702     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
703                  << " functions built from async.execute operations\n";
704   });
705 
706   // Returns true if operation is inside the coroutine.
707   auto isInCoroutine = [&](Operation *op) -> bool {
708     auto parentFunc = op->getParentOfType<FuncOp>();
709     return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
710   };
711 
712   if (eliminateBlockingAwaitOps &&
713       failed(funcsToCoroutines(module, outlinedFunctions))) {
714     signalPassFailure();
715     return;
716   }
717 
718   // Lower async operations to async.runtime operations.
719   MLIRContext *ctx = module->getContext();
720   RewritePatternSet asyncPatterns(ctx);
721 
722   // Conversion to async runtime augments original CFG with the coroutine CFG,
723   // and we have to make sure that structured control flow operations with async
724   // operations in nested regions will be converted to branch-based control flow
725   // before we add the coroutine basic blocks.
726   populateLoopToStdConversionPatterns(asyncPatterns);
727 
728   // Async lowering does not use type converter because it must preserve all
729   // types for async.runtime operations.
730   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
731   asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
732                     AwaitAllOpLowering, YieldOpLowering>(ctx,
733                                                          outlinedFunctions);
734 
735   // Lower assertions to conditional branches into error blocks.
736   asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
737 
738   // All high level async operations must be lowered to the runtime operations.
739   ConversionTarget runtimeTarget(*ctx);
740   runtimeTarget.addLegalDialect<AsyncDialect>();
741   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
742   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
743 
744   // Decide if structured control flow has to be lowered to branch-based CFG.
745   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
746     auto walkResult = op->walk([&](Operation *nested) {
747       bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
748       return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
749                                               : WalkResult::advance();
750     });
751     return !walkResult.wasInterrupted();
752   });
753   runtimeTarget.addLegalOp<BranchOp, CondBranchOp>();
754 
755   // Assertions must be converted to runtime errors inside async functions.
756   runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
757     auto func = op->getParentOfType<FuncOp>();
758     return outlinedFunctions.find(func) == outlinedFunctions.end();
759   });
760 
761   if (eliminateBlockingAwaitOps)
762     runtimeTarget.addIllegalOp<RuntimeAwaitOp>();
763 
764   if (failed(applyPartialConversion(module, runtimeTarget,
765                                     std::move(asyncPatterns)))) {
766     signalPassFailure();
767     return;
768   }
769 }
770 
771 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
772   return std::make_unique<AsyncToAsyncRuntimePass>();
773 }
774