1 //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements automatic reference counting for Async runtime 10 // operations and types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Analysis/Liveness.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/Async/Passes.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 #include "llvm/ADT/SmallSet.h" 23 24 using namespace mlir; 25 using namespace mlir::async; 26 27 #define DEBUG_TYPE "async-runtime-ref-counting" 28 29 namespace { 30 31 class AsyncRuntimeRefCountingPass 32 : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> { 33 public: 34 AsyncRuntimeRefCountingPass() = default; 35 void runOnOperation() override; 36 37 private: 38 /// Adds an automatic reference counting to the `value`. 39 /// 40 /// All values (token, group or value) are semantically created with a 41 /// reference count of +1 and it is the responsibility of the async value user 42 /// to place the `add_ref` and `drop_ref` operations to ensure that the value 43 /// is destroyed after the last use. 44 /// 45 /// The function returns failure if it can't deduce the locations where 46 /// to place the reference counting operations. 47 /// 48 /// Async values "semantically created" when: 49 /// 1. Operation returns async result (e.g. `async.runtime.create`) 50 /// 2. Async value passed in as a block argument (or function argument, 51 /// because function arguments are just entry block arguments) 52 /// 53 /// Passing async value as a function argument (or block argument) does not 54 /// really mean that a new async value is created, it only means that the 55 /// caller of a function transfered ownership of `+1` reference to the callee. 56 /// It is convenient to think that from the callee perspective async value was 57 /// "created" with `+1` reference by the block argument. 58 /// 59 /// Automatic reference counting algorithm outline: 60 /// 61 /// #1 Insert `drop_ref` operations after last use of the `value`. 62 /// #2 Insert `add_ref` operations before functions calls with reference 63 /// counted `value` operand (newly created `+1` reference will be 64 /// transferred to the callee). 65 /// #3 Verify that divergent control flow does not lead to leaked reference 66 /// counted objects. 67 /// 68 /// Async runtime reference counting optimization pass will optimize away 69 /// some of the redundant `add_ref` and `drop_ref` operations inserted by this 70 /// strategy (see `async-runtime-ref-counting-opt`). 71 LogicalResult addAutomaticRefCounting(Value value); 72 73 /// (#1) Adds the `drop_ref` operation after the last use of the `value` 74 /// relying on the liveness analysis. 75 /// 76 /// If the `value` is in the block `liveIn` set and it is not in the block 77 /// `liveOut` set, it means that it "dies" in the block. We find the last 78 /// use of the value in such block and: 79 /// 80 /// 1. If the last user is a `ReturnLike` operation we do nothing, because 81 /// it forwards the ownership to the caller. 82 /// 2. Otherwise we add a `drop_ref` operation immediately after the last 83 /// use. 84 LogicalResult addDropRefAfterLastUse(Value value); 85 86 /// (#2) Adds the `add_ref` operation before the function call taking `value` 87 /// operand to ensure that the value passed to the function entry block 88 /// has a `+1` reference count. 89 LogicalResult addAddRefBeforeFunctionCall(Value value); 90 91 /// (#3) Adds the `drop_ref` operation to account for successor blocks with 92 /// divergent `liveIn` property: `value` is not in the `liveIn` set of all 93 /// successor blocks. 94 /// 95 /// Example: 96 /// 97 /// ^entry: 98 /// %token = async.runtime.create : !async.token 99 /// cond_br %cond, ^bb1, ^bb2 100 /// ^bb1: 101 /// async.runtime.await %token 102 /// async.runtime.drop_ref %token 103 /// br ^bb2 104 /// ^bb2: 105 /// return 106 /// 107 /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have 108 /// to branch into a special "reference counting block" from the ^entry that 109 /// will have a `drop_ref` operation, and then branch into the ^bb2. 110 /// 111 /// After transformation: 112 /// 113 /// ^entry: 114 /// %token = async.runtime.create : !async.token 115 /// cond_br %cond, ^bb1, ^reference_counting 116 /// ^bb1: 117 /// async.runtime.await %token 118 /// async.runtime.drop_ref %token 119 /// br ^bb2 120 /// ^reference_counting: 121 /// async.runtime.drop_ref %token 122 /// br ^bb2 123 /// ^bb2: 124 /// return 125 /// 126 /// An exception to this rule are blocks with `async.coro.suspend` terminator, 127 /// because in Async to LLVM lowering it is guaranteed that the control flow 128 /// will jump into the resume block, and then follow into the cleanup and 129 /// suspend blocks. 130 /// 131 /// Example: 132 /// 133 /// ^entry(%value: !async.value<f32>): 134 /// async.runtime.await_and_resume %value, %hdl : !async.value<f32> 135 /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup 136 /// ^resume: 137 /// %0 = async.runtime.load %value 138 /// br ^cleanup 139 /// ^cleanup: 140 /// ... 141 /// ^suspend: 142 /// ... 143 /// 144 /// Although cleanup and suspend blocks do not have the `value` in the 145 /// `liveIn` set, it is guaranteed that execution will eventually continue in 146 /// the resume block (we never explicitly destroy coroutines). 147 LogicalResult addDropRefInDivergentLivenessSuccessor(Value value); 148 }; 149 150 } // namespace 151 152 LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { 153 OpBuilder builder(value.getContext()); 154 Location loc = value.getLoc(); 155 156 // Use liveness analysis to find the placement of `drop_ref`operation. 157 auto &liveness = getAnalysis<Liveness>(); 158 159 // We analyse only the blocks of the region that defines the `value`, and do 160 // not check nested blocks attached to operations. 161 // 162 // By analyzing only the `definingRegion` CFG we potentially loose an 163 // opportunity to drop the reference count earlier and can extend the lifetime 164 // of reference counted value longer then it is really required. 165 // 166 // We also assume that all nested regions finish their execution before the 167 // completion of the owner operation. The only exception to this rule is 168 // `async.execute` operation, and we verify that they are lowered to the 169 // `async.runtime` operations before adding automatic reference counting. 170 Region *definingRegion = value.getParentRegion(); 171 172 // Last users of the `value` inside all blocks where the value dies. 173 llvm::SmallSet<Operation *, 4> lastUsers; 174 175 // Find blocks in the `definingRegion` that have users of the `value` (if 176 // there are multiple users in the block, which one will be selected is 177 // undefined). User operation might be not the actual user of the value, but 178 // the operation in the block that has a "real user" in one of the attached 179 // regions. 180 llvm::DenseMap<Block *, Operation *> usersInTheBlocks; 181 182 for (Operation *user : value.getUsers()) { 183 Block *userBlock = user->getBlock(); 184 Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock); 185 usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user); 186 assert(ancestor && "ancestor block must be not null"); 187 assert(usersInTheBlocks[ancestor] && "ancestor op must be not null"); 188 } 189 190 // Find blocks where the `value` dies: the value is in `liveIn` set and not 191 // in the `liveOut` set. We place `drop_ref` immediately after the last use 192 // of the `value` in such regions (after handling few special cases). 193 // 194 // We do not traverse all the blocks in the `definingRegion`, because the 195 // `value` can be in the live in set only if it has users in the block, or it 196 // is defined in the block. 197 // 198 // Values with zero users (only definition) handled explicitly above. 199 for (auto &blockAndUser : usersInTheBlocks) { 200 Block *block = blockAndUser.getFirst(); 201 Operation *userInTheBlock = blockAndUser.getSecond(); 202 203 const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); 204 205 // Value must be in the live input set or defined in the block. 206 assert(blockLiveness->isLiveIn(value) || 207 blockLiveness->getBlock() == value.getParentBlock()); 208 209 // If value is in the live out set, it means it doesn't "die" in the block. 210 if (blockLiveness->isLiveOut(value)) 211 continue; 212 213 // At this point we proved that `value` dies in the `block`. Find the last 214 // use of the `value` inside the `block`, this is where it "dies". 215 Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); 216 assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); 217 lastUsers.insert(lastUser); 218 } 219 220 // Process all the last users of the `value` inside each block where the value 221 // dies. 222 for (Operation *lastUser : lastUsers) { 223 // Return like operations forward reference count. 224 if (lastUser->hasTrait<OpTrait::ReturnLike>()) 225 continue; 226 227 // We can't currently handle other types of terminators. 228 if (lastUser->hasTrait<OpTrait::IsTerminator>()) 229 return lastUser->emitError() << "async reference counting can't handle " 230 "terminators that are not ReturnLike"; 231 232 // Add a drop_ref immediately after the last user. 233 builder.setInsertionPointAfter(lastUser); 234 builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1)); 235 } 236 237 return success(); 238 } 239 240 LogicalResult 241 AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { 242 OpBuilder builder(value.getContext()); 243 Location loc = value.getLoc(); 244 245 for (Operation *user : value.getUsers()) { 246 if (!isa<CallOp>(user)) 247 continue; 248 249 // Add a reference before the function call to pass the value at `+1` 250 // reference to the function entry block. 251 builder.setInsertionPoint(user); 252 builder.create<RuntimeAddRefOp>(loc, value, builder.getI32IntegerAttr(1)); 253 } 254 255 return success(); 256 } 257 258 LogicalResult 259 AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( 260 Value value) { 261 using BlockSet = llvm::SmallPtrSet<Block *, 4>; 262 263 OpBuilder builder(value.getContext()); 264 265 // If a block has successors with different `liveIn` property of the `value`, 266 // record block successors that do not thave the `value` in the `liveIn` set. 267 llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks; 268 269 // Use liveness analysis to find the placement of `drop_ref`operation. 270 auto &liveness = getAnalysis<Liveness>(); 271 272 // Because we only add `drop_ref` operations to the region that defines the 273 // `value` we can only process CFG for the same region. 274 Region *definingRegion = value.getParentRegion(); 275 276 // Collect blocks with successors with mismatching `liveIn` sets. 277 for (Block &block : definingRegion->getBlocks()) { 278 const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); 279 280 // Skip the block if value is not in the `liveOut` set. 281 if (!blockLiveness || !blockLiveness->isLiveOut(value)) 282 continue; 283 284 BlockSet liveInSuccessors; // `value` is in `liveIn` set 285 BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set 286 287 // Collect successors that do not have `value` in the `liveIn` set. 288 for (Block *successor : block.getSuccessors()) { 289 const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); 290 if (succLiveness && succLiveness->isLiveIn(value)) 291 liveInSuccessors.insert(successor); 292 else 293 noLiveInSuccessors.insert(successor); 294 } 295 296 // Block has successors with different `liveIn` property of the `value`. 297 if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) 298 divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors); 299 } 300 301 // Try to insert `dropRef` operations to handle blocks with divergent liveness 302 // in successors blocks. 303 for (auto kv : divergentLivenessBlocks) { 304 Block *block = kv.getFirst(); 305 BlockSet &successors = kv.getSecond(); 306 307 // Coroutine suspension is a special case terminator for wich we do not 308 // need to create additional reference counting (see details above). 309 Operation *terminator = block->getTerminator(); 310 if (isa<CoroSuspendOp>(terminator)) 311 continue; 312 313 // We only support successor blocks with empty block argument list. 314 auto hasArgs = [](Block *block) { return !block->getArguments().empty(); }; 315 if (llvm::any_of(successors, hasArgs)) 316 return terminator->emitOpError() 317 << "successor have different `liveIn` property of the reference " 318 "counted value"; 319 320 // Make sure that `dropRef` operation is called when branched into the 321 // successor block without `value` in the `liveIn` set. 322 for (Block *successor : successors) { 323 // If successor has a unique predecessor, it is safe to create `dropRef` 324 // operations directly in the successor block. 325 // 326 // Otherwise we need to create a special block for reference counting 327 // operations, and branch from it to the original successor block. 328 Block *refCountingBlock = nullptr; 329 330 if (successor->getUniquePredecessor() == block) { 331 refCountingBlock = successor; 332 } else { 333 refCountingBlock = &successor->getParent()->emplaceBlock(); 334 refCountingBlock->moveBefore(successor); 335 OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); 336 builder.create<BranchOp>(value.getLoc(), successor); 337 } 338 339 OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); 340 builder.create<RuntimeDropRefOp>(value.getLoc(), value, 341 builder.getI32IntegerAttr(1)); 342 343 // No need to update the terminator operation. 344 if (successor == refCountingBlock) 345 continue; 346 347 // Update terminator `successor` block to `refCountingBlock`. 348 for (auto pair : llvm::enumerate(terminator->getSuccessors())) 349 if (pair.value() == successor) 350 terminator->setSuccessor(refCountingBlock, pair.index()); 351 } 352 } 353 354 return success(); 355 } 356 357 LogicalResult 358 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { 359 OpBuilder builder(value.getContext()); 360 Location loc = value.getLoc(); 361 362 // Set inserton point after the operation producing a value, or at the 363 // beginning of the block if the value defined by the block argument. 364 if (Operation *op = value.getDefiningOp()) 365 builder.setInsertionPointAfter(op); 366 else 367 builder.setInsertionPointToStart(value.getParentBlock()); 368 369 // Drop the reference count immediately if the value has no uses. 370 if (value.getUses().empty()) { 371 builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1)); 372 return success(); 373 } 374 375 // Add `drop_ref` operations based on the liveness analysis. 376 if (failed(addDropRefAfterLastUse(value))) 377 return failure(); 378 379 // Add `add_ref` operations before function calls. 380 if (failed(addAddRefBeforeFunctionCall(value))) 381 return failure(); 382 383 // Add `drop_ref` operations to successors with divergent `value` liveness. 384 if (failed(addDropRefInDivergentLivenessSuccessor(value))) 385 return failure(); 386 387 return success(); 388 } 389 390 void AsyncRuntimeRefCountingPass::runOnOperation() { 391 Operation *op = getOperation(); 392 393 // Check that we do not have high level async operations in the IR because 394 // otherwise automatic reference counting will produce incorrect results after 395 // execute operations will be lowered to `async.runtime` 396 WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult { 397 if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op)) 398 return WalkResult::advance(); 399 400 return op->emitError() 401 << "async operations must be lowered to async runtime operations"; 402 }); 403 404 if (executeOpWalk.wasInterrupted()) { 405 signalPassFailure(); 406 return; 407 } 408 409 // Add reference counting to block arguments. 410 WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { 411 for (BlockArgument arg : block->getArguments()) 412 if (isRefCounted(arg.getType())) 413 if (failed(addAutomaticRefCounting(arg))) 414 return WalkResult::interrupt(); 415 416 return WalkResult::advance(); 417 }); 418 419 if (blockWalk.wasInterrupted()) { 420 signalPassFailure(); 421 return; 422 } 423 424 // Add reference counting to operation results. 425 WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { 426 for (unsigned i = 0; i < op->getNumResults(); ++i) 427 if (isRefCounted(op->getResultTypes()[i])) 428 if (failed(addAutomaticRefCounting(op->getResult(i)))) 429 return WalkResult::interrupt(); 430 431 return WalkResult::advance(); 432 }); 433 434 if (opWalk.wasInterrupted()) 435 signalPassFailure(); 436 } 437 438 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() { 439 return std::make_unique<AsyncRuntimeRefCountingPass>(); 440 } 441