1a6628e59SEugene Zhulenev //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2a6628e59SEugene Zhulenev //
3a6628e59SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a6628e59SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
5a6628e59SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a6628e59SEugene Zhulenev //
7a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===//
8a6628e59SEugene Zhulenev //
9a6628e59SEugene Zhulenev // Optimize Async dialect reference counting operations.
10a6628e59SEugene Zhulenev //
11a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===//
12a6628e59SEugene Zhulenev
13a6628e59SEugene Zhulenev #include "PassDetail.h"
14a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
15a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h"
16*23aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
17a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h"
18297a5b7cSNico Weber #include "llvm/Support/Debug.h"
19a6628e59SEugene Zhulenev
20a6628e59SEugene Zhulenev using namespace mlir;
21a6628e59SEugene Zhulenev using namespace mlir::async;
22a6628e59SEugene Zhulenev
23a6628e59SEugene Zhulenev #define DEBUG_TYPE "async-ref-counting"
24a6628e59SEugene Zhulenev
25a6628e59SEugene Zhulenev namespace {
26a6628e59SEugene Zhulenev
27a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingOptPass
28a6628e59SEugene Zhulenev : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
29a6628e59SEugene Zhulenev public:
30a6628e59SEugene Zhulenev AsyncRuntimeRefCountingOptPass() = default;
318a316b00SEugene Zhulenev void runOnOperation() override;
32a6628e59SEugene Zhulenev
33a6628e59SEugene Zhulenev private:
34a6628e59SEugene Zhulenev LogicalResult optimizeReferenceCounting(
35a6628e59SEugene Zhulenev Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
36a6628e59SEugene Zhulenev };
37a6628e59SEugene Zhulenev
38a6628e59SEugene Zhulenev } // namespace
39a6628e59SEugene Zhulenev
optimizeReferenceCounting(Value value,llvm::SmallDenseMap<Operation *,Operation * > & cancellable)40a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
41a6628e59SEugene Zhulenev Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
42a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion();
43a6628e59SEugene Zhulenev
44a6628e59SEugene Zhulenev // Find all users of the `value` inside each block, including operations that
45a6628e59SEugene Zhulenev // do not use `value` directly, but have a direct use inside nested region(s).
46a6628e59SEugene Zhulenev //
47a6628e59SEugene Zhulenev // Example:
48a6628e59SEugene Zhulenev //
49a6628e59SEugene Zhulenev // ^bb1:
50a6628e59SEugene Zhulenev // %token = ...
51a6628e59SEugene Zhulenev // scf.if %cond {
52a6628e59SEugene Zhulenev // ^bb2:
53a6628e59SEugene Zhulenev // async.runtime.await %token : !async.token
54a6628e59SEugene Zhulenev // }
55a6628e59SEugene Zhulenev //
56a6628e59SEugene Zhulenev // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
57a6628e59SEugene Zhulenev // (`scf.if`).
58a6628e59SEugene Zhulenev
59a6628e59SEugene Zhulenev struct BlockUsersInfo {
60a6628e59SEugene Zhulenev llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
61a6628e59SEugene Zhulenev llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
62a6628e59SEugene Zhulenev llvm::SmallVector<Operation *, 4> users;
63a6628e59SEugene Zhulenev };
64a6628e59SEugene Zhulenev
65a6628e59SEugene Zhulenev llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
66a6628e59SEugene Zhulenev
67a6628e59SEugene Zhulenev auto updateBlockUsersInfo = [&](Operation *user) {
68a6628e59SEugene Zhulenev BlockUsersInfo &info = blockUsers[user->getBlock()];
69a6628e59SEugene Zhulenev info.users.push_back(user);
70a6628e59SEugene Zhulenev
71a6628e59SEugene Zhulenev if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
72a6628e59SEugene Zhulenev info.addRefs.push_back(addRef);
73a6628e59SEugene Zhulenev if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
74a6628e59SEugene Zhulenev info.dropRefs.push_back(dropRef);
75a6628e59SEugene Zhulenev };
76a6628e59SEugene Zhulenev
77a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) {
78a6628e59SEugene Zhulenev while (user->getParentRegion() != definingRegion) {
79a6628e59SEugene Zhulenev updateBlockUsersInfo(user);
80a6628e59SEugene Zhulenev user = user->getParentOp();
81a6628e59SEugene Zhulenev assert(user != nullptr && "value user lies outside of the value region");
82a6628e59SEugene Zhulenev }
83a6628e59SEugene Zhulenev
84a6628e59SEugene Zhulenev updateBlockUsersInfo(user);
85a6628e59SEugene Zhulenev }
86a6628e59SEugene Zhulenev
87a6628e59SEugene Zhulenev // Sort all operations found in the block.
88a6628e59SEugene Zhulenev auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
89a6628e59SEugene Zhulenev auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
90a6628e59SEugene Zhulenev return a->isBeforeInBlock(b);
91a6628e59SEugene Zhulenev };
92a6628e59SEugene Zhulenev llvm::sort(info.addRefs, isBeforeInBlock);
93a6628e59SEugene Zhulenev llvm::sort(info.dropRefs, isBeforeInBlock);
94a6628e59SEugene Zhulenev llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
95a6628e59SEugene Zhulenev return isBeforeInBlock(a, b);
96a6628e59SEugene Zhulenev });
97a6628e59SEugene Zhulenev
98a6628e59SEugene Zhulenev return info;
99a6628e59SEugene Zhulenev };
100a6628e59SEugene Zhulenev
101a6628e59SEugene Zhulenev // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
102a6628e59SEugene Zhulenev // blocks that modify the reference count of the `value`.
103a6628e59SEugene Zhulenev for (auto &kv : blockUsers) {
104a6628e59SEugene Zhulenev BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
105a6628e59SEugene Zhulenev
106a6628e59SEugene Zhulenev for (RuntimeAddRefOp addRef : info.addRefs) {
107a6628e59SEugene Zhulenev for (RuntimeDropRefOp dropRef : info.dropRefs) {
108a6628e59SEugene Zhulenev // `drop_ref` operation after the `add_ref` with matching count.
109a6628e59SEugene Zhulenev if (dropRef.count() != addRef.count() ||
110a6628e59SEugene Zhulenev dropRef->isBeforeInBlock(addRef.getOperation()))
111a6628e59SEugene Zhulenev continue;
112a6628e59SEugene Zhulenev
1139ccdaac8SEugene Zhulenev // When reference counted value passed to a function as an argument,
1149ccdaac8SEugene Zhulenev // function takes ownership of +1 reference and it will drop it before
1159ccdaac8SEugene Zhulenev // returning.
1169ccdaac8SEugene Zhulenev //
1179ccdaac8SEugene Zhulenev // Example:
1189ccdaac8SEugene Zhulenev //
1199ccdaac8SEugene Zhulenev // %token = ... : !async.token
1209ccdaac8SEugene Zhulenev //
12192db09cdSEugene Zhulenev // async.runtime.add_ref %token {count = 1 : i64} : !async.token
1229ccdaac8SEugene Zhulenev // call @pass_token(%token: !async.token, ...)
1239ccdaac8SEugene Zhulenev //
1249ccdaac8SEugene Zhulenev // async.await %token : !async.token
12592db09cdSEugene Zhulenev // async.runtime.drop_ref %token {count = 1 : i64} : !async.token
1269ccdaac8SEugene Zhulenev //
1279ccdaac8SEugene Zhulenev // In this example if we'll cancel a pair of reference counting
1289ccdaac8SEugene Zhulenev // operations we might end up with a deallocated token when we'll
1299ccdaac8SEugene Zhulenev // reach `async.await` operation.
1309ccdaac8SEugene Zhulenev Operation *firstFunctionCallUser = nullptr;
1319ccdaac8SEugene Zhulenev Operation *lastNonFunctionCallUser = nullptr;
1329ccdaac8SEugene Zhulenev
1339ccdaac8SEugene Zhulenev for (Operation *user : info.users) {
1349ccdaac8SEugene Zhulenev // `user` operation lies after `addRef` ...
1359ccdaac8SEugene Zhulenev if (user == addRef || user->isBeforeInBlock(addRef))
1369ccdaac8SEugene Zhulenev continue;
1379ccdaac8SEugene Zhulenev // ... and before `dropRef`.
1389ccdaac8SEugene Zhulenev if (user == dropRef || dropRef->isBeforeInBlock(user))
1399ccdaac8SEugene Zhulenev break;
1409ccdaac8SEugene Zhulenev
1419ccdaac8SEugene Zhulenev // Find the first function call user of the reference counted value.
142*23aa5a74SRiver Riddle Operation *functionCall = dyn_cast<func::CallOp>(user);
1439ccdaac8SEugene Zhulenev if (functionCall &&
1449ccdaac8SEugene Zhulenev (!firstFunctionCallUser ||
1459ccdaac8SEugene Zhulenev functionCall->isBeforeInBlock(firstFunctionCallUser))) {
1469ccdaac8SEugene Zhulenev firstFunctionCallUser = functionCall;
1479ccdaac8SEugene Zhulenev continue;
1489ccdaac8SEugene Zhulenev }
1499ccdaac8SEugene Zhulenev
1509ccdaac8SEugene Zhulenev // Find the last regular user of the reference counted value.
1519ccdaac8SEugene Zhulenev if (!functionCall &&
1529ccdaac8SEugene Zhulenev (!lastNonFunctionCallUser ||
1539ccdaac8SEugene Zhulenev lastNonFunctionCallUser->isBeforeInBlock(user))) {
1549ccdaac8SEugene Zhulenev lastNonFunctionCallUser = user;
1559ccdaac8SEugene Zhulenev continue;
1569ccdaac8SEugene Zhulenev }
1579ccdaac8SEugene Zhulenev }
1589ccdaac8SEugene Zhulenev
1599ccdaac8SEugene Zhulenev // Non function call user after the function call user of the reference
1609ccdaac8SEugene Zhulenev // counted value.
1619ccdaac8SEugene Zhulenev if (firstFunctionCallUser && lastNonFunctionCallUser &&
1629ccdaac8SEugene Zhulenev firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
1639ccdaac8SEugene Zhulenev continue;
1649ccdaac8SEugene Zhulenev
165a6628e59SEugene Zhulenev // Try to cancel the pair of `add_ref` and `drop_ref` operations.
166a6628e59SEugene Zhulenev auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
167a6628e59SEugene Zhulenev addRef.getOperation());
168a6628e59SEugene Zhulenev
169a6628e59SEugene Zhulenev if (!emplaced.second) // `drop_ref` was already marked for removal
170a6628e59SEugene Zhulenev continue; // go to the next `drop_ref`
171a6628e59SEugene Zhulenev
172a6628e59SEugene Zhulenev if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
173a6628e59SEugene Zhulenev break; // go to the next `add_ref`
174a6628e59SEugene Zhulenev }
175a6628e59SEugene Zhulenev }
176a6628e59SEugene Zhulenev }
177a6628e59SEugene Zhulenev
178a6628e59SEugene Zhulenev return success();
179a6628e59SEugene Zhulenev }
180a6628e59SEugene Zhulenev
runOnOperation()1818a316b00SEugene Zhulenev void AsyncRuntimeRefCountingOptPass::runOnOperation() {
1828a316b00SEugene Zhulenev Operation *op = getOperation();
183a6628e59SEugene Zhulenev
184a6628e59SEugene Zhulenev // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
185a6628e59SEugene Zhulenev //
186a6628e59SEugene Zhulenev // Find all cancellable pairs of operation and erase them in the end to keep
187a6628e59SEugene Zhulenev // all iterators valid while we are walking the function operations.
188a6628e59SEugene Zhulenev llvm::SmallDenseMap<Operation *, Operation *> cancellable;
189a6628e59SEugene Zhulenev
190a6628e59SEugene Zhulenev // Optimize reference counting for values defined by block arguments.
1918a316b00SEugene Zhulenev WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
192a6628e59SEugene Zhulenev for (BlockArgument arg : block->getArguments())
193a6628e59SEugene Zhulenev if (isRefCounted(arg.getType()))
194a6628e59SEugene Zhulenev if (failed(optimizeReferenceCounting(arg, cancellable)))
195a6628e59SEugene Zhulenev return WalkResult::interrupt();
196a6628e59SEugene Zhulenev
197a6628e59SEugene Zhulenev return WalkResult::advance();
198a6628e59SEugene Zhulenev });
199a6628e59SEugene Zhulenev
200a6628e59SEugene Zhulenev if (blockWalk.wasInterrupted())
201a6628e59SEugene Zhulenev signalPassFailure();
202a6628e59SEugene Zhulenev
203a6628e59SEugene Zhulenev // Optimize reference counting for values defined by operation results.
2048a316b00SEugene Zhulenev WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
205a6628e59SEugene Zhulenev for (unsigned i = 0; i < op->getNumResults(); ++i)
206a6628e59SEugene Zhulenev if (isRefCounted(op->getResultTypes()[i]))
207a6628e59SEugene Zhulenev if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
208a6628e59SEugene Zhulenev return WalkResult::interrupt();
209a6628e59SEugene Zhulenev
210a6628e59SEugene Zhulenev return WalkResult::advance();
211a6628e59SEugene Zhulenev });
212a6628e59SEugene Zhulenev
213a6628e59SEugene Zhulenev if (opWalk.wasInterrupted())
214a6628e59SEugene Zhulenev signalPassFailure();
215a6628e59SEugene Zhulenev
216a6628e59SEugene Zhulenev LLVM_DEBUG({
217a6628e59SEugene Zhulenev llvm::dbgs() << "Found " << cancellable.size()
218a6628e59SEugene Zhulenev << " cancellable reference counting operations\n";
219a6628e59SEugene Zhulenev });
220a6628e59SEugene Zhulenev
221a6628e59SEugene Zhulenev // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
222a6628e59SEugene Zhulenev for (auto &kv : cancellable) {
223a6628e59SEugene Zhulenev kv.first->erase();
224a6628e59SEugene Zhulenev kv.second->erase();
225a6628e59SEugene Zhulenev }
226a6628e59SEugene Zhulenev }
227a6628e59SEugene Zhulenev
createAsyncRuntimeRefCountingOptPass()2288a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
229a6628e59SEugene Zhulenev return std::make_unique<AsyncRuntimeRefCountingOptPass>();
230a6628e59SEugene Zhulenev }
231