//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering from high level async operations to async.coro
// and async.runtime operations.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-to-async-runtime"
// Prefix for functions outlined from `async.execute` op regions.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";

namespace {

/// Module pass that outlines `async.execute` regions into coroutine functions
/// and lowers the remaining high level async operations to `async.runtime.*`
/// and `async.coro.*` operations. The work is done in `runOnOperation`, which
/// is defined at the end of this file.
class AsyncToAsyncRuntimePass
    : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
public:
  AsyncToAsyncRuntimePass() = default;
  void runOnOperation() override;
};

} // namespace

//===----------------------------------------------------------------------===//
// async.execute op outlining to the coroutine functions.
//===----------------------------------------------------------------------===//

/// Function targeted for coroutine transformation has two additional blocks at
/// the end: coroutine cleanup and coroutine suspension.
50 /// 51 /// async.await op lowering additionaly creates a resume block for each 52 /// operation to enable non-blocking waiting via coroutine suspension. 53 namespace { 54 struct CoroMachinery { 55 // Async execute region returns a completion token, and an async value for 56 // each yielded value. 57 // 58 // %token, %result = async.execute -> !async.value<T> { 59 // %0 = constant ... : T 60 // async.yield %0 : T 61 // } 62 Value asyncToken; // token representing completion of the async region 63 llvm::SmallVector<Value, 4> returnValues; // returned async values 64 65 Value coroHandle; // coroutine handle (!async.coro.handle value) 66 Block *cleanup; // coroutine cleanup block 67 Block *suspend; // coroutine suspension block 68 }; 69 } // namespace 70 71 /// Builds an coroutine template compatible with LLVM coroutines switched-resume 72 /// lowering using `async.runtime.*` and `async.coro.*` operations. 73 /// 74 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 75 /// 76 /// - `entry` block sets up the coroutine. 77 /// - `cleanup` block cleans up the coroutine state. 78 /// - `suspend block after the @llvm.coro.end() defines what value will be 79 /// returned to the initial caller of a coroutine. Everything before the 80 /// @llvm.coro.end() will be executed at every suspension point. 
81 /// 82 /// Coroutine structure (only the important bits): 83 /// 84 /// func @async_execute_fn(<function-arguments>) 85 /// -> (!async.token, !async.value<T>) 86 /// { 87 /// ^entry(<function-arguments>): 88 /// %token = <async token> : !async.token // create async runtime token 89 /// %value = <async value> : !async.value<T> // create async value 90 /// %id = async.coro.id // create a coroutine id 91 /// %hdl = async.coro.begin %id // create a coroutine handle 92 /// br ^cleanup 93 /// 94 /// ^cleanup: 95 /// async.coro.free %hdl // delete the coroutine state 96 /// br ^suspend 97 /// 98 /// ^suspend: 99 /// async.coro.end %hdl // marks the end of a coroutine 100 /// return %token, %value : !async.token, !async.value<T> 101 /// } 102 /// 103 /// The actual code for the async.execute operation body region will be inserted 104 /// before the entry block terminator. 105 /// 106 /// 107 static CoroMachinery setupCoroMachinery(FuncOp func) { 108 assert(func.getBody().empty() && "Function must have empty body"); 109 110 MLIRContext *ctx = func.getContext(); 111 Block *entryBlock = func.addEntryBlock(); 112 113 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 114 115 // ------------------------------------------------------------------------ // 116 // Allocate async token/values that we will return from a ramp function. 117 // ------------------------------------------------------------------------ // 118 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 119 120 llvm::SmallVector<Value, 4> retValues; 121 for (auto resType : func.getCallableResults().drop_front()) 122 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 123 124 // ------------------------------------------------------------------------ // 125 // Initialize coroutine: get coroutine id and coroutine handle. 
126 // ------------------------------------------------------------------------ // 127 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 128 auto coroHdlOp = 129 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 130 131 Block *cleanupBlock = func.addBlock(); 132 Block *suspendBlock = func.addBlock(); 133 134 // ------------------------------------------------------------------------ // 135 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 136 // ------------------------------------------------------------------------ // 137 builder.setInsertionPointToStart(cleanupBlock); 138 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 139 140 // Branch into the suspend block. 141 builder.create<BranchOp>(suspendBlock); 142 143 // ------------------------------------------------------------------------ // 144 // Coroutine suspend block: mark the end of a coroutine and return allocated 145 // async token. 146 // ------------------------------------------------------------------------ // 147 builder.setInsertionPointToStart(suspendBlock); 148 149 // Mark the end of a coroutine: async.coro.end 150 builder.create<CoroEndOp>(coroHdlOp.handle()); 151 152 // Return created `async.token` and `async.values` from the suspend block. 153 // This will be the return value of a coroutine ramp function. 154 SmallVector<Value, 4> ret{retToken}; 155 ret.insert(ret.end(), retValues.begin(), retValues.end()); 156 builder.create<ReturnOp>(ret); 157 158 // Branch from the entry block to the cleanup block to create a valid CFG. 159 builder.setInsertionPointToEnd(entryBlock); 160 builder.create<BranchOp>(cleanupBlock); 161 162 // `async.await` op lowering will create resume blocks for async 163 // continuations, and will conditionally branch to cleanup or suspend blocks. 
164 165 CoroMachinery machinery; 166 machinery.asyncToken = retToken; 167 machinery.returnValues = retValues; 168 machinery.coroHandle = coroHdlOp.handle(); 169 machinery.cleanup = cleanupBlock; 170 machinery.suspend = suspendBlock; 171 return machinery; 172 } 173 174 /// Outline the body region attached to the `async.execute` op into a standalone 175 /// function. 176 /// 177 /// Note that this is not reversible transformation. 178 static std::pair<FuncOp, CoroMachinery> 179 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 180 ModuleOp module = execute->getParentOfType<ModuleOp>(); 181 182 MLIRContext *ctx = module.getContext(); 183 Location loc = execute.getLoc(); 184 185 // Collect all outlined function inputs. 186 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 187 execute.dependencies().end()); 188 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 189 getUsedValuesDefinedAbove(execute.body(), functionInputs); 190 191 // Collect types for the outlined function inputs and outputs. 192 auto typesRange = llvm::map_range( 193 functionInputs, [](Value value) { return value.getType(); }); 194 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 195 auto outputTypes = execute.getResultTypes(); 196 197 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 198 auto funcAttrs = ArrayRef<NamedAttribute>(); 199 200 // TODO: Derive outlined function name from the parent FuncOp (support 201 // multiple nested async.execute operations). 202 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 203 symbolTable.insert(func); 204 205 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 206 207 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 208 // blocks, adding async.coro operations and setting up control flow. 
209 CoroMachinery coro = setupCoroMachinery(func); 210 211 // Suspend async function at the end of an entry block, and resume it using 212 // Async resume operation (execution will be resumed in a thread managed by 213 // the async runtime). 214 Block *entryBlock = &func.getBlocks().front(); 215 auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); 216 217 // Save the coroutine state: async.coro.save 218 auto coroSaveOp = 219 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 220 221 // Pass coroutine to the runtime to be resumed on a runtime managed thread. 222 builder.create<RuntimeResumeOp>(coro.coroHandle); 223 224 // Split the entry block before the terminator (branch to suspend block). 225 auto *terminatorOp = entryBlock->getTerminator(); 226 Block *suspended = terminatorOp->getBlock(); 227 Block *resume = suspended->splitBlock(terminatorOp); 228 229 // Add async.coro.suspend as a suspended block terminator. 230 builder.setInsertionPointToEnd(suspended); 231 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 232 coro.cleanup); 233 234 size_t numDependencies = execute.dependencies().size(); 235 size_t numOperands = execute.operands().size(); 236 237 // Await on all dependencies before starting to execute the body region. 238 builder.setInsertionPointToStart(resume); 239 for (size_t i = 0; i < numDependencies; ++i) 240 builder.create<AwaitOp>(func.getArgument(i)); 241 242 // Await on all async value operands and unwrap the payload. 243 SmallVector<Value, 4> unwrappedOperands(numOperands); 244 for (size_t i = 0; i < numOperands; ++i) { 245 Value operand = func.getArgument(numDependencies + i); 246 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 247 } 248 249 // Map from function inputs defined above the execute op to the function 250 // arguments. 
251 BlockAndValueMapping valueMapping; 252 valueMapping.map(functionInputs, func.getArguments()); 253 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 254 255 // Clone all operations from the execute operation body into the outlined 256 // function body. 257 for (Operation &op : execute.body().getOps()) 258 builder.clone(op, valueMapping); 259 260 // Replace the original `async.execute` with a call to outlined function. 261 ImplicitLocOpBuilder callBuilder(loc, execute); 262 auto callOutlinedFunc = callBuilder.create<CallOp>( 263 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 264 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 265 execute.erase(); 266 267 return {func, coro}; 268 } 269 270 //===----------------------------------------------------------------------===// 271 // Convert async.create_group operation to async.runtime.create 272 //===----------------------------------------------------------------------===// 273 274 namespace { 275 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 276 public: 277 using OpConversionPattern::OpConversionPattern; 278 279 LogicalResult 280 matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands, 281 ConversionPatternRewriter &rewriter) const override { 282 rewriter.replaceOpWithNewOp<RuntimeCreateOp>( 283 op, GroupType::get(op->getContext())); 284 return success(); 285 } 286 }; 287 } // namespace 288 289 //===----------------------------------------------------------------------===// 290 // Convert async.add_to_group operation to async.runtime.add_to_group. 
291 //===----------------------------------------------------------------------===// 292 293 namespace { 294 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 295 public: 296 using OpConversionPattern::OpConversionPattern; 297 298 LogicalResult 299 matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands, 300 ConversionPatternRewriter &rewriter) const override { 301 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 302 op, rewriter.getIndexType(), operands); 303 return success(); 304 } 305 }; 306 } // namespace 307 308 //===----------------------------------------------------------------------===// 309 // Convert async.await and async.await_all operations to the async.runtime.await 310 // or async.runtime.await_and_resume operations. 311 //===----------------------------------------------------------------------===// 312 313 namespace { 314 template <typename AwaitType, typename AwaitableType> 315 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 316 using AwaitAdaptor = typename AwaitType::Adaptor; 317 318 public: 319 AwaitOpLoweringBase( 320 MLIRContext *ctx, 321 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 322 : OpConversionPattern<AwaitType>(ctx), 323 outlinedFunctions(outlinedFunctions) {} 324 325 LogicalResult 326 matchAndRewrite(AwaitType op, ArrayRef<Value> operands, 327 ConversionPatternRewriter &rewriter) const override { 328 // We can only await on one the `AwaitableType` (for `await` it can be 329 // a `token` or a `value`, for `await_all` it must be a `group`). 330 if (!op.operand().getType().template isa<AwaitableType>()) 331 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 332 333 // Check if await operation is inside the outlined coroutine function. 
334 auto func = op->template getParentOfType<FuncOp>(); 335 auto outlined = outlinedFunctions.find(func); 336 const bool isInCoroutine = outlined != outlinedFunctions.end(); 337 338 Location loc = op->getLoc(); 339 Value operand = AwaitAdaptor(operands).operand(); 340 341 // Inside regular functions we use the blocking wait operation to wait for 342 // the async object (token, value or group) to become available. 343 if (!isInCoroutine) 344 rewriter.create<RuntimeAwaitOp>(loc, operand); 345 346 // Inside the coroutine we convert await operation into coroutine suspension 347 // point, and resume execution asynchronously. 348 if (isInCoroutine) { 349 const CoroMachinery &coro = outlined->getSecond(); 350 Block *suspended = op->getBlock(); 351 352 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 353 MLIRContext *ctx = op->getContext(); 354 355 // Save the coroutine state and resume on a runtime managed thread when 356 // the operand becomes available. 357 auto coroSaveOp = 358 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 359 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 360 361 // Split the entry block before the await operation. 362 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 363 364 // Add async.coro.suspend as a suspended block terminator. 365 builder.setInsertionPointToEnd(suspended); 366 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 367 coro.cleanup); 368 369 // Make sure that replacement value will be constructed in resume block. 370 rewriter.setInsertionPointToStart(resume); 371 } 372 373 // Erase or replace the await operation with the new value. 
374 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 375 rewriter.replaceOp(op, replaceWith); 376 else 377 rewriter.eraseOp(op); 378 379 return success(); 380 } 381 382 virtual Value getReplacementValue(AwaitType op, Value operand, 383 ConversionPatternRewriter &rewriter) const { 384 return Value(); 385 } 386 387 private: 388 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 389 }; 390 391 /// Lowering for `async.await` with a token operand. 392 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 393 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 394 395 public: 396 using Base::Base; 397 }; 398 399 /// Lowering for `async.await` with a value operand. 400 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 401 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 402 403 public: 404 using Base::Base; 405 406 Value 407 getReplacementValue(AwaitOp op, Value operand, 408 ConversionPatternRewriter &rewriter) const override { 409 // Load from the async value storage. 410 auto valueType = operand.getType().cast<ValueType>().getValueType(); 411 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 412 } 413 }; 414 415 /// Lowering for `async.await_all` operation. 416 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 417 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 418 419 public: 420 using Base::Base; 421 }; 422 423 } // namespace 424 425 //===----------------------------------------------------------------------===// 426 // Convert async.yield operation to async.runtime operations. 
427 //===----------------------------------------------------------------------===// 428 429 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 430 public: 431 YieldOpLowering( 432 MLIRContext *ctx, 433 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 434 : OpConversionPattern<async::YieldOp>(ctx), 435 outlinedFunctions(outlinedFunctions) {} 436 437 LogicalResult 438 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 439 ConversionPatternRewriter &rewriter) const override { 440 // Check if yield operation is inside the outlined coroutine function. 441 auto func = op->template getParentOfType<FuncOp>(); 442 auto outlined = outlinedFunctions.find(func); 443 if (outlined == outlinedFunctions.end()) 444 return rewriter.notifyMatchFailure( 445 op, "operation is not inside the outlined async.execute function"); 446 447 Location loc = op->getLoc(); 448 const CoroMachinery &coro = outlined->getSecond(); 449 450 // Store yielded values into the async values storage and switch async 451 // values state to available. 452 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 453 Value yieldValue = std::get<0>(tuple); 454 Value asyncValue = std::get<1>(tuple); 455 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 456 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 457 } 458 459 // Switch the coroutine completion token to available state. 460 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 461 462 return success(); 463 } 464 465 private: 466 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 467 }; 468 469 //===----------------------------------------------------------------------===// 470 471 void AsyncToAsyncRuntimePass::runOnOperation() { 472 ModuleOp module = getOperation(); 473 SymbolTable symbolTable(module); 474 475 // Outline all `async.execute` body regions into async functions (coroutines). 
476 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 477 478 module.walk([&](ExecuteOp execute) { 479 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 480 }); 481 482 LLVM_DEBUG({ 483 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 484 << " functions built from async.execute operations\n"; 485 }); 486 487 // Lower async operations to async.runtime operations. 488 MLIRContext *ctx = module->getContext(); 489 RewritePatternSet asyncPatterns(ctx); 490 491 // Async lowering does not use type converter because it must preserve all 492 // types for async.runtime operations. 493 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 494 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 495 AwaitAllOpLowering, YieldOpLowering>(ctx, 496 outlinedFunctions); 497 498 // All high level async operations must be lowered to the runtime operations. 499 ConversionTarget runtimeTarget(*ctx); 500 runtimeTarget.addLegalDialect<AsyncDialect>(); 501 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 502 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 503 504 if (failed(applyPartialConversion(module, runtimeTarget, 505 std::move(asyncPatterns)))) { 506 signalPassFailure(); 507 return; 508 } 509 } 510 511 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 512 return std::make_unique<AsyncToAsyncRuntimePass>(); 513 } 514