1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Async/IR/Async.h"
16 #include "mlir/Dialect/Async/Passes.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/Support/Debug.h"
25 
26 using namespace mlir;
27 using namespace mlir::async;
28 
// Debug tag for this file; presumably the key that enables the LLVM_DEBUG
// output emitted by this pass -- confirm against llvm/Support/Debug.h.
#define DEBUG_TYPE "async-to-async-runtime"
// Prefix for functions outlined from `async.execute` op regions.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
32 
33 namespace {
34 
35 class AsyncToAsyncRuntimePass
36     : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
37 public:
38   AsyncToAsyncRuntimePass() = default;
39   void runOnOperation() override;
40 };
41 
42 } // namespace
43 
44 //===----------------------------------------------------------------------===//
45 // async.execute op outlining to the coroutine functions.
46 //===----------------------------------------------------------------------===//
47 
48 /// Function targeted for coroutine transformation has two additional blocks at
49 /// the end: coroutine cleanup and coroutine suspension.
50 ///
/// async.await op lowering additionally creates a resume block for each
52 /// operation to enable non-blocking waiting via coroutine suspension.
53 namespace {
54 struct CoroMachinery {
55   // Async execute region returns a completion token, and an async value for
56   // each yielded value.
57   //
58   //   %token, %result = async.execute -> !async.value<T> {
59   //     %0 = constant ... : T
60   //     async.yield %0 : T
61   //   }
62   Value asyncToken; // token representing completion of the async region
63   llvm::SmallVector<Value, 4> returnValues; // returned async values
64 
65   Value coroHandle; // coroutine handle (!async.coro.handle value)
66   Block *cleanup;   // coroutine cleanup block
67   Block *suspend;   // coroutine suspension block
68 };
69 } // namespace
70 
71 /// Builds an coroutine template compatible with LLVM coroutines switched-resume
72 /// lowering using `async.runtime.*` and `async.coro.*` operations.
73 ///
74 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
75 ///
76 ///  - `entry` block sets up the coroutine.
77 ///  - `cleanup` block cleans up the coroutine state.
78 ///  - `suspend block after the @llvm.coro.end() defines what value will be
79 ///    returned to the initial caller of a coroutine. Everything before the
80 ///    @llvm.coro.end() will be executed at every suspension point.
81 ///
82 /// Coroutine structure (only the important bits):
83 ///
84 ///   func @async_execute_fn(<function-arguments>)
85 ///        -> (!async.token, !async.value<T>)
86 ///   {
87 ///     ^entry(<function-arguments>):
88 ///       %token = <async token> : !async.token    // create async runtime token
89 ///       %value = <async value> : !async.value<T> // create async value
90 ///       %id = async.coro.id                      // create a coroutine id
91 ///       %hdl = async.coro.begin %id              // create a coroutine handle
92 ///       br ^cleanup
93 ///
94 ///     ^cleanup:
95 ///       async.coro.free %hdl // delete the coroutine state
96 ///       br ^suspend
97 ///
98 ///     ^suspend:
99 ///       async.coro.end %hdl // marks the end of a coroutine
100 ///       return %token, %value : !async.token, !async.value<T>
101 ///   }
102 ///
103 /// The actual code for the async.execute operation body region will be inserted
104 /// before the entry block terminator.
105 ///
106 ///
107 static CoroMachinery setupCoroMachinery(FuncOp func) {
108   assert(func.getBody().empty() && "Function must have empty body");
109 
110   MLIRContext *ctx = func.getContext();
111   Block *entryBlock = func.addEntryBlock();
112 
113   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
114 
115   // ------------------------------------------------------------------------ //
116   // Allocate async token/values that we will return from a ramp function.
117   // ------------------------------------------------------------------------ //
118   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
119 
120   llvm::SmallVector<Value, 4> retValues;
121   for (auto resType : func.getCallableResults().drop_front())
122     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
123 
124   // ------------------------------------------------------------------------ //
125   // Initialize coroutine: get coroutine id and coroutine handle.
126   // ------------------------------------------------------------------------ //
127   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
128   auto coroHdlOp =
129       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
130 
131   Block *cleanupBlock = func.addBlock();
132   Block *suspendBlock = func.addBlock();
133 
134   // ------------------------------------------------------------------------ //
135   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
136   // ------------------------------------------------------------------------ //
137   builder.setInsertionPointToStart(cleanupBlock);
138   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
139 
140   // Branch into the suspend block.
141   builder.create<BranchOp>(suspendBlock);
142 
143   // ------------------------------------------------------------------------ //
144   // Coroutine suspend block: mark the end of a coroutine and return allocated
145   // async token.
146   // ------------------------------------------------------------------------ //
147   builder.setInsertionPointToStart(suspendBlock);
148 
149   // Mark the end of a coroutine: async.coro.end
150   builder.create<CoroEndOp>(coroHdlOp.handle());
151 
152   // Return created `async.token` and `async.values` from the suspend block.
153   // This will be the return value of a coroutine ramp function.
154   SmallVector<Value, 4> ret{retToken};
155   ret.insert(ret.end(), retValues.begin(), retValues.end());
156   builder.create<ReturnOp>(ret);
157 
158   // Branch from the entry block to the cleanup block to create a valid CFG.
159   builder.setInsertionPointToEnd(entryBlock);
160   builder.create<BranchOp>(cleanupBlock);
161 
162   // `async.await` op lowering will create resume blocks for async
163   // continuations, and will conditionally branch to cleanup or suspend blocks.
164 
165   CoroMachinery machinery;
166   machinery.asyncToken = retToken;
167   machinery.returnValues = retValues;
168   machinery.coroHandle = coroHdlOp.handle();
169   machinery.cleanup = cleanupBlock;
170   machinery.suspend = suspendBlock;
171   return machinery;
172 }
173 
174 /// Outline the body region attached to the `async.execute` op into a standalone
175 /// function.
176 ///
177 /// Note that this is not reversible transformation.
178 static std::pair<FuncOp, CoroMachinery>
179 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
180   ModuleOp module = execute->getParentOfType<ModuleOp>();
181 
182   MLIRContext *ctx = module.getContext();
183   Location loc = execute.getLoc();
184 
185   // Collect all outlined function inputs.
186   SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
187                                         execute.dependencies().end());
188   functionInputs.insert(execute.operands().begin(), execute.operands().end());
189   getUsedValuesDefinedAbove(execute.body(), functionInputs);
190 
191   // Collect types for the outlined function inputs and outputs.
192   auto typesRange = llvm::map_range(
193       functionInputs, [](Value value) { return value.getType(); });
194   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
195   auto outputTypes = execute.getResultTypes();
196 
197   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
198   auto funcAttrs = ArrayRef<NamedAttribute>();
199 
200   // TODO: Derive outlined function name from the parent FuncOp (support
201   // multiple nested async.execute operations).
202   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
203   symbolTable.insert(func);
204 
205   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
206 
207   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
208   // blocks, adding async.coro operations and setting up control flow.
209   CoroMachinery coro = setupCoroMachinery(func);
210 
211   // Suspend async function at the end of an entry block, and resume it using
212   // Async resume operation (execution will be resumed in a thread managed by
213   // the async runtime).
214   Block *entryBlock = &func.getBlocks().front();
215   auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
216 
217   // Save the coroutine state: async.coro.save
218   auto coroSaveOp =
219       builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
220 
221   // Pass coroutine to the runtime to be resumed on a runtime managed thread.
222   builder.create<RuntimeResumeOp>(coro.coroHandle);
223 
224   // Split the entry block before the terminator (branch to suspend block).
225   auto *terminatorOp = entryBlock->getTerminator();
226   Block *suspended = terminatorOp->getBlock();
227   Block *resume = suspended->splitBlock(terminatorOp);
228 
229   // Add async.coro.suspend as a suspended block terminator.
230   builder.setInsertionPointToEnd(suspended);
231   builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
232                                 coro.cleanup);
233 
234   size_t numDependencies = execute.dependencies().size();
235   size_t numOperands = execute.operands().size();
236 
237   // Await on all dependencies before starting to execute the body region.
238   builder.setInsertionPointToStart(resume);
239   for (size_t i = 0; i < numDependencies; ++i)
240     builder.create<AwaitOp>(func.getArgument(i));
241 
242   // Await on all async value operands and unwrap the payload.
243   SmallVector<Value, 4> unwrappedOperands(numOperands);
244   for (size_t i = 0; i < numOperands; ++i) {
245     Value operand = func.getArgument(numDependencies + i);
246     unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
247   }
248 
249   // Map from function inputs defined above the execute op to the function
250   // arguments.
251   BlockAndValueMapping valueMapping;
252   valueMapping.map(functionInputs, func.getArguments());
253   valueMapping.map(execute.body().getArguments(), unwrappedOperands);
254 
255   // Clone all operations from the execute operation body into the outlined
256   // function body.
257   for (Operation &op : execute.body().getOps())
258     builder.clone(op, valueMapping);
259 
260   // Replace the original `async.execute` with a call to outlined function.
261   ImplicitLocOpBuilder callBuilder(loc, execute);
262   auto callOutlinedFunc = callBuilder.create<CallOp>(
263       func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
264   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
265   execute.erase();
266 
267   return {func, coro};
268 }
269 
270 //===----------------------------------------------------------------------===//
// Convert async.create_group operation to async.runtime.create.
272 //===----------------------------------------------------------------------===//
273 
274 namespace {
275 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
276 public:
277   using OpConversionPattern::OpConversionPattern;
278 
279   LogicalResult
280   matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
281                   ConversionPatternRewriter &rewriter) const override {
282     rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
283         op, GroupType::get(op->getContext()));
284     return success();
285   }
286 };
287 } // namespace
288 
289 //===----------------------------------------------------------------------===//
290 // Convert async.add_to_group operation to async.runtime.add_to_group.
291 //===----------------------------------------------------------------------===//
292 
293 namespace {
294 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
295 public:
296   using OpConversionPattern::OpConversionPattern;
297 
298   LogicalResult
299   matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands,
300                   ConversionPatternRewriter &rewriter) const override {
301     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
302         op, rewriter.getIndexType(), operands);
303     return success();
304   }
305 };
306 } // namespace
307 
308 //===----------------------------------------------------------------------===//
309 // Convert async.await and async.await_all operations to the async.runtime.await
310 // or async.runtime.await_and_resume operations.
311 //===----------------------------------------------------------------------===//
312 
313 namespace {
314 template <typename AwaitType, typename AwaitableType>
315 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
316   using AwaitAdaptor = typename AwaitType::Adaptor;
317 
318 public:
319   AwaitOpLoweringBase(
320       MLIRContext *ctx,
321       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
322       : OpConversionPattern<AwaitType>(ctx),
323         outlinedFunctions(outlinedFunctions) {}
324 
325   LogicalResult
326   matchAndRewrite(AwaitType op, ArrayRef<Value> operands,
327                   ConversionPatternRewriter &rewriter) const override {
328     // We can only await on one the `AwaitableType` (for `await` it can be
329     // a `token` or a `value`, for `await_all` it must be a `group`).
330     if (!op.operand().getType().template isa<AwaitableType>())
331       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
332 
333     // Check if await operation is inside the outlined coroutine function.
334     auto func = op->template getParentOfType<FuncOp>();
335     auto outlined = outlinedFunctions.find(func);
336     const bool isInCoroutine = outlined != outlinedFunctions.end();
337 
338     Location loc = op->getLoc();
339     Value operand = AwaitAdaptor(operands).operand();
340 
341     // Inside regular functions we use the blocking wait operation to wait for
342     // the async object (token, value or group) to become available.
343     if (!isInCoroutine)
344       rewriter.create<RuntimeAwaitOp>(loc, operand);
345 
346     // Inside the coroutine we convert await operation into coroutine suspension
347     // point, and resume execution asynchronously.
348     if (isInCoroutine) {
349       const CoroMachinery &coro = outlined->getSecond();
350       Block *suspended = op->getBlock();
351 
352       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
353       MLIRContext *ctx = op->getContext();
354 
355       // Save the coroutine state and resume on a runtime managed thread when
356       // the operand becomes available.
357       auto coroSaveOp =
358           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
359       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
360 
361       // Split the entry block before the await operation.
362       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
363 
364       // Add async.coro.suspend as a suspended block terminator.
365       builder.setInsertionPointToEnd(suspended);
366       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
367                                     coro.cleanup);
368 
369       // Make sure that replacement value will be constructed in resume block.
370       rewriter.setInsertionPointToStart(resume);
371     }
372 
373     // Erase or replace the await operation with the new value.
374     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
375       rewriter.replaceOp(op, replaceWith);
376     else
377       rewriter.eraseOp(op);
378 
379     return success();
380   }
381 
382   virtual Value getReplacementValue(AwaitType op, Value operand,
383                                     ConversionPatternRewriter &rewriter) const {
384     return Value();
385   }
386 
387 private:
388   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
389 };
390 
391 /// Lowering for `async.await` with a token operand.
392 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
393   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
394 
395 public:
396   using Base::Base;
397 };
398 
399 /// Lowering for `async.await` with a value operand.
400 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
401   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
402 
403 public:
404   using Base::Base;
405 
406   Value
407   getReplacementValue(AwaitOp op, Value operand,
408                       ConversionPatternRewriter &rewriter) const override {
409     // Load from the async value storage.
410     auto valueType = operand.getType().cast<ValueType>().getValueType();
411     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
412   }
413 };
414 
415 /// Lowering for `async.await_all` operation.
416 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
417   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
418 
419 public:
420   using Base::Base;
421 };
422 
423 } // namespace
424 
425 //===----------------------------------------------------------------------===//
426 // Convert async.yield operation to async.runtime operations.
427 //===----------------------------------------------------------------------===//
428 
429 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
430 public:
431   YieldOpLowering(
432       MLIRContext *ctx,
433       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
434       : OpConversionPattern<async::YieldOp>(ctx),
435         outlinedFunctions(outlinedFunctions) {}
436 
437   LogicalResult
438   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
439                   ConversionPatternRewriter &rewriter) const override {
440     // Check if yield operation is inside the outlined coroutine function.
441     auto func = op->template getParentOfType<FuncOp>();
442     auto outlined = outlinedFunctions.find(func);
443     if (outlined == outlinedFunctions.end())
444       return rewriter.notifyMatchFailure(
445           op, "operation is not inside the outlined async.execute function");
446 
447     Location loc = op->getLoc();
448     const CoroMachinery &coro = outlined->getSecond();
449 
450     // Store yielded values into the async values storage and switch async
451     // values state to available.
452     for (auto tuple : llvm::zip(operands, coro.returnValues)) {
453       Value yieldValue = std::get<0>(tuple);
454       Value asyncValue = std::get<1>(tuple);
455       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
456       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
457     }
458 
459     // Switch the coroutine completion token to available state.
460     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
461 
462     return success();
463   }
464 
465 private:
466   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
467 };
468 
469 //===----------------------------------------------------------------------===//
470 
471 void AsyncToAsyncRuntimePass::runOnOperation() {
472   ModuleOp module = getOperation();
473   SymbolTable symbolTable(module);
474 
475   // Outline all `async.execute` body regions into async functions (coroutines).
476   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
477 
478   module.walk([&](ExecuteOp execute) {
479     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
480   });
481 
482   LLVM_DEBUG({
483     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
484                  << " functions built from async.execute operations\n";
485   });
486 
487   // Lower async operations to async.runtime operations.
488   MLIRContext *ctx = module->getContext();
489   RewritePatternSet asyncPatterns(ctx);
490 
491   // Async lowering does not use type converter because it must preserve all
492   // types for async.runtime operations.
493   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
494   asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
495                     AwaitAllOpLowering, YieldOpLowering>(ctx,
496                                                          outlinedFunctions);
497 
498   // All high level async operations must be lowered to the runtime operations.
499   ConversionTarget runtimeTarget(*ctx);
500   runtimeTarget.addLegalDialect<AsyncDialect>();
501   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
502   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
503 
504   if (failed(applyPartialConversion(module, runtimeTarget,
505                                     std::move(asyncPatterns)))) {
506     signalPassFailure();
507     return;
508   }
509 }
510 
511 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
512   return std::make_unique<AsyncToAsyncRuntimePass>();
513 }
514