1 //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements automatic reference counting for Async runtime 10 // operations and types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Analysis/Liveness.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/Async/Passes.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 #include "llvm/ADT/SmallSet.h" 23 24 using namespace mlir; 25 using namespace mlir::async; 26 27 #define DEBUG_TYPE "async-runtime-ref-counting" 28 29 //===----------------------------------------------------------------------===// 30 // Utility functions shared by reference counting passes. 31 //===----------------------------------------------------------------------===// 32 33 // Drop the reference count immediately if the value has no uses. 34 static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { 35 if (!value.getUses().empty()) 36 return failure(); 37 38 OpBuilder b(value.getContext()); 39 40 // Set insertion point after the operation producing a value, or at the 41 // beginning of the block if the value defined by the block argument. 42 if (Operation *op = value.getDefiningOp()) 43 b.setInsertionPointAfter(op); 44 else 45 b.setInsertionPointToStart(value.getParentBlock()); 46 47 b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1)); 48 return success(); 49 } 50 51 // Calls `addRefCounting` for every reference counted value defined by the 52 // operation `op` (block arguments and values defined in nested regions). 53 static LogicalResult walkReferenceCountedValues( 54 Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) { 55 // Check that we do not have high level async operations in the IR because 56 // otherwise reference counting will produce incorrect results after high 57 // level async operations will be lowered to `async.runtime` 58 WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult { 59 if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op)) 60 return WalkResult::advance(); 61 62 return op->emitError() 63 << "async operations must be lowered to async runtime operations"; 64 }); 65 66 if (checkNoAsyncWalk.wasInterrupted()) 67 return failure(); 68 69 // Add reference counting to block arguments. 70 WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { 71 for (BlockArgument arg : block->getArguments()) 72 if (isRefCounted(arg.getType())) 73 if (failed(addRefCounting(arg))) 74 return WalkResult::interrupt(); 75 76 return WalkResult::advance(); 77 }); 78 79 if (blockWalk.wasInterrupted()) 80 return failure(); 81 82 // Add reference counting to operation results. 83 WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { 84 for (unsigned i = 0; i < op->getNumResults(); ++i) 85 if (isRefCounted(op->getResultTypes()[i])) 86 if (failed(addRefCounting(op->getResult(i)))) 87 return WalkResult::interrupt(); 88 89 return WalkResult::advance(); 90 }); 91 92 if (opWalk.wasInterrupted()) 93 return failure(); 94 95 return success(); 96 } 97 98 //===----------------------------------------------------------------------===// 99 // Automatic reference counting based on the liveness analysis. 100 //===----------------------------------------------------------------------===// 101 102 namespace { 103 104 class AsyncRuntimeRefCountingPass 105 : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> { 106 public: 107 AsyncRuntimeRefCountingPass() = default; 108 void runOnOperation() override; 109 110 private: 111 /// Adds an automatic reference counting to the `value`. 112 /// 113 /// All values (token, group or value) are semantically created with a 114 /// reference count of +1 and it is the responsibility of the async value user 115 /// to place the `add_ref` and `drop_ref` operations to ensure that the value 116 /// is destroyed after the last use. 117 /// 118 /// The function returns failure if it can't deduce the locations where 119 /// to place the reference counting operations. 120 /// 121 /// Async values "semantically created" when: 122 /// 1. Operation returns async result (e.g. `async.runtime.create`) 123 /// 2. Async value passed in as a block argument (or function argument, 124 /// because function arguments are just entry block arguments) 125 /// 126 /// Passing async value as a function argument (or block argument) does not 127 /// really mean that a new async value is created, it only means that the 128 /// caller of a function transfered ownership of `+1` reference to the callee. 129 /// It is convenient to think that from the callee perspective async value was 130 /// "created" with `+1` reference by the block argument. 131 /// 132 /// Automatic reference counting algorithm outline: 133 /// 134 /// #1 Insert `drop_ref` operations after last use of the `value`. 135 /// #2 Insert `add_ref` operations before functions calls with reference 136 /// counted `value` operand (newly created `+1` reference will be 137 /// transferred to the callee). 138 /// #3 Verify that divergent control flow does not lead to leaked reference 139 /// counted objects. 140 /// 141 /// Async runtime reference counting optimization pass will optimize away 142 /// some of the redundant `add_ref` and `drop_ref` operations inserted by this 143 /// strategy (see `async-runtime-ref-counting-opt`). 144 LogicalResult addAutomaticRefCounting(Value value); 145 146 /// (#1) Adds the `drop_ref` operation after the last use of the `value` 147 /// relying on the liveness analysis. 148 /// 149 /// If the `value` is in the block `liveIn` set and it is not in the block 150 /// `liveOut` set, it means that it "dies" in the block. We find the last 151 /// use of the value in such block and: 152 /// 153 /// 1. If the last user is a `ReturnLike` operation we do nothing, because 154 /// it forwards the ownership to the caller. 155 /// 2. Otherwise we add a `drop_ref` operation immediately after the last 156 /// use. 157 LogicalResult addDropRefAfterLastUse(Value value); 158 159 /// (#2) Adds the `add_ref` operation before the function call taking `value` 160 /// operand to ensure that the value passed to the function entry block 161 /// has a `+1` reference count. 162 LogicalResult addAddRefBeforeFunctionCall(Value value); 163 164 /// (#3) Adds the `drop_ref` operation to account for successor blocks with 165 /// divergent `liveIn` property: `value` is not in the `liveIn` set of all 166 /// successor blocks. 167 /// 168 /// Example: 169 /// 170 /// ^entry: 171 /// %token = async.runtime.create : !async.token 172 /// cond_br %cond, ^bb1, ^bb2 173 /// ^bb1: 174 /// async.runtime.await %token 175 /// async.runtime.drop_ref %token 176 /// br ^bb2 177 /// ^bb2: 178 /// return 179 /// 180 /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have 181 /// to branch into a special "reference counting block" from the ^entry that 182 /// will have a `drop_ref` operation, and then branch into the ^bb2. 183 /// 184 /// After transformation: 185 /// 186 /// ^entry: 187 /// %token = async.runtime.create : !async.token 188 /// cond_br %cond, ^bb1, ^reference_counting 189 /// ^bb1: 190 /// async.runtime.await %token 191 /// async.runtime.drop_ref %token 192 /// br ^bb2 193 /// ^reference_counting: 194 /// async.runtime.drop_ref %token 195 /// br ^bb2 196 /// ^bb2: 197 /// return 198 /// 199 /// An exception to this rule are blocks with `async.coro.suspend` terminator, 200 /// because in Async to LLVM lowering it is guaranteed that the control flow 201 /// will jump into the resume block, and then follow into the cleanup and 202 /// suspend blocks. 203 /// 204 /// Example: 205 /// 206 /// ^entry(%value: !async.value<f32>): 207 /// async.runtime.await_and_resume %value, %hdl : !async.value<f32> 208 /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup 209 /// ^resume: 210 /// %0 = async.runtime.load %value 211 /// br ^cleanup 212 /// ^cleanup: 213 /// ... 214 /// ^suspend: 215 /// ... 216 /// 217 /// Although cleanup and suspend blocks do not have the `value` in the 218 /// `liveIn` set, it is guaranteed that execution will eventually continue in 219 /// the resume block (we never explicitly destroy coroutines). 220 LogicalResult addDropRefInDivergentLivenessSuccessor(Value value); 221 }; 222 223 } // namespace 224 225 LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { 226 OpBuilder builder(value.getContext()); 227 Location loc = value.getLoc(); 228 229 // Use liveness analysis to find the placement of `drop_ref`operation. 230 auto &liveness = getAnalysis<Liveness>(); 231 232 // We analyse only the blocks of the region that defines the `value`, and do 233 // not check nested blocks attached to operations. 234 // 235 // By analyzing only the `definingRegion` CFG we potentially loose an 236 // opportunity to drop the reference count earlier and can extend the lifetime 237 // of reference counted value longer then it is really required. 238 // 239 // We also assume that all nested regions finish their execution before the 240 // completion of the owner operation. The only exception to this rule is 241 // `async.execute` operation, and we verify that they are lowered to the 242 // `async.runtime` operations before adding automatic reference counting. 243 Region *definingRegion = value.getParentRegion(); 244 245 // Last users of the `value` inside all blocks where the value dies. 246 llvm::SmallSet<Operation *, 4> lastUsers; 247 248 // Find blocks in the `definingRegion` that have users of the `value` (if 249 // there are multiple users in the block, which one will be selected is 250 // undefined). User operation might be not the actual user of the value, but 251 // the operation in the block that has a "real user" in one of the attached 252 // regions. 253 llvm::DenseMap<Block *, Operation *> usersInTheBlocks; 254 255 for (Operation *user : value.getUsers()) { 256 Block *userBlock = user->getBlock(); 257 Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock); 258 usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user); 259 assert(ancestor && "ancestor block must be not null"); 260 assert(usersInTheBlocks[ancestor] && "ancestor op must be not null"); 261 } 262 263 // Find blocks where the `value` dies: the value is in `liveIn` set and not 264 // in the `liveOut` set. We place `drop_ref` immediately after the last use 265 // of the `value` in such regions (after handling few special cases). 266 // 267 // We do not traverse all the blocks in the `definingRegion`, because the 268 // `value` can be in the live in set only if it has users in the block, or it 269 // is defined in the block. 270 // 271 // Values with zero users (only definition) handled explicitly above. 272 for (auto &blockAndUser : usersInTheBlocks) { 273 Block *block = blockAndUser.getFirst(); 274 Operation *userInTheBlock = blockAndUser.getSecond(); 275 276 const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); 277 278 // Value must be in the live input set or defined in the block. 279 assert(blockLiveness->isLiveIn(value) || 280 blockLiveness->getBlock() == value.getParentBlock()); 281 282 // If value is in the live out set, it means it doesn't "die" in the block. 283 if (blockLiveness->isLiveOut(value)) 284 continue; 285 286 // At this point we proved that `value` dies in the `block`. Find the last 287 // use of the `value` inside the `block`, this is where it "dies". 288 Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); 289 assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); 290 lastUsers.insert(lastUser); 291 } 292 293 // Process all the last users of the `value` inside each block where the value 294 // dies. 295 for (Operation *lastUser : lastUsers) { 296 // Return like operations forward reference count. 297 if (lastUser->hasTrait<OpTrait::ReturnLike>()) 298 continue; 299 300 // We can't currently handle other types of terminators. 301 if (lastUser->hasTrait<OpTrait::IsTerminator>()) 302 return lastUser->emitError() << "async reference counting can't handle " 303 "terminators that are not ReturnLike"; 304 305 // Add a drop_ref immediately after the last user. 306 builder.setInsertionPointAfter(lastUser); 307 builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1)); 308 } 309 310 return success(); 311 } 312 313 LogicalResult 314 AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { 315 OpBuilder builder(value.getContext()); 316 Location loc = value.getLoc(); 317 318 for (Operation *user : value.getUsers()) { 319 if (!isa<CallOp>(user)) 320 continue; 321 322 // Add a reference before the function call to pass the value at `+1` 323 // reference to the function entry block. 324 builder.setInsertionPoint(user); 325 builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1)); 326 } 327 328 return success(); 329 } 330 331 LogicalResult 332 AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( 333 Value value) { 334 using BlockSet = llvm::SmallPtrSet<Block *, 4>; 335 336 OpBuilder builder(value.getContext()); 337 338 // If a block has successors with different `liveIn` property of the `value`, 339 // record block successors that do not thave the `value` in the `liveIn` set. 340 llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks; 341 342 // Use liveness analysis to find the placement of `drop_ref`operation. 343 auto &liveness = getAnalysis<Liveness>(); 344 345 // Because we only add `drop_ref` operations to the region that defines the 346 // `value` we can only process CFG for the same region. 347 Region *definingRegion = value.getParentRegion(); 348 349 // Collect blocks with successors with mismatching `liveIn` sets. 350 for (Block &block : definingRegion->getBlocks()) { 351 const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); 352 353 // Skip the block if value is not in the `liveOut` set. 354 if (!blockLiveness || !blockLiveness->isLiveOut(value)) 355 continue; 356 357 BlockSet liveInSuccessors; // `value` is in `liveIn` set 358 BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set 359 360 // Collect successors that do not have `value` in the `liveIn` set. 361 for (Block *successor : block.getSuccessors()) { 362 const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); 363 if (succLiveness && succLiveness->isLiveIn(value)) 364 liveInSuccessors.insert(successor); 365 else 366 noLiveInSuccessors.insert(successor); 367 } 368 369 // Block has successors with different `liveIn` property of the `value`. 370 if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) 371 divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors); 372 } 373 374 // Try to insert `dropRef` operations to handle blocks with divergent liveness 375 // in successors blocks. 376 for (auto kv : divergentLivenessBlocks) { 377 Block *block = kv.getFirst(); 378 BlockSet &successors = kv.getSecond(); 379 380 // Coroutine suspension is a special case terminator for wich we do not 381 // need to create additional reference counting (see details above). 382 Operation *terminator = block->getTerminator(); 383 if (isa<CoroSuspendOp>(terminator)) 384 continue; 385 386 // We only support successor blocks with empty block argument list. 387 auto hasArgs = [](Block *block) { return !block->getArguments().empty(); }; 388 if (llvm::any_of(successors, hasArgs)) 389 return terminator->emitOpError() 390 << "successor have different `liveIn` property of the reference " 391 "counted value"; 392 393 // Make sure that `dropRef` operation is called when branched into the 394 // successor block without `value` in the `liveIn` set. 395 for (Block *successor : successors) { 396 // If successor has a unique predecessor, it is safe to create `dropRef` 397 // operations directly in the successor block. 398 // 399 // Otherwise we need to create a special block for reference counting 400 // operations, and branch from it to the original successor block. 401 Block *refCountingBlock = nullptr; 402 403 if (successor->getUniquePredecessor() == block) { 404 refCountingBlock = successor; 405 } else { 406 refCountingBlock = &successor->getParent()->emplaceBlock(); 407 refCountingBlock->moveBefore(successor); 408 OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); 409 builder.create<BranchOp>(value.getLoc(), successor); 410 } 411 412 OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); 413 builder.create<RuntimeDropRefOp>(value.getLoc(), value, 414 builder.getI64IntegerAttr(1)); 415 416 // No need to update the terminator operation. 417 if (successor == refCountingBlock) 418 continue; 419 420 // Update terminator `successor` block to `refCountingBlock`. 421 for (auto pair : llvm::enumerate(terminator->getSuccessors())) 422 if (pair.value() == successor) 423 terminator->setSuccessor(refCountingBlock, pair.index()); 424 } 425 } 426 427 return success(); 428 } 429 430 LogicalResult 431 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { 432 // Short-circuit reference counting for values without uses. 433 if (succeeded(dropRefIfNoUses(value))) 434 return success(); 435 436 // Add `drop_ref` operations based on the liveness analysis. 437 if (failed(addDropRefAfterLastUse(value))) 438 return failure(); 439 440 // Add `add_ref` operations before function calls. 441 if (failed(addAddRefBeforeFunctionCall(value))) 442 return failure(); 443 444 // Add `drop_ref` operations to successors with divergent `value` liveness. 445 if (failed(addDropRefInDivergentLivenessSuccessor(value))) 446 return failure(); 447 448 return success(); 449 } 450 451 void AsyncRuntimeRefCountingPass::runOnOperation() { 452 auto functor = [&](Value value) { return addAutomaticRefCounting(value); }; 453 if (failed(walkReferenceCountedValues(getOperation(), functor))) 454 signalPassFailure(); 455 } 456 457 //===----------------------------------------------------------------------===// 458 // Reference counting based on the user defined policy. 459 //===----------------------------------------------------------------------===// 460 461 namespace { 462 463 class AsyncRuntimePolicyBasedRefCountingPass 464 : public AsyncRuntimePolicyBasedRefCountingBase< 465 AsyncRuntimePolicyBasedRefCountingPass> { 466 public: 467 AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); } 468 469 void runOnOperation() override; 470 471 private: 472 // Adds a reference counting operations for all uses of the `value` according 473 // to the reference counting policy. 474 LogicalResult addRefCounting(Value value); 475 476 void initializeDefaultPolicy(); 477 478 llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy; 479 }; 480 481 } // namespace 482 483 LogicalResult 484 AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { 485 // Short-circuit reference counting for values without uses. 486 if (succeeded(dropRefIfNoUses(value))) 487 return success(); 488 489 OpBuilder b(value.getContext()); 490 491 // Consult the user defined policy for every value use. 492 for (OpOperand &operand : value.getUses()) { 493 Location loc = operand.getOwner()->getLoc(); 494 495 for (auto &func : policy) { 496 FailureOr<int> refCount = func(operand); 497 if (failed(refCount)) 498 return failure(); 499 500 int cnt = refCount.getValue(); 501 502 // Create `add_ref` operation before the operand owner. 503 if (cnt > 0) { 504 b.setInsertionPoint(operand.getOwner()); 505 b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt)); 506 } 507 508 // Create `drop_ref` operation after the operand owner. 509 if (cnt < 0) { 510 b.setInsertionPointAfter(operand.getOwner()); 511 b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt)); 512 } 513 } 514 } 515 516 return success(); 517 } 518 519 void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { 520 policy.push_back([](OpOperand &operand) -> FailureOr<int> { 521 Operation *op = operand.getOwner(); 522 Type type = operand.get().getType(); 523 524 bool isToken = type.isa<TokenType>(); 525 bool isGroup = type.isa<GroupType>(); 526 bool isValue = type.isa<ValueType>(); 527 528 // Drop reference after async token or group error check (coro await). 529 if (auto await = dyn_cast<RuntimeIsErrorOp>(op)) 530 return (isToken || isGroup) ? -1 : 0; 531 532 // Drop reference after async value load. 533 if (auto load = dyn_cast<RuntimeLoadOp>(op)) 534 return isValue ? -1 : 0; 535 536 // Drop reference after async token added to the group. 537 if (auto add = dyn_cast<RuntimeAddToGroupOp>(op)) 538 return isToken ? -1 : 0; 539 540 return 0; 541 }); 542 } 543 544 void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() { 545 auto functor = [&](Value value) { return addRefCounting(value); }; 546 if (failed(walkReferenceCountedValues(getOperation(), functor))) 547 signalPassFailure(); 548 } 549 550 //----------------------------------------------------------------------------// 551 552 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() { 553 return std::make_unique<AsyncRuntimeRefCountingPass>(); 554 } 555 556 std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() { 557 return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>(); 558 } 559