1 //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Optimize Async dialect reference counting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Async/IR/Async.h"
15 #include "mlir/Dialect/Async/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/Support/Debug.h"
19 
20 using namespace mlir;
21 using namespace mlir::async;
22 
23 #define DEBUG_TYPE "async-ref-counting"
24 
25 namespace {
26 
27 class AsyncRuntimeRefCountingOptPass
28     : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
29 public:
30   AsyncRuntimeRefCountingOptPass() = default;
31   void runOnOperation() override;
32 
33 private:
34   LogicalResult optimizeReferenceCounting(
35       Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
36 };
37 
38 } // namespace
39 
optimizeReferenceCounting(Value value,llvm::SmallDenseMap<Operation *,Operation * > & cancellable)40 LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
41     Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
42   Region *definingRegion = value.getParentRegion();
43 
44   // Find all users of the `value` inside each block, including operations that
45   // do not use `value` directly, but have a direct use inside nested region(s).
46   //
47   // Example:
48   //
49   //  ^bb1:
50   //    %token = ...
51   //    scf.if %cond {
52   //      ^bb2:
53   //      async.runtime.await %token : !async.token
54   //    }
55   //
56   // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
57   // (`scf.if`).
58 
59   struct BlockUsersInfo {
60     llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
61     llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
62     llvm::SmallVector<Operation *, 4> users;
63   };
64 
65   llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
66 
67   auto updateBlockUsersInfo = [&](Operation *user) {
68     BlockUsersInfo &info = blockUsers[user->getBlock()];
69     info.users.push_back(user);
70 
71     if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
72       info.addRefs.push_back(addRef);
73     if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
74       info.dropRefs.push_back(dropRef);
75   };
76 
77   for (Operation *user : value.getUsers()) {
78     while (user->getParentRegion() != definingRegion) {
79       updateBlockUsersInfo(user);
80       user = user->getParentOp();
81       assert(user != nullptr && "value user lies outside of the value region");
82     }
83 
84     updateBlockUsersInfo(user);
85   }
86 
87   // Sort all operations found in the block.
88   auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
89     auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
90       return a->isBeforeInBlock(b);
91     };
92     llvm::sort(info.addRefs, isBeforeInBlock);
93     llvm::sort(info.dropRefs, isBeforeInBlock);
94     llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
95       return isBeforeInBlock(a, b);
96     });
97 
98     return info;
99   };
100 
101   // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
102   // blocks that modify the reference count of the `value`.
103   for (auto &kv : blockUsers) {
104     BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
105 
106     for (RuntimeAddRefOp addRef : info.addRefs) {
107       for (RuntimeDropRefOp dropRef : info.dropRefs) {
108         // `drop_ref` operation after the `add_ref` with matching count.
109         if (dropRef.count() != addRef.count() ||
110             dropRef->isBeforeInBlock(addRef.getOperation()))
111           continue;
112 
113         // When reference counted value passed to a function as an argument,
114         // function takes ownership of +1 reference and it will drop it before
115         // returning.
116         //
117         // Example:
118         //
119         //   %token = ... : !async.token
120         //
121         //   async.runtime.add_ref %token {count = 1 : i64} : !async.token
122         //   call @pass_token(%token: !async.token, ...)
123         //
124         //   async.await %token : !async.token
125         //   async.runtime.drop_ref %token {count = 1 : i64} : !async.token
126         //
127         // In this example if we'll cancel a pair of reference counting
128         // operations we might end up with a deallocated token when we'll
129         // reach `async.await` operation.
130         Operation *firstFunctionCallUser = nullptr;
131         Operation *lastNonFunctionCallUser = nullptr;
132 
133         for (Operation *user : info.users) {
134           // `user` operation lies after `addRef` ...
135           if (user == addRef || user->isBeforeInBlock(addRef))
136             continue;
137           // ... and before `dropRef`.
138           if (user == dropRef || dropRef->isBeforeInBlock(user))
139             break;
140 
141           // Find the first function call user of the reference counted value.
142           Operation *functionCall = dyn_cast<func::CallOp>(user);
143           if (functionCall &&
144               (!firstFunctionCallUser ||
145                functionCall->isBeforeInBlock(firstFunctionCallUser))) {
146             firstFunctionCallUser = functionCall;
147             continue;
148           }
149 
150           // Find the last regular user of the reference counted value.
151           if (!functionCall &&
152               (!lastNonFunctionCallUser ||
153                lastNonFunctionCallUser->isBeforeInBlock(user))) {
154             lastNonFunctionCallUser = user;
155             continue;
156           }
157         }
158 
159         // Non function call user after the function call user of the reference
160         // counted value.
161         if (firstFunctionCallUser && lastNonFunctionCallUser &&
162             firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
163           continue;
164 
165         // Try to cancel the pair of `add_ref` and `drop_ref` operations.
166         auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
167                                                 addRef.getOperation());
168 
169         if (!emplaced.second) // `drop_ref` was already marked for removal
170           continue;           // go to the next `drop_ref`
171 
172         if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
173           break;             // go to the next `add_ref`
174       }
175     }
176   }
177 
178   return success();
179 }
180 
runOnOperation()181 void AsyncRuntimeRefCountingOptPass::runOnOperation() {
182   Operation *op = getOperation();
183 
184   // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
185   //
186   // Find all cancellable pairs of operation and erase them in the end to keep
187   // all iterators valid while we are walking the function operations.
188   llvm::SmallDenseMap<Operation *, Operation *> cancellable;
189 
190   // Optimize reference counting for values defined by block arguments.
191   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
192     for (BlockArgument arg : block->getArguments())
193       if (isRefCounted(arg.getType()))
194         if (failed(optimizeReferenceCounting(arg, cancellable)))
195           return WalkResult::interrupt();
196 
197     return WalkResult::advance();
198   });
199 
200   if (blockWalk.wasInterrupted())
201     signalPassFailure();
202 
203   // Optimize reference counting for values defined by operation results.
204   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
205     for (unsigned i = 0; i < op->getNumResults(); ++i)
206       if (isRefCounted(op->getResultTypes()[i]))
207         if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
208           return WalkResult::interrupt();
209 
210     return WalkResult::advance();
211   });
212 
213   if (opWalk.wasInterrupted())
214     signalPassFailure();
215 
216   LLVM_DEBUG({
217     llvm::dbgs() << "Found " << cancellable.size()
218                  << " cancellable reference counting operations\n";
219   });
220 
221   // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
222   for (auto &kv : cancellable) {
223     kv.first->erase();
224     kv.second->erase();
225   }
226 }
227 
createAsyncRuntimeRefCountingOptPass()228 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
229   return std::make_unique<AsyncRuntimeRefCountingOptPass>();
230 }
231