1a6628e59SEugene Zhulenev //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===// 2a6628e59SEugene Zhulenev // 3a6628e59SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a6628e59SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 5a6628e59SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a6628e59SEugene Zhulenev // 7a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===// 8a6628e59SEugene Zhulenev // 9a6628e59SEugene Zhulenev // Optimize Async dialect reference counting operations. 10a6628e59SEugene Zhulenev // 11a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===// 12a6628e59SEugene Zhulenev 13a6628e59SEugene Zhulenev #include "PassDetail.h" 14a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h" 15a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h" 16a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h" 17a6628e59SEugene Zhulenev 18a6628e59SEugene Zhulenev using namespace mlir; 19a6628e59SEugene Zhulenev using namespace mlir::async; 20a6628e59SEugene Zhulenev 21a6628e59SEugene Zhulenev #define DEBUG_TYPE "async-ref-counting" 22a6628e59SEugene Zhulenev 23a6628e59SEugene Zhulenev namespace { 24a6628e59SEugene Zhulenev 25a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingOptPass 26a6628e59SEugene Zhulenev : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> { 27a6628e59SEugene Zhulenev public: 28a6628e59SEugene Zhulenev AsyncRuntimeRefCountingOptPass() = default; 29*8a316b00SEugene Zhulenev void runOnOperation() override; 30a6628e59SEugene Zhulenev 31a6628e59SEugene Zhulenev private: 32a6628e59SEugene Zhulenev LogicalResult optimizeReferenceCounting( 33a6628e59SEugene Zhulenev Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable); 34a6628e59SEugene Zhulenev }; 35a6628e59SEugene Zhulenev 36a6628e59SEugene Zhulenev } // namespace 37a6628e59SEugene Zhulenev 38a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting( 39a6628e59SEugene Zhulenev Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) { 40a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion(); 41a6628e59SEugene Zhulenev 42a6628e59SEugene Zhulenev // Find all users of the `value` inside each block, including operations that 43a6628e59SEugene Zhulenev // do not use `value` directly, but have a direct use inside nested region(s). 44a6628e59SEugene Zhulenev // 45a6628e59SEugene Zhulenev // Example: 46a6628e59SEugene Zhulenev // 47a6628e59SEugene Zhulenev // ^bb1: 48a6628e59SEugene Zhulenev // %token = ... 49a6628e59SEugene Zhulenev // scf.if %cond { 50a6628e59SEugene Zhulenev // ^bb2: 51a6628e59SEugene Zhulenev // async.runtime.await %token : !async.token 52a6628e59SEugene Zhulenev // } 53a6628e59SEugene Zhulenev // 54a6628e59SEugene Zhulenev // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1 55a6628e59SEugene Zhulenev // (`scf.if`). 56a6628e59SEugene Zhulenev 57a6628e59SEugene Zhulenev struct BlockUsersInfo { 58a6628e59SEugene Zhulenev llvm::SmallVector<RuntimeAddRefOp, 4> addRefs; 59a6628e59SEugene Zhulenev llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs; 60a6628e59SEugene Zhulenev llvm::SmallVector<Operation *, 4> users; 61a6628e59SEugene Zhulenev }; 62a6628e59SEugene Zhulenev 63a6628e59SEugene Zhulenev llvm::DenseMap<Block *, BlockUsersInfo> blockUsers; 64a6628e59SEugene Zhulenev 65a6628e59SEugene Zhulenev auto updateBlockUsersInfo = [&](Operation *user) { 66a6628e59SEugene Zhulenev BlockUsersInfo &info = blockUsers[user->getBlock()]; 67a6628e59SEugene Zhulenev info.users.push_back(user); 68a6628e59SEugene Zhulenev 69a6628e59SEugene Zhulenev if (auto addRef = dyn_cast<RuntimeAddRefOp>(user)) 70a6628e59SEugene Zhulenev info.addRefs.push_back(addRef); 71a6628e59SEugene Zhulenev if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user)) 72a6628e59SEugene Zhulenev info.dropRefs.push_back(dropRef); 73a6628e59SEugene Zhulenev }; 74a6628e59SEugene Zhulenev 75a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) { 76a6628e59SEugene Zhulenev while (user->getParentRegion() != definingRegion) { 77a6628e59SEugene Zhulenev updateBlockUsersInfo(user); 78a6628e59SEugene Zhulenev user = user->getParentOp(); 79a6628e59SEugene Zhulenev assert(user != nullptr && "value user lies outside of the value region"); 80a6628e59SEugene Zhulenev } 81a6628e59SEugene Zhulenev 82a6628e59SEugene Zhulenev updateBlockUsersInfo(user); 83a6628e59SEugene Zhulenev } 84a6628e59SEugene Zhulenev 85a6628e59SEugene Zhulenev // Sort all operations found in the block. 86a6628e59SEugene Zhulenev auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & { 87a6628e59SEugene Zhulenev auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool { 88a6628e59SEugene Zhulenev return a->isBeforeInBlock(b); 89a6628e59SEugene Zhulenev }; 90a6628e59SEugene Zhulenev llvm::sort(info.addRefs, isBeforeInBlock); 91a6628e59SEugene Zhulenev llvm::sort(info.dropRefs, isBeforeInBlock); 92a6628e59SEugene Zhulenev llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool { 93a6628e59SEugene Zhulenev return isBeforeInBlock(a, b); 94a6628e59SEugene Zhulenev }); 95a6628e59SEugene Zhulenev 96a6628e59SEugene Zhulenev return info; 97a6628e59SEugene Zhulenev }; 98a6628e59SEugene Zhulenev 99a6628e59SEugene Zhulenev // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the 100a6628e59SEugene Zhulenev // blocks that modify the reference count of the `value`. 101a6628e59SEugene Zhulenev for (auto &kv : blockUsers) { 102a6628e59SEugene Zhulenev BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second); 103a6628e59SEugene Zhulenev 104a6628e59SEugene Zhulenev for (RuntimeAddRefOp addRef : info.addRefs) { 105a6628e59SEugene Zhulenev for (RuntimeDropRefOp dropRef : info.dropRefs) { 106a6628e59SEugene Zhulenev // `drop_ref` operation after the `add_ref` with matching count. 107a6628e59SEugene Zhulenev if (dropRef.count() != addRef.count() || 108a6628e59SEugene Zhulenev dropRef->isBeforeInBlock(addRef.getOperation())) 109a6628e59SEugene Zhulenev continue; 110a6628e59SEugene Zhulenev 111a6628e59SEugene Zhulenev // Try to cancel the pair of `add_ref` and `drop_ref` operations. 112a6628e59SEugene Zhulenev auto emplaced = cancellable.try_emplace(dropRef.getOperation(), 113a6628e59SEugene Zhulenev addRef.getOperation()); 114a6628e59SEugene Zhulenev 115a6628e59SEugene Zhulenev if (!emplaced.second) // `drop_ref` was already marked for removal 116a6628e59SEugene Zhulenev continue; // go to the next `drop_ref` 117a6628e59SEugene Zhulenev 118a6628e59SEugene Zhulenev if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref` 119a6628e59SEugene Zhulenev break; // go to the next `add_ref` 120a6628e59SEugene Zhulenev } 121a6628e59SEugene Zhulenev } 122a6628e59SEugene Zhulenev } 123a6628e59SEugene Zhulenev 124a6628e59SEugene Zhulenev return success(); 125a6628e59SEugene Zhulenev } 126a6628e59SEugene Zhulenev 127*8a316b00SEugene Zhulenev void AsyncRuntimeRefCountingOptPass::runOnOperation() { 128*8a316b00SEugene Zhulenev Operation *op = getOperation(); 129a6628e59SEugene Zhulenev 130a6628e59SEugene Zhulenev // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`. 131a6628e59SEugene Zhulenev // 132a6628e59SEugene Zhulenev // Find all cancellable pairs of operation and erase them in the end to keep 133a6628e59SEugene Zhulenev // all iterators valid while we are walking the function operations. 134a6628e59SEugene Zhulenev llvm::SmallDenseMap<Operation *, Operation *> cancellable; 135a6628e59SEugene Zhulenev 136a6628e59SEugene Zhulenev // Optimize reference counting for values defined by block arguments. 137*8a316b00SEugene Zhulenev WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { 138a6628e59SEugene Zhulenev for (BlockArgument arg : block->getArguments()) 139a6628e59SEugene Zhulenev if (isRefCounted(arg.getType())) 140a6628e59SEugene Zhulenev if (failed(optimizeReferenceCounting(arg, cancellable))) 141a6628e59SEugene Zhulenev return WalkResult::interrupt(); 142a6628e59SEugene Zhulenev 143a6628e59SEugene Zhulenev return WalkResult::advance(); 144a6628e59SEugene Zhulenev }); 145a6628e59SEugene Zhulenev 146a6628e59SEugene Zhulenev if (blockWalk.wasInterrupted()) 147a6628e59SEugene Zhulenev signalPassFailure(); 148a6628e59SEugene Zhulenev 149a6628e59SEugene Zhulenev // Optimize reference counting for values defined by operation results. 150*8a316b00SEugene Zhulenev WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { 151a6628e59SEugene Zhulenev for (unsigned i = 0; i < op->getNumResults(); ++i) 152a6628e59SEugene Zhulenev if (isRefCounted(op->getResultTypes()[i])) 153a6628e59SEugene Zhulenev if (failed(optimizeReferenceCounting(op->getResult(i), cancellable))) 154a6628e59SEugene Zhulenev return WalkResult::interrupt(); 155a6628e59SEugene Zhulenev 156a6628e59SEugene Zhulenev return WalkResult::advance(); 157a6628e59SEugene Zhulenev }); 158a6628e59SEugene Zhulenev 159a6628e59SEugene Zhulenev if (opWalk.wasInterrupted()) 160a6628e59SEugene Zhulenev signalPassFailure(); 161a6628e59SEugene Zhulenev 162a6628e59SEugene Zhulenev LLVM_DEBUG({ 163a6628e59SEugene Zhulenev llvm::dbgs() << "Found " << cancellable.size() 164a6628e59SEugene Zhulenev << " cancellable reference counting operations\n"; 165a6628e59SEugene Zhulenev }); 166a6628e59SEugene Zhulenev 167a6628e59SEugene Zhulenev // Erase all cancellable `add_ref <-> drop_ref` operation pairs. 168a6628e59SEugene Zhulenev for (auto &kv : cancellable) { 169a6628e59SEugene Zhulenev kv.first->erase(); 170a6628e59SEugene Zhulenev kv.second->erase(); 171a6628e59SEugene Zhulenev } 172a6628e59SEugene Zhulenev } 173a6628e59SEugene Zhulenev 174*8a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() { 175a6628e59SEugene Zhulenev return std::make_unique<AsyncRuntimeRefCountingOptPass>(); 176a6628e59SEugene Zhulenev } 177