1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Async/Passes.h"
18 #include "mlir/Dialect/SCF/SCF.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "mlir/Transforms/RegionUtils.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/Support/Debug.h"
27 
28 using namespace mlir;
29 using namespace mlir::async;
30 
31 #define DEBUG_TYPE "async-to-async-runtime"
// Symbol name passed to `FuncOp::create` for functions outlined from
// `async.execute` op regions.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
34 
35 namespace {
36 
37 class AsyncToAsyncRuntimePass
38     : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
39 public:
40   AsyncToAsyncRuntimePass() = default;
41   void runOnOperation() override;
42 };
43 
44 } // namespace
45 
46 //===----------------------------------------------------------------------===//
47 // async.execute op outlining to the coroutine functions.
48 //===----------------------------------------------------------------------===//
49 
50 /// Function targeted for coroutine transformation has two additional blocks at
51 /// the end: coroutine cleanup and coroutine suspension.
52 ///
/// async.await op lowering additionally creates a resume block for each
54 /// operation to enable non-blocking waiting via coroutine suspension.
55 namespace {
56 struct CoroMachinery {
57   FuncOp func;
58 
59   // Async execute region returns a completion token, and an async value for
60   // each yielded value.
61   //
62   //   %token, %result = async.execute -> !async.value<T> {
63   //     %0 = constant ... : T
64   //     async.yield %0 : T
65   //   }
66   Value asyncToken; // token representing completion of the async region
67   llvm::SmallVector<Value, 4> returnValues; // returned async values
68 
69   Value coroHandle; // coroutine handle (!async.coro.handle value)
70   Block *entry;     // coroutine entry block
71   Block *setError;  // switch completion token and all values to error state
72   Block *cleanup;   // coroutine cleanup block
73   Block *suspend;   // coroutine suspension block
74 };
75 } // namespace
76 
77 /// Utility to partially update the regular function CFG to the coroutine CFG
78 /// compatible with LLVM coroutines switched-resume lowering using
79 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
80 /// that branches into preexisting entry block. Also inserts trailing blocks.
81 ///
82 /// The result types of the passed `func` must start with an `async.token`
83 /// and be continued with some number of `async.value`s.
84 ///
85 /// The func given to this function needs to have been preprocessed to have
86 /// either branch or yield ops as terminators. Branches to the cleanup block are
87 /// inserted after each yield.
88 ///
89 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
90 ///
91 ///  - `entry` block sets up the coroutine.
92 ///  - `set_error` block sets completion token and async values state to error.
93 ///  - `cleanup` block cleans up the coroutine state.
///  - `suspend` block after the @llvm.coro.end() defines what value will be
95 ///    returned to the initial caller of a coroutine. Everything before the
96 ///    @llvm.coro.end() will be executed at every suspension point.
97 ///
98 /// Coroutine structure (only the important bits):
99 ///
100 ///   func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
101 ///   {
102 ///     ^entry(<function-arguments>):
103 ///       %token = <async token> : !async.token    // create async runtime token
104 ///       %value = <async value> : !async.value<T> // create async value
105 ///       %id = async.coro.id                      // create a coroutine id
106 ///       %hdl = async.coro.begin %id              // create a coroutine handle
107 ///       br ^preexisting_entry_block
108 ///
109 ///     /*  preexisting blocks modified to branch to the cleanup block */
110 ///
111 ///     ^set_error: // this block created lazily only if needed (see code below)
112 ///       async.runtime.set_error %token : !async.token
113 ///       async.runtime.set_error %value : !async.value<T>
114 ///       br ^cleanup
115 ///
116 ///     ^cleanup:
117 ///       async.coro.free %hdl // delete the coroutine state
118 ///       br ^suspend
119 ///
120 ///     ^suspend:
121 ///       async.coro.end %hdl // marks the end of a coroutine
122 ///       return %token, %value : !async.token, !async.value<T>
123 ///   }
124 ///
125 static CoroMachinery setupCoroMachinery(FuncOp func) {
126   assert(!func.getBlocks().empty() && "Function must have an entry block");
127 
128   MLIRContext *ctx = func.getContext();
129   Block *entryBlock = &func.getBlocks().front();
130   Block *originalEntryBlock =
131       entryBlock->splitBlock(entryBlock->getOperations().begin());
132   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
133 
134   // ------------------------------------------------------------------------ //
135   // Allocate async token/values that we will return from a ramp function.
136   // ------------------------------------------------------------------------ //
137   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
138 
139   llvm::SmallVector<Value, 4> retValues;
140   for (auto resType : func.getCallableResults().drop_front())
141     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
142 
143   // ------------------------------------------------------------------------ //
144   // Initialize coroutine: get coroutine id and coroutine handle.
145   // ------------------------------------------------------------------------ //
146   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
147   auto coroHdlOp =
148       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
149   builder.create<BranchOp>(originalEntryBlock);
150 
151   Block *cleanupBlock = func.addBlock();
152   Block *suspendBlock = func.addBlock();
153 
154   // ------------------------------------------------------------------------ //
155   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
156   // ------------------------------------------------------------------------ //
157   builder.setInsertionPointToStart(cleanupBlock);
158   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
159 
160   // Branch into the suspend block.
161   builder.create<BranchOp>(suspendBlock);
162 
163   // ------------------------------------------------------------------------ //
164   // Coroutine suspend block: mark the end of a coroutine and return allocated
165   // async token.
166   // ------------------------------------------------------------------------ //
167   builder.setInsertionPointToStart(suspendBlock);
168 
169   // Mark the end of a coroutine: async.coro.end
170   builder.create<CoroEndOp>(coroHdlOp.handle());
171 
172   // Return created `async.token` and `async.values` from the suspend block.
173   // This will be the return value of a coroutine ramp function.
174   SmallVector<Value, 4> ret{retToken};
175   ret.insert(ret.end(), retValues.begin(), retValues.end());
176   builder.create<ReturnOp>(ret);
177 
178   // `async.await` op lowering will create resume blocks for async
179   // continuations, and will conditionally branch to cleanup or suspend blocks.
180 
181   for (Block &block : func.body().getBlocks()) {
182     if (&block == entryBlock || &block == cleanupBlock ||
183         &block == suspendBlock)
184       continue;
185     Operation *terminator = block.getTerminator();
186     if (auto yield = dyn_cast<YieldOp>(terminator)) {
187       builder.setInsertionPointToEnd(&block);
188       builder.create<BranchOp>(cleanupBlock);
189     }
190   }
191 
192   CoroMachinery machinery;
193   machinery.func = func;
194   machinery.asyncToken = retToken;
195   machinery.returnValues = retValues;
196   machinery.coroHandle = coroHdlOp.handle();
197   machinery.entry = entryBlock;
198   machinery.setError = nullptr; // created lazily only if needed
199   machinery.cleanup = cleanupBlock;
200   machinery.suspend = suspendBlock;
201   return machinery;
202 }
203 
204 // Lazily creates `set_error` block only if it is required for lowering to the
205 // runtime operations (see for example lowering of assert operation).
206 static Block *setupSetErrorBlock(CoroMachinery &coro) {
207   if (coro.setError)
208     return coro.setError;
209 
210   coro.setError = coro.func.addBlock();
211   coro.setError->moveBefore(coro.cleanup);
212 
213   auto builder =
214       ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
215 
216   // Coroutine set_error block: set error on token and all returned values.
217   builder.create<RuntimeSetErrorOp>(coro.asyncToken);
218   for (Value retValue : coro.returnValues)
219     builder.create<RuntimeSetErrorOp>(retValue);
220 
221   // Branch into the cleanup block.
222   builder.create<BranchOp>(coro.cleanup);
223 
224   return coro.setError;
225 }
226 
227 /// Outline the body region attached to the `async.execute` op into a standalone
228 /// function.
229 ///
230 /// Note that this is not reversible transformation.
231 static std::pair<FuncOp, CoroMachinery>
232 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
233   ModuleOp module = execute->getParentOfType<ModuleOp>();
234 
235   MLIRContext *ctx = module.getContext();
236   Location loc = execute.getLoc();
237 
238   // Make sure that all constants will be inside the outlined async function to
239   // reduce the number of function arguments.
240   cloneConstantsIntoTheRegion(execute.body());
241 
242   // Collect all outlined function inputs.
243   SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
244                                         execute.dependencies().end());
245   functionInputs.insert(execute.operands().begin(), execute.operands().end());
246   getUsedValuesDefinedAbove(execute.body(), functionInputs);
247 
248   // Collect types for the outlined function inputs and outputs.
249   auto typesRange = llvm::map_range(
250       functionInputs, [](Value value) { return value.getType(); });
251   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
252   auto outputTypes = execute.getResultTypes();
253 
254   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
255   auto funcAttrs = ArrayRef<NamedAttribute>();
256 
257   // TODO: Derive outlined function name from the parent FuncOp (support
258   // multiple nested async.execute operations).
259   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
260   symbolTable.insert(func);
261 
262   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
263   auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
264 
265   // Prepare for coroutine conversion by creating the body of the function.
266   {
267     size_t numDependencies = execute.dependencies().size();
268     size_t numOperands = execute.operands().size();
269 
270     // Await on all dependencies before starting to execute the body region.
271     for (size_t i = 0; i < numDependencies; ++i)
272       builder.create<AwaitOp>(func.getArgument(i));
273 
274     // Await on all async value operands and unwrap the payload.
275     SmallVector<Value, 4> unwrappedOperands(numOperands);
276     for (size_t i = 0; i < numOperands; ++i) {
277       Value operand = func.getArgument(numDependencies + i);
278       unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
279     }
280 
281     // Map from function inputs defined above the execute op to the function
282     // arguments.
283     BlockAndValueMapping valueMapping;
284     valueMapping.map(functionInputs, func.getArguments());
285     valueMapping.map(execute.body().getArguments(), unwrappedOperands);
286 
287     // Clone all operations from the execute operation body into the outlined
288     // function body.
289     for (Operation &op : execute.body().getOps())
290       builder.clone(op, valueMapping);
291   }
292 
293   // Adding entry/cleanup/suspend blocks.
294   CoroMachinery coro = setupCoroMachinery(func);
295 
296   // Suspend async function at the end of an entry block, and resume it using
297   // Async resume operation (execution will be resumed in a thread managed by
298   // the async runtime).
299   {
300     BranchOp branch = cast<BranchOp>(coro.entry->getTerminator());
301     builder.setInsertionPointToEnd(coro.entry);
302 
303     // Save the coroutine state: async.coro.save
304     auto coroSaveOp =
305         builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
306 
307     // Pass coroutine to the runtime to be resumed on a runtime managed
308     // thread.
309     builder.create<RuntimeResumeOp>(coro.coroHandle);
310 
311     // Add async.coro.suspend as a suspended block terminator.
312     builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
313                                   branch.getDest(), coro.cleanup);
314 
315     branch.erase();
316   }
317 
318   // Replace the original `async.execute` with a call to outlined function.
319   {
320     ImplicitLocOpBuilder callBuilder(loc, execute);
321     auto callOutlinedFunc = callBuilder.create<CallOp>(
322         func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
323     execute.replaceAllUsesWith(callOutlinedFunc.getResults());
324     execute.erase();
325   }
326 
327   return {func, coro};
328 }
329 
330 //===----------------------------------------------------------------------===//
331 // Convert async.create_group operation to async.runtime.create_group
332 //===----------------------------------------------------------------------===//
333 
334 namespace {
335 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
336 public:
337   using OpConversionPattern::OpConversionPattern;
338 
339   LogicalResult
340   matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
341                   ConversionPatternRewriter &rewriter) const override {
342     rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
343         op, GroupType::get(op->getContext()), operands);
344     return success();
345   }
346 };
347 } // namespace
348 
349 //===----------------------------------------------------------------------===//
350 // Convert async.add_to_group operation to async.runtime.add_to_group.
351 //===----------------------------------------------------------------------===//
352 
353 namespace {
354 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
355 public:
356   using OpConversionPattern::OpConversionPattern;
357 
358   LogicalResult
359   matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands,
360                   ConversionPatternRewriter &rewriter) const override {
361     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
362         op, rewriter.getIndexType(), operands);
363     return success();
364   }
365 };
366 } // namespace
367 
368 //===----------------------------------------------------------------------===//
369 // Convert async.await and async.await_all operations to the async.runtime.await
370 // or async.runtime.await_and_resume operations.
371 //===----------------------------------------------------------------------===//
372 
373 namespace {
/// Shared lowering for await-like operations.
///
/// `AwaitType` is the operation being lowered and `AwaitableType` is the type
/// its operand must have for the pattern to apply. Outside of a coroutine the
/// op becomes a blocking `async.runtime.await` followed by an assertion that
/// the operand is not in the error state. Inside a coroutine (a function
/// registered in `outlinedFunctions`) the op becomes a suspension point that is
/// resumed via `async.runtime.await_and_resume`, followed by an error check
/// that branches to the lazily created `set_error` block.
///
/// Derived patterns override `getReplacementValue` to produce the value that
/// replaces the result of the await operation.
template <typename AwaitType, typename AwaitableType>
class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
  using AwaitAdaptor = typename AwaitType::Adaptor;

public:
  /// `outlinedFunctions` is kept by reference and is not const: lowering an
  /// await inside a coroutine may add a `set_error` block to its machinery.
  AwaitOpLoweringBase(MLIRContext *ctx,
                      llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
      : OpConversionPattern<AwaitType>(ctx),
        outlinedFunctions(outlinedFunctions) {}

  LogicalResult
  matchAndRewrite(AwaitType op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // We can only await on the `AwaitableType` (for `await` it can be
    // a `token` or a `value`, for `await_all` it must be a `group`).
    if (!op.operand().getType().template isa<AwaitableType>())
      return rewriter.notifyMatchFailure(op, "unsupported awaitable type");

    // Check if await operation is inside the outlined coroutine function.
    auto func = op->template getParentOfType<FuncOp>();
    auto outlined = outlinedFunctions.find(func);
    const bool isInCoroutine = outlined != outlinedFunctions.end();

    Location loc = op->getLoc();
    Value operand = AwaitAdaptor(operands).operand();

    Type i1 = rewriter.getI1Type();

    // Inside regular functions we use the blocking wait operation to wait for
    // the async object (token, value or group) to become available.
    if (!isInCoroutine) {
      ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
      builder.create<RuntimeAwaitOp>(loc, operand);

      // Assert that the awaited operand is not in the error state. The `i1`
      // error flag is negated by xor-ing it with the constant 1.
      Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
      Value notError = builder.create<XOrOp>(
          isError,
          builder.create<ConstantOp>(loc, i1, builder.getIntegerAttr(i1, 1)));

      builder.create<AssertOp>(notError,
                               "Awaited async operand is in error state");
    }

    // Inside the coroutine we convert await operation into coroutine suspension
    // point, and resume execution asynchronously.
    if (isInCoroutine) {
      CoroMachinery &coro = outlined->getSecond();
      Block *suspended = op->getBlock();

      ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
      MLIRContext *ctx = op->getContext();

      // Save the coroutine state and resume on a runtime managed thread when
      // the operand becomes available.
      auto coroSaveOp =
          builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
      builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);

      // Split the block containing the await operation right before it; the
      // await operation becomes the first operation of the `resume` block.
      Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));

      // Add async.coro.suspend as a suspended block terminator.
      builder.setInsertionPointToEnd(suspended);
      builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
                                    coro.cleanup);

      // Split the resume block into error checking and continuation. The split
      // point is again the await operation, so `resume` is left empty and the
      // await operation moves to the start of `continuation`.
      Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));

      // Check if the awaited value is in the error state.
      builder.setInsertionPointToStart(resume);
      auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
      builder.create<CondBranchOp>(isError,
                                   /*trueDest=*/setupSetErrorBlock(coro),
                                   /*trueArgs=*/ArrayRef<Value>(),
                                   /*falseDest=*/continuation,
                                   /*falseArgs=*/ArrayRef<Value>());

      // Make sure that replacement value will be constructed in the
      // continuation block.
      rewriter.setInsertionPointToStart(continuation);
    }

    // Erase or replace the await operation with the new value.
    if (Value replaceWith = getReplacementValue(op, operand, rewriter))
      rewriter.replaceOp(op, replaceWith);
    else
      rewriter.eraseOp(op);

    return success();
  }

  /// Hook for derived patterns. Returns the value that replaces the await
  /// operation, or a null `Value` if the operation should simply be erased.
  /// The default implementation returns a null `Value`.
  virtual Value getReplacementValue(AwaitType op, Value operand,
                                    ConversionPatternRewriter &rewriter) const {
    return Value();
  }

