1 //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Optimize Async dialect reference counting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
18 
19 using namespace mlir;
20 using namespace mlir::async;
21 
22 #define DEBUG_TYPE "async-ref-counting"
23 
24 namespace {
25 
26 class AsyncRuntimeRefCountingOptPass
27     : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
28 public:
29   AsyncRuntimeRefCountingOptPass() = default;
30   void runOnOperation() override;
31 
32 private:
33   LogicalResult optimizeReferenceCounting(
34       Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
35 };
36 
37 } // namespace
38 
39 LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
40     Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
41   Region *definingRegion = value.getParentRegion();
42 
43   // Find all users of the `value` inside each block, including operations that
44   // do not use `value` directly, but have a direct use inside nested region(s).
45   //
46   // Example:
47   //
48   //  ^bb1:
49   //    %token = ...
50   //    scf.if %cond {
51   //      ^bb2:
52   //      async.runtime.await %token : !async.token
53   //    }
54   //
55   // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
56   // (`scf.if`).
57 
58   struct BlockUsersInfo {
59     llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
60     llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
61     llvm::SmallVector<Operation *, 4> users;
62   };
63 
64   llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
65 
66   auto updateBlockUsersInfo = [&](Operation *user) {
67     BlockUsersInfo &info = blockUsers[user->getBlock()];
68     info.users.push_back(user);
69 
70     if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
71       info.addRefs.push_back(addRef);
72     if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
73       info.dropRefs.push_back(dropRef);
74   };
75 
76   for (Operation *user : value.getUsers()) {
77     while (user->getParentRegion() != definingRegion) {
78       updateBlockUsersInfo(user);
79       user = user->getParentOp();
80       assert(user != nullptr && "value user lies outside of the value region");
81     }
82 
83     updateBlockUsersInfo(user);
84   }
85 
86   // Sort all operations found in the block.
87   auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
88     auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
89       return a->isBeforeInBlock(b);
90     };
91     llvm::sort(info.addRefs, isBeforeInBlock);
92     llvm::sort(info.dropRefs, isBeforeInBlock);
93     llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
94       return isBeforeInBlock(a, b);
95     });
96 
97     return info;
98   };
99 
100   // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
101   // blocks that modify the reference count of the `value`.
102   for (auto &kv : blockUsers) {
103     BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
104 
105     for (RuntimeAddRefOp addRef : info.addRefs) {
106       for (RuntimeDropRefOp dropRef : info.dropRefs) {
107         // `drop_ref` operation after the `add_ref` with matching count.
108         if (dropRef.count() != addRef.count() ||
109             dropRef->isBeforeInBlock(addRef.getOperation()))
110           continue;
111 
112         // Try to cancel the pair of `add_ref` and `drop_ref` operations.
113         auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
114                                                 addRef.getOperation());
115 
116         if (!emplaced.second) // `drop_ref` was already marked for removal
117           continue;           // go to the next `drop_ref`
118 
119         if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
120           break;             // go to the next `add_ref`
121       }
122     }
123   }
124 
125   return success();
126 }
127 
128 void AsyncRuntimeRefCountingOptPass::runOnOperation() {
129   Operation *op = getOperation();
130 
131   // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
132   //
133   // Find all cancellable pairs of operation and erase them in the end to keep
134   // all iterators valid while we are walking the function operations.
135   llvm::SmallDenseMap<Operation *, Operation *> cancellable;
136 
137   // Optimize reference counting for values defined by block arguments.
138   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
139     for (BlockArgument arg : block->getArguments())
140       if (isRefCounted(arg.getType()))
141         if (failed(optimizeReferenceCounting(arg, cancellable)))
142           return WalkResult::interrupt();
143 
144     return WalkResult::advance();
145   });
146 
147   if (blockWalk.wasInterrupted())
148     signalPassFailure();
149 
150   // Optimize reference counting for values defined by operation results.
151   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
152     for (unsigned i = 0; i < op->getNumResults(); ++i)
153       if (isRefCounted(op->getResultTypes()[i]))
154         if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
155           return WalkResult::interrupt();
156 
157     return WalkResult::advance();
158   });
159 
160   if (opWalk.wasInterrupted())
161     signalPassFailure();
162 
163   LLVM_DEBUG({
164     llvm::dbgs() << "Found " << cancellable.size()
165                  << " cancellable reference counting operations\n";
166   });
167 
168   // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
169   for (auto &kv : cancellable) {
170     kv.first->erase();
171     kv.second->erase();
172   }
173 }
174 
175 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
176   return std::make_unique<AsyncRuntimeRefCountingOptPass>();
177 }
178