1a6628e59SEugene Zhulenev //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===// 2a6628e59SEugene Zhulenev // 3a6628e59SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a6628e59SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 5a6628e59SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a6628e59SEugene Zhulenev // 7a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===// 8a6628e59SEugene Zhulenev // 9a6628e59SEugene Zhulenev // Optimize Async dialect reference counting operations. 10a6628e59SEugene Zhulenev // 11a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===// 12a6628e59SEugene Zhulenev 13a6628e59SEugene Zhulenev #include "PassDetail.h" 14a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h" 15a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h" 169ccdaac8SEugene Zhulenev #include "mlir/Dialect/StandardOps/IR/Ops.h" 17a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h" 18297a5b7cSNico Weber #include "llvm/Support/Debug.h" 19a6628e59SEugene Zhulenev 20a6628e59SEugene Zhulenev using namespace mlir; 21a6628e59SEugene Zhulenev using namespace mlir::async; 22a6628e59SEugene Zhulenev 23a6628e59SEugene Zhulenev #define DEBUG_TYPE "async-ref-counting" 24a6628e59SEugene Zhulenev 25a6628e59SEugene Zhulenev namespace { 26a6628e59SEugene Zhulenev 27a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingOptPass 28a6628e59SEugene Zhulenev : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> { 29a6628e59SEugene Zhulenev public: 30a6628e59SEugene Zhulenev AsyncRuntimeRefCountingOptPass() = default; 318a316b00SEugene Zhulenev void runOnOperation() override; 32a6628e59SEugene Zhulenev 33a6628e59SEugene Zhulenev private: 34a6628e59SEugene Zhulenev LogicalResult optimizeReferenceCounting( 35a6628e59SEugene Zhulenev Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable); 36a6628e59SEugene Zhulenev }; 37a6628e59SEugene Zhulenev 38a6628e59SEugene Zhulenev } // namespace 39a6628e59SEugene Zhulenev 40a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting( 41a6628e59SEugene Zhulenev Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) { 42a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion(); 43a6628e59SEugene Zhulenev 44a6628e59SEugene Zhulenev // Find all users of the `value` inside each block, including operations that 45a6628e59SEugene Zhulenev // do not use `value` directly, but have a direct use inside nested region(s). 46a6628e59SEugene Zhulenev // 47a6628e59SEugene Zhulenev // Example: 48a6628e59SEugene Zhulenev // 49a6628e59SEugene Zhulenev // ^bb1: 50a6628e59SEugene Zhulenev // %token = ... 51a6628e59SEugene Zhulenev // scf.if %cond { 52a6628e59SEugene Zhulenev // ^bb2: 53a6628e59SEugene Zhulenev // async.runtime.await %token : !async.token 54a6628e59SEugene Zhulenev // } 55a6628e59SEugene Zhulenev // 56a6628e59SEugene Zhulenev // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1 57a6628e59SEugene Zhulenev // (`scf.if`). 58a6628e59SEugene Zhulenev 59a6628e59SEugene Zhulenev struct BlockUsersInfo { 60a6628e59SEugene Zhulenev llvm::SmallVector<RuntimeAddRefOp, 4> addRefs; 61a6628e59SEugene Zhulenev llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs; 62a6628e59SEugene Zhulenev llvm::SmallVector<Operation *, 4> users; 63a6628e59SEugene Zhulenev }; 64a6628e59SEugene Zhulenev 65a6628e59SEugene Zhulenev llvm::DenseMap<Block *, BlockUsersInfo> blockUsers; 66a6628e59SEugene Zhulenev 67a6628e59SEugene Zhulenev auto updateBlockUsersInfo = [&](Operation *user) { 68a6628e59SEugene Zhulenev BlockUsersInfo &info = blockUsers[user->getBlock()]; 69a6628e59SEugene Zhulenev info.users.push_back(user); 70a6628e59SEugene Zhulenev 71a6628e59SEugene Zhulenev if (auto addRef = dyn_cast<RuntimeAddRefOp>(user)) 72a6628e59SEugene Zhulenev info.addRefs.push_back(addRef); 73a6628e59SEugene Zhulenev if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user)) 74a6628e59SEugene Zhulenev info.dropRefs.push_back(dropRef); 75a6628e59SEugene Zhulenev }; 76a6628e59SEugene Zhulenev 77a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) { 78a6628e59SEugene Zhulenev while (user->getParentRegion() != definingRegion) { 79a6628e59SEugene Zhulenev updateBlockUsersInfo(user); 80a6628e59SEugene Zhulenev user = user->getParentOp(); 81a6628e59SEugene Zhulenev assert(user != nullptr && "value user lies outside of the value region"); 82a6628e59SEugene Zhulenev } 83a6628e59SEugene Zhulenev 84a6628e59SEugene Zhulenev updateBlockUsersInfo(user); 85a6628e59SEugene Zhulenev } 86a6628e59SEugene Zhulenev 87a6628e59SEugene Zhulenev // Sort all operations found in the block. 88a6628e59SEugene Zhulenev auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & { 89a6628e59SEugene Zhulenev auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool { 90a6628e59SEugene Zhulenev return a->isBeforeInBlock(b); 91a6628e59SEugene Zhulenev }; 92a6628e59SEugene Zhulenev llvm::sort(info.addRefs, isBeforeInBlock); 93a6628e59SEugene Zhulenev llvm::sort(info.dropRefs, isBeforeInBlock); 94a6628e59SEugene Zhulenev llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool { 95a6628e59SEugene Zhulenev return isBeforeInBlock(a, b); 96a6628e59SEugene Zhulenev }); 97a6628e59SEugene Zhulenev 98a6628e59SEugene Zhulenev return info; 99a6628e59SEugene Zhulenev }; 100a6628e59SEugene Zhulenev 101a6628e59SEugene Zhulenev // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the 102a6628e59SEugene Zhulenev // blocks that modify the reference count of the `value`. 103a6628e59SEugene Zhulenev for (auto &kv : blockUsers) { 104a6628e59SEugene Zhulenev BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second); 105a6628e59SEugene Zhulenev 106a6628e59SEugene Zhulenev for (RuntimeAddRefOp addRef : info.addRefs) { 107a6628e59SEugene Zhulenev for (RuntimeDropRefOp dropRef : info.dropRefs) { 108a6628e59SEugene Zhulenev // `drop_ref` operation after the `add_ref` with matching count. 109a6628e59SEugene Zhulenev if (dropRef.count() != addRef.count() || 110a6628e59SEugene Zhulenev dropRef->isBeforeInBlock(addRef.getOperation())) 111a6628e59SEugene Zhulenev continue; 112a6628e59SEugene Zhulenev 1139ccdaac8SEugene Zhulenev // When reference counted value passed to a function as an argument, 1149ccdaac8SEugene Zhulenev // function takes ownership of +1 reference and it will drop it before 1159ccdaac8SEugene Zhulenev // returning. 1169ccdaac8SEugene Zhulenev // 1179ccdaac8SEugene Zhulenev // Example: 1189ccdaac8SEugene Zhulenev // 1199ccdaac8SEugene Zhulenev // %token = ... : !async.token 1209ccdaac8SEugene Zhulenev // 121*92db09cdSEugene Zhulenev // async.runtime.add_ref %token {count = 1 : i64} : !async.token 1229ccdaac8SEugene Zhulenev // call @pass_token(%token: !async.token, ...) 1239ccdaac8SEugene Zhulenev // 1249ccdaac8SEugene Zhulenev // async.await %token : !async.token 125*92db09cdSEugene Zhulenev // async.runtime.drop_ref %token {count = 1 : i64} : !async.token 1269ccdaac8SEugene Zhulenev // 1279ccdaac8SEugene Zhulenev // In this example if we'll cancel a pair of reference counting 1289ccdaac8SEugene Zhulenev // operations we might end up with a deallocated token when we'll 1299ccdaac8SEugene Zhulenev // reach `async.await` operation. 1309ccdaac8SEugene Zhulenev Operation *firstFunctionCallUser = nullptr; 1319ccdaac8SEugene Zhulenev Operation *lastNonFunctionCallUser = nullptr; 1329ccdaac8SEugene Zhulenev 1339ccdaac8SEugene Zhulenev for (Operation *user : info.users) { 1349ccdaac8SEugene Zhulenev // `user` operation lies after `addRef` ... 1359ccdaac8SEugene Zhulenev if (user == addRef || user->isBeforeInBlock(addRef)) 1369ccdaac8SEugene Zhulenev continue; 1379ccdaac8SEugene Zhulenev // ... and before `dropRef`. 1389ccdaac8SEugene Zhulenev if (user == dropRef || dropRef->isBeforeInBlock(user)) 1399ccdaac8SEugene Zhulenev break; 1409ccdaac8SEugene Zhulenev 1419ccdaac8SEugene Zhulenev // Find the first function call user of the reference counted value. 1429ccdaac8SEugene Zhulenev Operation *functionCall = dyn_cast<CallOp>(user); 1439ccdaac8SEugene Zhulenev if (functionCall && 1449ccdaac8SEugene Zhulenev (!firstFunctionCallUser || 1459ccdaac8SEugene Zhulenev functionCall->isBeforeInBlock(firstFunctionCallUser))) { 1469ccdaac8SEugene Zhulenev firstFunctionCallUser = functionCall; 1479ccdaac8SEugene Zhulenev continue; 1489ccdaac8SEugene Zhulenev } 1499ccdaac8SEugene Zhulenev 1509ccdaac8SEugene Zhulenev // Find the last regular user of the reference counted value. 1519ccdaac8SEugene Zhulenev if (!functionCall && 1529ccdaac8SEugene Zhulenev (!lastNonFunctionCallUser || 1539ccdaac8SEugene Zhulenev lastNonFunctionCallUser->isBeforeInBlock(user))) { 1549ccdaac8SEugene Zhulenev lastNonFunctionCallUser = user; 1559ccdaac8SEugene Zhulenev continue; 1569ccdaac8SEugene Zhulenev } 1579ccdaac8SEugene Zhulenev } 1589ccdaac8SEugene Zhulenev 1599ccdaac8SEugene Zhulenev // Non function call user after the function call user of the reference 1609ccdaac8SEugene Zhulenev // counted value. 1619ccdaac8SEugene Zhulenev if (firstFunctionCallUser && lastNonFunctionCallUser && 1629ccdaac8SEugene Zhulenev firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser)) 1639ccdaac8SEugene Zhulenev continue; 1649ccdaac8SEugene Zhulenev 165a6628e59SEugene Zhulenev // Try to cancel the pair of `add_ref` and `drop_ref` operations. 166a6628e59SEugene Zhulenev auto emplaced = cancellable.try_emplace(dropRef.getOperation(), 167a6628e59SEugene Zhulenev addRef.getOperation()); 168a6628e59SEugene Zhulenev 169a6628e59SEugene Zhulenev if (!emplaced.second) // `drop_ref` was already marked for removal 170a6628e59SEugene Zhulenev continue; // go to the next `drop_ref` 171a6628e59SEugene Zhulenev 172a6628e59SEugene Zhulenev if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref` 173a6628e59SEugene Zhulenev break; // go to the next `add_ref` 174a6628e59SEugene Zhulenev } 175a6628e59SEugene Zhulenev } 176a6628e59SEugene Zhulenev } 177a6628e59SEugene Zhulenev 178a6628e59SEugene Zhulenev return success(); 179a6628e59SEugene Zhulenev } 180a6628e59SEugene Zhulenev 1818a316b00SEugene Zhulenev void AsyncRuntimeRefCountingOptPass::runOnOperation() { 1828a316b00SEugene Zhulenev Operation *op = getOperation(); 183a6628e59SEugene Zhulenev 184a6628e59SEugene Zhulenev // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`. 185a6628e59SEugene Zhulenev // 186a6628e59SEugene Zhulenev // Find all cancellable pairs of operation and erase them in the end to keep 187a6628e59SEugene Zhulenev // all iterators valid while we are walking the function operations. 188a6628e59SEugene Zhulenev llvm::SmallDenseMap<Operation *, Operation *> cancellable; 189a6628e59SEugene Zhulenev 190a6628e59SEugene Zhulenev // Optimize reference counting for values defined by block arguments. 1918a316b00SEugene Zhulenev WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { 192a6628e59SEugene Zhulenev for (BlockArgument arg : block->getArguments()) 193a6628e59SEugene Zhulenev if (isRefCounted(arg.getType())) 194a6628e59SEugene Zhulenev if (failed(optimizeReferenceCounting(arg, cancellable))) 195a6628e59SEugene Zhulenev return WalkResult::interrupt(); 196a6628e59SEugene Zhulenev 197a6628e59SEugene Zhulenev return WalkResult::advance(); 198a6628e59SEugene Zhulenev }); 199a6628e59SEugene Zhulenev 200a6628e59SEugene Zhulenev if (blockWalk.wasInterrupted()) 201a6628e59SEugene Zhulenev signalPassFailure(); 202a6628e59SEugene Zhulenev 203a6628e59SEugene Zhulenev // Optimize reference counting for values defined by operation results. 2048a316b00SEugene Zhulenev WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { 205a6628e59SEugene Zhulenev for (unsigned i = 0; i < op->getNumResults(); ++i) 206a6628e59SEugene Zhulenev if (isRefCounted(op->getResultTypes()[i])) 207a6628e59SEugene Zhulenev if (failed(optimizeReferenceCounting(op->getResult(i), cancellable))) 208a6628e59SEugene Zhulenev return WalkResult::interrupt(); 209a6628e59SEugene Zhulenev 210a6628e59SEugene Zhulenev return WalkResult::advance(); 211a6628e59SEugene Zhulenev }); 212a6628e59SEugene Zhulenev 213a6628e59SEugene Zhulenev if (opWalk.wasInterrupted()) 214a6628e59SEugene Zhulenev signalPassFailure(); 215a6628e59SEugene Zhulenev 216a6628e59SEugene Zhulenev LLVM_DEBUG({ 217a6628e59SEugene Zhulenev llvm::dbgs() << "Found " << cancellable.size() 218a6628e59SEugene Zhulenev << " cancellable reference counting operations\n"; 219a6628e59SEugene Zhulenev }); 220a6628e59SEugene Zhulenev 221a6628e59SEugene Zhulenev // Erase all cancellable `add_ref <-> drop_ref` operation pairs. 222a6628e59SEugene Zhulenev for (auto &kv : cancellable) { 223a6628e59SEugene Zhulenev kv.first->erase(); 224a6628e59SEugene Zhulenev kv.second->erase(); 225a6628e59SEugene Zhulenev } 226a6628e59SEugene Zhulenev } 227a6628e59SEugene Zhulenev 2288a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() { 229a6628e59SEugene Zhulenev return std::make_unique<AsyncRuntimeRefCountingOptPass>(); 230a6628e59SEugene Zhulenev } 231