//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-ref-counting"

namespace {

class AsyncRuntimeRefCountingOptPass
    : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
public:
  AsyncRuntimeRefCountingOptPass() = default;
  void runOnOperation() override;

private:
  LogicalResult optimizeReferenceCounting(
      Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
};

} // namespace

LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
    Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
  Region *definingRegion = value.getParentRegion();

  // Find all users of the `value` inside each block, including operations that
  // do not use `value` directly, but have a direct use inside nested region(s).
  //
  // Example:
  //
  //   ^bb1:
  //     %token = ...
  //     scf.if %cond {
  //       ^bb2:
  //       async.runtime.await %token : !async.token
  //     }
  //
  // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
  // (`scf.if`).

  struct BlockUsersInfo {
    llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
    llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
    llvm::SmallVector<Operation *, 4> users;
  };

  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;

  auto updateBlockUsersInfo = [&](Operation *user) {
    BlockUsersInfo &info = blockUsers[user->getBlock()];
    info.users.push_back(user);

    if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
      info.addRefs.push_back(addRef);
    if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
      info.dropRefs.push_back(dropRef);
  };

  for (Operation *user : value.getUsers()) {
    while (user->getParentRegion() != definingRegion) {
      updateBlockUsersInfo(user);
      user = user->getParentOp();
      assert(user != nullptr && "value user lies outside of the value region");
    }

    updateBlockUsersInfo(user);
  }

  // Sort all operations found in the block.
  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
      return a->isBeforeInBlock(b);
    };
    llvm::sort(info.addRefs, isBeforeInBlock);
    llvm::sort(info.dropRefs, isBeforeInBlock);
    llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
      return isBeforeInBlock(a, b);
    });

    return info;
  };

  // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
  // blocks that modify the reference count of the `value`.
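  //
  // For example (an illustrative sketch, not verbatim IR from any test), a
  // pair within a single block such as
  //
  //   async.runtime.add_ref %token {count = 1 : i64} : !async.token
  //   ...
  //   async.runtime.drop_ref %token {count = 1 : i64} : !async.token
  //
  // cancels out: erasing both operations leaves the final reference count of
  // %token unchanged.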
  for (auto &kv : blockUsers) {
    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);

    for (RuntimeAddRefOp addRef : info.addRefs) {
      for (RuntimeDropRefOp dropRef : info.dropRefs) {
        // Only consider a `drop_ref` that is after the `add_ref` and has a
        // matching count.
        if (dropRef.count() != addRef.count() ||
            dropRef->isBeforeInBlock(addRef.getOperation()))
          continue;

        // Try to cancel the pair of `add_ref` and `drop_ref` operations.
        auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
                                                addRef.getOperation());

        if (!emplaced.second) // `drop_ref` was already marked for removal
          continue;           // go to the next `drop_ref`

        if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
          break;             // go to the next `add_ref`
      }
    }
  }

  return success();
}

void AsyncRuntimeRefCountingOptPass::runOnOperation() {
  Operation *op = getOperation();

  // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
  //
  // Find all cancellable pairs of operations and erase them at the end to keep
  // all iterators valid while we are walking the function operations.
  llvm::SmallDenseMap<Operation *, Operation *> cancellable;

  // Optimize reference counting for values defined by block arguments.
  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
    for (BlockArgument arg : block->getArguments())
      if (isRefCounted(arg.getType()))
        if (failed(optimizeReferenceCounting(arg, cancellable)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (blockWalk.wasInterrupted())
    signalPassFailure();

  // Optimize reference counting for values defined by operation results.
  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      if (isRefCounted(op->getResultTypes()[i]))
        if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (opWalk.wasInterrupted())
    signalPassFailure();

  LLVM_DEBUG({
    llvm::dbgs() << "Found " << cancellable.size()
                 << " cancellable reference counting operations\n";
  });

  // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
  for (auto &kv : cancellable) {
    kv.first->erase();
    kv.second->erase();
  }
}

std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
  return std::make_unique<AsyncRuntimeRefCountingOptPass>();
}
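
// A minimal usage sketch (an illustration, not part of the upstream file),
// assuming `module` is a previously parsed mlir::ModuleOp: this pass is
// typically scheduled right after the pass that inserts the reference counting
// operations (createAsyncRuntimeRefCountingPass()), so that the redundant
// `add_ref`/`drop_ref` pairs it produces are cleaned up immediately.
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::createAsyncRuntimeRefCountingPass());
//   pm.addPass(mlir::createAsyncRuntimeRefCountingOptPass());
//   if (failed(pm.run(module)))
//     llvm::errs() << "async reference counting pipeline failed\n";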