//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations.
//
//===----------------------------------------------------------------------===//
12a6628e59SEugene Zhulenev 
13a6628e59SEugene Zhulenev #include "PassDetail.h"
14a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
15a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h"
16*23aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
17a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h"
18297a5b7cSNico Weber #include "llvm/Support/Debug.h"
19a6628e59SEugene Zhulenev 
20a6628e59SEugene Zhulenev using namespace mlir;
21a6628e59SEugene Zhulenev using namespace mlir::async;
22a6628e59SEugene Zhulenev 
23a6628e59SEugene Zhulenev #define DEBUG_TYPE "async-ref-counting"
24a6628e59SEugene Zhulenev 
25a6628e59SEugene Zhulenev namespace {
26a6628e59SEugene Zhulenev 
27a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingOptPass
28a6628e59SEugene Zhulenev     : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
29a6628e59SEugene Zhulenev public:
30a6628e59SEugene Zhulenev   AsyncRuntimeRefCountingOptPass() = default;
318a316b00SEugene Zhulenev   void runOnOperation() override;
32a6628e59SEugene Zhulenev 
33a6628e59SEugene Zhulenev private:
34a6628e59SEugene Zhulenev   LogicalResult optimizeReferenceCounting(
35a6628e59SEugene Zhulenev       Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
36a6628e59SEugene Zhulenev };
37a6628e59SEugene Zhulenev 
38a6628e59SEugene Zhulenev } // namespace
39a6628e59SEugene Zhulenev 
optimizeReferenceCounting(Value value,llvm::SmallDenseMap<Operation *,Operation * > & cancellable)40a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
41a6628e59SEugene Zhulenev     Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
42a6628e59SEugene Zhulenev   Region *definingRegion = value.getParentRegion();
43a6628e59SEugene Zhulenev 
44a6628e59SEugene Zhulenev   // Find all users of the `value` inside each block, including operations that
45a6628e59SEugene Zhulenev   // do not use `value` directly, but have a direct use inside nested region(s).
46a6628e59SEugene Zhulenev   //
47a6628e59SEugene Zhulenev   // Example:
48a6628e59SEugene Zhulenev   //
49a6628e59SEugene Zhulenev   //  ^bb1:
50a6628e59SEugene Zhulenev   //    %token = ...
51a6628e59SEugene Zhulenev   //    scf.if %cond {
52a6628e59SEugene Zhulenev   //      ^bb2:
53a6628e59SEugene Zhulenev   //      async.runtime.await %token : !async.token
54a6628e59SEugene Zhulenev   //    }
55a6628e59SEugene Zhulenev   //
56a6628e59SEugene Zhulenev   // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
57a6628e59SEugene Zhulenev   // (`scf.if`).
58a6628e59SEugene Zhulenev 
59a6628e59SEugene Zhulenev   struct BlockUsersInfo {
60a6628e59SEugene Zhulenev     llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
61a6628e59SEugene Zhulenev     llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
62a6628e59SEugene Zhulenev     llvm::SmallVector<Operation *, 4> users;
63a6628e59SEugene Zhulenev   };
64a6628e59SEugene Zhulenev 
65a6628e59SEugene Zhulenev   llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
66a6628e59SEugene Zhulenev 
67a6628e59SEugene Zhulenev   auto updateBlockUsersInfo = [&](Operation *user) {
68a6628e59SEugene Zhulenev     BlockUsersInfo &info = blockUsers[user->getBlock()];
69a6628e59SEugene Zhulenev     info.users.push_back(user);
70a6628e59SEugene Zhulenev 
71a6628e59SEugene Zhulenev     if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
72a6628e59SEugene Zhulenev       info.addRefs.push_back(addRef);
73a6628e59SEugene Zhulenev     if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
74a6628e59SEugene Zhulenev       info.dropRefs.push_back(dropRef);
75a6628e59SEugene Zhulenev   };
76a6628e59SEugene Zhulenev 
77a6628e59SEugene Zhulenev   for (Operation *user : value.getUsers()) {
78a6628e59SEugene Zhulenev     while (user->getParentRegion() != definingRegion) {
79a6628e59SEugene Zhulenev       updateBlockUsersInfo(user);
80a6628e59SEugene Zhulenev       user = user->getParentOp();
81a6628e59SEugene Zhulenev       assert(user != nullptr && "value user lies outside of the value region");
82a6628e59SEugene Zhulenev     }
83a6628e59SEugene Zhulenev 
84a6628e59SEugene Zhulenev     updateBlockUsersInfo(user);
85a6628e59SEugene Zhulenev   }
86a6628e59SEugene Zhulenev 
87a6628e59SEugene Zhulenev   // Sort all operations found in the block.
88a6628e59SEugene Zhulenev   auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
89a6628e59SEugene Zhulenev     auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
90a6628e59SEugene Zhulenev       return a->isBeforeInBlock(b);
91a6628e59SEugene Zhulenev     };
92a6628e59SEugene Zhulenev     llvm::sort(info.addRefs, isBeforeInBlock);
93a6628e59SEugene Zhulenev     llvm::sort(info.dropRefs, isBeforeInBlock);
94a6628e59SEugene Zhulenev     llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
95a6628e59SEugene Zhulenev       return isBeforeInBlock(a, b);
96a6628e59SEugene Zhulenev     });
97a6628e59SEugene Zhulenev 
98a6628e59SEugene Zhulenev     return info;
99a6628e59SEugene Zhulenev   };
100a6628e59SEugene Zhulenev 
101a6628e59SEugene Zhulenev   // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
102a6628e59SEugene Zhulenev   // blocks that modify the reference count of the `value`.
103a6628e59SEugene Zhulenev   for (auto &kv : blockUsers) {
104a6628e59SEugene Zhulenev     BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
105a6628e59SEugene Zhulenev 
106a6628e59SEugene Zhulenev     for (RuntimeAddRefOp addRef : info.addRefs) {
107a6628e59SEugene Zhulenev       for (RuntimeDropRefOp dropRef : info.dropRefs) {
108a6628e59SEugene Zhulenev         // `drop_ref` operation after the `add_ref` with matching count.
109a6628e59SEugene Zhulenev         if (dropRef.count() != addRef.count() ||
110a6628e59SEugene Zhulenev             dropRef->isBeforeInBlock(addRef.getOperation()))
111a6628e59SEugene Zhulenev           continue;
112a6628e59SEugene Zhulenev 
1139ccdaac8SEugene Zhulenev         // When reference counted value passed to a function as an argument,
1149ccdaac8SEugene Zhulenev         // function takes ownership of +1 reference and it will drop it before
1159ccdaac8SEugene Zhulenev         // returning.
1169ccdaac8SEugene Zhulenev         //
1179ccdaac8SEugene Zhulenev         // Example:
1189ccdaac8SEugene Zhulenev         //
1199ccdaac8SEugene Zhulenev         //   %token = ... : !async.token
1209ccdaac8SEugene Zhulenev         //
12192db09cdSEugene Zhulenev         //   async.runtime.add_ref %token {count = 1 : i64} : !async.token
1229ccdaac8SEugene Zhulenev         //   call @pass_token(%token: !async.token, ...)
1239ccdaac8SEugene Zhulenev         //
1249ccdaac8SEugene Zhulenev         //   async.await %token : !async.token
12592db09cdSEugene Zhulenev         //   async.runtime.drop_ref %token {count = 1 : i64} : !async.token
1269ccdaac8SEugene Zhulenev         //
1279ccdaac8SEugene Zhulenev         // In this example if we'll cancel a pair of reference counting
1289ccdaac8SEugene Zhulenev         // operations we might end up with a deallocated token when we'll
1299ccdaac8SEugene Zhulenev         // reach `async.await` operation.
1309ccdaac8SEugene Zhulenev         Operation *firstFunctionCallUser = nullptr;
1319ccdaac8SEugene Zhulenev         Operation *lastNonFunctionCallUser = nullptr;
1329ccdaac8SEugene Zhulenev 
1339ccdaac8SEugene Zhulenev         for (Operation *user : info.users) {
1349ccdaac8SEugene Zhulenev           // `user` operation lies after `addRef` ...
1359ccdaac8SEugene Zhulenev           if (user == addRef || user->isBeforeInBlock(addRef))
1369ccdaac8SEugene Zhulenev             continue;
1379ccdaac8SEugene Zhulenev           // ... and before `dropRef`.
1389ccdaac8SEugene Zhulenev           if (user == dropRef || dropRef->isBeforeInBlock(user))
1399ccdaac8SEugene Zhulenev             break;
1409ccdaac8SEugene Zhulenev 
1419ccdaac8SEugene Zhulenev           // Find the first function call user of the reference counted value.
142*23aa5a74SRiver Riddle           Operation *functionCall = dyn_cast<func::CallOp>(user);
1439ccdaac8SEugene Zhulenev           if (functionCall &&
1449ccdaac8SEugene Zhulenev               (!firstFunctionCallUser ||
1459ccdaac8SEugene Zhulenev                functionCall->isBeforeInBlock(firstFunctionCallUser))) {
1469ccdaac8SEugene Zhulenev             firstFunctionCallUser = functionCall;
1479ccdaac8SEugene Zhulenev             continue;
1489ccdaac8SEugene Zhulenev           }
1499ccdaac8SEugene Zhulenev 
1509ccdaac8SEugene Zhulenev           // Find the last regular user of the reference counted value.
1519ccdaac8SEugene Zhulenev           if (!functionCall &&
1529ccdaac8SEugene Zhulenev               (!lastNonFunctionCallUser ||
1539ccdaac8SEugene Zhulenev                lastNonFunctionCallUser->isBeforeInBlock(user))) {
1549ccdaac8SEugene Zhulenev             lastNonFunctionCallUser = user;
1559ccdaac8SEugene Zhulenev             continue;
1569ccdaac8SEugene Zhulenev           }
1579ccdaac8SEugene Zhulenev         }
1589ccdaac8SEugene Zhulenev 
1599ccdaac8SEugene Zhulenev         // Non function call user after the function call user of the reference
1609ccdaac8SEugene Zhulenev         // counted value.
1619ccdaac8SEugene Zhulenev         if (firstFunctionCallUser && lastNonFunctionCallUser &&
1629ccdaac8SEugene Zhulenev             firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
1639ccdaac8SEugene Zhulenev           continue;
1649ccdaac8SEugene Zhulenev 
165a6628e59SEugene Zhulenev         // Try to cancel the pair of `add_ref` and `drop_ref` operations.
166a6628e59SEugene Zhulenev         auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
167a6628e59SEugene Zhulenev                                                 addRef.getOperation());
168a6628e59SEugene Zhulenev 
169a6628e59SEugene Zhulenev         if (!emplaced.second) // `drop_ref` was already marked for removal
170a6628e59SEugene Zhulenev           continue;           // go to the next `drop_ref`
171a6628e59SEugene Zhulenev 
172a6628e59SEugene Zhulenev         if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
173a6628e59SEugene Zhulenev           break;             // go to the next `add_ref`
174a6628e59SEugene Zhulenev       }
175a6628e59SEugene Zhulenev     }
176a6628e59SEugene Zhulenev   }
177a6628e59SEugene Zhulenev 
178a6628e59SEugene Zhulenev   return success();
179a6628e59SEugene Zhulenev }
180a6628e59SEugene Zhulenev 
runOnOperation()1818a316b00SEugene Zhulenev void AsyncRuntimeRefCountingOptPass::runOnOperation() {
1828a316b00SEugene Zhulenev   Operation *op = getOperation();
183a6628e59SEugene Zhulenev 
184a6628e59SEugene Zhulenev   // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
185a6628e59SEugene Zhulenev   //
186a6628e59SEugene Zhulenev   // Find all cancellable pairs of operation and erase them in the end to keep
187a6628e59SEugene Zhulenev   // all iterators valid while we are walking the function operations.
188a6628e59SEugene Zhulenev   llvm::SmallDenseMap<Operation *, Operation *> cancellable;
189a6628e59SEugene Zhulenev 
190a6628e59SEugene Zhulenev   // Optimize reference counting for values defined by block arguments.
1918a316b00SEugene Zhulenev   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
192a6628e59SEugene Zhulenev     for (BlockArgument arg : block->getArguments())
193a6628e59SEugene Zhulenev       if (isRefCounted(arg.getType()))
194a6628e59SEugene Zhulenev         if (failed(optimizeReferenceCounting(arg, cancellable)))
195a6628e59SEugene Zhulenev           return WalkResult::interrupt();
196a6628e59SEugene Zhulenev 
197a6628e59SEugene Zhulenev     return WalkResult::advance();
198a6628e59SEugene Zhulenev   });
199a6628e59SEugene Zhulenev 
200a6628e59SEugene Zhulenev   if (blockWalk.wasInterrupted())
201a6628e59SEugene Zhulenev     signalPassFailure();
202a6628e59SEugene Zhulenev 
203a6628e59SEugene Zhulenev   // Optimize reference counting for values defined by operation results.
2048a316b00SEugene Zhulenev   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
205a6628e59SEugene Zhulenev     for (unsigned i = 0; i < op->getNumResults(); ++i)
206a6628e59SEugene Zhulenev       if (isRefCounted(op->getResultTypes()[i]))
207a6628e59SEugene Zhulenev         if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
208a6628e59SEugene Zhulenev           return WalkResult::interrupt();
209a6628e59SEugene Zhulenev 
210a6628e59SEugene Zhulenev     return WalkResult::advance();
211a6628e59SEugene Zhulenev   });
212a6628e59SEugene Zhulenev 
213a6628e59SEugene Zhulenev   if (opWalk.wasInterrupted())
214a6628e59SEugene Zhulenev     signalPassFailure();
215a6628e59SEugene Zhulenev 
216a6628e59SEugene Zhulenev   LLVM_DEBUG({
217a6628e59SEugene Zhulenev     llvm::dbgs() << "Found " << cancellable.size()
218a6628e59SEugene Zhulenev                  << " cancellable reference counting operations\n";
219a6628e59SEugene Zhulenev   });
220a6628e59SEugene Zhulenev 
221a6628e59SEugene Zhulenev   // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
222a6628e59SEugene Zhulenev   for (auto &kv : cancellable) {
223a6628e59SEugene Zhulenev     kv.first->erase();
224a6628e59SEugene Zhulenev     kv.second->erase();
225a6628e59SEugene Zhulenev   }
226a6628e59SEugene Zhulenev }
227a6628e59SEugene Zhulenev 
createAsyncRuntimeRefCountingOptPass()2288a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
229a6628e59SEugene Zhulenev   return std::make_unique<AsyncRuntimeRefCountingOptPass>();
230a6628e59SEugene Zhulenev }
231