//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations.
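//
// For example, a minimal sketch of the rewrite (the IR below is illustrative;
// exact attribute syntax may differ from the dialect definition): an `add_ref`
// followed by a `drop_ref` with the same count on the same value in one block
// cancel each other, so
//
//   %token = async.runtime.create : !async.token
//   async.runtime.add_ref %token {count = 1 : i32} : !async.token
//   async.runtime.await %token : !async.token
//   async.runtime.drop_ref %token {count = 1 : i32} : !async.token
//
// is reduced to
//
//   %token = async.runtime.create : !async.token
//   async.runtime.await %token : !async.token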
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-ref-counting"

namespace {

class AsyncRuntimeRefCountingOptPass
    : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
public:
  AsyncRuntimeRefCountingOptPass() = default;
  void runOnFunction() override;

private:
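  // Analyzes all uses of the reference counted `value` and records matching
  // pairs of `add_ref` / `drop_ref` operations that cancel each other in the
  // `cancellable` map (keyed by the `drop_ref` operation). Erasing the
  // recorded operations is left to the caller.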
  LogicalResult optimizeReferenceCounting(
      Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
};

} // namespace

LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
    Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
  Region *definingRegion = value.getParentRegion();

  // Find all users of `value` inside each block, including operations that do
  // not use `value` directly but contain a direct use inside their nested
  // region(s).
  //
  // Example:
  //
  //  ^bb1:
  //    %token = ...
  //    scf.if %cond {
  //      ^bb2:
  //      async.runtime.await %token : !async.token
  //    }
  //
  // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
  // (`scf.if`).

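  // All users of `value` within a single block; `add_ref` / `drop_ref`
  // operations among them are additionally tracked in their own lists.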
  struct BlockUsersInfo {
    llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
    llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
    llvm::SmallVector<Operation *, 4> users;
  };

  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;

  auto updateBlockUsersInfo = [&](Operation *user) {
    BlockUsersInfo &info = blockUsers[user->getBlock()];
    info.users.push_back(user);

    if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
      info.addRefs.push_back(addRef);
    if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
      info.dropRefs.push_back(dropRef);
  };

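  // Walk from each direct user of `value` up through its parent operations
  // until we reach the defining region, so that a user nested in another
  // region (e.g. inside the `scf.if` from the example above) is also
  // attributed to the enclosing operation in every ancestor block.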
  for (Operation *user : value.getUsers()) {
    while (user->getParentRegion() != definingRegion) {
      updateBlockUsersInfo(user);
      user = user->getParentOp();
      assert(user != nullptr && "value user lies outside of the value region");
    }

    updateBlockUsersInfo(user);
  }

  // Sort all operations found in the block.
  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
      return a->isBeforeInBlock(b);
    };
    llvm::sort(info.addRefs, isBeforeInBlock);
    llvm::sort(info.dropRefs, isBeforeInBlock);
    llvm::sort(info.users, isBeforeInBlock);

    return info;
  };

  // Find matching pairs of `add_ref` / `drop_ref` operations in each block
  // that modify the reference count of `value`, and record them in the
  // `cancellable` map so the caller can erase both operations.
  for (auto &kv : blockUsers) {
    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);

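    // For every `add_ref`, pair it with the first `drop_ref` that comes later
    // in the block with a matching count; each `drop_ref` can cancel at most
    // one `add_ref`.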
    for (RuntimeAddRefOp addRef : info.addRefs) {
      for (RuntimeDropRefOp dropRef : info.dropRefs) {
        // Only a `drop_ref` that comes after the `add_ref` and has a matching
        // count is a candidate for cancellation.
        if (dropRef.count() != addRef.count() ||
            dropRef->isBeforeInBlock(addRef.getOperation()))
          continue;

        // Try to cancel the pair of `add_ref` and `drop_ref` operations.
        auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
                                                addRef.getOperation());

        if (!emplaced.second) // `drop_ref` was already marked for removal
          continue;           // go to the next `drop_ref`

        // Successfully cancelled the `add_ref` <-> `drop_ref` pair.
        break; // go to the next `add_ref`
      }
    }
  }

  return success();
}

void AsyncRuntimeRefCountingOptPass::runOnFunction() {
  FuncOp func = getFunction();

  // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
  //
  // Find all cancellable pairs of operations first and erase them at the end,
  // to keep all iterators valid while we are walking the function operations.
  llvm::SmallDenseMap<Operation *, Operation *> cancellable;

  // Optimize reference counting for values defined by block arguments.
  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
    for (BlockArgument arg : block->getArguments())
      if (isRefCounted(arg.getType()))
        if (failed(optimizeReferenceCounting(arg, cancellable)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (blockWalk.wasInterrupted())
    signalPassFailure();

  // Optimize reference counting for values defined by operation results.
  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
    for (Value result : op->getResults())
      if (isRefCounted(result.getType()))
        if (failed(optimizeReferenceCounting(result, cancellable)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (opWalk.wasInterrupted())
    signalPassFailure();

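  // Note: in builds with assertions enabled, this statistic is printed when
  // the pass is run with LLVM's `-debug-only=async-ref-counting` flag
  // (matching the DEBUG_TYPE defined above).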
  LLVM_DEBUG({
    llvm::dbgs() << "Found " << cancellable.size()
                 << " cancellable reference counting operations\n";
  });

  // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
  for (auto &kv : cancellable) {
    kv.first->erase();
    kv.second->erase();
  }
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRuntimeRefCountingOptPass() {
  return std::make_unique<AsyncRuntimeRefCountingOptPass>();
}