//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-ref-counting"

namespace {

class AsyncRuntimeRefCountingOptPass
    : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
public:
  AsyncRuntimeRefCountingOptPass() = default;
  void runOnOperation() override;

private:
  LogicalResult optimizeReferenceCounting(
      Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
};

} // namespace

LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
    Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
  Region *definingRegion = value.getParentRegion();

  // Find all users of the `value` inside each block, including operations that
  // do not use `value` directly, but have a direct use inside nested region(s).
  //
  // Example:
  //
  //   ^bb1:
  //     %token = ...
  //     scf.if %cond {
  //       ^bb2:
  //       async.runtime.await %token : !async.token
  //     }
  //
  // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
  // (`scf.if`).

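  // Per-block bookkeeping for the users of `value`: reference counting
  // operations (`add_ref` / `drop_ref`) are tracked separately from the full
  // list of users.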
  struct BlockUsersInfo {
    llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
    llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
    llvm::SmallVector<Operation *, 4> users;
  };

  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;

  auto updateBlockUsersInfo = [&](Operation *user) {
    BlockUsersInfo &info = blockUsers[user->getBlock()];
    info.users.push_back(user);

    if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
      info.addRefs.push_back(addRef);
    if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
      info.dropRefs.push_back(dropRef);
  };

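  // Record every user in a block: a direct user is recorded in its own block,
  // and for users inside nested regions every enclosing operation up to the
  // region where `value` is defined is recorded as a user as well.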
  for (Operation *user : value.getUsers()) {
    while (user->getParentRegion() != definingRegion) {
      updateBlockUsersInfo(user);
      user = user->getParentOp();
      assert(user != nullptr && "value user lies outside of the value region");
    }

    updateBlockUsersInfo(user);
  }

  // Sort all operations found in the block.
  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
      return a->isBeforeInBlock(b);
    };
    llvm::sort(info.addRefs, isBeforeInBlock);
    llvm::sort(info.dropRefs, isBeforeInBlock);
    llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
      return isBeforeInBlock(a, b);
    });

    return info;
  };

  // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
  // blocks that modify the reference count of the `value`.
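  //
  // For example, a pair like this can be cancelled when there is no function
  // call user of the value in between (see the ownership check below):
  //
  //   async.runtime.add_ref %token {count = 1 : i64} : !async.token
  //   async.runtime.await %token : !async.token
  //   async.runtime.drop_ref %token {count = 1 : i64} : !async.token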
  for (auto &kv : blockUsers) {
    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);

    for (RuntimeAddRefOp addRef : info.addRefs) {
      for (RuntimeDropRefOp dropRef : info.dropRefs) {
        // Only consider a `drop_ref` that lies after the `add_ref` and has a
        // matching count.
        if (dropRef.count() != addRef.count() ||
            dropRef->isBeforeInBlock(addRef.getOperation()))
          continue;

        // When a reference counted value is passed to a function as an
        // argument, the function takes ownership of a +1 reference and will
        // drop it before returning.
        //
        // Example:
        //
        //   %token = ... : !async.token
        //
        //   async.runtime.add_ref %token {count = 1 : i64} : !async.token
        //   call @pass_token(%token: !async.token, ...)
        //
        //   async.await %token : !async.token
        //   async.runtime.drop_ref %token {count = 1 : i64} : !async.token
        //
        // In this example, if we cancel the pair of reference counting
        // operations, we might end up with a deallocated token by the time we
        // reach the `async.await` operation.
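        //
        // To stay conservative, remember the first function call user and the
        // last regular user of the value between `addRef` and `dropRef`; the
        // pair is cancelled only if no regular user comes after a function
        // call (checked below).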
        Operation *firstFunctionCallUser = nullptr;
        Operation *lastNonFunctionCallUser = nullptr;

        for (Operation *user : info.users) {
          // `user` operation lies after `addRef` ...
          if (user == addRef || user->isBeforeInBlock(addRef))
            continue;
          // ... and before `dropRef`.
          if (user == dropRef || dropRef->isBeforeInBlock(user))
            break;

          // Find the first function call user of the reference counted value.
          Operation *functionCall = dyn_cast<func::CallOp>(user);
          if (functionCall &&
              (!firstFunctionCallUser ||
               functionCall->isBeforeInBlock(firstFunctionCallUser))) {
            firstFunctionCallUser = functionCall;
            continue;
          }

          // Find the last regular user of the reference counted value.
          if (!functionCall &&
              (!lastNonFunctionCallUser ||
               lastNonFunctionCallUser->isBeforeInBlock(user))) {
            lastNonFunctionCallUser = user;
            continue;
          }
        }

        // Do not cancel the pair if a regular (non function call) user of the
        // value comes after a function call user: the call may have consumed
        // the +1 reference, so the value could already be deallocated at that
        // later use.
        if (firstFunctionCallUser && lastNonFunctionCallUser &&
            firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
          continue;

        // Try to cancel the pair of `add_ref` and `drop_ref` operations.
        auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
                                                addRef.getOperation());

        if (!emplaced.second) // `drop_ref` was already marked for removal
          continue;           // go to the next `drop_ref`

        // Successfully cancelled `add_ref` <-> `drop_ref`; go to the next
        // `add_ref`.
        break;
      }
    }
  }

  return success();
}

void AsyncRuntimeRefCountingOptPass::runOnOperation() {
  Operation *op = getOperation();

  // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
  //
  // Find all cancellable pairs of operations first and erase them at the end,
  // to keep all iterators valid while walking the function operations.
  llvm::SmallDenseMap<Operation *, Operation *> cancellable;

  // Optimize reference counting for values defined by block arguments.
  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
    for (BlockArgument arg : block->getArguments())
      if (isRefCounted(arg.getType()))
        if (failed(optimizeReferenceCounting(arg, cancellable)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (blockWalk.wasInterrupted())
    signalPassFailure();

  // Optimize reference counting for values defined by operation results.
  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      if (isRefCounted(op->getResultTypes()[i]))
        if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (opWalk.wasInterrupted())
    signalPassFailure();

  LLVM_DEBUG({
    llvm::dbgs() << "Found " << cancellable.size()
                 << " cancellable reference counting operations\n";
  });

  // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
  for (auto &kv : cancellable) {
    kv.first->erase();
    kv.second->erase();
  }
}

std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
  return std::make_unique<AsyncRuntimeRefCountingOptPass>();
}