1 //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Optimize Async dialect reference counting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Async/IR/Async.h"
15 #include "mlir/Dialect/Async/Passes.h"
16 #include "llvm/ADT/SmallSet.h"
17 
18 using namespace mlir;
19 using namespace mlir::async;
20 
21 #define DEBUG_TYPE "async-ref-counting"
22 
23 namespace {
24 
25 class AsyncRuntimeRefCountingOptPass
26     : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
27 public:
28   AsyncRuntimeRefCountingOptPass() = default;
29   void runOnFunction() override;
30 
31 private:
32   LogicalResult optimizeReferenceCounting(
33       Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
34 };
35 
36 } // namespace
37 
38 LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
39     Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
40   Region *definingRegion = value.getParentRegion();
41 
42   // Find all users of the `value` inside each block, including operations that
43   // do not use `value` directly, but have a direct use inside nested region(s).
44   //
45   // Example:
46   //
47   //  ^bb1:
48   //    %token = ...
49   //    scf.if %cond {
50   //      ^bb2:
51   //      async.runtime.await %token : !async.token
52   //    }
53   //
54   // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
55   // (`scf.if`).
56 
57   struct BlockUsersInfo {
58     llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
59     llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
60     llvm::SmallVector<Operation *, 4> users;
61   };
62 
63   llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
64 
65   auto updateBlockUsersInfo = [&](Operation *user) {
66     BlockUsersInfo &info = blockUsers[user->getBlock()];
67     info.users.push_back(user);
68 
69     if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
70       info.addRefs.push_back(addRef);
71     if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
72       info.dropRefs.push_back(dropRef);
73   };
74 
75   for (Operation *user : value.getUsers()) {
76     while (user->getParentRegion() != definingRegion) {
77       updateBlockUsersInfo(user);
78       user = user->getParentOp();
79       assert(user != nullptr && "value user lies outside of the value region");
80     }
81 
82     updateBlockUsersInfo(user);
83   }
84 
85   // Sort all operations found in the block.
86   auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
87     auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
88       return a->isBeforeInBlock(b);
89     };
90     llvm::sort(info.addRefs, isBeforeInBlock);
91     llvm::sort(info.dropRefs, isBeforeInBlock);
92     llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
93       return isBeforeInBlock(a, b);
94     });
95 
96     return info;
97   };
98 
99   // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
100   // blocks that modify the reference count of the `value`.
101   for (auto &kv : blockUsers) {
102     BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
103 
104     for (RuntimeAddRefOp addRef : info.addRefs) {
105       for (RuntimeDropRefOp dropRef : info.dropRefs) {
106         // `drop_ref` operation after the `add_ref` with matching count.
107         if (dropRef.count() != addRef.count() ||
108             dropRef->isBeforeInBlock(addRef.getOperation()))
109           continue;
110 
111         // Try to cancel the pair of `add_ref` and `drop_ref` operations.
112         auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
113                                                 addRef.getOperation());
114 
115         if (!emplaced.second) // `drop_ref` was already marked for removal
116           continue;           // go to the next `drop_ref`
117 
118         if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
119           break;             // go to the next `add_ref`
120       }
121     }
122   }
123 
124   return success();
125 }
126 
127 void AsyncRuntimeRefCountingOptPass::runOnFunction() {
128   FuncOp func = getFunction();
129 
130   // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
131   //
132   // Find all cancellable pairs of operation and erase them in the end to keep
133   // all iterators valid while we are walking the function operations.
134   llvm::SmallDenseMap<Operation *, Operation *> cancellable;
135 
136   // Optimize reference counting for values defined by block arguments.
137   WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
138     for (BlockArgument arg : block->getArguments())
139       if (isRefCounted(arg.getType()))
140         if (failed(optimizeReferenceCounting(arg, cancellable)))
141           return WalkResult::interrupt();
142 
143     return WalkResult::advance();
144   });
145 
146   if (blockWalk.wasInterrupted())
147     signalPassFailure();
148 
149   // Optimize reference counting for values defined by operation results.
150   WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
151     for (unsigned i = 0; i < op->getNumResults(); ++i)
152       if (isRefCounted(op->getResultTypes()[i]))
153         if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
154           return WalkResult::interrupt();
155 
156     return WalkResult::advance();
157   });
158 
159   if (opWalk.wasInterrupted())
160     signalPassFailure();
161 
162   LLVM_DEBUG({
163     llvm::dbgs() << "Found " << cancellable.size()
164                  << " cancellable reference counting operations\n";
165   });
166 
167   // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
168   for (auto &kv : cancellable) {
169     kv.first->erase();
170     kv.second->erase();
171   }
172 }
173 
174 std::unique_ptr<OperationPass<FuncOp>>
175 mlir::createAsyncRuntimeRefCountingOptPass() {
176   return std::make_unique<AsyncRuntimeRefCountingOptPass>();
177 }
178