1*a6628e59SEugene Zhulenev //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===// 2*a6628e59SEugene Zhulenev // 3*a6628e59SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*a6628e59SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 5*a6628e59SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*a6628e59SEugene Zhulenev // 7*a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===// 8*a6628e59SEugene Zhulenev // 9*a6628e59SEugene Zhulenev // Optimize Async dialect reference counting operations. 10*a6628e59SEugene Zhulenev // 11*a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===// 12*a6628e59SEugene Zhulenev 13*a6628e59SEugene Zhulenev #include "PassDetail.h" 14*a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h" 15*a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h" 16*a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h" 17*a6628e59SEugene Zhulenev 18*a6628e59SEugene Zhulenev using namespace mlir; 19*a6628e59SEugene Zhulenev using namespace mlir::async; 20*a6628e59SEugene Zhulenev 21*a6628e59SEugene Zhulenev #define DEBUG_TYPE "async-ref-counting" 22*a6628e59SEugene Zhulenev 23*a6628e59SEugene Zhulenev namespace { 24*a6628e59SEugene Zhulenev 25*a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingOptPass 26*a6628e59SEugene Zhulenev : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> { 27*a6628e59SEugene Zhulenev public: 28*a6628e59SEugene Zhulenev AsyncRuntimeRefCountingOptPass() = default; 29*a6628e59SEugene Zhulenev void runOnFunction() override; 30*a6628e59SEugene Zhulenev 31*a6628e59SEugene Zhulenev private: 32*a6628e59SEugene Zhulenev LogicalResult optimizeReferenceCounting( 33*a6628e59SEugene Zhulenev Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable); 34*a6628e59SEugene Zhulenev }; 35*a6628e59SEugene Zhulenev 36*a6628e59SEugene Zhulenev } // namespace 37*a6628e59SEugene Zhulenev 38*a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting( 39*a6628e59SEugene Zhulenev Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) { 40*a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion(); 41*a6628e59SEugene Zhulenev 42*a6628e59SEugene Zhulenev // Find all users of the `value` inside each block, including operations that 43*a6628e59SEugene Zhulenev // do not use `value` directly, but have a direct use inside nested region(s). 44*a6628e59SEugene Zhulenev // 45*a6628e59SEugene Zhulenev // Example: 46*a6628e59SEugene Zhulenev // 47*a6628e59SEugene Zhulenev // ^bb1: 48*a6628e59SEugene Zhulenev // %token = ... 49*a6628e59SEugene Zhulenev // scf.if %cond { 50*a6628e59SEugene Zhulenev // ^bb2: 51*a6628e59SEugene Zhulenev // async.runtime.await %token : !async.token 52*a6628e59SEugene Zhulenev // } 53*a6628e59SEugene Zhulenev // 54*a6628e59SEugene Zhulenev // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1 55*a6628e59SEugene Zhulenev // (`scf.if`). 56*a6628e59SEugene Zhulenev 57*a6628e59SEugene Zhulenev struct BlockUsersInfo { 58*a6628e59SEugene Zhulenev llvm::SmallVector<RuntimeAddRefOp, 4> addRefs; 59*a6628e59SEugene Zhulenev llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs; 60*a6628e59SEugene Zhulenev llvm::SmallVector<Operation *, 4> users; 61*a6628e59SEugene Zhulenev }; 62*a6628e59SEugene Zhulenev 63*a6628e59SEugene Zhulenev llvm::DenseMap<Block *, BlockUsersInfo> blockUsers; 64*a6628e59SEugene Zhulenev 65*a6628e59SEugene Zhulenev auto updateBlockUsersInfo = [&](Operation *user) { 66*a6628e59SEugene Zhulenev BlockUsersInfo &info = blockUsers[user->getBlock()]; 67*a6628e59SEugene Zhulenev info.users.push_back(user); 68*a6628e59SEugene Zhulenev 69*a6628e59SEugene Zhulenev if (auto addRef = dyn_cast<RuntimeAddRefOp>(user)) 70*a6628e59SEugene Zhulenev info.addRefs.push_back(addRef); 71*a6628e59SEugene Zhulenev if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user)) 72*a6628e59SEugene Zhulenev info.dropRefs.push_back(dropRef); 73*a6628e59SEugene Zhulenev }; 74*a6628e59SEugene Zhulenev 75*a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) { 76*a6628e59SEugene Zhulenev while (user->getParentRegion() != definingRegion) { 77*a6628e59SEugene Zhulenev updateBlockUsersInfo(user); 78*a6628e59SEugene Zhulenev user = user->getParentOp(); 79*a6628e59SEugene Zhulenev assert(user != nullptr && "value user lies outside of the value region"); 80*a6628e59SEugene Zhulenev } 81*a6628e59SEugene Zhulenev 82*a6628e59SEugene Zhulenev updateBlockUsersInfo(user); 83*a6628e59SEugene Zhulenev } 84*a6628e59SEugene Zhulenev 85*a6628e59SEugene Zhulenev // Sort all operations found in the block. 86*a6628e59SEugene Zhulenev auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & { 87*a6628e59SEugene Zhulenev auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool { 88*a6628e59SEugene Zhulenev return a->isBeforeInBlock(b); 89*a6628e59SEugene Zhulenev }; 90*a6628e59SEugene Zhulenev llvm::sort(info.addRefs, isBeforeInBlock); 91*a6628e59SEugene Zhulenev llvm::sort(info.dropRefs, isBeforeInBlock); 92*a6628e59SEugene Zhulenev llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool { 93*a6628e59SEugene Zhulenev return isBeforeInBlock(a, b); 94*a6628e59SEugene Zhulenev }); 95*a6628e59SEugene Zhulenev 96*a6628e59SEugene Zhulenev return info; 97*a6628e59SEugene Zhulenev }; 98*a6628e59SEugene Zhulenev 99*a6628e59SEugene Zhulenev // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the 100*a6628e59SEugene Zhulenev // blocks that modify the reference count of the `value`. 101*a6628e59SEugene Zhulenev for (auto &kv : blockUsers) { 102*a6628e59SEugene Zhulenev BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second); 103*a6628e59SEugene Zhulenev 104*a6628e59SEugene Zhulenev for (RuntimeAddRefOp addRef : info.addRefs) { 105*a6628e59SEugene Zhulenev for (RuntimeDropRefOp dropRef : info.dropRefs) { 106*a6628e59SEugene Zhulenev // `drop_ref` operation after the `add_ref` with matching count. 107*a6628e59SEugene Zhulenev if (dropRef.count() != addRef.count() || 108*a6628e59SEugene Zhulenev dropRef->isBeforeInBlock(addRef.getOperation())) 109*a6628e59SEugene Zhulenev continue; 110*a6628e59SEugene Zhulenev 111*a6628e59SEugene Zhulenev // Try to cancel the pair of `add_ref` and `drop_ref` operations. 112*a6628e59SEugene Zhulenev auto emplaced = cancellable.try_emplace(dropRef.getOperation(), 113*a6628e59SEugene Zhulenev addRef.getOperation()); 114*a6628e59SEugene Zhulenev 115*a6628e59SEugene Zhulenev if (!emplaced.second) // `drop_ref` was already marked for removal 116*a6628e59SEugene Zhulenev continue; // go to the next `drop_ref` 117*a6628e59SEugene Zhulenev 118*a6628e59SEugene Zhulenev if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref` 119*a6628e59SEugene Zhulenev break; // go to the next `add_ref` 120*a6628e59SEugene Zhulenev } 121*a6628e59SEugene Zhulenev } 122*a6628e59SEugene Zhulenev } 123*a6628e59SEugene Zhulenev 124*a6628e59SEugene Zhulenev return success(); 125*a6628e59SEugene Zhulenev } 126*a6628e59SEugene Zhulenev 127*a6628e59SEugene Zhulenev void AsyncRuntimeRefCountingOptPass::runOnFunction() { 128*a6628e59SEugene Zhulenev FuncOp func = getFunction(); 129*a6628e59SEugene Zhulenev 130*a6628e59SEugene Zhulenev // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`. 131*a6628e59SEugene Zhulenev // 132*a6628e59SEugene Zhulenev // Find all cancellable pairs of operation and erase them in the end to keep 133*a6628e59SEugene Zhulenev // all iterators valid while we are walking the function operations. 134*a6628e59SEugene Zhulenev llvm::SmallDenseMap<Operation *, Operation *> cancellable; 135*a6628e59SEugene Zhulenev 136*a6628e59SEugene Zhulenev // Optimize reference counting for values defined by block arguments. 137*a6628e59SEugene Zhulenev WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult { 138*a6628e59SEugene Zhulenev for (BlockArgument arg : block->getArguments()) 139*a6628e59SEugene Zhulenev if (isRefCounted(arg.getType())) 140*a6628e59SEugene Zhulenev if (failed(optimizeReferenceCounting(arg, cancellable))) 141*a6628e59SEugene Zhulenev return WalkResult::interrupt(); 142*a6628e59SEugene Zhulenev 143*a6628e59SEugene Zhulenev return WalkResult::advance(); 144*a6628e59SEugene Zhulenev }); 145*a6628e59SEugene Zhulenev 146*a6628e59SEugene Zhulenev if (blockWalk.wasInterrupted()) 147*a6628e59SEugene Zhulenev signalPassFailure(); 148*a6628e59SEugene Zhulenev 149*a6628e59SEugene Zhulenev // Optimize reference counting for values defined by operation results. 150*a6628e59SEugene Zhulenev WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult { 151*a6628e59SEugene Zhulenev for (unsigned i = 0; i < op->getNumResults(); ++i) 152*a6628e59SEugene Zhulenev if (isRefCounted(op->getResultTypes()[i])) 153*a6628e59SEugene Zhulenev if (failed(optimizeReferenceCounting(op->getResult(i), cancellable))) 154*a6628e59SEugene Zhulenev return WalkResult::interrupt(); 155*a6628e59SEugene Zhulenev 156*a6628e59SEugene Zhulenev return WalkResult::advance(); 157*a6628e59SEugene Zhulenev }); 158*a6628e59SEugene Zhulenev 159*a6628e59SEugene Zhulenev if (opWalk.wasInterrupted()) 160*a6628e59SEugene Zhulenev signalPassFailure(); 161*a6628e59SEugene Zhulenev 162*a6628e59SEugene Zhulenev LLVM_DEBUG({ 163*a6628e59SEugene Zhulenev llvm::dbgs() << "Found " << cancellable.size() 164*a6628e59SEugene Zhulenev << " cancellable reference counting operations\n"; 165*a6628e59SEugene Zhulenev }); 166*a6628e59SEugene Zhulenev 167*a6628e59SEugene Zhulenev // Erase all cancellable `add_ref <-> drop_ref` operation pairs. 168*a6628e59SEugene Zhulenev for (auto &kv : cancellable) { 169*a6628e59SEugene Zhulenev kv.first->erase(); 170*a6628e59SEugene Zhulenev kv.second->erase(); 171*a6628e59SEugene Zhulenev } 172*a6628e59SEugene Zhulenev } 173*a6628e59SEugene Zhulenev 174*a6628e59SEugene Zhulenev std::unique_ptr<OperationPass<FuncOp>> 175*a6628e59SEugene Zhulenev mlir::createAsyncRuntimeRefCountingOptPass() { 176*a6628e59SEugene Zhulenev return std::make_unique<AsyncRuntimeRefCountingOptPass>(); 177*a6628e59SEugene Zhulenev } 178