1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Async/IR/Async.h"
18 #include "mlir/Dialect/Async/Passes.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/BlockAndValueMapping.h"
22 #include "mlir/IR/ImplicitLocOpBuilder.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/RegionUtils.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/Support/Debug.h"
28 
29 using namespace mlir;
30 using namespace mlir::async;
31 
32 #define DEBUG_TYPE "async-to-async-runtime"
// Symbol name given to each function outlined from an `async.execute` op
// region.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
35 
36 namespace {
37 
38 class AsyncToAsyncRuntimePass
39     : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
40 public:
41   AsyncToAsyncRuntimePass() = default;
42   void runOnOperation() override;
43 };
44 
45 } // namespace
46 
47 //===----------------------------------------------------------------------===//
48 // async.execute op outlining to the coroutine functions.
49 //===----------------------------------------------------------------------===//
50 
51 /// Function targeted for coroutine transformation has two additional blocks at
52 /// the end: coroutine cleanup and coroutine suspension.
53 ///
54 /// async.await op lowering additionaly creates a resume block for each
55 /// operation to enable non-blocking waiting via coroutine suspension.
56 namespace {
57 struct CoroMachinery {
58   FuncOp func;
59 
60   // Async execute region returns a completion token, and an async value for
61   // each yielded value.
62   //
63   //   %token, %result = async.execute -> !async.value<T> {
64   //     %0 = arith.constant ... : T
65   //     async.yield %0 : T
66   //   }
67   Value asyncToken; // token representing completion of the async region
68   llvm::SmallVector<Value, 4> returnValues; // returned async values
69 
70   Value coroHandle; // coroutine handle (!async.coro.handle value)
71   Block *entry;     // coroutine entry block
72   Block *setError;  // switch completion token and all values to error state
73   Block *cleanup;   // coroutine cleanup block
74   Block *suspend;   // coroutine suspension block
75 };
76 } // namespace
77 
78 /// Utility to partially update the regular function CFG to the coroutine CFG
79 /// compatible with LLVM coroutines switched-resume lowering using
80 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
81 /// that branches into preexisting entry block. Also inserts trailing blocks.
82 ///
83 /// The result types of the passed `func` must start with an `async.token`
84 /// and be continued with some number of `async.value`s.
85 ///
86 /// The func given to this function needs to have been preprocessed to have
87 /// either branch or yield ops as terminators. Branches to the cleanup block are
88 /// inserted after each yield.
89 ///
90 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
91 ///
92 ///  - `entry` block sets up the coroutine.
93 ///  - `set_error` block sets completion token and async values state to error.
94 ///  - `cleanup` block cleans up the coroutine state.
95 ///  - `suspend block after the @llvm.coro.end() defines what value will be
96 ///    returned to the initial caller of a coroutine. Everything before the
97 ///    @llvm.coro.end() will be executed at every suspension point.
98 ///
99 /// Coroutine structure (only the important bits):
100 ///
101 ///   func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
102 ///   {
103 ///     ^entry(<function-arguments>):
104 ///       %token = <async token> : !async.token    // create async runtime token
105 ///       %value = <async value> : !async.value<T> // create async value
106 ///       %id = async.coro.id                      // create a coroutine id
107 ///       %hdl = async.coro.begin %id              // create a coroutine handle
108 ///       br ^preexisting_entry_block
109 ///
110 ///     /*  preexisting blocks modified to branch to the cleanup block */
111 ///
112 ///     ^set_error: // this block created lazily only if needed (see code below)
113 ///       async.runtime.set_error %token : !async.token
114 ///       async.runtime.set_error %value : !async.value<T>
115 ///       br ^cleanup
116 ///
117 ///     ^cleanup:
118 ///       async.coro.free %hdl // delete the coroutine state
119 ///       br ^suspend
120 ///
121 ///     ^suspend:
122 ///       async.coro.end %hdl // marks the end of a coroutine
123 ///       return %token, %value : !async.token, !async.value<T>
124 ///   }
125 ///
126 static CoroMachinery setupCoroMachinery(FuncOp func) {
127   assert(!func.getBlocks().empty() && "Function must have an entry block");
128 
129   MLIRContext *ctx = func.getContext();
130   Block *entryBlock = &func.getBlocks().front();
131   Block *originalEntryBlock =
132       entryBlock->splitBlock(entryBlock->getOperations().begin());
133   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
134 
135   // ------------------------------------------------------------------------ //
136   // Allocate async token/values that we will return from a ramp function.
137   // ------------------------------------------------------------------------ //
138   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
139 
140   llvm::SmallVector<Value, 4> retValues;
141   for (auto resType : func.getCallableResults().drop_front())
142     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
143 
144   // ------------------------------------------------------------------------ //
145   // Initialize coroutine: get coroutine id and coroutine handle.
146   // ------------------------------------------------------------------------ //
147   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
148   auto coroHdlOp =
149       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
150   builder.create<BranchOp>(originalEntryBlock);
151 
152   Block *cleanupBlock = func.addBlock();
153   Block *suspendBlock = func.addBlock();
154 
155   // ------------------------------------------------------------------------ //
156   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
157   // ------------------------------------------------------------------------ //
158   builder.setInsertionPointToStart(cleanupBlock);
159   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
160 
161   // Branch into the suspend block.
162   builder.create<BranchOp>(suspendBlock);
163 
164   // ------------------------------------------------------------------------ //
165   // Coroutine suspend block: mark the end of a coroutine and return allocated
166   // async token.
167   // ------------------------------------------------------------------------ //
168   builder.setInsertionPointToStart(suspendBlock);
169 
170   // Mark the end of a coroutine: async.coro.end
171   builder.create<CoroEndOp>(coroHdlOp.handle());
172 
173   // Return created `async.token` and `async.values` from the suspend block.
174   // This will be the return value of a coroutine ramp function.
175   SmallVector<Value, 4> ret{retToken};
176   ret.insert(ret.end(), retValues.begin(), retValues.end());
177   builder.create<ReturnOp>(ret);
178 
179   // `async.await` op lowering will create resume blocks for async
180   // continuations, and will conditionally branch to cleanup or suspend blocks.
181 
182   for (Block &block : func.body().getBlocks()) {
183     if (&block == entryBlock || &block == cleanupBlock ||
184         &block == suspendBlock)
185       continue;
186     Operation *terminator = block.getTerminator();
187     if (auto yield = dyn_cast<YieldOp>(terminator)) {
188       builder.setInsertionPointToEnd(&block);
189       builder.create<BranchOp>(cleanupBlock);
190     }
191   }
192 
193   CoroMachinery machinery;
194   machinery.func = func;
195   machinery.asyncToken = retToken;
196   machinery.returnValues = retValues;
197   machinery.coroHandle = coroHdlOp.handle();
198   machinery.entry = entryBlock;
199   machinery.setError = nullptr; // created lazily only if needed
200   machinery.cleanup = cleanupBlock;
201   machinery.suspend = suspendBlock;
202   return machinery;
203 }
204 
205 // Lazily creates `set_error` block only if it is required for lowering to the
206 // runtime operations (see for example lowering of assert operation).
207 static Block *setupSetErrorBlock(CoroMachinery &coro) {
208   if (coro.setError)
209     return coro.setError;
210 
211   coro.setError = coro.func.addBlock();
212   coro.setError->moveBefore(coro.cleanup);
213 
214   auto builder =
215       ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
216 
217   // Coroutine set_error block: set error on token and all returned values.
218   builder.create<RuntimeSetErrorOp>(coro.asyncToken);
219   for (Value retValue : coro.returnValues)
220     builder.create<RuntimeSetErrorOp>(retValue);
221 
222   // Branch into the cleanup block.
223   builder.create<BranchOp>(coro.cleanup);
224 
225   return coro.setError;
226 }
227 
/// Outline the body region attached to the `async.execute` op into a standalone
/// private function, and turn that function into a coroutine.
///
/// The outlined function takes, in order: the execute op's dependencies, its
/// async value operands, and every value used in the body but defined above
/// it. Its result types are the execute op's result types. In the function
/// body, the dependencies and operands are awaited before the cloned body
/// region runs. The coroutine suspends once at the end of its entry block and
/// is resumed by the async runtime.
///
/// The `async.execute` op is replaced with a call to the outlined function and
/// erased. Note that this is not a reversible transformation.
///
/// Returns the outlined function together with its coroutine machinery.
static std::pair<FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
  ModuleOp module = execute->getParentOfType<ModuleOp>();

  MLIRContext *ctx = module.getContext();
  Location loc = execute.getLoc();

  // Make sure that all constants will be inside the outlined async function to
  // reduce the number of function arguments.
  cloneConstantsIntoTheRegion(execute.body());

  // Collect all outlined function inputs: dependencies first, then async value
  // operands, then values captured from above the execute op.
  SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
                                        execute.dependencies().end());
  functionInputs.insert(execute.operands().begin(), execute.operands().end());
  getUsedValuesDefinedAbove(execute.body(), functionInputs);

  // Collect types for the outlined function inputs and outputs.
  auto typesRange = llvm::map_range(
      functionInputs, [](Value value) { return value.getType(); });
  SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
  auto outputTypes = execute.getResultTypes();

  auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
  auto funcAttrs = ArrayRef<NamedAttribute>();

  // TODO: Derive outlined function name from the parent FuncOp (support
  // multiple nested async.execute operations).
  FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
  // NOTE(review): presumably the symbol table renames `func` on a name clash,
  // which is why the call below uses `func.getName()` -- confirm.
  symbolTable.insert(func);

  SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());

  // Prepare for coroutine conversion by creating the body of the function.
  {
    size_t numDependencies = execute.dependencies().size();
    size_t numOperands = execute.operands().size();

    // Await on all dependencies before starting to execute the body region.
    for (size_t i = 0; i < numDependencies; ++i)
      builder.create<AwaitOp>(func.getArgument(i));

    // Await on all async value operands and unwrap the payload.
    SmallVector<Value, 4> unwrappedOperands(numOperands);
    for (size_t i = 0; i < numOperands; ++i) {
      Value operand = func.getArgument(numDependencies + i);
      unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
    }

    // Map from function inputs defined above the execute op to the function
    // arguments.
    BlockAndValueMapping valueMapping;
    valueMapping.map(functionInputs, func.getArguments());
    // Body block arguments are replaced by the unwrapped payloads.
    valueMapping.map(execute.body().getArguments(), unwrappedOperands);

    // Clone all operations from the execute operation body into the outlined
    // function body.
    for (Operation &op : execute.body().getOps())
      builder.clone(op, valueMapping);
  }

  // Adding entry/cleanup/suspend blocks.
  CoroMachinery coro = setupCoroMachinery(func);

  // Suspend async function at the end of an entry block, and resume it using
  // Async resume operation (execution will be resumed in a thread managed by
  // the async runtime).
  {
    // The entry block currently ends with a branch into the original body; it
    // is replaced by the suspend sequence below.
    BranchOp branch = cast<BranchOp>(coro.entry->getTerminator());
    builder.setInsertionPointToEnd(coro.entry);

    // Save the coroutine state: async.coro.save
    auto coroSaveOp =
        builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);

    // Pass coroutine to the runtime to be resumed on a runtime managed
    // thread.
    builder.create<RuntimeResumeOp>(coro.coroHandle);

    // Add async.coro.suspend as a suspended block terminator. On resume,
    // control continues at the former branch destination.
    builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
                                  branch.getDest(), coro.cleanup);

    branch.erase();
  }

  // Replace the original `async.execute` with a call to outlined function.
  {
    ImplicitLocOpBuilder callBuilder(loc, execute);
    auto callOutlinedFunc = callBuilder.create<CallOp>(
        func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
    execute.replaceAllUsesWith(callOutlinedFunc.getResults());
    execute.erase();
  }

  return {func, coro};
}
330 
331 //===----------------------------------------------------------------------===//
332 // Convert async.create_group operation to async.runtime.create_group
333 //===----------------------------------------------------------------------===//
334 
335 namespace {
336 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
337 public:
338   using OpConversionPattern::OpConversionPattern;
339 
340   LogicalResult
341   matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
342                   ConversionPatternRewriter &rewriter) const override {
343     rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
344         op, GroupType::get(op->getContext()), adaptor.getOperands());
345     return success();
346   }
347 };
348 } // namespace
349 
350 //===----------------------------------------------------------------------===//
351 // Convert async.add_to_group operation to async.runtime.add_to_group.
352 //===----------------------------------------------------------------------===//
353 
354 namespace {
355 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
356 public:
357   using OpConversionPattern::OpConversionPattern;
358 
359   LogicalResult
360   matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
361                   ConversionPatternRewriter &rewriter) const override {
362     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
363         op, rewriter.getIndexType(), adaptor.getOperands());
364     return success();
365   }
366 };
367 } // namespace
368 
369 //===----------------------------------------------------------------------===//
370 // Convert async.await and async.await_all operations to the async.runtime.await
371 // or async.runtime.await_and_resume operations.
372 //===----------------------------------------------------------------------===//
373 
374 namespace {
375 template <typename AwaitType, typename AwaitableType>
376 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
377   using AwaitAdaptor = typename AwaitType::Adaptor;
378 
379 public:
380   AwaitOpLoweringBase(MLIRContext *ctx,
381                       llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
382       : OpConversionPattern<AwaitType>(ctx),
383         outlinedFunctions(outlinedFunctions) {}
384 
385   LogicalResult
386   matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
387                   ConversionPatternRewriter &rewriter) const override {
388     // We can only await on one the `AwaitableType` (for `await` it can be
389     // a `token` or a `value`, for `await_all` it must be a `group`).
390     if (!op.operand().getType().template isa<AwaitableType>())
391       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
392 
393     // Check if await operation is inside the outlined coroutine function.
394     auto func = op->template getParentOfType<FuncOp>();
395     auto outlined = outlinedFunctions.find(func);
396     const bool isInCoroutine = outlined != outlinedFunctions.end();
397 
398     Location loc = op->getLoc();
399     Value operand = adaptor.operand();
400 
401     Type i1 = rewriter.getI1Type();
402 
403     // Inside regular functions we use the blocking wait operation to wait for
404     // the async object (token, value or group) to become available.
405     if (!isInCoroutine) {
406       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
407       builder.create<RuntimeAwaitOp>(loc, operand);
408 
409       // Assert that the awaited operands is not in the error state.
410       Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
411       Value notError = builder.create<arith::XOrIOp>(
412           isError, builder.create<arith::ConstantOp>(
413                        loc, i1, builder.getIntegerAttr(i1, 1)));
414 
415       builder.create<AssertOp>(notError,
416                                "Awaited async operand is in error state");
417     }
418 
419     // Inside the coroutine we convert await operation into coroutine suspension
420     // point, and resume execution asynchronously.
421     if (isInCoroutine) {
422       CoroMachinery &coro = outlined->getSecond();
423       Block *suspended = op->getBlock();
424 
425       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
426       MLIRContext *ctx = op->getContext();
427 
428       // Save the coroutine state and resume on a runtime managed thread when
429       // the operand becomes available.
430       auto coroSaveOp =
431           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
432       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
433 
434       // Split the entry block before the await operation.
435       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
436 
437       // Add async.coro.suspend as a suspended block terminator.
438       builder.setInsertionPointToEnd(suspended);
439       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
440                                     coro.cleanup);
441 
442       // Split the resume block into error checking and continuation.
443       Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
444 
445       // Check if the awaited value is in the error state.
446       builder.setInsertionPointToStart(resume);
447       auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
448       builder.create<CondBranchOp>(isError,
449                                    /*trueDest=*/setupSetErrorBlock(coro),
450                                    /*trueArgs=*/ArrayRef<Value>(),
451                                    /*falseDest=*/continuation,
452                                    /*falseArgs=*/ArrayRef<Value>());
453 
454       // Make sure that replacement value will be constructed in the
455       // continuation block.
456       rewriter.setInsertionPointToStart(continuation);
457     }
458 
459     // Erase or replace the await operation with the new value.
460     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
461       rewriter.replaceOp(op, replaceWith);
462     else
463       rewriter.eraseOp(op);
464 
465     return success();
466   }
467 
468   virtual Value getReplacementValue(AwaitType op, Value operand,
469                                     ConversionPatternRewriter &rewriter) const {
470     return Value();
471   }
472 
473 private:
474   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
475 };
476 
477 /// Lowering for `async.await` with a token operand.
478 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
479   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
480 
481 public:
482   using Base::Base;
483 };
484 
485 /// Lowering for `async.await` with a value operand.
486 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
487   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
488 
489 public:
490   using Base::Base;
491 
492   Value
493   getReplacementValue(AwaitOp op, Value operand,
494                       ConversionPatternRewriter &rewriter) const override {
495     // Load from the async value storage.
496     auto valueType = operand.getType().cast<ValueType>().getValueType();
497     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
498   }
499 };
500 
501 /// Lowering for `async.await_all` operation.
502 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
503   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
504 
505 public:
506   using Base::Base;
507 };
508 
509 } // namespace
510 
511 //===----------------------------------------------------------------------===//
512 // Convert async.yield operation to async.runtime operations.
513 //===----------------------------------------------------------------------===//
514 
515 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
516 public:
517   YieldOpLowering(
518       MLIRContext *ctx,
519       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
520       : OpConversionPattern<async::YieldOp>(ctx),
521         outlinedFunctions(outlinedFunctions) {}
522 
523   LogicalResult
524   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
525                   ConversionPatternRewriter &rewriter) const override {
526     // Check if yield operation is inside the async coroutine function.
527     auto func = op->template getParentOfType<FuncOp>();
528     auto outlined = outlinedFunctions.find(func);
529     if (outlined == outlinedFunctions.end())
530       return rewriter.notifyMatchFailure(
531           op, "operation is not inside the async coroutine function");
532 
533     Location loc = op->getLoc();
534     const CoroMachinery &coro = outlined->getSecond();
535 
536     // Store yielded values into the async values storage and switch async
537     // values state to available.
538     for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
539       Value yieldValue = std::get<0>(tuple);
540       Value asyncValue = std::get<1>(tuple);
541       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
542       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
543     }
544 
545     // Switch the coroutine completion token to available state.
546     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
547 
548     return success();
549   }
550 
551 private:
552   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
553 };
554 
555 //===----------------------------------------------------------------------===//
556 // Convert std.assert operation to cond_br into `set_error` block.
557 //===----------------------------------------------------------------------===//
558 
559 class AssertOpLowering : public OpConversionPattern<AssertOp> {
560 public:
561   AssertOpLowering(MLIRContext *ctx,
562                    llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
563       : OpConversionPattern<AssertOp>(ctx),
564         outlinedFunctions(outlinedFunctions) {}
565 
566   LogicalResult
567   matchAndRewrite(AssertOp op, OpAdaptor adaptor,
568                   ConversionPatternRewriter &rewriter) const override {
569     // Check if assert operation is inside the async coroutine function.
570     auto func = op->template getParentOfType<FuncOp>();
571     auto outlined = outlinedFunctions.find(func);
572     if (outlined == outlinedFunctions.end())
573       return rewriter.notifyMatchFailure(
574           op, "operation is not inside the async coroutine function");
575 
576     Location loc = op->getLoc();
577     CoroMachinery &coro = outlined->getSecond();
578 
579     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
580     rewriter.setInsertionPointToEnd(cont->getPrevNode());
581     rewriter.create<CondBranchOp>(loc, adaptor.getArg(),
582                                   /*trueDest=*/cont,
583                                   /*trueArgs=*/ArrayRef<Value>(),
584                                   /*falseDest=*/setupSetErrorBlock(coro),
585                                   /*falseArgs=*/ArrayRef<Value>());
586     rewriter.eraseOp(op);
587 
588     return success();
589   }
590 
591 private:
592   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
593 };
594 
595 //===----------------------------------------------------------------------===//
596 
597 /// Rewrite a func as a coroutine by:
598 /// 1) Wrapping the results into `async.value`.
599 /// 2) Prepending the results with `async.token`.
600 /// 3) Setting up coroutine blocks.
601 /// 4) Rewriting return ops as yield op and branch op into the suspend block.
602 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) {
603   auto *ctx = func->getContext();
604   auto loc = func.getLoc();
605   SmallVector<Type> resultTypes;
606   resultTypes.reserve(func.getCallableResults().size());
607   llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
608                   [](Type type) { return ValueType::get(type); });
609   func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes));
610   func.insertResult(0, TokenType::get(ctx), {});
611   for (Block &block : func.getBlocks()) {
612     Operation *terminator = block.getTerminator();
613     if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) {
614       ImplicitLocOpBuilder builder(loc, returnOp);
615       builder.create<YieldOp>(returnOp.getOperands());
616       returnOp.erase();
617     }
618   }
619   return setupCoroMachinery(func);
620 }
621 
622 /// Rewrites a call into a function that has been rewritten as a coroutine.
623 ///
624 /// The invocation of this function is safe only when call ops are traversed in
625 /// reverse order of how they appear in a single block. See `funcsToCoroutines`.
626 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
627   auto loc = func.getLoc();
628   ImplicitLocOpBuilder callBuilder(loc, oldCall);
629   auto newCall = callBuilder.create<CallOp>(
630       func.getName(), func.getCallableResults(), oldCall.getArgOperands());
631 
632   // Await on the async token and all the value results and unwrap the latter.
633   callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
634   SmallVector<Value> unwrappedResults;
635   unwrappedResults.reserve(newCall->getResults().size() - 1);
636   for (Value result : newCall.getResults().drop_front())
637     unwrappedResults.push_back(
638         callBuilder.create<AwaitOp>(loc, result).result());
639   // Careful, when result of a call is piped into another call this could lead
640   // to a dangling pointer.
641   oldCall.replaceAllUsesWith(unwrappedResults);
642   oldCall.erase();
643 }
644 
645 static bool isAllowedToBlock(FuncOp func) {
646   return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
647 }
648 
649 static LogicalResult
650 funcsToCoroutines(ModuleOp module,
651                   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
652   // The following code supports the general case when 2 functions mutually
653   // recurse into each other. Because of this and that we are relying on
654   // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
655   // a FuncOp while inserting an equivalent coroutine, because that could lead
656   // to dangling pointers.
657 
658   SmallVector<FuncOp> funcWorklist;
659 
660   // Careful, it's okay to add a func to the worklist multiple times if and only
661   // if the loop processing the worklist will skip the functions that have
662   // already been converted to coroutines.
663   auto addToWorklist = [&](FuncOp func) {
664     if (isAllowedToBlock(func))
665       return;
666     // N.B. To refactor this code into a separate pass the lookup in
667     // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
668     // func and recognizing if it has a coroutine structure is messy. Passing
669     // this dict between the passes is ugly.
670     if (isAllowedToBlock(func) ||
671         outlinedFunctions.find(func) == outlinedFunctions.end()) {
672       for (Operation &op : func.body().getOps()) {
673         if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
674           funcWorklist.push_back(func);
675           break;
676         }
677       }
678     }
679   };
680 
681   // Traverse in post-order collecting for each func op the await ops it has.
682   for (FuncOp func : module.getOps<FuncOp>())
683     addToWorklist(func);
684 
685   SymbolTableCollection symbolTable;
686   SymbolUserMap symbolUserMap(symbolTable, module);
687 
688   // Rewrite funcs, while updating call sites and adding them to the worklist.
689   while (!funcWorklist.empty()) {
690     auto func = funcWorklist.pop_back_val();
691     auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
692     if (!insertion.second)
693       // This function has already been processed because this is either
694       // the corecursive case, or a caller with multiple calls to a newly
695       // created corouting. Either way, skip updating the call sites.
696       continue;
697     insertion.first->second = rewriteFuncAsCoroutine(func);
698     SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
699                                    symbolUserMap.getUsers(func).end());
700     // If there are multiple calls from the same block they need to be traversed
701     // in reverse order so that symbolUserMap references are not invalidated
702     // when updating the users of the call op which is earlier in the block.
703     llvm::sort(users, [](Operation *a, Operation *b) {
704       Block *blockA = a->getBlock();
705       Block *blockB = b->getBlock();
706       // Impose arbitrary order on blocks so that there is a well-defined order.
707       return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b));
708     });
709     // Rewrite the callsites to await on results of the newly created coroutine.
710     for (Operation *op : users) {
711       if (CallOp call = dyn_cast<mlir::CallOp>(*op)) {
712         FuncOp caller = call->getParentOfType<FuncOp>();
713         rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
714         addToWorklist(caller);
715       } else {
716         op->emitError("Unexpected reference to func referenced by symbol");
717         return failure();
718       }
719     }
720   }
721   return success();
722 }
723 
724 //===----------------------------------------------------------------------===//
725 void AsyncToAsyncRuntimePass::runOnOperation() {
726   ModuleOp module = getOperation();
727   SymbolTable symbolTable(module);
728 
729   // Outline all `async.execute` body regions into async functions (coroutines).
730   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
731 
732   module.walk([&](ExecuteOp execute) {
733     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
734   });
735 
736   LLVM_DEBUG({
737     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
738                  << " functions built from async.execute operations\n";
739   });
740 
741   // Returns true if operation is inside the coroutine.
742   auto isInCoroutine = [&](Operation *op) -> bool {
743     auto parentFunc = op->getParentOfType<FuncOp>();
744     return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
745   };
746 
747   if (eliminateBlockingAwaitOps &&
748       failed(funcsToCoroutines(module, outlinedFunctions))) {
749     signalPassFailure();
750     return;
751   }
752 
753   // Lower async operations to async.runtime operations.
754   MLIRContext *ctx = module->getContext();
755   RewritePatternSet asyncPatterns(ctx);
756 
757   // Conversion to async runtime augments original CFG with the coroutine CFG,
758   // and we have to make sure that structured control flow operations with async
759   // operations in nested regions will be converted to branch-based control flow
760   // before we add the coroutine basic blocks.
761   populateLoopToStdConversionPatterns(asyncPatterns);
762 
763   // Async lowering does not use type converter because it must preserve all
764   // types for async.runtime operations.
765   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
766   asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
767                     AwaitAllOpLowering, YieldOpLowering>(ctx,
768                                                          outlinedFunctions);
769 
770   // Lower assertions to conditional branches into error blocks.
771   asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
772 
773   // All high level async operations must be lowered to the runtime operations.
774   ConversionTarget runtimeTarget(*ctx);
775   runtimeTarget.addLegalDialect<AsyncDialect>();
776   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
777   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
778 
779   // Decide if structured control flow has to be lowered to branch-based CFG.
780   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
781     auto walkResult = op->walk([&](Operation *nested) {
782       bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
783       return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
784                                               : WalkResult::advance();
785     });
786     return !walkResult.wasInterrupted();
787   });
788   runtimeTarget.addLegalOp<AssertOp, arith::XOrIOp, arith::ConstantOp,
789                            ConstantOp, BranchOp, CondBranchOp>();
790 
791   // Assertions must be converted to runtime errors inside async functions.
792   runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
793     auto func = op->getParentOfType<FuncOp>();
794     return outlinedFunctions.find(func) == outlinedFunctions.end();
795   });
796 
797   if (eliminateBlockingAwaitOps)
798     runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
799         [&](RuntimeAwaitOp op) -> bool {
800           return isAllowedToBlock(op->getParentOfType<FuncOp>());
801         });
802 
803   if (failed(applyPartialConversion(module, runtimeTarget,
804                                     std::move(asyncPatterns)))) {
805     signalPassFailure();
806     return;
807   }
808 }
809 
810 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
811   return std::make_unique<AsyncToAsyncRuntimePass>();
812 }
813