1 //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Optimize Async dialect reference counting operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Async/IR/Async.h" 15 #include "mlir/Dialect/Async/Passes.h" 16 #include "llvm/ADT/SmallSet.h" 17 18 using namespace mlir; 19 using namespace mlir::async; 20 21 #define DEBUG_TYPE "async-ref-counting" 22 23 namespace { 24 25 class AsyncRuntimeRefCountingOptPass 26 : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> { 27 public: 28 AsyncRuntimeRefCountingOptPass() = default; 29 void runOnFunction() override; 30 31 private: 32 LogicalResult optimizeReferenceCounting( 33 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable); 34 }; 35 36 } // namespace 37 38 LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting( 39 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) { 40 Region *definingRegion = value.getParentRegion(); 41 42 // Find all users of the `value` inside each block, including operations that 43 // do not use `value` directly, but have a direct use inside nested region(s). 44 // 45 // Example: 46 // 47 // ^bb1: 48 // %token = ... 49 // scf.if %cond { 50 // ^bb2: 51 // async.runtime.await %token : !async.token 52 // } 53 // 54 // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1 55 // (`scf.if`). 56 57 struct BlockUsersInfo { 58 llvm::SmallVector<RuntimeAddRefOp, 4> addRefs; 59 llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs; 60 llvm::SmallVector<Operation *, 4> users; 61 }; 62 63 llvm::DenseMap<Block *, BlockUsersInfo> blockUsers; 64 65 auto updateBlockUsersInfo = [&](Operation *user) { 66 BlockUsersInfo &info = blockUsers[user->getBlock()]; 67 info.users.push_back(user); 68 69 if (auto addRef = dyn_cast<RuntimeAddRefOp>(user)) 70 info.addRefs.push_back(addRef); 71 if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user)) 72 info.dropRefs.push_back(dropRef); 73 }; 74 75 for (Operation *user : value.getUsers()) { 76 while (user->getParentRegion() != definingRegion) { 77 updateBlockUsersInfo(user); 78 user = user->getParentOp(); 79 assert(user != nullptr && "value user lies outside of the value region"); 80 } 81 82 updateBlockUsersInfo(user); 83 } 84 85 // Sort all operations found in the block. 86 auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & { 87 auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool { 88 return a->isBeforeInBlock(b); 89 }; 90 llvm::sort(info.addRefs, isBeforeInBlock); 91 llvm::sort(info.dropRefs, isBeforeInBlock); 92 llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool { 93 return isBeforeInBlock(a, b); 94 }); 95 96 return info; 97 }; 98 99 // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the 100 // blocks that modify the reference count of the `value`. 101 for (auto &kv : blockUsers) { 102 BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second); 103 104 for (RuntimeAddRefOp addRef : info.addRefs) { 105 for (RuntimeDropRefOp dropRef : info.dropRefs) { 106 // `drop_ref` operation after the `add_ref` with matching count. 107 if (dropRef.count() != addRef.count() || 108 dropRef->isBeforeInBlock(addRef.getOperation())) 109 continue; 110 111 // Try to cancel the pair of `add_ref` and `drop_ref` operations. 112 auto emplaced = cancellable.try_emplace(dropRef.getOperation(), 113 addRef.getOperation()); 114 115 if (!emplaced.second) // `drop_ref` was already marked for removal 116 continue; // go to the next `drop_ref` 117 118 if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref` 119 break; // go to the next `add_ref` 120 } 121 } 122 } 123 124 return success(); 125 } 126 127 void AsyncRuntimeRefCountingOptPass::runOnFunction() { 128 FuncOp func = getFunction(); 129 130 // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`. 131 // 132 // Find all cancellable pairs of operation and erase them in the end to keep 133 // all iterators valid while we are walking the function operations. 134 llvm::SmallDenseMap<Operation *, Operation *> cancellable; 135 136 // Optimize reference counting for values defined by block arguments. 137 WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult { 138 for (BlockArgument arg : block->getArguments()) 139 if (isRefCounted(arg.getType())) 140 if (failed(optimizeReferenceCounting(arg, cancellable))) 141 return WalkResult::interrupt(); 142 143 return WalkResult::advance(); 144 }); 145 146 if (blockWalk.wasInterrupted()) 147 signalPassFailure(); 148 149 // Optimize reference counting for values defined by operation results. 150 WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult { 151 for (unsigned i = 0; i < op->getNumResults(); ++i) 152 if (isRefCounted(op->getResultTypes()[i])) 153 if (failed(optimizeReferenceCounting(op->getResult(i), cancellable))) 154 return WalkResult::interrupt(); 155 156 return WalkResult::advance(); 157 }); 158 159 if (opWalk.wasInterrupted()) 160 signalPassFailure(); 161 162 LLVM_DEBUG({ 163 llvm::dbgs() << "Found " << cancellable.size() 164 << " cancellable reference counting operations\n"; 165 }); 166 167 // Erase all cancellable `add_ref <-> drop_ref` operation pairs. 168 for (auto &kv : cancellable) { 169 kv.first->erase(); 170 kv.second->erase(); 171 } 172 } 173 174 std::unique_ptr<OperationPass<FuncOp>> 175 mlir::createAsyncRuntimeRefCountingOptPass() { 176 return std::make_unique<AsyncRuntimeRefCountingOptPass>(); 177 } 178