1a6628e59SEugene Zhulenev //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// 2a6628e59SEugene Zhulenev // 3a6628e59SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a6628e59SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 5a6628e59SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a6628e59SEugene Zhulenev // 7a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===// 8a6628e59SEugene Zhulenev // 9a6628e59SEugene Zhulenev // This file implements automatic reference counting for Async runtime 10a6628e59SEugene Zhulenev // operations and types. 11a6628e59SEugene Zhulenev // 12a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===// 13a6628e59SEugene Zhulenev 14a6628e59SEugene Zhulenev #include "PassDetail.h" 15a6628e59SEugene Zhulenev #include "mlir/Analysis/Liveness.h" 16a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h" 17a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h" 18a6628e59SEugene Zhulenev #include "mlir/Dialect/StandardOps/IR/Ops.h" 19a6628e59SEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h" 20a6628e59SEugene Zhulenev #include "mlir/IR/PatternMatch.h" 21a6628e59SEugene Zhulenev #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h" 23a6628e59SEugene Zhulenev 24a6628e59SEugene Zhulenev using namespace mlir; 25a6628e59SEugene Zhulenev using namespace mlir::async; 26a6628e59SEugene Zhulenev 27a6628e59SEugene Zhulenev #define DEBUG_TYPE "async-runtime-ref-counting" 28a6628e59SEugene Zhulenev 29f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 30f57b2420SEugene Zhulenev // Utility functions shared by reference counting passes. 31f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 32f57b2420SEugene Zhulenev 33f57b2420SEugene Zhulenev // Drop the reference count immediately if the value has no uses. 34f57b2420SEugene Zhulenev static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { 35f57b2420SEugene Zhulenev if (!value.getUses().empty()) 36f57b2420SEugene Zhulenev return failure(); 37f57b2420SEugene Zhulenev 38f57b2420SEugene Zhulenev OpBuilder b(value.getContext()); 39f57b2420SEugene Zhulenev 40f57b2420SEugene Zhulenev // Set insertion point after the operation producing a value, or at the 41f57b2420SEugene Zhulenev // beginning of the block if the value defined by the block argument. 42f57b2420SEugene Zhulenev if (Operation *op = value.getDefiningOp()) 43f57b2420SEugene Zhulenev b.setInsertionPointAfter(op); 44f57b2420SEugene Zhulenev else 45f57b2420SEugene Zhulenev b.setInsertionPointToStart(value.getParentBlock()); 46f57b2420SEugene Zhulenev 47*92db09cdSEugene Zhulenev b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1)); 48f57b2420SEugene Zhulenev return success(); 49f57b2420SEugene Zhulenev } 50f57b2420SEugene Zhulenev 51f57b2420SEugene Zhulenev // Calls `addRefCounting` for every reference counted value defined by the 52f57b2420SEugene Zhulenev // operation `op` (block arguments and values defined in nested regions). 53f57b2420SEugene Zhulenev static LogicalResult walkReferenceCountedValues( 54f57b2420SEugene Zhulenev Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) { 55f57b2420SEugene Zhulenev // Check that we do not have high level async operations in the IR because 56f57b2420SEugene Zhulenev // otherwise reference counting will produce incorrect results after high 57f57b2420SEugene Zhulenev // level async operations will be lowered to `async.runtime` 58f57b2420SEugene Zhulenev WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult { 59f57b2420SEugene Zhulenev if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op)) 60f57b2420SEugene Zhulenev return WalkResult::advance(); 61f57b2420SEugene Zhulenev 62f57b2420SEugene Zhulenev return op->emitError() 63f57b2420SEugene Zhulenev << "async operations must be lowered to async runtime operations"; 64f57b2420SEugene Zhulenev }); 65f57b2420SEugene Zhulenev 66f57b2420SEugene Zhulenev if (checkNoAsyncWalk.wasInterrupted()) 67f57b2420SEugene Zhulenev return failure(); 68f57b2420SEugene Zhulenev 69f57b2420SEugene Zhulenev // Add reference counting to block arguments. 70f57b2420SEugene Zhulenev WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { 71f57b2420SEugene Zhulenev for (BlockArgument arg : block->getArguments()) 72f57b2420SEugene Zhulenev if (isRefCounted(arg.getType())) 73f57b2420SEugene Zhulenev if (failed(addRefCounting(arg))) 74f57b2420SEugene Zhulenev return WalkResult::interrupt(); 75f57b2420SEugene Zhulenev 76f57b2420SEugene Zhulenev return WalkResult::advance(); 77f57b2420SEugene Zhulenev }); 78f57b2420SEugene Zhulenev 79f57b2420SEugene Zhulenev if (blockWalk.wasInterrupted()) 80f57b2420SEugene Zhulenev return failure(); 81f57b2420SEugene Zhulenev 82f57b2420SEugene Zhulenev // Add reference counting to operation results. 83f57b2420SEugene Zhulenev WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { 84f57b2420SEugene Zhulenev for (unsigned i = 0; i < op->getNumResults(); ++i) 85f57b2420SEugene Zhulenev if (isRefCounted(op->getResultTypes()[i])) 86f57b2420SEugene Zhulenev if (failed(addRefCounting(op->getResult(i)))) 87f57b2420SEugene Zhulenev return WalkResult::interrupt(); 88f57b2420SEugene Zhulenev 89f57b2420SEugene Zhulenev return WalkResult::advance(); 90f57b2420SEugene Zhulenev }); 91f57b2420SEugene Zhulenev 92f57b2420SEugene Zhulenev if (opWalk.wasInterrupted()) 93f57b2420SEugene Zhulenev return failure(); 94f57b2420SEugene Zhulenev 95f57b2420SEugene Zhulenev return success(); 96f57b2420SEugene Zhulenev } 97f57b2420SEugene Zhulenev 98f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 99f57b2420SEugene Zhulenev // Automatic reference counting based on the liveness analysis. 100f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 101f57b2420SEugene Zhulenev 102a6628e59SEugene Zhulenev namespace { 103a6628e59SEugene Zhulenev 104a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingPass 105a6628e59SEugene Zhulenev : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> { 106a6628e59SEugene Zhulenev public: 107a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass() = default; 1088a316b00SEugene Zhulenev void runOnOperation() override; 109a6628e59SEugene Zhulenev 110a6628e59SEugene Zhulenev private: 111a6628e59SEugene Zhulenev /// Adds an automatic reference counting to the `value`. 112a6628e59SEugene Zhulenev /// 113a6628e59SEugene Zhulenev /// All values (token, group or value) are semantically created with a 114a6628e59SEugene Zhulenev /// reference count of +1 and it is the responsibility of the async value user 115a6628e59SEugene Zhulenev /// to place the `add_ref` and `drop_ref` operations to ensure that the value 116a6628e59SEugene Zhulenev /// is destroyed after the last use. 117a6628e59SEugene Zhulenev /// 118a6628e59SEugene Zhulenev /// The function returns failure if it can't deduce the locations where 119a6628e59SEugene Zhulenev /// to place the reference counting operations. 120a6628e59SEugene Zhulenev /// 121a6628e59SEugene Zhulenev /// Async values "semantically created" when: 122a6628e59SEugene Zhulenev /// 1. Operation returns async result (e.g. `async.runtime.create`) 123a6628e59SEugene Zhulenev /// 2. Async value passed in as a block argument (or function argument, 124a6628e59SEugene Zhulenev /// because function arguments are just entry block arguments) 125a6628e59SEugene Zhulenev /// 126a6628e59SEugene Zhulenev /// Passing async value as a function argument (or block argument) does not 127a6628e59SEugene Zhulenev /// really mean that a new async value is created, it only means that the 128a6628e59SEugene Zhulenev /// caller of a function transfered ownership of `+1` reference to the callee. 129a6628e59SEugene Zhulenev /// It is convenient to think that from the callee perspective async value was 130a6628e59SEugene Zhulenev /// "created" with `+1` reference by the block argument. 131a6628e59SEugene Zhulenev /// 132a6628e59SEugene Zhulenev /// Automatic reference counting algorithm outline: 133a6628e59SEugene Zhulenev /// 134a6628e59SEugene Zhulenev /// #1 Insert `drop_ref` operations after last use of the `value`. 135a6628e59SEugene Zhulenev /// #2 Insert `add_ref` operations before functions calls with reference 136a6628e59SEugene Zhulenev /// counted `value` operand (newly created `+1` reference will be 137a6628e59SEugene Zhulenev /// transferred to the callee). 138a6628e59SEugene Zhulenev /// #3 Verify that divergent control flow does not lead to leaked reference 139a6628e59SEugene Zhulenev /// counted objects. 140a6628e59SEugene Zhulenev /// 141a6628e59SEugene Zhulenev /// Async runtime reference counting optimization pass will optimize away 142a6628e59SEugene Zhulenev /// some of the redundant `add_ref` and `drop_ref` operations inserted by this 143a6628e59SEugene Zhulenev /// strategy (see `async-runtime-ref-counting-opt`). 144a6628e59SEugene Zhulenev LogicalResult addAutomaticRefCounting(Value value); 145a6628e59SEugene Zhulenev 146a6628e59SEugene Zhulenev /// (#1) Adds the `drop_ref` operation after the last use of the `value` 147a6628e59SEugene Zhulenev /// relying on the liveness analysis. 148a6628e59SEugene Zhulenev /// 149a6628e59SEugene Zhulenev /// If the `value` is in the block `liveIn` set and it is not in the block 150a6628e59SEugene Zhulenev /// `liveOut` set, it means that it "dies" in the block. We find the last 151a6628e59SEugene Zhulenev /// use of the value in such block and: 152a6628e59SEugene Zhulenev /// 153a6628e59SEugene Zhulenev /// 1. If the last user is a `ReturnLike` operation we do nothing, because 154a6628e59SEugene Zhulenev /// it forwards the ownership to the caller. 155a6628e59SEugene Zhulenev /// 2. Otherwise we add a `drop_ref` operation immediately after the last 156a6628e59SEugene Zhulenev /// use. 157a6628e59SEugene Zhulenev LogicalResult addDropRefAfterLastUse(Value value); 158a6628e59SEugene Zhulenev 159a6628e59SEugene Zhulenev /// (#2) Adds the `add_ref` operation before the function call taking `value` 160a6628e59SEugene Zhulenev /// operand to ensure that the value passed to the function entry block 161a6628e59SEugene Zhulenev /// has a `+1` reference count. 162a6628e59SEugene Zhulenev LogicalResult addAddRefBeforeFunctionCall(Value value); 163a6628e59SEugene Zhulenev 164c412979cSEugene Zhulenev /// (#3) Adds the `drop_ref` operation to account for successor blocks with 165c412979cSEugene Zhulenev /// divergent `liveIn` property: `value` is not in the `liveIn` set of all 166c412979cSEugene Zhulenev /// successor blocks. 167a6628e59SEugene Zhulenev /// 168a6628e59SEugene Zhulenev /// Example: 169a6628e59SEugene Zhulenev /// 170a6628e59SEugene Zhulenev /// ^entry: 171a6628e59SEugene Zhulenev /// %token = async.runtime.create : !async.token 172a6628e59SEugene Zhulenev /// cond_br %cond, ^bb1, ^bb2 173a6628e59SEugene Zhulenev /// ^bb1: 174a6628e59SEugene Zhulenev /// async.runtime.await %token 175c412979cSEugene Zhulenev /// async.runtime.drop_ref %token 176c412979cSEugene Zhulenev /// br ^bb2 177a6628e59SEugene Zhulenev /// ^bb2: 178a6628e59SEugene Zhulenev /// return 179a6628e59SEugene Zhulenev /// 180c412979cSEugene Zhulenev /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have 181c412979cSEugene Zhulenev /// to branch into a special "reference counting block" from the ^entry that 182c412979cSEugene Zhulenev /// will have a `drop_ref` operation, and then branch into the ^bb2. 183c412979cSEugene Zhulenev /// 184c412979cSEugene Zhulenev /// After transformation: 185c412979cSEugene Zhulenev /// 186c412979cSEugene Zhulenev /// ^entry: 187c412979cSEugene Zhulenev /// %token = async.runtime.create : !async.token 188c412979cSEugene Zhulenev /// cond_br %cond, ^bb1, ^reference_counting 189c412979cSEugene Zhulenev /// ^bb1: 190c412979cSEugene Zhulenev /// async.runtime.await %token 191c412979cSEugene Zhulenev /// async.runtime.drop_ref %token 192c412979cSEugene Zhulenev /// br ^bb2 193c412979cSEugene Zhulenev /// ^reference_counting: 194c412979cSEugene Zhulenev /// async.runtime.drop_ref %token 195c412979cSEugene Zhulenev /// br ^bb2 196c412979cSEugene Zhulenev /// ^bb2: 197c412979cSEugene Zhulenev /// return 198a6628e59SEugene Zhulenev /// 199a6628e59SEugene Zhulenev /// An exception to this rule are blocks with `async.coro.suspend` terminator, 200a6628e59SEugene Zhulenev /// because in Async to LLVM lowering it is guaranteed that the control flow 201a6628e59SEugene Zhulenev /// will jump into the resume block, and then follow into the cleanup and 202a6628e59SEugene Zhulenev /// suspend blocks. 203a6628e59SEugene Zhulenev /// 204a6628e59SEugene Zhulenev /// Example: 205a6628e59SEugene Zhulenev /// 206a6628e59SEugene Zhulenev /// ^entry(%value: !async.value<f32>): 207a6628e59SEugene Zhulenev /// async.runtime.await_and_resume %value, %hdl : !async.value<f32> 208a6628e59SEugene Zhulenev /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup 209a6628e59SEugene Zhulenev /// ^resume: 210a6628e59SEugene Zhulenev /// %0 = async.runtime.load %value 211a6628e59SEugene Zhulenev /// br ^cleanup 212a6628e59SEugene Zhulenev /// ^cleanup: 213a6628e59SEugene Zhulenev /// ... 214a6628e59SEugene Zhulenev /// ^suspend: 215a6628e59SEugene Zhulenev /// ... 216a6628e59SEugene Zhulenev /// 217a6628e59SEugene Zhulenev /// Although cleanup and suspend blocks do not have the `value` in the 218a6628e59SEugene Zhulenev /// `liveIn` set, it is guaranteed that execution will eventually continue in 219a6628e59SEugene Zhulenev /// the resume block (we never explicitly destroy coroutines). 220c412979cSEugene Zhulenev LogicalResult addDropRefInDivergentLivenessSuccessor(Value value); 221a6628e59SEugene Zhulenev }; 222a6628e59SEugene Zhulenev 223a6628e59SEugene Zhulenev } // namespace 224a6628e59SEugene Zhulenev 225a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { 226a6628e59SEugene Zhulenev OpBuilder builder(value.getContext()); 227a6628e59SEugene Zhulenev Location loc = value.getLoc(); 228a6628e59SEugene Zhulenev 229a6628e59SEugene Zhulenev // Use liveness analysis to find the placement of `drop_ref`operation. 230a6628e59SEugene Zhulenev auto &liveness = getAnalysis<Liveness>(); 231a6628e59SEugene Zhulenev 232a6628e59SEugene Zhulenev // We analyse only the blocks of the region that defines the `value`, and do 233a6628e59SEugene Zhulenev // not check nested blocks attached to operations. 234a6628e59SEugene Zhulenev // 235a6628e59SEugene Zhulenev // By analyzing only the `definingRegion` CFG we potentially loose an 236a6628e59SEugene Zhulenev // opportunity to drop the reference count earlier and can extend the lifetime 237a6628e59SEugene Zhulenev // of reference counted value longer then it is really required. 238a6628e59SEugene Zhulenev // 239a6628e59SEugene Zhulenev // We also assume that all nested regions finish their execution before the 240a6628e59SEugene Zhulenev // completion of the owner operation. The only exception to this rule is 241a6628e59SEugene Zhulenev // `async.execute` operation, and we verify that they are lowered to the 242a6628e59SEugene Zhulenev // `async.runtime` operations before adding automatic reference counting. 243a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion(); 244a6628e59SEugene Zhulenev 245a6628e59SEugene Zhulenev // Last users of the `value` inside all blocks where the value dies. 246a6628e59SEugene Zhulenev llvm::SmallSet<Operation *, 4> lastUsers; 247a6628e59SEugene Zhulenev 248a6628e59SEugene Zhulenev // Find blocks in the `definingRegion` that have users of the `value` (if 249a6628e59SEugene Zhulenev // there are multiple users in the block, which one will be selected is 250a6628e59SEugene Zhulenev // undefined). User operation might be not the actual user of the value, but 251a6628e59SEugene Zhulenev // the operation in the block that has a "real user" in one of the attached 252a6628e59SEugene Zhulenev // regions. 253a6628e59SEugene Zhulenev llvm::DenseMap<Block *, Operation *> usersInTheBlocks; 254a6628e59SEugene Zhulenev 255a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) { 256a6628e59SEugene Zhulenev Block *userBlock = user->getBlock(); 257a6628e59SEugene Zhulenev Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock); 258a6628e59SEugene Zhulenev usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user); 259a6628e59SEugene Zhulenev assert(ancestor && "ancestor block must be not null"); 260a6628e59SEugene Zhulenev assert(usersInTheBlocks[ancestor] && "ancestor op must be not null"); 261a6628e59SEugene Zhulenev } 262a6628e59SEugene Zhulenev 263a6628e59SEugene Zhulenev // Find blocks where the `value` dies: the value is in `liveIn` set and not 264a6628e59SEugene Zhulenev // in the `liveOut` set. We place `drop_ref` immediately after the last use 265a6628e59SEugene Zhulenev // of the `value` in such regions (after handling few special cases). 266a6628e59SEugene Zhulenev // 267a6628e59SEugene Zhulenev // We do not traverse all the blocks in the `definingRegion`, because the 268a6628e59SEugene Zhulenev // `value` can be in the live in set only if it has users in the block, or it 269a6628e59SEugene Zhulenev // is defined in the block. 270a6628e59SEugene Zhulenev // 271a6628e59SEugene Zhulenev // Values with zero users (only definition) handled explicitly above. 272a6628e59SEugene Zhulenev for (auto &blockAndUser : usersInTheBlocks) { 273a6628e59SEugene Zhulenev Block *block = blockAndUser.getFirst(); 274a6628e59SEugene Zhulenev Operation *userInTheBlock = blockAndUser.getSecond(); 275a6628e59SEugene Zhulenev 276a6628e59SEugene Zhulenev const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); 277a6628e59SEugene Zhulenev 278a6628e59SEugene Zhulenev // Value must be in the live input set or defined in the block. 279a6628e59SEugene Zhulenev assert(blockLiveness->isLiveIn(value) || 280a6628e59SEugene Zhulenev blockLiveness->getBlock() == value.getParentBlock()); 281a6628e59SEugene Zhulenev 282a6628e59SEugene Zhulenev // If value is in the live out set, it means it doesn't "die" in the block. 283a6628e59SEugene Zhulenev if (blockLiveness->isLiveOut(value)) 284a6628e59SEugene Zhulenev continue; 285a6628e59SEugene Zhulenev 286a6628e59SEugene Zhulenev // At this point we proved that `value` dies in the `block`. Find the last 287a6628e59SEugene Zhulenev // use of the `value` inside the `block`, this is where it "dies". 288a6628e59SEugene Zhulenev Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); 289a6628e59SEugene Zhulenev assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); 290a6628e59SEugene Zhulenev lastUsers.insert(lastUser); 291a6628e59SEugene Zhulenev } 292a6628e59SEugene Zhulenev 293a6628e59SEugene Zhulenev // Process all the last users of the `value` inside each block where the value 294a6628e59SEugene Zhulenev // dies. 295a6628e59SEugene Zhulenev for (Operation *lastUser : lastUsers) { 296a6628e59SEugene Zhulenev // Return like operations forward reference count. 297a6628e59SEugene Zhulenev if (lastUser->hasTrait<OpTrait::ReturnLike>()) 298a6628e59SEugene Zhulenev continue; 299a6628e59SEugene Zhulenev 300a6628e59SEugene Zhulenev // We can't currently handle other types of terminators. 301a6628e59SEugene Zhulenev if (lastUser->hasTrait<OpTrait::IsTerminator>()) 302a6628e59SEugene Zhulenev return lastUser->emitError() << "async reference counting can't handle " 303a6628e59SEugene Zhulenev "terminators that are not ReturnLike"; 304a6628e59SEugene Zhulenev 305a6628e59SEugene Zhulenev // Add a drop_ref immediately after the last user. 306a6628e59SEugene Zhulenev builder.setInsertionPointAfter(lastUser); 307*92db09cdSEugene Zhulenev builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1)); 308a6628e59SEugene Zhulenev } 309a6628e59SEugene Zhulenev 310a6628e59SEugene Zhulenev return success(); 311a6628e59SEugene Zhulenev } 312a6628e59SEugene Zhulenev 313a6628e59SEugene Zhulenev LogicalResult 314a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { 315a6628e59SEugene Zhulenev OpBuilder builder(value.getContext()); 316a6628e59SEugene Zhulenev Location loc = value.getLoc(); 317a6628e59SEugene Zhulenev 318a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) { 319a6628e59SEugene Zhulenev if (!isa<CallOp>(user)) 320a6628e59SEugene Zhulenev continue; 321a6628e59SEugene Zhulenev 322a6628e59SEugene Zhulenev // Add a reference before the function call to pass the value at `+1` 323a6628e59SEugene Zhulenev // reference to the function entry block. 324a6628e59SEugene Zhulenev builder.setInsertionPoint(user); 325*92db09cdSEugene Zhulenev builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1)); 326a6628e59SEugene Zhulenev } 327a6628e59SEugene Zhulenev 328a6628e59SEugene Zhulenev return success(); 329a6628e59SEugene Zhulenev } 330a6628e59SEugene Zhulenev 331c412979cSEugene Zhulenev LogicalResult 332c412979cSEugene Zhulenev AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( 333c412979cSEugene Zhulenev Value value) { 334c412979cSEugene Zhulenev using BlockSet = llvm::SmallPtrSet<Block *, 4>; 335c412979cSEugene Zhulenev 336a6628e59SEugene Zhulenev OpBuilder builder(value.getContext()); 337a6628e59SEugene Zhulenev 338c412979cSEugene Zhulenev // If a block has successors with different `liveIn` property of the `value`, 339c412979cSEugene Zhulenev // record block successors that do not thave the `value` in the `liveIn` set. 340c412979cSEugene Zhulenev llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks; 341a6628e59SEugene Zhulenev 342a6628e59SEugene Zhulenev // Use liveness analysis to find the placement of `drop_ref`operation. 343a6628e59SEugene Zhulenev auto &liveness = getAnalysis<Liveness>(); 344a6628e59SEugene Zhulenev 345a6628e59SEugene Zhulenev // Because we only add `drop_ref` operations to the region that defines the 346a6628e59SEugene Zhulenev // `value` we can only process CFG for the same region. 347a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion(); 348a6628e59SEugene Zhulenev 349a6628e59SEugene Zhulenev // Collect blocks with successors with mismatching `liveIn` sets. 350a6628e59SEugene Zhulenev for (Block &block : definingRegion->getBlocks()) { 351a6628e59SEugene Zhulenev const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); 352a6628e59SEugene Zhulenev 353a6628e59SEugene Zhulenev // Skip the block if value is not in the `liveOut` set. 3549136b7d0SEugene Zhulenev if (!blockLiveness || !blockLiveness->isLiveOut(value)) 355a6628e59SEugene Zhulenev continue; 356a6628e59SEugene Zhulenev 357c412979cSEugene Zhulenev BlockSet liveInSuccessors; // `value` is in `liveIn` set 358c412979cSEugene Zhulenev BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set 359a6628e59SEugene Zhulenev 360a6628e59SEugene Zhulenev // Collect successors that do not have `value` in the `liveIn` set. 361a6628e59SEugene Zhulenev for (Block *successor : block.getSuccessors()) { 362a6628e59SEugene Zhulenev const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); 3639136b7d0SEugene Zhulenev if (succLiveness && succLiveness->isLiveIn(value)) 364a6628e59SEugene Zhulenev liveInSuccessors.insert(successor); 365a6628e59SEugene Zhulenev else 366a6628e59SEugene Zhulenev noLiveInSuccessors.insert(successor); 367a6628e59SEugene Zhulenev } 368a6628e59SEugene Zhulenev 369a6628e59SEugene Zhulenev // Block has successors with different `liveIn` property of the `value`. 370a6628e59SEugene Zhulenev if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) 371c412979cSEugene Zhulenev divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors); 372a6628e59SEugene Zhulenev } 373a6628e59SEugene Zhulenev 374c412979cSEugene Zhulenev // Try to insert `dropRef` operations to handle blocks with divergent liveness 375c412979cSEugene Zhulenev // in successors blocks. 376c412979cSEugene Zhulenev for (auto kv : divergentLivenessBlocks) { 377c412979cSEugene Zhulenev Block *block = kv.getFirst(); 378c412979cSEugene Zhulenev BlockSet &successors = kv.getSecond(); 379c412979cSEugene Zhulenev 380c412979cSEugene Zhulenev // Coroutine suspension is a special case terminator for wich we do not 381c412979cSEugene Zhulenev // need to create additional reference counting (see details above). 382a6628e59SEugene Zhulenev Operation *terminator = block->getTerminator(); 383a6628e59SEugene Zhulenev if (isa<CoroSuspendOp>(terminator)) 384a6628e59SEugene Zhulenev continue; 385a6628e59SEugene Zhulenev 386c412979cSEugene Zhulenev // We only support successor blocks with empty block argument list. 387c412979cSEugene Zhulenev auto hasArgs = [](Block *block) { return !block->getArguments().empty(); }; 388c412979cSEugene Zhulenev if (llvm::any_of(successors, hasArgs)) 389c412979cSEugene Zhulenev return terminator->emitOpError() 390c412979cSEugene Zhulenev << "successor have different `liveIn` property of the reference " 391c412979cSEugene Zhulenev "counted value"; 392c412979cSEugene Zhulenev 393c412979cSEugene Zhulenev // Make sure that `dropRef` operation is called when branched into the 394c412979cSEugene Zhulenev // successor block without `value` in the `liveIn` set. 395c412979cSEugene Zhulenev for (Block *successor : successors) { 396c412979cSEugene Zhulenev // If successor has a unique predecessor, it is safe to create `dropRef` 397c412979cSEugene Zhulenev // operations directly in the successor block. 398c412979cSEugene Zhulenev // 399c412979cSEugene Zhulenev // Otherwise we need to create a special block for reference counting 400c412979cSEugene Zhulenev // operations, and branch from it to the original successor block. 401c412979cSEugene Zhulenev Block *refCountingBlock = nullptr; 402c412979cSEugene Zhulenev 403c412979cSEugene Zhulenev if (successor->getUniquePredecessor() == block) { 404c412979cSEugene Zhulenev refCountingBlock = successor; 405c412979cSEugene Zhulenev } else { 406c412979cSEugene Zhulenev refCountingBlock = &successor->getParent()->emplaceBlock(); 407c412979cSEugene Zhulenev refCountingBlock->moveBefore(successor); 408c412979cSEugene Zhulenev OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); 409c412979cSEugene Zhulenev builder.create<BranchOp>(value.getLoc(), successor); 410c412979cSEugene Zhulenev } 411c412979cSEugene Zhulenev 412c412979cSEugene Zhulenev OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); 413c412979cSEugene Zhulenev builder.create<RuntimeDropRefOp>(value.getLoc(), value, 414*92db09cdSEugene Zhulenev builder.getI64IntegerAttr(1)); 415c412979cSEugene Zhulenev 416c412979cSEugene Zhulenev // No need to update the terminator operation. 417c412979cSEugene Zhulenev if (successor == refCountingBlock) 418c412979cSEugene Zhulenev continue; 419c412979cSEugene Zhulenev 420c412979cSEugene Zhulenev // Update terminator `successor` block to `refCountingBlock`. 421c412979cSEugene Zhulenev for (auto pair : llvm::enumerate(terminator->getSuccessors())) 422c412979cSEugene Zhulenev if (pair.value() == successor) 423c412979cSEugene Zhulenev terminator->setSuccessor(refCountingBlock, pair.index()); 424c412979cSEugene Zhulenev } 425a6628e59SEugene Zhulenev } 426a6628e59SEugene Zhulenev 427a6628e59SEugene Zhulenev return success(); 428a6628e59SEugene Zhulenev } 429a6628e59SEugene Zhulenev 430a6628e59SEugene Zhulenev LogicalResult 431a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { 432f57b2420SEugene Zhulenev // Short-circuit reference counting for values without uses. 433f57b2420SEugene Zhulenev if (succeeded(dropRefIfNoUses(value))) 434a6628e59SEugene Zhulenev return success(); 435a6628e59SEugene Zhulenev 436a6628e59SEugene Zhulenev // Add `drop_ref` operations based on the liveness analysis. 437a6628e59SEugene Zhulenev if (failed(addDropRefAfterLastUse(value))) 438a6628e59SEugene Zhulenev return failure(); 439a6628e59SEugene Zhulenev 440a6628e59SEugene Zhulenev // Add `add_ref` operations before function calls. 441a6628e59SEugene Zhulenev if (failed(addAddRefBeforeFunctionCall(value))) 442a6628e59SEugene Zhulenev return failure(); 443a6628e59SEugene Zhulenev 444c412979cSEugene Zhulenev // Add `drop_ref` operations to successors with divergent `value` liveness. 445c412979cSEugene Zhulenev if (failed(addDropRefInDivergentLivenessSuccessor(value))) 446a6628e59SEugene Zhulenev return failure(); 447a6628e59SEugene Zhulenev 448a6628e59SEugene Zhulenev return success(); 449a6628e59SEugene Zhulenev } 450a6628e59SEugene Zhulenev 4518a316b00SEugene Zhulenev void AsyncRuntimeRefCountingPass::runOnOperation() { 452f57b2420SEugene Zhulenev auto functor = [&](Value value) { return addAutomaticRefCounting(value); }; 453f57b2420SEugene Zhulenev if (failed(walkReferenceCountedValues(getOperation(), functor))) 454a6628e59SEugene Zhulenev signalPassFailure(); 455a6628e59SEugene Zhulenev } 456a6628e59SEugene Zhulenev 457f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 458f57b2420SEugene Zhulenev // Reference counting based on the user defined policy. 459f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 460f57b2420SEugene Zhulenev 461f57b2420SEugene Zhulenev namespace { 462f57b2420SEugene Zhulenev 463f57b2420SEugene Zhulenev class AsyncRuntimePolicyBasedRefCountingPass 464f57b2420SEugene Zhulenev : public AsyncRuntimePolicyBasedRefCountingBase< 465f57b2420SEugene Zhulenev AsyncRuntimePolicyBasedRefCountingPass> { 466f57b2420SEugene Zhulenev public: 467f57b2420SEugene Zhulenev AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); } 468f57b2420SEugene Zhulenev 469f57b2420SEugene Zhulenev void runOnOperation() override; 470f57b2420SEugene Zhulenev 471f57b2420SEugene Zhulenev private: 472f57b2420SEugene Zhulenev // Adds a reference counting operations for all uses of the `value` according 473f57b2420SEugene Zhulenev // to the reference counting policy. 474f57b2420SEugene Zhulenev LogicalResult addRefCounting(Value value); 475f57b2420SEugene Zhulenev 476f57b2420SEugene Zhulenev void initializeDefaultPolicy(); 477f57b2420SEugene Zhulenev 478f57b2420SEugene Zhulenev llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy; 479f57b2420SEugene Zhulenev }; 480f57b2420SEugene Zhulenev 481f57b2420SEugene Zhulenev } // namespace 482f57b2420SEugene Zhulenev 483f57b2420SEugene Zhulenev LogicalResult 484f57b2420SEugene Zhulenev AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { 485f57b2420SEugene Zhulenev // Short-circuit reference counting for values without uses. 486f57b2420SEugene Zhulenev if (succeeded(dropRefIfNoUses(value))) 487f57b2420SEugene Zhulenev return success(); 488f57b2420SEugene Zhulenev 489f57b2420SEugene Zhulenev OpBuilder b(value.getContext()); 490f57b2420SEugene Zhulenev 491f57b2420SEugene Zhulenev // Consult the user defined policy for every value use. 492f57b2420SEugene Zhulenev for (OpOperand &operand : value.getUses()) { 493f57b2420SEugene Zhulenev Location loc = operand.getOwner()->getLoc(); 494f57b2420SEugene Zhulenev 495f57b2420SEugene Zhulenev for (auto &func : policy) { 496f57b2420SEugene Zhulenev FailureOr<int> refCount = func(operand); 497f57b2420SEugene Zhulenev if (failed(refCount)) 498f57b2420SEugene Zhulenev return failure(); 499f57b2420SEugene Zhulenev 500f57b2420SEugene Zhulenev int cnt = refCount.getValue(); 501f57b2420SEugene Zhulenev 502f57b2420SEugene Zhulenev // Create `add_ref` operation before the operand owner. 503f57b2420SEugene Zhulenev if (cnt > 0) { 504f57b2420SEugene Zhulenev b.setInsertionPoint(operand.getOwner()); 505*92db09cdSEugene Zhulenev b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt)); 506f57b2420SEugene Zhulenev } 507f57b2420SEugene Zhulenev 508f57b2420SEugene Zhulenev // Create `drop_ref` operation after the operand owner. 509f57b2420SEugene Zhulenev if (cnt < 0) { 510f57b2420SEugene Zhulenev b.setInsertionPointAfter(operand.getOwner()); 511*92db09cdSEugene Zhulenev b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt)); 512f57b2420SEugene Zhulenev } 513f57b2420SEugene Zhulenev } 514f57b2420SEugene Zhulenev } 515f57b2420SEugene Zhulenev 516f57b2420SEugene Zhulenev return success(); 517f57b2420SEugene Zhulenev } 518f57b2420SEugene Zhulenev 519f57b2420SEugene Zhulenev void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { 520f57b2420SEugene Zhulenev policy.push_back([](OpOperand &operand) -> FailureOr<int> { 521f57b2420SEugene Zhulenev Operation *op = operand.getOwner(); 522f57b2420SEugene Zhulenev Type type = operand.get().getType(); 523f57b2420SEugene Zhulenev 524f57b2420SEugene Zhulenev bool isToken = type.isa<TokenType>(); 525f57b2420SEugene Zhulenev bool isGroup = type.isa<GroupType>(); 526f57b2420SEugene Zhulenev bool isValue = type.isa<ValueType>(); 527f57b2420SEugene Zhulenev 528f57b2420SEugene Zhulenev // Drop reference after async token or group error check (coro await). 529f57b2420SEugene Zhulenev if (auto await = dyn_cast<RuntimeIsErrorOp>(op)) 530f57b2420SEugene Zhulenev return (isToken || isGroup) ? -1 : 0; 531f57b2420SEugene Zhulenev 532f57b2420SEugene Zhulenev // Drop reference after async value load. 533f57b2420SEugene Zhulenev if (auto load = dyn_cast<RuntimeLoadOp>(op)) 534f57b2420SEugene Zhulenev return isValue ? -1 : 0; 535f57b2420SEugene Zhulenev 536f57b2420SEugene Zhulenev // Drop reference after async token added to the group. 537f57b2420SEugene Zhulenev if (auto add = dyn_cast<RuntimeAddToGroupOp>(op)) 538f57b2420SEugene Zhulenev return isToken ? -1 : 0; 539f57b2420SEugene Zhulenev 540f57b2420SEugene Zhulenev return 0; 541f57b2420SEugene Zhulenev }); 542f57b2420SEugene Zhulenev } 543f57b2420SEugene Zhulenev 544f57b2420SEugene Zhulenev void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() { 545f57b2420SEugene Zhulenev auto functor = [&](Value value) { return addRefCounting(value); }; 546f57b2420SEugene Zhulenev if (failed(walkReferenceCountedValues(getOperation(), functor))) 547f57b2420SEugene Zhulenev signalPassFailure(); 548f57b2420SEugene Zhulenev } 549f57b2420SEugene Zhulenev 550f57b2420SEugene Zhulenev //----------------------------------------------------------------------------// 551f57b2420SEugene Zhulenev 5528a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() { 553a6628e59SEugene Zhulenev return std::make_unique<AsyncRuntimeRefCountingPass>(); 554a6628e59SEugene Zhulenev } 555f57b2420SEugene Zhulenev 556f57b2420SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() { 557f57b2420SEugene Zhulenev return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>(); 558f57b2420SEugene Zhulenev } 559