private:
  // Functions that have been set up as coroutines, with their machinery.
  llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
};
475 
476 /// Lowering for `async.await` with a token operand.
477 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
478   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
479 
480 public:
481   using Base::Base;
482 };
483 
484 /// Lowering for `async.await` with a value operand.
485 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
486   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
487 
488 public:
489   using Base::Base;
490 
491   Value
492   getReplacementValue(AwaitOp op, Value operand,
493                       ConversionPatternRewriter &rewriter) const override {
494     // Load from the async value storage.
495     auto valueType = operand.getType().cast<ValueType>().getValueType();
496     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
497   }
498 };
499 
500 /// Lowering for `async.await_all` operation.
501 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
502   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
503 
504 public:
505   using Base::Base;
506 };
507 
508 } // namespace
509 
510 //===----------------------------------------------------------------------===//
511 // Convert async.yield operation to async.runtime operations.
512 //===----------------------------------------------------------------------===//
513 
514 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
515 public:
516   YieldOpLowering(
517       MLIRContext *ctx,
518       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
519       : OpConversionPattern<async::YieldOp>(ctx),
520         outlinedFunctions(outlinedFunctions) {}
521 
522   LogicalResult
523   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
524                   ConversionPatternRewriter &rewriter) const override {
525     // Check if yield operation is inside the async coroutine function.
526     auto func = op->template getParentOfType<FuncOp>();
527     auto outlined = outlinedFunctions.find(func);
528     if (outlined == outlinedFunctions.end())
529       return rewriter.notifyMatchFailure(
530           op, "operation is not inside the async coroutine function");
531 
532     Location loc = op->getLoc();
533     const CoroMachinery &coro = outlined->getSecond();
534 
535     // Store yielded values into the async values storage and switch async
536     // values state to available.
537     for (auto tuple : llvm::zip(operands, coro.returnValues)) {
538       Value yieldValue = std::get<0>(tuple);
539       Value asyncValue = std::get<1>(tuple);
540       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
541       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
542     }
543 
544     // Switch the coroutine completion token to available state.
545     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
546 
547     return success();
548   }
549 
550 private:
551   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
552 };
553 
554 //===----------------------------------------------------------------------===//
555 // Convert std.assert operation to cond_br into `set_error` block.
556 //===----------------------------------------------------------------------===//
557 
558 class AssertOpLowering : public OpConversionPattern<AssertOp> {
559 public:
560   AssertOpLowering(MLIRContext *ctx,
561                    llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
562       : OpConversionPattern<AssertOp>(ctx),
563         outlinedFunctions(outlinedFunctions) {}
564 
565   LogicalResult
566   matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
567                   ConversionPatternRewriter &rewriter) const override {
568     // Check if assert operation is inside the async coroutine function.
569     auto func = op->template getParentOfType<FuncOp>();
570     auto outlined = outlinedFunctions.find(func);
571     if (outlined == outlinedFunctions.end())
572       return rewriter.notifyMatchFailure(
573           op, "operation is not inside the async coroutine function");
574 
575     Location loc = op->getLoc();
576     CoroMachinery &coro = outlined->getSecond();
577 
578     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
579     rewriter.setInsertionPointToEnd(cont->getPrevNode());
580     rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(),
581                                   /*trueDest=*/cont,
582                                   /*trueArgs=*/ArrayRef<Value>(),
583                                   /*falseDest=*/setupSetErrorBlock(coro),
584                                   /*falseArgs=*/ArrayRef<Value>());
585     rewriter.eraseOp(op);
586 
587     return success();
588   }
589 
590 private:
591   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
592 };
593 
594 //===----------------------------------------------------------------------===//
595 
596 /// Rewrite a func as a coroutine by:
597 /// 1) Wrapping the results into `async.value`.
598 /// 2) Prepending the results with `async.token`.
599 /// 3) Setting up coroutine blocks.
600 /// 4) Rewriting return ops as yield op and branch op into the suspend block.
601 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) {
602   auto *ctx = func->getContext();
603   auto loc = func.getLoc();
604   SmallVector<Type> resultTypes;
605   resultTypes.reserve(func.getCallableResults().size());
606   llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
607                   [](Type type) { return ValueType::get(type); });
608   func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes));
609   func.insertResult(0, TokenType::get(ctx), {});
610   for (Block &block : func.getBlocks()) {
611     Operation *terminator = block.getTerminator();
612     if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) {
613       ImplicitLocOpBuilder builder(loc, returnOp);
614       builder.create<YieldOp>(returnOp.getOperands());
615       returnOp.erase();
616     }
617   }
618   return setupCoroMachinery(func);
619 }
620 
621 /// Rewrites a call into a function that has been rewritten as a coroutine.
622 ///
623 /// The invocation of this function is safe only when call ops are traversed in
624 /// reverse order of how they appear in a single block. See `funcsToCoroutines`.
625 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
626   auto loc = func.getLoc();
627   ImplicitLocOpBuilder callBuilder(loc, oldCall);
628   auto newCall = callBuilder.create<CallOp>(
629       func.getName(), func.getCallableResults(), oldCall.getArgOperands());
630 
631   // Await on the async token and all the value results and unwrap the latter.
632   callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
633   SmallVector<Value> unwrappedResults;
634   unwrappedResults.reserve(newCall->getResults().size() - 1);
635   for (Value result : newCall.getResults().drop_front())
636     unwrappedResults.push_back(
637         callBuilder.create<AwaitOp>(loc, result).result());
638   // Careful, when result of a call is piped into another call this could lead
639   // to a dangling pointer.
640   oldCall.replaceAllUsesWith(unwrappedResults);
641   oldCall.erase();
642 }
643 
644 static bool isAllowedToBlock(FuncOp func) {
645   return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
646 }
647 
648 static LogicalResult
649 funcsToCoroutines(ModuleOp module,
650                   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
651   // The following code supports the general case when 2 functions mutually
652   // recurse into each other. Because of this and that we are relying on
653   // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
654   // a FuncOp while inserting an equivalent coroutine, because that could lead
655   // to dangling pointers.
656 
657   SmallVector<FuncOp> funcWorklist;
658 
659   // Careful, it's okay to add a func to the worklist multiple times if and only
660   // if the loop processing the worklist will skip the functions that have
661   // already been converted to coroutines.
662   auto addToWorklist = [&](FuncOp func) {
663     if (isAllowedToBlock(func))
664       return;
665     // N.B. To refactor this code into a separate pass the lookup in
666     // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
667     // func and recognizing if it has a coroutine structure is messy. Passing
668     // this dict between the passes is ugly.
669     if (isAllowedToBlock(func) ||
670         outlinedFunctions.find(func) == outlinedFunctions.end()) {
671       for (Operation &op : func.body().getOps()) {
672         if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
673           funcWorklist.push_back(func);
674           break;
675         }
676       }
677     }
678   };
679 
680   // Traverse in post-order collecting for each func op the await ops it has.
681   for (FuncOp func : module.getOps<FuncOp>())
682     addToWorklist(func);
683 
684   SymbolTableCollection symbolTable;
685   SymbolUserMap symbolUserMap(symbolTable, module);
686 
687   // Rewrite funcs, while updating call sites and adding them to the worklist.
688   while (!funcWorklist.empty()) {
689     auto func = funcWorklist.pop_back_val();
690     auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
691     if (!insertion.second)
692       // This function has already been processed because this is either
693       // the corecursive case, or a caller with multiple calls to a newly
694       // created corouting. Either way, skip updating the call sites.
695       continue;
696     insertion.first->second = rewriteFuncAsCoroutine(func);
697     SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
698                                    symbolUserMap.getUsers(func).end());
699     // If there are multiple calls from the same block they need to be traversed
700     // in reverse order so that symbolUserMap references are not invalidated
701     // when updating the users of the call op which is earlier in the block.
702     llvm::sort(users, [](Operation *a, Operation *b) {
703       Block *blockA = a->getBlock();
704       Block *blockB = b->getBlock();
705       // Impose arbitrary order on blocks so that there is a well-defined order.
706       return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b));
707     });
708     // Rewrite the callsites to await on results of the newly created coroutine.
709     for (Operation *op : users) {
710       if (CallOp call = dyn_cast<mlir::CallOp>(*op)) {
711         FuncOp caller = call->getParentOfType<FuncOp>();
712         rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
713         addToWorklist(caller);
714       } else {
715         op->emitError("Unexpected reference to func referenced by symbol");
716         return failure();
717       }
718     }
719   }
720   return success();
721 }
722 
723 //===----------------------------------------------------------------------===//
724 void AsyncToAsyncRuntimePass::runOnOperation() {
725   ModuleOp module = getOperation();
726   SymbolTable symbolTable(module);
727 
728   // Outline all `async.execute` body regions into async functions (coroutines).
729   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
730 
731   module.walk([&](ExecuteOp execute) {
732     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
733   });
734 
735   LLVM_DEBUG({
736     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
737                  << " functions built from async.execute operations\n";
738   });
739 
740   // Returns true if operation is inside the coroutine.
741   auto isInCoroutine = [&](Operation *op) -> bool {
742     auto parentFunc = op->getParentOfType<FuncOp>();
743     return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
744   };
745 
746   if (eliminateBlockingAwaitOps &&
747       failed(funcsToCoroutines(module, outlinedFunctions))) {
748     signalPassFailure();
749     return;
750   }
751 
752   // Lower async operations to async.runtime operations.
753   MLIRContext *ctx = module->getContext();
754   RewritePatternSet asyncPatterns(ctx);
755 
756   // Conversion to async runtime augments original CFG with the coroutine CFG,
757   // and we have to make sure that structured control flow operations with async
758   // operations in nested regions will be converted to branch-based control flow
759   // before we add the coroutine basic blocks.
760   populateLoopToStdConversionPatterns(asyncPatterns);
761 
762   // Async lowering does not use type converter because it must preserve all
763   // types for async.runtime operations.
764   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
765   asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
766                     AwaitAllOpLowering, YieldOpLowering>(ctx,
767                                                          outlinedFunctions);
768 
769   // Lower assertions to conditional branches into error blocks.
770   asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
771 
772   // All high level async operations must be lowered to the runtime operations.
773   ConversionTarget runtimeTarget(*ctx);
774   runtimeTarget.addLegalDialect<AsyncDialect>();
775   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
776   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
777 
778   // Decide if structured control flow has to be lowered to branch-based CFG.
779   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
780     auto walkResult = op->walk([&](Operation *nested) {
781       bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
782       return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
783                                               : WalkResult::advance();
784     });
785     return !walkResult.wasInterrupted();
786   });
787   runtimeTarget
788       .addLegalOp<AssertOp, XOrOp, ConstantOp, BranchOp, CondBranchOp>();
789 
790   // Assertions must be converted to runtime errors inside async functions.
791   runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
792     auto func = op->getParentOfType<FuncOp>();
793     return outlinedFunctions.find(func) == outlinedFunctions.end();
794   });
795 
796   if (eliminateBlockingAwaitOps)
797     runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
798         [&](RuntimeAwaitOp op) -> bool {
799           return isAllowedToBlock(op->getParentOfType<FuncOp>());
800         });
801 
802   if (failed(applyPartialConversion(module, runtimeTarget,
803                                     std::move(asyncPatterns)))) {
804     signalPassFailure();
805     return;
806   }
807 }
808 
809 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
810   return std::make_unique<AsyncToAsyncRuntimePass>();
811 }
812