1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering from high level async operations to async.coro 10 // and async.runtime operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Async/IR/Async.h" 16 #include "mlir/Dialect/Async/Passes.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/IR/BlockAndValueMapping.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "mlir/Transforms/RegionUtils.h" 23 #include "llvm/ADT/SetVector.h" 24 25 using namespace mlir; 26 using namespace mlir::async; 27 28 #define DEBUG_TYPE "async-to-async-runtime" 29 // Prefix for functions outlined from `async.execute` op regions. 30 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 31 32 namespace { 33 34 class AsyncToAsyncRuntimePass 35 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 36 public: 37 AsyncToAsyncRuntimePass() = default; 38 void runOnOperation() override; 39 }; 40 41 } // namespace 42 43 //===----------------------------------------------------------------------===// 44 // async.execute op outlining to the coroutine functions. 45 //===----------------------------------------------------------------------===// 46 47 /// Function targeted for coroutine transformation has two additional blocks at 48 /// the end: coroutine cleanup and coroutine suspension. 49 /// 50 /// async.await op lowering additionaly creates a resume block for each 51 /// operation to enable non-blocking waiting via coroutine suspension. 52 namespace { 53 struct CoroMachinery { 54 // Async execute region returns a completion token, and an async value for 55 // each yielded value. 56 // 57 // %token, %result = async.execute -> !async.value<T> { 58 // %0 = constant ... : T 59 // async.yield %0 : T 60 // } 61 Value asyncToken; // token representing completion of the async region 62 llvm::SmallVector<Value, 4> returnValues; // returned async values 63 64 Value coroHandle; // coroutine handle (!async.coro.handle value) 65 Block *cleanup; // coroutine cleanup block 66 Block *suspend; // coroutine suspension block 67 }; 68 } // namespace 69 70 /// Builds an coroutine template compatible with LLVM coroutines switched-resume 71 /// lowering using `async.runtime.*` and `async.coro.*` operations. 72 /// 73 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 74 /// 75 /// - `entry` block sets up the coroutine. 76 /// - `cleanup` block cleans up the coroutine state. 77 /// - `suspend block after the @llvm.coro.end() defines what value will be 78 /// returned to the initial caller of a coroutine. Everything before the 79 /// @llvm.coro.end() will be executed at every suspension point. 80 /// 81 /// Coroutine structure (only the important bits): 82 /// 83 /// func @async_execute_fn(<function-arguments>) 84 /// -> (!async.token, !async.value<T>) 85 /// { 86 /// ^entry(<function-arguments>): 87 /// %token = <async token> : !async.token // create async runtime token 88 /// %value = <async value> : !async.value<T> // create async value 89 /// %id = async.coro.id // create a coroutine id 90 /// %hdl = async.coro.begin %id // create a coroutine handle 91 /// br ^cleanup 92 /// 93 /// ^cleanup: 94 /// async.coro.free %hdl // delete the coroutine state 95 /// br ^suspend 96 /// 97 /// ^suspend: 98 /// async.coro.end %hdl // marks the end of a coroutine 99 /// return %token, %value : !async.token, !async.value<T> 100 /// } 101 /// 102 /// The actual code for the async.execute operation body region will be inserted 103 /// before the entry block terminator. 104 /// 105 /// 106 static CoroMachinery setupCoroMachinery(FuncOp func) { 107 assert(func.getBody().empty() && "Function must have empty body"); 108 109 MLIRContext *ctx = func.getContext(); 110 Block *entryBlock = func.addEntryBlock(); 111 112 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 113 114 // ------------------------------------------------------------------------ // 115 // Allocate async token/values that we will return from a ramp function. 116 // ------------------------------------------------------------------------ // 117 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 118 119 llvm::SmallVector<Value, 4> retValues; 120 for (auto resType : func.getCallableResults().drop_front()) 121 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 122 123 // ------------------------------------------------------------------------ // 124 // Initialize coroutine: get coroutine id and coroutine handle. 125 // ------------------------------------------------------------------------ // 126 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 127 auto coroHdlOp = 128 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 129 130 Block *cleanupBlock = func.addBlock(); 131 Block *suspendBlock = func.addBlock(); 132 133 // ------------------------------------------------------------------------ // 134 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 135 // ------------------------------------------------------------------------ // 136 builder.setInsertionPointToStart(cleanupBlock); 137 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 138 139 // Branch into the suspend block. 140 builder.create<BranchOp>(suspendBlock); 141 142 // ------------------------------------------------------------------------ // 143 // Coroutine suspend block: mark the end of a coroutine and return allocated 144 // async token. 145 // ------------------------------------------------------------------------ // 146 builder.setInsertionPointToStart(suspendBlock); 147 148 // Mark the end of a coroutine: async.coro.end 149 builder.create<CoroEndOp>(coroHdlOp.handle()); 150 151 // Return created `async.token` and `async.values` from the suspend block. 152 // This will be the return value of a coroutine ramp function. 153 SmallVector<Value, 4> ret{retToken}; 154 ret.insert(ret.end(), retValues.begin(), retValues.end()); 155 builder.create<ReturnOp>(ret); 156 157 // Branch from the entry block to the cleanup block to create a valid CFG. 158 builder.setInsertionPointToEnd(entryBlock); 159 builder.create<BranchOp>(cleanupBlock); 160 161 // `async.await` op lowering will create resume blocks for async 162 // continuations, and will conditionally branch to cleanup or suspend blocks. 163 164 CoroMachinery machinery; 165 machinery.asyncToken = retToken; 166 machinery.returnValues = retValues; 167 machinery.coroHandle = coroHdlOp.handle(); 168 machinery.cleanup = cleanupBlock; 169 machinery.suspend = suspendBlock; 170 return machinery; 171 } 172 173 /// Outline the body region attached to the `async.execute` op into a standalone 174 /// function. 175 /// 176 /// Note that this is not reversible transformation. 177 static std::pair<FuncOp, CoroMachinery> 178 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 179 ModuleOp module = execute->getParentOfType<ModuleOp>(); 180 181 MLIRContext *ctx = module.getContext(); 182 Location loc = execute.getLoc(); 183 184 // Collect all outlined function inputs. 185 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 186 execute.dependencies().end()); 187 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 188 getUsedValuesDefinedAbove(execute.body(), functionInputs); 189 190 // Collect types for the outlined function inputs and outputs. 191 auto typesRange = llvm::map_range( 192 functionInputs, [](Value value) { return value.getType(); }); 193 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 194 auto outputTypes = execute.getResultTypes(); 195 196 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 197 auto funcAttrs = ArrayRef<NamedAttribute>(); 198 199 // TODO: Derive outlined function name from the parent FuncOp (support 200 // multiple nested async.execute operations). 201 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 202 symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); 203 204 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 205 206 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 207 // blocks, adding async.coro operations and setting up control flow. 208 CoroMachinery coro = setupCoroMachinery(func); 209 210 // Suspend async function at the end of an entry block, and resume it using 211 // Async resume operation (execution will be resumed in a thread managed by 212 // the async runtime). 213 Block *entryBlock = &func.getBlocks().front(); 214 auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); 215 216 // Save the coroutine state: async.coro.save 217 auto coroSaveOp = 218 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 219 220 // Pass coroutine to the runtime to be resumed on a runtime managed thread. 221 builder.create<RuntimeResumeOp>(coro.coroHandle); 222 223 // Split the entry block before the terminator (branch to suspend block). 224 auto *terminatorOp = entryBlock->getTerminator(); 225 Block *suspended = terminatorOp->getBlock(); 226 Block *resume = suspended->splitBlock(terminatorOp); 227 228 // Add async.coro.suspend as a suspended block terminator. 229 builder.setInsertionPointToEnd(suspended); 230 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 231 coro.cleanup); 232 233 size_t numDependencies = execute.dependencies().size(); 234 size_t numOperands = execute.operands().size(); 235 236 // Await on all dependencies before starting to execute the body region. 237 builder.setInsertionPointToStart(resume); 238 for (size_t i = 0; i < numDependencies; ++i) 239 builder.create<AwaitOp>(func.getArgument(i)); 240 241 // Await on all async value operands and unwrap the payload. 242 SmallVector<Value, 4> unwrappedOperands(numOperands); 243 for (size_t i = 0; i < numOperands; ++i) { 244 Value operand = func.getArgument(numDependencies + i); 245 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 246 } 247 248 // Map from function inputs defined above the execute op to the function 249 // arguments. 250 BlockAndValueMapping valueMapping; 251 valueMapping.map(functionInputs, func.getArguments()); 252 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 253 254 // Clone all operations from the execute operation body into the outlined 255 // function body. 256 for (Operation &op : execute.body().getOps()) 257 builder.clone(op, valueMapping); 258 259 // Replace the original `async.execute` with a call to outlined function. 260 ImplicitLocOpBuilder callBuilder(loc, execute); 261 auto callOutlinedFunc = callBuilder.create<CallOp>( 262 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 263 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 264 execute.erase(); 265 266 return {func, coro}; 267 } 268 269 //===----------------------------------------------------------------------===// 270 // Convert async.create_group operation to async.runtime.create 271 //===----------------------------------------------------------------------===// 272 273 namespace { 274 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 275 public: 276 using OpConversionPattern::OpConversionPattern; 277 278 LogicalResult 279 matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands, 280 ConversionPatternRewriter &rewriter) const override { 281 rewriter.replaceOpWithNewOp<RuntimeCreateOp>( 282 op, GroupType::get(op->getContext())); 283 return success(); 284 } 285 }; 286 } // namespace 287 288 //===----------------------------------------------------------------------===// 289 // Convert async.add_to_group operation to async.runtime.add_to_group. 290 //===----------------------------------------------------------------------===// 291 292 namespace { 293 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 294 public: 295 using OpConversionPattern::OpConversionPattern; 296 297 LogicalResult 298 matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands, 299 ConversionPatternRewriter &rewriter) const override { 300 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 301 op, rewriter.getIndexType(), operands); 302 return success(); 303 } 304 }; 305 } // namespace 306 307 //===----------------------------------------------------------------------===// 308 // Convert async.await and async.await_all operations to the async.runtime.await 309 // or async.runtime.await_and_resume operations. 310 //===----------------------------------------------------------------------===// 311 312 namespace { 313 template <typename AwaitType, typename AwaitableType> 314 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 315 using AwaitAdaptor = typename AwaitType::Adaptor; 316 317 public: 318 AwaitOpLoweringBase( 319 MLIRContext *ctx, 320 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 321 : OpConversionPattern<AwaitType>(ctx), 322 outlinedFunctions(outlinedFunctions) {} 323 324 LogicalResult 325 matchAndRewrite(AwaitType op, ArrayRef<Value> operands, 326 ConversionPatternRewriter &rewriter) const override { 327 // We can only await on one the `AwaitableType` (for `await` it can be 328 // a `token` or a `value`, for `await_all` it must be a `group`). 329 if (!op.operand().getType().template isa<AwaitableType>()) 330 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 331 332 // Check if await operation is inside the outlined coroutine function. 333 auto func = op->template getParentOfType<FuncOp>(); 334 auto outlined = outlinedFunctions.find(func); 335 const bool isInCoroutine = outlined != outlinedFunctions.end(); 336 337 Location loc = op->getLoc(); 338 Value operand = AwaitAdaptor(operands).operand(); 339 340 // Inside regular functions we use the blocking wait operation to wait for 341 // the async object (token, value or group) to become available. 342 if (!isInCoroutine) 343 rewriter.create<RuntimeAwaitOp>(loc, operand); 344 345 // Inside the coroutine we convert await operation into coroutine suspension 346 // point, and resume execution asynchronously. 347 if (isInCoroutine) { 348 const CoroMachinery &coro = outlined->getSecond(); 349 Block *suspended = op->getBlock(); 350 351 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 352 MLIRContext *ctx = op->getContext(); 353 354 // Save the coroutine state and resume on a runtime managed thread when 355 // the operand becomes available. 356 auto coroSaveOp = 357 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 358 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 359 360 // Split the entry block before the await operation. 361 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 362 363 // Add async.coro.suspend as a suspended block terminator. 364 builder.setInsertionPointToEnd(suspended); 365 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 366 coro.cleanup); 367 368 // Make sure that replacement value will be constructed in resume block. 369 rewriter.setInsertionPointToStart(resume); 370 } 371 372 // Erase or replace the await operation with the new value. 373 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 374 rewriter.replaceOp(op, replaceWith); 375 else 376 rewriter.eraseOp(op); 377 378 return success(); 379 } 380 381 virtual Value getReplacementValue(AwaitType op, Value operand, 382 ConversionPatternRewriter &rewriter) const { 383 return Value(); 384 } 385 386 private: 387 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 388 }; 389 390 /// Lowering for `async.await` with a token operand. 391 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 392 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 393 394 public: 395 using Base::Base; 396 }; 397 398 /// Lowering for `async.await` with a value operand. 399 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 400 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 401 402 public: 403 using Base::Base; 404 405 Value 406 getReplacementValue(AwaitOp op, Value operand, 407 ConversionPatternRewriter &rewriter) const override { 408 // Load from the async value storage. 409 auto valueType = operand.getType().cast<ValueType>().getValueType(); 410 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 411 } 412 }; 413 414 /// Lowering for `async.await_all` operation. 415 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 416 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 417 418 public: 419 using Base::Base; 420 }; 421 422 } // namespace 423 424 //===----------------------------------------------------------------------===// 425 // Convert async.yield operation to async.runtime operations. 426 //===----------------------------------------------------------------------===// 427 428 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 429 public: 430 YieldOpLowering( 431 MLIRContext *ctx, 432 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 433 : OpConversionPattern<async::YieldOp>(ctx), 434 outlinedFunctions(outlinedFunctions) {} 435 436 LogicalResult 437 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 438 ConversionPatternRewriter &rewriter) const override { 439 // Check if yield operation is inside the outlined coroutine function. 440 auto func = op->template getParentOfType<FuncOp>(); 441 auto outlined = outlinedFunctions.find(func); 442 if (outlined == outlinedFunctions.end()) 443 return rewriter.notifyMatchFailure( 444 op, "operation is not inside the outlined async.execute function"); 445 446 Location loc = op->getLoc(); 447 const CoroMachinery &coro = outlined->getSecond(); 448 449 // Store yielded values into the async values storage and switch async 450 // values state to available. 451 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 452 Value yieldValue = std::get<0>(tuple); 453 Value asyncValue = std::get<1>(tuple); 454 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 455 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 456 } 457 458 // Switch the coroutine completion token to available state. 459 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 460 461 return success(); 462 } 463 464 private: 465 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 466 }; 467 468 //===----------------------------------------------------------------------===// 469 470 void AsyncToAsyncRuntimePass::runOnOperation() { 471 ModuleOp module = getOperation(); 472 SymbolTable symbolTable(module); 473 474 // Outline all `async.execute` body regions into async functions (coroutines). 475 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 476 477 module.walk([&](ExecuteOp execute) { 478 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 479 }); 480 481 LLVM_DEBUG({ 482 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 483 << " functions built from async.execute operations\n"; 484 }); 485 486 // Lower async operations to async.runtime operations. 487 MLIRContext *ctx = module->getContext(); 488 OwningRewritePatternList asyncPatterns; 489 490 // Async lowering does not use type converter because it must preserve all 491 // types for async.runtime operations. 492 asyncPatterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 493 asyncPatterns.insert<AwaitTokenOpLowering, AwaitValueOpLowering, 494 AwaitAllOpLowering, YieldOpLowering>(ctx, 495 outlinedFunctions); 496 497 // All high level async operations must be lowered to the runtime operations. 498 ConversionTarget runtimeTarget(*ctx); 499 runtimeTarget.addLegalDialect<AsyncDialect>(); 500 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 501 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 502 503 if (failed(applyPartialConversion(module, runtimeTarget, 504 std::move(asyncPatterns)))) { 505 signalPassFailure(); 506 return; 507 } 508 } 509 510 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 511 return std::make_unique<AsyncToAsyncRuntimePass>(); 512 } 513