1a6628e59SEugene Zhulenev //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2a6628e59SEugene Zhulenev //
3a6628e59SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a6628e59SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
5a6628e59SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a6628e59SEugene Zhulenev //
7a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===//
8a6628e59SEugene Zhulenev //
9a6628e59SEugene Zhulenev // Optimize Async dialect reference counting operations.
10a6628e59SEugene Zhulenev //
11a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===//
12a6628e59SEugene Zhulenev 
13a6628e59SEugene Zhulenev #include "PassDetail.h"
14a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
15a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h"
16a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h"
17a6628e59SEugene Zhulenev 
18a6628e59SEugene Zhulenev using namespace mlir;
19a6628e59SEugene Zhulenev using namespace mlir::async;
20a6628e59SEugene Zhulenev 
21a6628e59SEugene Zhulenev #define DEBUG_TYPE "async-ref-counting"
22a6628e59SEugene Zhulenev 
23a6628e59SEugene Zhulenev namespace {
24a6628e59SEugene Zhulenev 
25a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingOptPass
26a6628e59SEugene Zhulenev     : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
27a6628e59SEugene Zhulenev public:
28a6628e59SEugene Zhulenev   AsyncRuntimeRefCountingOptPass() = default;
29*8a316b00SEugene Zhulenev   void runOnOperation() override;
30a6628e59SEugene Zhulenev 
31a6628e59SEugene Zhulenev private:
32a6628e59SEugene Zhulenev   LogicalResult optimizeReferenceCounting(
33a6628e59SEugene Zhulenev       Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
34a6628e59SEugene Zhulenev };
35a6628e59SEugene Zhulenev 
36a6628e59SEugene Zhulenev } // namespace
37a6628e59SEugene Zhulenev 
38a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
39a6628e59SEugene Zhulenev     Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
40a6628e59SEugene Zhulenev   Region *definingRegion = value.getParentRegion();
41a6628e59SEugene Zhulenev 
42a6628e59SEugene Zhulenev   // Find all users of the `value` inside each block, including operations that
43a6628e59SEugene Zhulenev   // do not use `value` directly, but have a direct use inside nested region(s).
44a6628e59SEugene Zhulenev   //
45a6628e59SEugene Zhulenev   // Example:
46a6628e59SEugene Zhulenev   //
47a6628e59SEugene Zhulenev   //  ^bb1:
48a6628e59SEugene Zhulenev   //    %token = ...
49a6628e59SEugene Zhulenev   //    scf.if %cond {
50a6628e59SEugene Zhulenev   //      ^bb2:
51a6628e59SEugene Zhulenev   //      async.runtime.await %token : !async.token
52a6628e59SEugene Zhulenev   //    }
53a6628e59SEugene Zhulenev   //
54a6628e59SEugene Zhulenev   // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
55a6628e59SEugene Zhulenev   // (`scf.if`).
56a6628e59SEugene Zhulenev 
57a6628e59SEugene Zhulenev   struct BlockUsersInfo {
58a6628e59SEugene Zhulenev     llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
59a6628e59SEugene Zhulenev     llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
60a6628e59SEugene Zhulenev     llvm::SmallVector<Operation *, 4> users;
61a6628e59SEugene Zhulenev   };
62a6628e59SEugene Zhulenev 
63a6628e59SEugene Zhulenev   llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
64a6628e59SEugene Zhulenev 
65a6628e59SEugene Zhulenev   auto updateBlockUsersInfo = [&](Operation *user) {
66a6628e59SEugene Zhulenev     BlockUsersInfo &info = blockUsers[user->getBlock()];
67a6628e59SEugene Zhulenev     info.users.push_back(user);
68a6628e59SEugene Zhulenev 
69a6628e59SEugene Zhulenev     if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
70a6628e59SEugene Zhulenev       info.addRefs.push_back(addRef);
71a6628e59SEugene Zhulenev     if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
72a6628e59SEugene Zhulenev       info.dropRefs.push_back(dropRef);
73a6628e59SEugene Zhulenev   };
74a6628e59SEugene Zhulenev 
75a6628e59SEugene Zhulenev   for (Operation *user : value.getUsers()) {
76a6628e59SEugene Zhulenev     while (user->getParentRegion() != definingRegion) {
77a6628e59SEugene Zhulenev       updateBlockUsersInfo(user);
78a6628e59SEugene Zhulenev       user = user->getParentOp();
79a6628e59SEugene Zhulenev       assert(user != nullptr && "value user lies outside of the value region");
80a6628e59SEugene Zhulenev     }
81a6628e59SEugene Zhulenev 
82a6628e59SEugene Zhulenev     updateBlockUsersInfo(user);
83a6628e59SEugene Zhulenev   }
84a6628e59SEugene Zhulenev 
85a6628e59SEugene Zhulenev   // Sort all operations found in the block.
86a6628e59SEugene Zhulenev   auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
87a6628e59SEugene Zhulenev     auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
88a6628e59SEugene Zhulenev       return a->isBeforeInBlock(b);
89a6628e59SEugene Zhulenev     };
90a6628e59SEugene Zhulenev     llvm::sort(info.addRefs, isBeforeInBlock);
91a6628e59SEugene Zhulenev     llvm::sort(info.dropRefs, isBeforeInBlock);
92a6628e59SEugene Zhulenev     llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
93a6628e59SEugene Zhulenev       return isBeforeInBlock(a, b);
94a6628e59SEugene Zhulenev     });
95a6628e59SEugene Zhulenev 
96a6628e59SEugene Zhulenev     return info;
97a6628e59SEugene Zhulenev   };
98a6628e59SEugene Zhulenev 
99a6628e59SEugene Zhulenev   // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
100a6628e59SEugene Zhulenev   // blocks that modify the reference count of the `value`.
101a6628e59SEugene Zhulenev   for (auto &kv : blockUsers) {
102a6628e59SEugene Zhulenev     BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
103a6628e59SEugene Zhulenev 
104a6628e59SEugene Zhulenev     for (RuntimeAddRefOp addRef : info.addRefs) {
105a6628e59SEugene Zhulenev       for (RuntimeDropRefOp dropRef : info.dropRefs) {
106a6628e59SEugene Zhulenev         // `drop_ref` operation after the `add_ref` with matching count.
107a6628e59SEugene Zhulenev         if (dropRef.count() != addRef.count() ||
108a6628e59SEugene Zhulenev             dropRef->isBeforeInBlock(addRef.getOperation()))
109a6628e59SEugene Zhulenev           continue;
110a6628e59SEugene Zhulenev 
111a6628e59SEugene Zhulenev         // Try to cancel the pair of `add_ref` and `drop_ref` operations.
112a6628e59SEugene Zhulenev         auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
113a6628e59SEugene Zhulenev                                                 addRef.getOperation());
114a6628e59SEugene Zhulenev 
115a6628e59SEugene Zhulenev         if (!emplaced.second) // `drop_ref` was already marked for removal
116a6628e59SEugene Zhulenev           continue;           // go to the next `drop_ref`
117a6628e59SEugene Zhulenev 
118a6628e59SEugene Zhulenev         if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
119a6628e59SEugene Zhulenev           break;             // go to the next `add_ref`
120a6628e59SEugene Zhulenev       }
121a6628e59SEugene Zhulenev     }
122a6628e59SEugene Zhulenev   }
123a6628e59SEugene Zhulenev 
124a6628e59SEugene Zhulenev   return success();
125a6628e59SEugene Zhulenev }
126a6628e59SEugene Zhulenev 
127*8a316b00SEugene Zhulenev void AsyncRuntimeRefCountingOptPass::runOnOperation() {
128*8a316b00SEugene Zhulenev   Operation *op = getOperation();
129a6628e59SEugene Zhulenev 
130a6628e59SEugene Zhulenev   // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
131a6628e59SEugene Zhulenev   //
132a6628e59SEugene Zhulenev   // Find all cancellable pairs of operation and erase them in the end to keep
133a6628e59SEugene Zhulenev   // all iterators valid while we are walking the function operations.
134a6628e59SEugene Zhulenev   llvm::SmallDenseMap<Operation *, Operation *> cancellable;
135a6628e59SEugene Zhulenev 
136a6628e59SEugene Zhulenev   // Optimize reference counting for values defined by block arguments.
137*8a316b00SEugene Zhulenev   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
138a6628e59SEugene Zhulenev     for (BlockArgument arg : block->getArguments())
139a6628e59SEugene Zhulenev       if (isRefCounted(arg.getType()))
140a6628e59SEugene Zhulenev         if (failed(optimizeReferenceCounting(arg, cancellable)))
141a6628e59SEugene Zhulenev           return WalkResult::interrupt();
142a6628e59SEugene Zhulenev 
143a6628e59SEugene Zhulenev     return WalkResult::advance();
144a6628e59SEugene Zhulenev   });
145a6628e59SEugene Zhulenev 
146a6628e59SEugene Zhulenev   if (blockWalk.wasInterrupted())
147a6628e59SEugene Zhulenev     signalPassFailure();
148a6628e59SEugene Zhulenev 
149a6628e59SEugene Zhulenev   // Optimize reference counting for values defined by operation results.
150*8a316b00SEugene Zhulenev   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
151a6628e59SEugene Zhulenev     for (unsigned i = 0; i < op->getNumResults(); ++i)
152a6628e59SEugene Zhulenev       if (isRefCounted(op->getResultTypes()[i]))
153a6628e59SEugene Zhulenev         if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
154a6628e59SEugene Zhulenev           return WalkResult::interrupt();
155a6628e59SEugene Zhulenev 
156a6628e59SEugene Zhulenev     return WalkResult::advance();
157a6628e59SEugene Zhulenev   });
158a6628e59SEugene Zhulenev 
159a6628e59SEugene Zhulenev   if (opWalk.wasInterrupted())
160a6628e59SEugene Zhulenev     signalPassFailure();
161a6628e59SEugene Zhulenev 
162a6628e59SEugene Zhulenev   LLVM_DEBUG({
163a6628e59SEugene Zhulenev     llvm::dbgs() << "Found " << cancellable.size()
164a6628e59SEugene Zhulenev                  << " cancellable reference counting operations\n";
165a6628e59SEugene Zhulenev   });
166a6628e59SEugene Zhulenev 
167a6628e59SEugene Zhulenev   // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
168a6628e59SEugene Zhulenev   for (auto &kv : cancellable) {
169a6628e59SEugene Zhulenev     kv.first->erase();
170a6628e59SEugene Zhulenev     kv.second->erase();
171a6628e59SEugene Zhulenev   }
172a6628e59SEugene Zhulenev }
173a6628e59SEugene Zhulenev 
174*8a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
175a6628e59SEugene Zhulenev   return std::make_unique<AsyncRuntimeRefCountingOptPass>();
176a6628e59SEugene Zhulenev }
177