1 //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Optimize Async dialect reference counting operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Async/IR/Async.h" 15 #include "mlir/Dialect/Async/Passes.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "llvm/ADT/SmallSet.h" 18 #include "llvm/Support/Debug.h" 19 20 using namespace mlir; 21 using namespace mlir::async; 22 23 #define DEBUG_TYPE "async-ref-counting" 24 25 namespace { 26 27 class AsyncRuntimeRefCountingOptPass 28 : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> { 29 public: 30 AsyncRuntimeRefCountingOptPass() = default; 31 void runOnOperation() override; 32 33 private: 34 LogicalResult optimizeReferenceCounting( 35 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable); 36 }; 37 38 } // namespace 39 40 LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting( 41 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) { 42 Region *definingRegion = value.getParentRegion(); 43 44 // Find all users of the `value` inside each block, including operations that 45 // do not use `value` directly, but have a direct use inside nested region(s). 46 // 47 // Example: 48 // 49 // ^bb1: 50 // %token = ... 51 // scf.if %cond { 52 // ^bb2: 53 // async.runtime.await %token : !async.token 54 // } 55 // 56 // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1 57 // (`scf.if`). 58 59 struct BlockUsersInfo { 60 llvm::SmallVector<RuntimeAddRefOp, 4> addRefs; 61 llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs; 62 llvm::SmallVector<Operation *, 4> users; 63 }; 64 65 llvm::DenseMap<Block *, BlockUsersInfo> blockUsers; 66 67 auto updateBlockUsersInfo = [&](Operation *user) { 68 BlockUsersInfo &info = blockUsers[user->getBlock()]; 69 info.users.push_back(user); 70 71 if (auto addRef = dyn_cast<RuntimeAddRefOp>(user)) 72 info.addRefs.push_back(addRef); 73 if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user)) 74 info.dropRefs.push_back(dropRef); 75 }; 76 77 for (Operation *user : value.getUsers()) { 78 while (user->getParentRegion() != definingRegion) { 79 updateBlockUsersInfo(user); 80 user = user->getParentOp(); 81 assert(user != nullptr && "value user lies outside of the value region"); 82 } 83 84 updateBlockUsersInfo(user); 85 } 86 87 // Sort all operations found in the block. 88 auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & { 89 auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool { 90 return a->isBeforeInBlock(b); 91 }; 92 llvm::sort(info.addRefs, isBeforeInBlock); 93 llvm::sort(info.dropRefs, isBeforeInBlock); 94 llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool { 95 return isBeforeInBlock(a, b); 96 }); 97 98 return info; 99 }; 100 101 // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the 102 // blocks that modify the reference count of the `value`. 103 for (auto &kv : blockUsers) { 104 BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second); 105 106 for (RuntimeAddRefOp addRef : info.addRefs) { 107 for (RuntimeDropRefOp dropRef : info.dropRefs) { 108 // `drop_ref` operation after the `add_ref` with matching count. 109 if (dropRef.count() != addRef.count() || 110 dropRef->isBeforeInBlock(addRef.getOperation())) 111 continue; 112 113 // When reference counted value passed to a function as an argument, 114 // function takes ownership of +1 reference and it will drop it before 115 // returning. 116 // 117 // Example: 118 // 119 // %token = ... : !async.token 120 // 121 // async.runtime.add_ref %token {count = 1 : i64} : !async.token 122 // call @pass_token(%token: !async.token, ...) 123 // 124 // async.await %token : !async.token 125 // async.runtime.drop_ref %token {count = 1 : i64} : !async.token 126 // 127 // In this example if we'll cancel a pair of reference counting 128 // operations we might end up with a deallocated token when we'll 129 // reach `async.await` operation. 130 Operation *firstFunctionCallUser = nullptr; 131 Operation *lastNonFunctionCallUser = nullptr; 132 133 for (Operation *user : info.users) { 134 // `user` operation lies after `addRef` ... 135 if (user == addRef || user->isBeforeInBlock(addRef)) 136 continue; 137 // ... and before `dropRef`. 138 if (user == dropRef || dropRef->isBeforeInBlock(user)) 139 break; 140 141 // Find the first function call user of the reference counted value. 142 Operation *functionCall = dyn_cast<CallOp>(user); 143 if (functionCall && 144 (!firstFunctionCallUser || 145 functionCall->isBeforeInBlock(firstFunctionCallUser))) { 146 firstFunctionCallUser = functionCall; 147 continue; 148 } 149 150 // Find the last regular user of the reference counted value. 151 if (!functionCall && 152 (!lastNonFunctionCallUser || 153 lastNonFunctionCallUser->isBeforeInBlock(user))) { 154 lastNonFunctionCallUser = user; 155 continue; 156 } 157 } 158 159 // Non function call user after the function call user of the reference 160 // counted value. 161 if (firstFunctionCallUser && lastNonFunctionCallUser && 162 firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser)) 163 continue; 164 165 // Try to cancel the pair of `add_ref` and `drop_ref` operations. 166 auto emplaced = cancellable.try_emplace(dropRef.getOperation(), 167 addRef.getOperation()); 168 169 if (!emplaced.second) // `drop_ref` was already marked for removal 170 continue; // go to the next `drop_ref` 171 172 if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref` 173 break; // go to the next `add_ref` 174 } 175 } 176 } 177 178 return success(); 179 } 180 181 void AsyncRuntimeRefCountingOptPass::runOnOperation() { 182 Operation *op = getOperation(); 183 184 // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`. 185 // 186 // Find all cancellable pairs of operation and erase them in the end to keep 187 // all iterators valid while we are walking the function operations. 188 llvm::SmallDenseMap<Operation *, Operation *> cancellable; 189 190 // Optimize reference counting for values defined by block arguments. 191 WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { 192 for (BlockArgument arg : block->getArguments()) 193 if (isRefCounted(arg.getType())) 194 if (failed(optimizeReferenceCounting(arg, cancellable))) 195 return WalkResult::interrupt(); 196 197 return WalkResult::advance(); 198 }); 199 200 if (blockWalk.wasInterrupted()) 201 signalPassFailure(); 202 203 // Optimize reference counting for values defined by operation results. 204 WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { 205 for (unsigned i = 0; i < op->getNumResults(); ++i) 206 if (isRefCounted(op->getResultTypes()[i])) 207 if (failed(optimizeReferenceCounting(op->getResult(i), cancellable))) 208 return WalkResult::interrupt(); 209 210 return WalkResult::advance(); 211 }); 212 213 if (opWalk.wasInterrupted()) 214 signalPassFailure(); 215 216 LLVM_DEBUG({ 217 llvm::dbgs() << "Found " << cancellable.size() 218 << " cancellable reference counting operations\n"; 219 }); 220 221 // Erase all cancellable `add_ref <-> drop_ref` operation pairs. 222 for (auto &kv : cancellable) { 223 kv.first->erase(); 224 kv.second->erase(); 225 } 226 } 227 228 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() { 229 return std::make_unique<AsyncRuntimeRefCountingOptPass>(); 230 } 231