1 //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements automatic reference counting for Async runtime 10 // operations and types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Analysis/Liveness.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/Async/Passes.h" 18 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/IR/ImplicitLocOpBuilder.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 #include "llvm/ADT/SmallSet.h" 24 25 using namespace mlir; 26 using namespace mlir::async; 27 28 #define DEBUG_TYPE "async-runtime-ref-counting" 29 30 //===----------------------------------------------------------------------===// 31 // Utility functions shared by reference counting passes. 32 //===----------------------------------------------------------------------===// 33 34 // Drop the reference count immediately if the value has no uses. 35 static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { 36 if (!value.getUses().empty()) 37 return failure(); 38 39 OpBuilder b(value.getContext()); 40 41 // Set insertion point after the operation producing a value, or at the 42 // beginning of the block if the value defined by the block argument. 43 if (Operation *op = value.getDefiningOp()) 44 b.setInsertionPointAfter(op); 45 else 46 b.setInsertionPointToStart(value.getParentBlock()); 47 48 b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1)); 49 return success(); 50 } 51 52 // Calls `addRefCounting` for every reference counted value defined by the 53 // operation `op` (block arguments and values defined in nested regions). 54 static LogicalResult walkReferenceCountedValues( 55 Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) { 56 // Check that we do not have high level async operations in the IR because 57 // otherwise reference counting will produce incorrect results after high 58 // level async operations will be lowered to `async.runtime` 59 WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult { 60 if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op)) 61 return WalkResult::advance(); 62 63 return op->emitError() 64 << "async operations must be lowered to async runtime operations"; 65 }); 66 67 if (checkNoAsyncWalk.wasInterrupted()) 68 return failure(); 69 70 // Add reference counting to block arguments. 71 WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { 72 for (BlockArgument arg : block->getArguments()) 73 if (isRefCounted(arg.getType())) 74 if (failed(addRefCounting(arg))) 75 return WalkResult::interrupt(); 76 77 return WalkResult::advance(); 78 }); 79 80 if (blockWalk.wasInterrupted()) 81 return failure(); 82 83 // Add reference counting to operation results. 84 WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { 85 for (unsigned i = 0; i < op->getNumResults(); ++i) 86 if (isRefCounted(op->getResultTypes()[i])) 87 if (failed(addRefCounting(op->getResult(i)))) 88 return WalkResult::interrupt(); 89 90 return WalkResult::advance(); 91 }); 92 93 if (opWalk.wasInterrupted()) 94 return failure(); 95 96 return success(); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // Automatic reference counting based on the liveness analysis. 101 //===----------------------------------------------------------------------===// 102 103 namespace { 104 105 class AsyncRuntimeRefCountingPass 106 : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> { 107 public: 108 AsyncRuntimeRefCountingPass() = default; 109 void runOnOperation() override; 110 111 private: 112 /// Adds an automatic reference counting to the `value`. 113 /// 114 /// All values (token, group or value) are semantically created with a 115 /// reference count of +1 and it is the responsibility of the async value user 116 /// to place the `add_ref` and `drop_ref` operations to ensure that the value 117 /// is destroyed after the last use. 118 /// 119 /// The function returns failure if it can't deduce the locations where 120 /// to place the reference counting operations. 121 /// 122 /// Async values "semantically created" when: 123 /// 1. Operation returns async result (e.g. `async.runtime.create`) 124 /// 2. Async value passed in as a block argument (or function argument, 125 /// because function arguments are just entry block arguments) 126 /// 127 /// Passing async value as a function argument (or block argument) does not 128 /// really mean that a new async value is created, it only means that the 129 /// caller of a function transfered ownership of `+1` reference to the callee. 130 /// It is convenient to think that from the callee perspective async value was 131 /// "created" with `+1` reference by the block argument. 132 /// 133 /// Automatic reference counting algorithm outline: 134 /// 135 /// #1 Insert `drop_ref` operations after last use of the `value`. 136 /// #2 Insert `add_ref` operations before functions calls with reference 137 /// counted `value` operand (newly created `+1` reference will be 138 /// transferred to the callee). 139 /// #3 Verify that divergent control flow does not lead to leaked reference 140 /// counted objects. 141 /// 142 /// Async runtime reference counting optimization pass will optimize away 143 /// some of the redundant `add_ref` and `drop_ref` operations inserted by this 144 /// strategy (see `async-runtime-ref-counting-opt`). 145 LogicalResult addAutomaticRefCounting(Value value); 146 147 /// (#1) Adds the `drop_ref` operation after the last use of the `value` 148 /// relying on the liveness analysis. 149 /// 150 /// If the `value` is in the block `liveIn` set and it is not in the block 151 /// `liveOut` set, it means that it "dies" in the block. We find the last 152 /// use of the value in such block and: 153 /// 154 /// 1. If the last user is a `ReturnLike` operation we do nothing, because 155 /// it forwards the ownership to the caller. 156 /// 2. Otherwise we add a `drop_ref` operation immediately after the last 157 /// use. 158 LogicalResult addDropRefAfterLastUse(Value value); 159 160 /// (#2) Adds the `add_ref` operation before the function call taking `value` 161 /// operand to ensure that the value passed to the function entry block 162 /// has a `+1` reference count. 163 LogicalResult addAddRefBeforeFunctionCall(Value value); 164 165 /// (#3) Adds the `drop_ref` operation to account for successor blocks with 166 /// divergent `liveIn` property: `value` is not in the `liveIn` set of all 167 /// successor blocks. 168 /// 169 /// Example: 170 /// 171 /// ^entry: 172 /// %token = async.runtime.create : !async.token 173 /// cf.cond_br %cond, ^bb1, ^bb2 174 /// ^bb1: 175 /// async.runtime.await %token 176 /// async.runtime.drop_ref %token 177 /// cf.br ^bb2 178 /// ^bb2: 179 /// return 180 /// 181 /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have 182 /// to branch into a special "reference counting block" from the ^entry that 183 /// will have a `drop_ref` operation, and then branch into the ^bb2. 184 /// 185 /// After transformation: 186 /// 187 /// ^entry: 188 /// %token = async.runtime.create : !async.token 189 /// cf.cond_br %cond, ^bb1, ^reference_counting 190 /// ^bb1: 191 /// async.runtime.await %token 192 /// async.runtime.drop_ref %token 193 /// cf.br ^bb2 194 /// ^reference_counting: 195 /// async.runtime.drop_ref %token 196 /// cf.br ^bb2 197 /// ^bb2: 198 /// return 199 /// 200 /// An exception to this rule are blocks with `async.coro.suspend` terminator, 201 /// because in Async to LLVM lowering it is guaranteed that the control flow 202 /// will jump into the resume block, and then follow into the cleanup and 203 /// suspend blocks. 204 /// 205 /// Example: 206 /// 207 /// ^entry(%value: !async.value<f32>): 208 /// async.runtime.await_and_resume %value, %hdl : !async.value<f32> 209 /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup 210 /// ^resume: 211 /// %0 = async.runtime.load %value 212 /// cf.br ^cleanup 213 /// ^cleanup: 214 /// ... 215 /// ^suspend: 216 /// ... 217 /// 218 /// Although cleanup and suspend blocks do not have the `value` in the 219 /// `liveIn` set, it is guaranteed that execution will eventually continue in 220 /// the resume block (we never explicitly destroy coroutines). 221 LogicalResult addDropRefInDivergentLivenessSuccessor(Value value); 222 }; 223 224 } // namespace 225 226 LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { 227 OpBuilder builder(value.getContext()); 228 Location loc = value.getLoc(); 229 230 // Use liveness analysis to find the placement of `drop_ref`operation. 231 auto &liveness = getAnalysis<Liveness>(); 232 233 // We analyse only the blocks of the region that defines the `value`, and do 234 // not check nested blocks attached to operations. 235 // 236 // By analyzing only the `definingRegion` CFG we potentially loose an 237 // opportunity to drop the reference count earlier and can extend the lifetime 238 // of reference counted value longer then it is really required. 239 // 240 // We also assume that all nested regions finish their execution before the 241 // completion of the owner operation. The only exception to this rule is 242 // `async.execute` operation, and we verify that they are lowered to the 243 // `async.runtime` operations before adding automatic reference counting. 244 Region *definingRegion = value.getParentRegion(); 245 246 // Last users of the `value` inside all blocks where the value dies. 247 llvm::SmallSet<Operation *, 4> lastUsers; 248 249 // Find blocks in the `definingRegion` that have users of the `value` (if 250 // there are multiple users in the block, which one will be selected is 251 // undefined). User operation might be not the actual user of the value, but 252 // the operation in the block that has a "real user" in one of the attached 253 // regions. 254 llvm::DenseMap<Block *, Operation *> usersInTheBlocks; 255 256 for (Operation *user : value.getUsers()) { 257 Block *userBlock = user->getBlock(); 258 Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock); 259 usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user); 260 assert(ancestor && "ancestor block must be not null"); 261 assert(usersInTheBlocks[ancestor] && "ancestor op must be not null"); 262 } 263 264 // Find blocks where the `value` dies: the value is in `liveIn` set and not 265 // in the `liveOut` set. We place `drop_ref` immediately after the last use 266 // of the `value` in such regions (after handling few special cases). 267 // 268 // We do not traverse all the blocks in the `definingRegion`, because the 269 // `value` can be in the live in set only if it has users in the block, or it 270 // is defined in the block. 271 // 272 // Values with zero users (only definition) handled explicitly above. 273 for (auto &blockAndUser : usersInTheBlocks) { 274 Block *block = blockAndUser.getFirst(); 275 Operation *userInTheBlock = blockAndUser.getSecond(); 276 277 const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); 278 279 // Value must be in the live input set or defined in the block. 280 assert(blockLiveness->isLiveIn(value) || 281 blockLiveness->getBlock() == value.getParentBlock()); 282 283 // If value is in the live out set, it means it doesn't "die" in the block. 284 if (blockLiveness->isLiveOut(value)) 285 continue; 286 287 // At this point we proved that `value` dies in the `block`. Find the last 288 // use of the `value` inside the `block`, this is where it "dies". 289 Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); 290 assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); 291 lastUsers.insert(lastUser); 292 } 293 294 // Process all the last users of the `value` inside each block where the value 295 // dies. 296 for (Operation *lastUser : lastUsers) { 297 // Return like operations forward reference count. 298 if (lastUser->hasTrait<OpTrait::ReturnLike>()) 299 continue; 300 301 // We can't currently handle other types of terminators. 302 if (lastUser->hasTrait<OpTrait::IsTerminator>()) 303 return lastUser->emitError() << "async reference counting can't handle " 304 "terminators that are not ReturnLike"; 305 306 // Add a drop_ref immediately after the last user. 307 builder.setInsertionPointAfter(lastUser); 308 builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1)); 309 } 310 311 return success(); 312 } 313 314 LogicalResult 315 AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { 316 OpBuilder builder(value.getContext()); 317 Location loc = value.getLoc(); 318 319 for (Operation *user : value.getUsers()) { 320 if (!isa<CallOp>(user)) 321 continue; 322 323 // Add a reference before the function call to pass the value at `+1` 324 // reference to the function entry block. 325 builder.setInsertionPoint(user); 326 builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1)); 327 } 328 329 return success(); 330 } 331 332 LogicalResult 333 AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( 334 Value value) { 335 using BlockSet = llvm::SmallPtrSet<Block *, 4>; 336 337 OpBuilder builder(value.getContext()); 338 339 // If a block has successors with different `liveIn` property of the `value`, 340 // record block successors that do not thave the `value` in the `liveIn` set. 341 llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks; 342 343 // Use liveness analysis to find the placement of `drop_ref`operation. 344 auto &liveness = getAnalysis<Liveness>(); 345 346 // Because we only add `drop_ref` operations to the region that defines the 347 // `value` we can only process CFG for the same region. 348 Region *definingRegion = value.getParentRegion(); 349 350 // Collect blocks with successors with mismatching `liveIn` sets. 351 for (Block &block : definingRegion->getBlocks()) { 352 const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); 353 354 // Skip the block if value is not in the `liveOut` set. 355 if (!blockLiveness || !blockLiveness->isLiveOut(value)) 356 continue; 357 358 BlockSet liveInSuccessors; // `value` is in `liveIn` set 359 BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set 360 361 // Collect successors that do not have `value` in the `liveIn` set. 362 for (Block *successor : block.getSuccessors()) { 363 const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); 364 if (succLiveness && succLiveness->isLiveIn(value)) 365 liveInSuccessors.insert(successor); 366 else 367 noLiveInSuccessors.insert(successor); 368 } 369 370 // Block has successors with different `liveIn` property of the `value`. 371 if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) 372 divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors); 373 } 374 375 // Try to insert `dropRef` operations to handle blocks with divergent liveness 376 // in successors blocks. 377 for (auto kv : divergentLivenessBlocks) { 378 Block *block = kv.getFirst(); 379 BlockSet &successors = kv.getSecond(); 380 381 // Coroutine suspension is a special case terminator for wich we do not 382 // need to create additional reference counting (see details above). 383 Operation *terminator = block->getTerminator(); 384 if (isa<CoroSuspendOp>(terminator)) 385 continue; 386 387 // We only support successor blocks with empty block argument list. 388 auto hasArgs = [](Block *block) { return !block->getArguments().empty(); }; 389 if (llvm::any_of(successors, hasArgs)) 390 return terminator->emitOpError() 391 << "successor have different `liveIn` property of the reference " 392 "counted value"; 393 394 // Make sure that `dropRef` operation is called when branched into the 395 // successor block without `value` in the `liveIn` set. 396 for (Block *successor : successors) { 397 // If successor has a unique predecessor, it is safe to create `dropRef` 398 // operations directly in the successor block. 399 // 400 // Otherwise we need to create a special block for reference counting 401 // operations, and branch from it to the original successor block. 402 Block *refCountingBlock = nullptr; 403 404 if (successor->getUniquePredecessor() == block) { 405 refCountingBlock = successor; 406 } else { 407 refCountingBlock = &successor->getParent()->emplaceBlock(); 408 refCountingBlock->moveBefore(successor); 409 OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); 410 builder.create<cf::BranchOp>(value.getLoc(), successor); 411 } 412 413 OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); 414 builder.create<RuntimeDropRefOp>(value.getLoc(), value, 415 builder.getI64IntegerAttr(1)); 416 417 // No need to update the terminator operation. 418 if (successor == refCountingBlock) 419 continue; 420 421 // Update terminator `successor` block to `refCountingBlock`. 422 for (const auto &pair : llvm::enumerate(terminator->getSuccessors())) 423 if (pair.value() == successor) 424 terminator->setSuccessor(refCountingBlock, pair.index()); 425 } 426 } 427 428 return success(); 429 } 430 431 LogicalResult 432 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { 433 // Short-circuit reference counting for values without uses. 434 if (succeeded(dropRefIfNoUses(value))) 435 return success(); 436 437 // Add `drop_ref` operations based on the liveness analysis. 438 if (failed(addDropRefAfterLastUse(value))) 439 return failure(); 440 441 // Add `add_ref` operations before function calls. 442 if (failed(addAddRefBeforeFunctionCall(value))) 443 return failure(); 444 445 // Add `drop_ref` operations to successors with divergent `value` liveness. 446 if (failed(addDropRefInDivergentLivenessSuccessor(value))) 447 return failure(); 448 449 return success(); 450 } 451 452 void AsyncRuntimeRefCountingPass::runOnOperation() { 453 auto functor = [&](Value value) { return addAutomaticRefCounting(value); }; 454 if (failed(walkReferenceCountedValues(getOperation(), functor))) 455 signalPassFailure(); 456 } 457 458 //===----------------------------------------------------------------------===// 459 // Reference counting based on the user defined policy. 460 //===----------------------------------------------------------------------===// 461 462 namespace { 463 464 class AsyncRuntimePolicyBasedRefCountingPass 465 : public AsyncRuntimePolicyBasedRefCountingBase< 466 AsyncRuntimePolicyBasedRefCountingPass> { 467 public: 468 AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); } 469 470 void runOnOperation() override; 471 472 private: 473 // Adds a reference counting operations for all uses of the `value` according 474 // to the reference counting policy. 475 LogicalResult addRefCounting(Value value); 476 477 void initializeDefaultPolicy(); 478 479 llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy; 480 }; 481 482 } // namespace 483 484 LogicalResult 485 AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { 486 // Short-circuit reference counting for values without uses. 487 if (succeeded(dropRefIfNoUses(value))) 488 return success(); 489 490 OpBuilder b(value.getContext()); 491 492 // Consult the user defined policy for every value use. 493 for (OpOperand &operand : value.getUses()) { 494 Location loc = operand.getOwner()->getLoc(); 495 496 for (auto &func : policy) { 497 FailureOr<int> refCount = func(operand); 498 if (failed(refCount)) 499 return failure(); 500 501 int cnt = refCount.getValue(); 502 503 // Create `add_ref` operation before the operand owner. 504 if (cnt > 0) { 505 b.setInsertionPoint(operand.getOwner()); 506 b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt)); 507 } 508 509 // Create `drop_ref` operation after the operand owner. 510 if (cnt < 0) { 511 b.setInsertionPointAfter(operand.getOwner()); 512 b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt)); 513 } 514 } 515 } 516 517 return success(); 518 } 519 520 void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { 521 policy.push_back([](OpOperand &operand) -> FailureOr<int> { 522 Operation *op = operand.getOwner(); 523 Type type = operand.get().getType(); 524 525 bool isToken = type.isa<TokenType>(); 526 bool isGroup = type.isa<GroupType>(); 527 bool isValue = type.isa<ValueType>(); 528 529 // Drop reference after async token or group error check (coro await). 530 if (auto await = dyn_cast<RuntimeIsErrorOp>(op)) 531 return (isToken || isGroup) ? -1 : 0; 532 533 // Drop reference after async value load. 534 if (auto load = dyn_cast<RuntimeLoadOp>(op)) 535 return isValue ? -1 : 0; 536 537 // Drop reference after async token added to the group. 538 if (auto add = dyn_cast<RuntimeAddToGroupOp>(op)) 539 return isToken ? -1 : 0; 540 541 return 0; 542 }); 543 } 544 545 void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() { 546 auto functor = [&](Value value) { return addRefCounting(value); }; 547 if (failed(walkReferenceCountedValues(getOperation(), functor))) 548 signalPassFailure(); 549 } 550 551 //----------------------------------------------------------------------------// 552 553 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() { 554 return std::make_unique<AsyncRuntimeRefCountingPass>(); 555 } 556 557 std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() { 558 return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>(); 559 } 560