1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Async/IR/Async.h"
18 #include "mlir/Dialect/Async/Passes.h"
19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/RegionUtils.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/Support/Debug.h"
29
30 using namespace mlir;
31 using namespace mlir::async;
32
#define DEBUG_TYPE "async-to-async-runtime"

/// Symbol name requested for every function outlined from an `async.execute`
/// op region.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
36
37 namespace {
38
39 class AsyncToAsyncRuntimePass
40 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
41 public:
42 AsyncToAsyncRuntimePass() = default;
43 void runOnOperation() override;
44 };
45
46 } // namespace
47
48 //===----------------------------------------------------------------------===//
49 // async.execute op outlining to the coroutine functions.
50 //===----------------------------------------------------------------------===//
51
52 /// Function targeted for coroutine transformation has two additional blocks at
53 /// the end: coroutine cleanup and coroutine suspension.
54 ///
/// async.await op lowering additionally creates a resume block for each
56 /// operation to enable non-blocking waiting via coroutine suspension.
57 namespace {
58 struct CoroMachinery {
59 func::FuncOp func;
60
61 // Async execute region returns a completion token, and an async value for
62 // each yielded value.
63 //
64 // %token, %result = async.execute -> !async.value<T> {
65 // %0 = arith.constant ... : T
66 // async.yield %0 : T
67 // }
68 Value asyncToken; // token representing completion of the async region
69 llvm::SmallVector<Value, 4> returnValues; // returned async values
70
71 Value coroHandle; // coroutine handle (!async.coro.handle value)
72 Block *entry; // coroutine entry block
73 Block *setError; // switch completion token and all values to error state
74 Block *cleanup; // coroutine cleanup block
75 Block *suspend; // coroutine suspension block
76 };
77 } // namespace
78
79 /// Utility to partially update the regular function CFG to the coroutine CFG
80 /// compatible with LLVM coroutines switched-resume lowering using
81 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
82 /// that branches into preexisting entry block. Also inserts trailing blocks.
83 ///
84 /// The result types of the passed `func` must start with an `async.token`
85 /// and be continued with some number of `async.value`s.
86 ///
87 /// The func given to this function needs to have been preprocessed to have
88 /// either branch or yield ops as terminators. Branches to the cleanup block are
89 /// inserted after each yield.
90 ///
91 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
92 ///
93 /// - `entry` block sets up the coroutine.
94 /// - `set_error` block sets completion token and async values state to error.
95 /// - `cleanup` block cleans up the coroutine state.
///  - `suspend` block after the @llvm.coro.end() defines what value will be
97 /// returned to the initial caller of a coroutine. Everything before the
98 /// @llvm.coro.end() will be executed at every suspension point.
99 ///
100 /// Coroutine structure (only the important bits):
101 ///
102 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
103 /// {
104 /// ^entry(<function-arguments>):
105 /// %token = <async token> : !async.token // create async runtime token
106 /// %value = <async value> : !async.value<T> // create async value
107 /// %id = async.coro.id // create a coroutine id
108 /// %hdl = async.coro.begin %id // create a coroutine handle
109 /// cf.br ^preexisting_entry_block
110 ///
111 /// /* preexisting blocks modified to branch to the cleanup block */
112 ///
113 /// ^set_error: // this block created lazily only if needed (see code below)
114 /// async.runtime.set_error %token : !async.token
115 /// async.runtime.set_error %value : !async.value<T>
116 /// cf.br ^cleanup
117 ///
118 /// ^cleanup:
119 /// async.coro.free %hdl // delete the coroutine state
120 /// cf.br ^suspend
121 ///
122 /// ^suspend:
123 /// async.coro.end %hdl // marks the end of a coroutine
124 /// return %token, %value : !async.token, !async.value<T>
125 /// }
126 ///
setupCoroMachinery(func::FuncOp func)127 static CoroMachinery setupCoroMachinery(func::FuncOp func) {
128 assert(!func.getBlocks().empty() && "Function must have an entry block");
129
130 MLIRContext *ctx = func.getContext();
131 Block *entryBlock = &func.getBlocks().front();
132 Block *originalEntryBlock =
133 entryBlock->splitBlock(entryBlock->getOperations().begin());
134 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
135
136 // ------------------------------------------------------------------------ //
137 // Allocate async token/values that we will return from a ramp function.
138 // ------------------------------------------------------------------------ //
139 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
140
141 llvm::SmallVector<Value, 4> retValues;
142 for (auto resType : func.getCallableResults().drop_front())
143 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
144
145 // ------------------------------------------------------------------------ //
146 // Initialize coroutine: get coroutine id and coroutine handle.
147 // ------------------------------------------------------------------------ //
148 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
149 auto coroHdlOp =
150 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
151 builder.create<cf::BranchOp>(originalEntryBlock);
152
153 Block *cleanupBlock = func.addBlock();
154 Block *suspendBlock = func.addBlock();
155
156 // ------------------------------------------------------------------------ //
157 // Coroutine cleanup block: deallocate coroutine frame, free the memory.
158 // ------------------------------------------------------------------------ //
159 builder.setInsertionPointToStart(cleanupBlock);
160 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
161
162 // Branch into the suspend block.
163 builder.create<cf::BranchOp>(suspendBlock);
164
165 // ------------------------------------------------------------------------ //
166 // Coroutine suspend block: mark the end of a coroutine and return allocated
167 // async token.
168 // ------------------------------------------------------------------------ //
169 builder.setInsertionPointToStart(suspendBlock);
170
171 // Mark the end of a coroutine: async.coro.end
172 builder.create<CoroEndOp>(coroHdlOp.handle());
173
174 // Return created `async.token` and `async.values` from the suspend block.
175 // This will be the return value of a coroutine ramp function.
176 SmallVector<Value, 4> ret{retToken};
177 ret.insert(ret.end(), retValues.begin(), retValues.end());
178 builder.create<func::ReturnOp>(ret);
179
180 // `async.await` op lowering will create resume blocks for async
181 // continuations, and will conditionally branch to cleanup or suspend blocks.
182
183 for (Block &block : func.getBody().getBlocks()) {
184 if (&block == entryBlock || &block == cleanupBlock ||
185 &block == suspendBlock)
186 continue;
187 Operation *terminator = block.getTerminator();
188 if (auto yield = dyn_cast<YieldOp>(terminator)) {
189 builder.setInsertionPointToEnd(&block);
190 builder.create<cf::BranchOp>(cleanupBlock);
191 }
192 }
193
194 // The switch-resumed API based coroutine should be marked with
195 // coroutine.presplit attribute to mark the function as a coroutine.
196 func->setAttr("passthrough", builder.getArrayAttr(
197 StringAttr::get(ctx, "presplitcoroutine")));
198
199 CoroMachinery machinery;
200 machinery.func = func;
201 machinery.asyncToken = retToken;
202 machinery.returnValues = retValues;
203 machinery.coroHandle = coroHdlOp.handle();
204 machinery.entry = entryBlock;
205 machinery.setError = nullptr; // created lazily only if needed
206 machinery.cleanup = cleanupBlock;
207 machinery.suspend = suspendBlock;
208 return machinery;
209 }
210
211 // Lazily creates `set_error` block only if it is required for lowering to the
212 // runtime operations (see for example lowering of assert operation).
setupSetErrorBlock(CoroMachinery & coro)213 static Block *setupSetErrorBlock(CoroMachinery &coro) {
214 if (coro.setError)
215 return coro.setError;
216
217 coro.setError = coro.func.addBlock();
218 coro.setError->moveBefore(coro.cleanup);
219
220 auto builder =
221 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
222
223 // Coroutine set_error block: set error on token and all returned values.
224 builder.create<RuntimeSetErrorOp>(coro.asyncToken);
225 for (Value retValue : coro.returnValues)
226 builder.create<RuntimeSetErrorOp>(retValue);
227
228 // Branch into the cleanup block.
229 builder.create<cf::BranchOp>(coro.cleanup);
230
231 return coro.setError;
232 }
233
234 /// Outline the body region attached to the `async.execute` op into a standalone
235 /// function.
236 ///
237 /// Note that this is not reversible transformation.
238 static std::pair<func::FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable & symbolTable,ExecuteOp execute)239 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
240 ModuleOp module = execute->getParentOfType<ModuleOp>();
241
242 MLIRContext *ctx = module.getContext();
243 Location loc = execute.getLoc();
244
245 // Make sure that all constants will be inside the outlined async function to
246 // reduce the number of function arguments.
247 cloneConstantsIntoTheRegion(execute.body());
248
249 // Collect all outlined function inputs.
250 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
251 execute.dependencies().end());
252 functionInputs.insert(execute.operands().begin(), execute.operands().end());
253 getUsedValuesDefinedAbove(execute.body(), functionInputs);
254
255 // Collect types for the outlined function inputs and outputs.
256 auto typesRange = llvm::map_range(
257 functionInputs, [](Value value) { return value.getType(); });
258 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
259 auto outputTypes = execute.getResultTypes();
260
261 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
262 auto funcAttrs = ArrayRef<NamedAttribute>();
263
264 // TODO: Derive outlined function name from the parent FuncOp (support
265 // multiple nested async.execute operations).
266 func::FuncOp func =
267 func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
268 symbolTable.insert(func);
269
270 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
271 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
272
273 // Prepare for coroutine conversion by creating the body of the function.
274 {
275 size_t numDependencies = execute.dependencies().size();
276 size_t numOperands = execute.operands().size();
277
278 // Await on all dependencies before starting to execute the body region.
279 for (size_t i = 0; i < numDependencies; ++i)
280 builder.create<AwaitOp>(func.getArgument(i));
281
282 // Await on all async value operands and unwrap the payload.
283 SmallVector<Value, 4> unwrappedOperands(numOperands);
284 for (size_t i = 0; i < numOperands; ++i) {
285 Value operand = func.getArgument(numDependencies + i);
286 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
287 }
288
289 // Map from function inputs defined above the execute op to the function
290 // arguments.
291 BlockAndValueMapping valueMapping;
292 valueMapping.map(functionInputs, func.getArguments());
293 valueMapping.map(execute.body().getArguments(), unwrappedOperands);
294
295 // Clone all operations from the execute operation body into the outlined
296 // function body.
297 for (Operation &op : execute.body().getOps())
298 builder.clone(op, valueMapping);
299 }
300
301 // Adding entry/cleanup/suspend blocks.
302 CoroMachinery coro = setupCoroMachinery(func);
303
304 // Suspend async function at the end of an entry block, and resume it using
305 // Async resume operation (execution will be resumed in a thread managed by
306 // the async runtime).
307 {
308 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
309 builder.setInsertionPointToEnd(coro.entry);
310
311 // Save the coroutine state: async.coro.save
312 auto coroSaveOp =
313 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
314
315 // Pass coroutine to the runtime to be resumed on a runtime managed
316 // thread.
317 builder.create<RuntimeResumeOp>(coro.coroHandle);
318
319 // Add async.coro.suspend as a suspended block terminator.
320 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
321 branch.getDest(), coro.cleanup);
322
323 branch.erase();
324 }
325
326 // Replace the original `async.execute` with a call to outlined function.
327 {
328 ImplicitLocOpBuilder callBuilder(loc, execute);
329 auto callOutlinedFunc = callBuilder.create<func::CallOp>(
330 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
331 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
332 execute.erase();
333 }
334
335 return {func, coro};
336 }
337
338 //===----------------------------------------------------------------------===//
339 // Convert async.create_group operation to async.runtime.create_group
340 //===----------------------------------------------------------------------===//
341
342 namespace {
343 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
344 public:
345 using OpConversionPattern::OpConversionPattern;
346
347 LogicalResult
matchAndRewrite(CreateGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const348 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override {
350 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
351 op, GroupType::get(op->getContext()), adaptor.getOperands());
352 return success();
353 }
354 };
355 } // namespace
356
357 //===----------------------------------------------------------------------===//
358 // Convert async.add_to_group operation to async.runtime.add_to_group.
359 //===----------------------------------------------------------------------===//
360
361 namespace {
362 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
363 public:
364 using OpConversionPattern::OpConversionPattern;
365
366 LogicalResult
matchAndRewrite(AddToGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const367 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
368 ConversionPatternRewriter &rewriter) const override {
369 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
370 op, rewriter.getIndexType(), adaptor.getOperands());
371 return success();
372 }
373 };
374 } // namespace
375
376 //===----------------------------------------------------------------------===//
377 // Convert async.await and async.await_all operations to the async.runtime.await
378 // or async.runtime.await_and_resume operations.
379 //===----------------------------------------------------------------------===//
380
381 namespace {
382 template <typename AwaitType, typename AwaitableType>
383 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
384 using AwaitAdaptor = typename AwaitType::Adaptor;
385
386 public:
AwaitOpLoweringBase(MLIRContext * ctx,llvm::DenseMap<func::FuncOp,CoroMachinery> & outlinedFunctions)387 AwaitOpLoweringBase(
388 MLIRContext *ctx,
389 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
390 : OpConversionPattern<AwaitType>(ctx),
391 outlinedFunctions(outlinedFunctions) {}
392
393 LogicalResult
matchAndRewrite(AwaitType op,typename AwaitType::Adaptor adaptor,ConversionPatternRewriter & rewriter) const394 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
395 ConversionPatternRewriter &rewriter) const override {
396 // We can only await on one the `AwaitableType` (for `await` it can be
397 // a `token` or a `value`, for `await_all` it must be a `group`).
398 if (!op.operand().getType().template isa<AwaitableType>())
399 return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
400
401 // Check if await operation is inside the outlined coroutine function.
402 auto func = op->template getParentOfType<func::FuncOp>();
403 auto outlined = outlinedFunctions.find(func);
404 const bool isInCoroutine = outlined != outlinedFunctions.end();
405
406 Location loc = op->getLoc();
407 Value operand = adaptor.operand();
408
409 Type i1 = rewriter.getI1Type();
410
411 // Inside regular functions we use the blocking wait operation to wait for
412 // the async object (token, value or group) to become available.
413 if (!isInCoroutine) {
414 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
415 builder.create<RuntimeAwaitOp>(loc, operand);
416
417 // Assert that the awaited operands is not in the error state.
418 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
419 Value notError = builder.create<arith::XOrIOp>(
420 isError, builder.create<arith::ConstantOp>(
421 loc, i1, builder.getIntegerAttr(i1, 1)));
422
423 builder.create<cf::AssertOp>(notError,
424 "Awaited async operand is in error state");
425 }
426
427 // Inside the coroutine we convert await operation into coroutine suspension
428 // point, and resume execution asynchronously.
429 if (isInCoroutine) {
430 CoroMachinery &coro = outlined->getSecond();
431 Block *suspended = op->getBlock();
432
433 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
434 MLIRContext *ctx = op->getContext();
435
436 // Save the coroutine state and resume on a runtime managed thread when
437 // the operand becomes available.
438 auto coroSaveOp =
439 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
440 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
441
442 // Split the entry block before the await operation.
443 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
444
445 // Add async.coro.suspend as a suspended block terminator.
446 builder.setInsertionPointToEnd(suspended);
447 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
448 coro.cleanup);
449
450 // Split the resume block into error checking and continuation.
451 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
452
453 // Check if the awaited value is in the error state.
454 builder.setInsertionPointToStart(resume);
455 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
456 builder.create<cf::CondBranchOp>(isError,
457 /*trueDest=*/setupSetErrorBlock(coro),
458 /*trueArgs=*/ArrayRef<Value>(),
459 /*falseDest=*/continuation,
460 /*falseArgs=*/ArrayRef<Value>());
461
462 // Make sure that replacement value will be constructed in the
463 // continuation block.
464 rewriter.setInsertionPointToStart(continuation);
465 }
466
467 // Erase or replace the await operation with the new value.
468 if (Value replaceWith = getReplacementValue(op, operand, rewriter))
469 rewriter.replaceOp(op, replaceWith);
470 else
471 rewriter.eraseOp(op);
472
473 return success();
474 }
475
getReplacementValue(AwaitType op,Value operand,ConversionPatternRewriter & rewriter) const476 virtual Value getReplacementValue(AwaitType op, Value operand,
477 ConversionPatternRewriter &rewriter) const {
478 return Value();
479 }
480
481 private:
482 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
483 };
484
485 /// Lowering for `async.await` with a token operand.
486 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
487 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
488
489 public:
490 using Base::Base;
491 };
492
493 /// Lowering for `async.await` with a value operand.
494 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
495 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
496
497 public:
498 using Base::Base;
499
500 Value
getReplacementValue(AwaitOp op,Value operand,ConversionPatternRewriter & rewriter) const501 getReplacementValue(AwaitOp op, Value operand,
502 ConversionPatternRewriter &rewriter) const override {
503 // Load from the async value storage.
504 auto valueType = operand.getType().cast<ValueType>().getValueType();
505 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
506 }
507 };
508
509 /// Lowering for `async.await_all` operation.
510 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
511 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
512
513 public:
514 using Base::Base;
515 };
516
517 } // namespace
518
519 //===----------------------------------------------------------------------===//
520 // Convert async.yield operation to async.runtime operations.
521 //===----------------------------------------------------------------------===//
522
523 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
524 public:
YieldOpLowering(MLIRContext * ctx,const llvm::DenseMap<func::FuncOp,CoroMachinery> & outlinedFunctions)525 YieldOpLowering(
526 MLIRContext *ctx,
527 const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
528 : OpConversionPattern<async::YieldOp>(ctx),
529 outlinedFunctions(outlinedFunctions) {}
530
531 LogicalResult
matchAndRewrite(async::YieldOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const532 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
533 ConversionPatternRewriter &rewriter) const override {
534 // Check if yield operation is inside the async coroutine function.
535 auto func = op->template getParentOfType<func::FuncOp>();
536 auto outlined = outlinedFunctions.find(func);
537 if (outlined == outlinedFunctions.end())
538 return rewriter.notifyMatchFailure(
539 op, "operation is not inside the async coroutine function");
540
541 Location loc = op->getLoc();
542 const CoroMachinery &coro = outlined->getSecond();
543
544 // Store yielded values into the async values storage and switch async
545 // values state to available.
546 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
547 Value yieldValue = std::get<0>(tuple);
548 Value asyncValue = std::get<1>(tuple);
549 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
550 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
551 }
552
553 // Switch the coroutine completion token to available state.
554 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
555
556 return success();
557 }
558
559 private:
560 const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
561 };
562
563 //===----------------------------------------------------------------------===//
564 // Convert cf.assert operation to cf.cond_br into `set_error` block.
565 //===----------------------------------------------------------------------===//
566
567 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
568 public:
AssertOpLowering(MLIRContext * ctx,llvm::DenseMap<func::FuncOp,CoroMachinery> & outlinedFunctions)569 AssertOpLowering(
570 MLIRContext *ctx,
571 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
572 : OpConversionPattern<cf::AssertOp>(ctx),
573 outlinedFunctions(outlinedFunctions) {}
574
575 LogicalResult
matchAndRewrite(cf::AssertOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const576 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
577 ConversionPatternRewriter &rewriter) const override {
578 // Check if assert operation is inside the async coroutine function.
579 auto func = op->template getParentOfType<func::FuncOp>();
580 auto outlined = outlinedFunctions.find(func);
581 if (outlined == outlinedFunctions.end())
582 return rewriter.notifyMatchFailure(
583 op, "operation is not inside the async coroutine function");
584
585 Location loc = op->getLoc();
586 CoroMachinery &coro = outlined->getSecond();
587
588 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
589 rewriter.setInsertionPointToEnd(cont->getPrevNode());
590 rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
591 /*trueDest=*/cont,
592 /*trueArgs=*/ArrayRef<Value>(),
593 /*falseDest=*/setupSetErrorBlock(coro),
594 /*falseArgs=*/ArrayRef<Value>());
595 rewriter.eraseOp(op);
596
597 return success();
598 }
599
600 private:
601 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
602 };
603
604 //===----------------------------------------------------------------------===//
605
606 /// Rewrite a func as a coroutine by:
607 /// 1) Wrapping the results into `async.value`.
608 /// 2) Prepending the results with `async.token`.
609 /// 3) Setting up coroutine blocks.
610 /// 4) Rewriting return ops as yield op and branch op into the suspend block.
rewriteFuncAsCoroutine(func::FuncOp func)611 static CoroMachinery rewriteFuncAsCoroutine(func::FuncOp func) {
612 auto *ctx = func->getContext();
613 auto loc = func.getLoc();
614 SmallVector<Type> resultTypes;
615 resultTypes.reserve(func.getCallableResults().size());
616 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
617 [](Type type) { return ValueType::get(type); });
618 func.setType(
619 FunctionType::get(ctx, func.getFunctionType().getInputs(), resultTypes));
620 func.insertResult(0, TokenType::get(ctx), {});
621 for (Block &block : func.getBlocks()) {
622 Operation *terminator = block.getTerminator();
623 if (auto returnOp = dyn_cast<func::ReturnOp>(*terminator)) {
624 ImplicitLocOpBuilder builder(loc, returnOp);
625 builder.create<YieldOp>(returnOp.getOperands());
626 returnOp.erase();
627 }
628 }
629 return setupCoroMachinery(func);
630 }
631
632 /// Rewrites a call into a function that has been rewritten as a coroutine.
633 ///
634 /// The invocation of this function is safe only when call ops are traversed in
635 /// reverse order of how they appear in a single block. See `funcsToCoroutines`.
rewriteCallsiteForCoroutine(func::CallOp oldCall,func::FuncOp func)636 static void rewriteCallsiteForCoroutine(func::CallOp oldCall,
637 func::FuncOp func) {
638 auto loc = func.getLoc();
639 ImplicitLocOpBuilder callBuilder(loc, oldCall);
640 auto newCall = callBuilder.create<func::CallOp>(
641 func.getName(), func.getCallableResults(), oldCall.getArgOperands());
642
643 // Await on the async token and all the value results and unwrap the latter.
644 callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
645 SmallVector<Value> unwrappedResults;
646 unwrappedResults.reserve(newCall->getResults().size() - 1);
647 for (Value result : newCall.getResults().drop_front())
648 unwrappedResults.push_back(
649 callBuilder.create<AwaitOp>(loc, result).result());
650 // Careful, when result of a call is piped into another call this could lead
651 // to a dangling pointer.
652 oldCall.replaceAllUsesWith(unwrappedResults);
653 oldCall.erase();
654 }
655
isAllowedToBlock(func::FuncOp func)656 static bool isAllowedToBlock(func::FuncOp func) {
657 return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
658 }
659
funcsToCoroutines(ModuleOp module,llvm::DenseMap<func::FuncOp,CoroMachinery> & outlinedFunctions)660 static LogicalResult funcsToCoroutines(
661 ModuleOp module,
662 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) {
663 // The following code supports the general case when 2 functions mutually
664 // recurse into each other. Because of this and that we are relying on
665 // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
666 // a FuncOp while inserting an equivalent coroutine, because that could lead
667 // to dangling pointers.
668
669 SmallVector<func::FuncOp> funcWorklist;
670
671 // Careful, it's okay to add a func to the worklist multiple times if and only
672 // if the loop processing the worklist will skip the functions that have
673 // already been converted to coroutines.
674 auto addToWorklist = [&](func::FuncOp func) {
675 if (isAllowedToBlock(func))
676 return;
677 // N.B. To refactor this code into a separate pass the lookup in
678 // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
679 // func and recognizing if it has a coroutine structure is messy. Passing
680 // this dict between the passes is ugly.
681 if (isAllowedToBlock(func) ||
682 outlinedFunctions.find(func) == outlinedFunctions.end()) {
683 for (Operation &op : func.getBody().getOps()) {
684 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
685 funcWorklist.push_back(func);
686 break;
687 }
688 }
689 }
690 };
691
692 // Traverse in post-order collecting for each func op the await ops it has.
693 for (func::FuncOp func : module.getOps<func::FuncOp>())
694 addToWorklist(func);
695
696 SymbolTableCollection symbolTable;
697 SymbolUserMap symbolUserMap(symbolTable, module);
698
699 // Rewrite funcs, while updating call sites and adding them to the worklist.
700 while (!funcWorklist.empty()) {
701 auto func = funcWorklist.pop_back_val();
702 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
703 if (!insertion.second)
704 // This function has already been processed because this is either
705 // the corecursive case, or a caller with multiple calls to a newly
706 // created corouting. Either way, skip updating the call sites.
707 continue;
708 insertion.first->second = rewriteFuncAsCoroutine(func);
709 SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
710 symbolUserMap.getUsers(func).end());
711 // If there are multiple calls from the same block they need to be traversed
712 // in reverse order so that symbolUserMap references are not invalidated
713 // when updating the users of the call op which is earlier in the block.
714 llvm::sort(users, [](Operation *a, Operation *b) {
715 Block *blockA = a->getBlock();
716 Block *blockB = b->getBlock();
717 // Impose arbitrary order on blocks so that there is a well-defined order.
718 return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b));
719 });
720 // Rewrite the callsites to await on results of the newly created coroutine.
721 for (Operation *op : users) {
722 if (func::CallOp call = dyn_cast<func::CallOp>(*op)) {
723 func::FuncOp caller = call->getParentOfType<func::FuncOp>();
724 rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
725 addToWorklist(caller);
726 } else {
727 op->emitError("Unexpected reference to func referenced by symbol");
728 return failure();
729 }
730 }
731 }
732 return success();
733 }
734
735 //===----------------------------------------------------------------------===//
runOnOperation()736 void AsyncToAsyncRuntimePass::runOnOperation() {
737 ModuleOp module = getOperation();
738 SymbolTable symbolTable(module);
739
740 // Outline all `async.execute` body regions into async functions (coroutines).
741 llvm::DenseMap<func::FuncOp, CoroMachinery> outlinedFunctions;
742
743 module.walk([&](ExecuteOp execute) {
744 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
745 });
746
747 LLVM_DEBUG({
748 llvm::dbgs() << "Outlined " << outlinedFunctions.size()
749 << " functions built from async.execute operations\n";
750 });
751
752 // Returns true if operation is inside the coroutine.
753 auto isInCoroutine = [&](Operation *op) -> bool {
754 auto parentFunc = op->getParentOfType<func::FuncOp>();
755 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
756 };
757
758 if (eliminateBlockingAwaitOps &&
759 failed(funcsToCoroutines(module, outlinedFunctions))) {
760 signalPassFailure();
761 return;
762 }
763
764 // Lower async operations to async.runtime operations.
765 MLIRContext *ctx = module->getContext();
766 RewritePatternSet asyncPatterns(ctx);
767
768 // Conversion to async runtime augments original CFG with the coroutine CFG,
769 // and we have to make sure that structured control flow operations with async
770 // operations in nested regions will be converted to branch-based control flow
771 // before we add the coroutine basic blocks.
772 populateSCFToControlFlowConversionPatterns(asyncPatterns);
773
774 // Async lowering does not use type converter because it must preserve all
775 // types for async.runtime operations.
776 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
777 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
778 AwaitAllOpLowering, YieldOpLowering>(ctx,
779 outlinedFunctions);
780
781 // Lower assertions to conditional branches into error blocks.
782 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
783
784 // All high level async operations must be lowered to the runtime operations.
785 ConversionTarget runtimeTarget(*ctx);
786 runtimeTarget.addLegalDialect<AsyncDialect>();
787 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
788 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
789
790 // Decide if structured control flow has to be lowered to branch-based CFG.
791 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
792 auto walkResult = op->walk([&](Operation *nested) {
793 bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
794 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
795 : WalkResult::advance();
796 });
797 return !walkResult.wasInterrupted();
798 });
799 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
800 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
801
802 // Assertions must be converted to runtime errors inside async functions.
803 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
804 [&](cf::AssertOp op) -> bool {
805 auto func = op->getParentOfType<func::FuncOp>();
806 return outlinedFunctions.find(func) == outlinedFunctions.end();
807 });
808
809 if (eliminateBlockingAwaitOps)
810 runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
811 [&](RuntimeAwaitOp op) -> bool {
812 return isAllowedToBlock(op->getParentOfType<func::FuncOp>());
813 });
814
815 if (failed(applyPartialConversion(module, runtimeTarget,
816 std::move(asyncPatterns)))) {
817 signalPassFailure();
818 return;
819 }
820 }
821
createAsyncToAsyncRuntimePass()822 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
823 return std::make_unique<AsyncToAsyncRuntimePass>();
824 }
825