1a6628e59SEugene Zhulenev //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
2a6628e59SEugene Zhulenev //
3a6628e59SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a6628e59SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
5a6628e59SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a6628e59SEugene Zhulenev //
7a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===//
8a6628e59SEugene Zhulenev //
9a6628e59SEugene Zhulenev // This file implements automatic reference counting for Async runtime
10a6628e59SEugene Zhulenev // operations and types.
11a6628e59SEugene Zhulenev //
12a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===//
13a6628e59SEugene Zhulenev 
14a6628e59SEugene Zhulenev #include "PassDetail.h"
15a6628e59SEugene Zhulenev #include "mlir/Analysis/Liveness.h"
16a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
17a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h"
18a6628e59SEugene Zhulenev #include "mlir/Dialect/StandardOps/IR/Ops.h"
19a6628e59SEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h"
20a6628e59SEugene Zhulenev #include "mlir/IR/PatternMatch.h"
21a6628e59SEugene Zhulenev #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h"
23a6628e59SEugene Zhulenev 
24a6628e59SEugene Zhulenev using namespace mlir;
25a6628e59SEugene Zhulenev using namespace mlir::async;
26a6628e59SEugene Zhulenev 
27a6628e59SEugene Zhulenev #define DEBUG_TYPE "async-runtime-ref-counting"
28a6628e59SEugene Zhulenev 
29f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===//
30f57b2420SEugene Zhulenev // Utility functions shared by reference counting passes.
31f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===//
32f57b2420SEugene Zhulenev 
33f57b2420SEugene Zhulenev // Drop the reference count immediately if the value has no uses.
34f57b2420SEugene Zhulenev static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) {
35f57b2420SEugene Zhulenev   if (!value.getUses().empty())
36f57b2420SEugene Zhulenev     return failure();
37f57b2420SEugene Zhulenev 
38f57b2420SEugene Zhulenev   OpBuilder b(value.getContext());
39f57b2420SEugene Zhulenev 
40f57b2420SEugene Zhulenev   // Set insertion point after the operation producing a value, or at the
41f57b2420SEugene Zhulenev   // beginning of the block if the value defined by the block argument.
42f57b2420SEugene Zhulenev   if (Operation *op = value.getDefiningOp())
43f57b2420SEugene Zhulenev     b.setInsertionPointAfter(op);
44f57b2420SEugene Zhulenev   else
45f57b2420SEugene Zhulenev     b.setInsertionPointToStart(value.getParentBlock());
46f57b2420SEugene Zhulenev 
47*92db09cdSEugene Zhulenev   b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1));
48f57b2420SEugene Zhulenev   return success();
49f57b2420SEugene Zhulenev }
50f57b2420SEugene Zhulenev 
51f57b2420SEugene Zhulenev // Calls `addRefCounting` for every reference counted value defined by the
52f57b2420SEugene Zhulenev // operation `op` (block arguments and values defined in nested regions).
53f57b2420SEugene Zhulenev static LogicalResult walkReferenceCountedValues(
54f57b2420SEugene Zhulenev     Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) {
55f57b2420SEugene Zhulenev   // Check that we do not have high level async operations in the IR because
56f57b2420SEugene Zhulenev   // otherwise reference counting will produce incorrect results after high
57f57b2420SEugene Zhulenev   // level async operations will be lowered to `async.runtime`
58f57b2420SEugene Zhulenev   WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult {
59f57b2420SEugene Zhulenev     if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
60f57b2420SEugene Zhulenev       return WalkResult::advance();
61f57b2420SEugene Zhulenev 
62f57b2420SEugene Zhulenev     return op->emitError()
63f57b2420SEugene Zhulenev            << "async operations must be lowered to async runtime operations";
64f57b2420SEugene Zhulenev   });
65f57b2420SEugene Zhulenev 
66f57b2420SEugene Zhulenev   if (checkNoAsyncWalk.wasInterrupted())
67f57b2420SEugene Zhulenev     return failure();
68f57b2420SEugene Zhulenev 
69f57b2420SEugene Zhulenev   // Add reference counting to block arguments.
70f57b2420SEugene Zhulenev   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
71f57b2420SEugene Zhulenev     for (BlockArgument arg : block->getArguments())
72f57b2420SEugene Zhulenev       if (isRefCounted(arg.getType()))
73f57b2420SEugene Zhulenev         if (failed(addRefCounting(arg)))
74f57b2420SEugene Zhulenev           return WalkResult::interrupt();
75f57b2420SEugene Zhulenev 
76f57b2420SEugene Zhulenev     return WalkResult::advance();
77f57b2420SEugene Zhulenev   });
78f57b2420SEugene Zhulenev 
79f57b2420SEugene Zhulenev   if (blockWalk.wasInterrupted())
80f57b2420SEugene Zhulenev     return failure();
81f57b2420SEugene Zhulenev 
82f57b2420SEugene Zhulenev   // Add reference counting to operation results.
83f57b2420SEugene Zhulenev   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
84f57b2420SEugene Zhulenev     for (unsigned i = 0; i < op->getNumResults(); ++i)
85f57b2420SEugene Zhulenev       if (isRefCounted(op->getResultTypes()[i]))
86f57b2420SEugene Zhulenev         if (failed(addRefCounting(op->getResult(i))))
87f57b2420SEugene Zhulenev           return WalkResult::interrupt();
88f57b2420SEugene Zhulenev 
89f57b2420SEugene Zhulenev     return WalkResult::advance();
90f57b2420SEugene Zhulenev   });
91f57b2420SEugene Zhulenev 
92f57b2420SEugene Zhulenev   if (opWalk.wasInterrupted())
93f57b2420SEugene Zhulenev     return failure();
94f57b2420SEugene Zhulenev 
95f57b2420SEugene Zhulenev   return success();
96f57b2420SEugene Zhulenev }
97f57b2420SEugene Zhulenev 
98f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===//
99f57b2420SEugene Zhulenev // Automatic reference counting based on the liveness analysis.
100f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===//
101f57b2420SEugene Zhulenev 
102a6628e59SEugene Zhulenev namespace {
103a6628e59SEugene Zhulenev 
104a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingPass
105a6628e59SEugene Zhulenev     : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
106a6628e59SEugene Zhulenev public:
107a6628e59SEugene Zhulenev   AsyncRuntimeRefCountingPass() = default;
1088a316b00SEugene Zhulenev   void runOnOperation() override;
109a6628e59SEugene Zhulenev 
110a6628e59SEugene Zhulenev private:
111a6628e59SEugene Zhulenev   /// Adds an automatic reference counting to the `value`.
112a6628e59SEugene Zhulenev   ///
113a6628e59SEugene Zhulenev   /// All values (token, group or value) are semantically created with a
114a6628e59SEugene Zhulenev   /// reference count of +1 and it is the responsibility of the async value user
115a6628e59SEugene Zhulenev   /// to place the `add_ref` and `drop_ref` operations to ensure that the value
116a6628e59SEugene Zhulenev   /// is destroyed after the last use.
117a6628e59SEugene Zhulenev   ///
118a6628e59SEugene Zhulenev   /// The function returns failure if it can't deduce the locations where
119a6628e59SEugene Zhulenev   /// to place the reference counting operations.
120a6628e59SEugene Zhulenev   ///
121a6628e59SEugene Zhulenev   /// Async values "semantically created" when:
122a6628e59SEugene Zhulenev   ///   1. Operation returns async result (e.g. `async.runtime.create`)
123a6628e59SEugene Zhulenev   ///   2. Async value passed in as a block argument (or function argument,
124a6628e59SEugene Zhulenev   ///      because function arguments are just entry block arguments)
125a6628e59SEugene Zhulenev   ///
126a6628e59SEugene Zhulenev   /// Passing async value as a function argument (or block argument) does not
127a6628e59SEugene Zhulenev   /// really mean that a new async value is created, it only means that the
128a6628e59SEugene Zhulenev   /// caller of a function transfered ownership of `+1` reference to the callee.
129a6628e59SEugene Zhulenev   /// It is convenient to think that from the callee perspective async value was
130a6628e59SEugene Zhulenev   /// "created" with `+1` reference by the block argument.
131a6628e59SEugene Zhulenev   ///
132a6628e59SEugene Zhulenev   /// Automatic reference counting algorithm outline:
133a6628e59SEugene Zhulenev   ///
134a6628e59SEugene Zhulenev   /// #1 Insert `drop_ref` operations after last use of the `value`.
135a6628e59SEugene Zhulenev   /// #2 Insert `add_ref` operations before functions calls with reference
136a6628e59SEugene Zhulenev   ///    counted `value` operand (newly created `+1` reference will be
137a6628e59SEugene Zhulenev   ///    transferred to the callee).
138a6628e59SEugene Zhulenev   /// #3 Verify that divergent control flow does not lead to leaked reference
139a6628e59SEugene Zhulenev   ///    counted objects.
140a6628e59SEugene Zhulenev   ///
141a6628e59SEugene Zhulenev   /// Async runtime reference counting optimization pass will optimize away
142a6628e59SEugene Zhulenev   /// some of the redundant `add_ref` and `drop_ref` operations inserted by this
143a6628e59SEugene Zhulenev   /// strategy (see `async-runtime-ref-counting-opt`).
144a6628e59SEugene Zhulenev   LogicalResult addAutomaticRefCounting(Value value);
145a6628e59SEugene Zhulenev 
146a6628e59SEugene Zhulenev   /// (#1) Adds the `drop_ref` operation after the last use of the `value`
147a6628e59SEugene Zhulenev   /// relying on the liveness analysis.
148a6628e59SEugene Zhulenev   ///
149a6628e59SEugene Zhulenev   /// If the `value` is in the block `liveIn` set and it is not in the block
150a6628e59SEugene Zhulenev   /// `liveOut` set, it means that it "dies" in the block. We find the last
151a6628e59SEugene Zhulenev   /// use of the value in such block and:
152a6628e59SEugene Zhulenev   ///
153a6628e59SEugene Zhulenev   ///   1. If the last user is a `ReturnLike` operation we do nothing, because
154a6628e59SEugene Zhulenev   ///      it forwards the ownership to the caller.
155a6628e59SEugene Zhulenev   ///   2. Otherwise we add a `drop_ref` operation immediately after the last
156a6628e59SEugene Zhulenev   ///      use.
157a6628e59SEugene Zhulenev   LogicalResult addDropRefAfterLastUse(Value value);
158a6628e59SEugene Zhulenev 
159a6628e59SEugene Zhulenev   /// (#2) Adds the `add_ref` operation before the function call taking `value`
160a6628e59SEugene Zhulenev   /// operand to ensure that the value passed to the function entry block
161a6628e59SEugene Zhulenev   /// has a `+1` reference count.
162a6628e59SEugene Zhulenev   LogicalResult addAddRefBeforeFunctionCall(Value value);
163a6628e59SEugene Zhulenev 
164c412979cSEugene Zhulenev   /// (#3) Adds the `drop_ref` operation to account for successor blocks with
165c412979cSEugene Zhulenev   /// divergent `liveIn` property: `value` is not in the `liveIn` set of all
166c412979cSEugene Zhulenev   /// successor blocks.
167a6628e59SEugene Zhulenev   ///
168a6628e59SEugene Zhulenev   /// Example:
169a6628e59SEugene Zhulenev   ///
170a6628e59SEugene Zhulenev   ///   ^entry:
171a6628e59SEugene Zhulenev   ///     %token = async.runtime.create : !async.token
172a6628e59SEugene Zhulenev   ///     cond_br %cond, ^bb1, ^bb2
173a6628e59SEugene Zhulenev   ///   ^bb1:
174a6628e59SEugene Zhulenev   ///     async.runtime.await %token
175c412979cSEugene Zhulenev   ///     async.runtime.drop_ref %token
176c412979cSEugene Zhulenev   ///     br ^bb2
177a6628e59SEugene Zhulenev   ///   ^bb2:
178a6628e59SEugene Zhulenev   ///     return
179a6628e59SEugene Zhulenev   ///
180c412979cSEugene Zhulenev   /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have
181c412979cSEugene Zhulenev   /// to branch into a special "reference counting block" from the ^entry that
182c412979cSEugene Zhulenev   /// will have a `drop_ref` operation, and then branch into the ^bb2.
183c412979cSEugene Zhulenev   ///
184c412979cSEugene Zhulenev   /// After transformation:
185c412979cSEugene Zhulenev   ///
186c412979cSEugene Zhulenev   ///   ^entry:
187c412979cSEugene Zhulenev   ///     %token = async.runtime.create : !async.token
188c412979cSEugene Zhulenev   ///     cond_br %cond, ^bb1, ^reference_counting
189c412979cSEugene Zhulenev   ///   ^bb1:
190c412979cSEugene Zhulenev   ///     async.runtime.await %token
191c412979cSEugene Zhulenev   ///     async.runtime.drop_ref %token
192c412979cSEugene Zhulenev   ///     br ^bb2
193c412979cSEugene Zhulenev   ///   ^reference_counting:
194c412979cSEugene Zhulenev   ///     async.runtime.drop_ref %token
195c412979cSEugene Zhulenev   ///     br ^bb2
196c412979cSEugene Zhulenev   ///   ^bb2:
197c412979cSEugene Zhulenev   ///     return
198a6628e59SEugene Zhulenev   ///
199a6628e59SEugene Zhulenev   /// An exception to this rule are blocks with `async.coro.suspend` terminator,
200a6628e59SEugene Zhulenev   /// because in Async to LLVM lowering it is guaranteed that the control flow
201a6628e59SEugene Zhulenev   /// will jump into the resume block, and then follow into the cleanup and
202a6628e59SEugene Zhulenev   /// suspend blocks.
203a6628e59SEugene Zhulenev   ///
204a6628e59SEugene Zhulenev   /// Example:
205a6628e59SEugene Zhulenev   ///
206a6628e59SEugene Zhulenev   ///  ^entry(%value: !async.value<f32>):
207a6628e59SEugene Zhulenev   ///     async.runtime.await_and_resume %value, %hdl : !async.value<f32>
208a6628e59SEugene Zhulenev   ///     async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
209a6628e59SEugene Zhulenev   ///   ^resume:
210a6628e59SEugene Zhulenev   ///     %0 = async.runtime.load %value
211a6628e59SEugene Zhulenev   ///     br ^cleanup
212a6628e59SEugene Zhulenev   ///   ^cleanup:
213a6628e59SEugene Zhulenev   ///     ...
214a6628e59SEugene Zhulenev   ///   ^suspend:
215a6628e59SEugene Zhulenev   ///     ...
216a6628e59SEugene Zhulenev   ///
217a6628e59SEugene Zhulenev   /// Although cleanup and suspend blocks do not have the `value` in the
218a6628e59SEugene Zhulenev   /// `liveIn` set, it is guaranteed that execution will eventually continue in
219a6628e59SEugene Zhulenev   /// the resume block (we never explicitly destroy coroutines).
220c412979cSEugene Zhulenev   LogicalResult addDropRefInDivergentLivenessSuccessor(Value value);
221a6628e59SEugene Zhulenev };
222a6628e59SEugene Zhulenev 
223a6628e59SEugene Zhulenev } // namespace
224a6628e59SEugene Zhulenev 
225a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
226a6628e59SEugene Zhulenev   OpBuilder builder(value.getContext());
227a6628e59SEugene Zhulenev   Location loc = value.getLoc();
228a6628e59SEugene Zhulenev 
229a6628e59SEugene Zhulenev   // Use liveness analysis to find the placement of `drop_ref`operation.
230a6628e59SEugene Zhulenev   auto &liveness = getAnalysis<Liveness>();
231a6628e59SEugene Zhulenev 
232a6628e59SEugene Zhulenev   // We analyse only the blocks of the region that defines the `value`, and do
233a6628e59SEugene Zhulenev   // not check nested blocks attached to operations.
234a6628e59SEugene Zhulenev   //
235a6628e59SEugene Zhulenev   // By analyzing only the `definingRegion` CFG we potentially loose an
236a6628e59SEugene Zhulenev   // opportunity to drop the reference count earlier and can extend the lifetime
237a6628e59SEugene Zhulenev   // of reference counted value longer then it is really required.
238a6628e59SEugene Zhulenev   //
239a6628e59SEugene Zhulenev   // We also assume that all nested regions finish their execution before the
240a6628e59SEugene Zhulenev   // completion of the owner operation. The only exception to this rule is
241a6628e59SEugene Zhulenev   // `async.execute` operation, and we verify that they are lowered to the
242a6628e59SEugene Zhulenev   // `async.runtime` operations before adding automatic reference counting.
243a6628e59SEugene Zhulenev   Region *definingRegion = value.getParentRegion();
244a6628e59SEugene Zhulenev 
245a6628e59SEugene Zhulenev   // Last users of the `value` inside all blocks where the value dies.
246a6628e59SEugene Zhulenev   llvm::SmallSet<Operation *, 4> lastUsers;
247a6628e59SEugene Zhulenev 
248a6628e59SEugene Zhulenev   // Find blocks in the `definingRegion` that have users of the `value` (if
249a6628e59SEugene Zhulenev   // there are multiple users in the block, which one will be selected is
250a6628e59SEugene Zhulenev   // undefined). User operation might be not the actual user of the value, but
251a6628e59SEugene Zhulenev   // the operation in the block that has a "real user" in one of the attached
252a6628e59SEugene Zhulenev   // regions.
253a6628e59SEugene Zhulenev   llvm::DenseMap<Block *, Operation *> usersInTheBlocks;
254a6628e59SEugene Zhulenev 
255a6628e59SEugene Zhulenev   for (Operation *user : value.getUsers()) {
256a6628e59SEugene Zhulenev     Block *userBlock = user->getBlock();
257a6628e59SEugene Zhulenev     Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
258a6628e59SEugene Zhulenev     usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user);
259a6628e59SEugene Zhulenev     assert(ancestor && "ancestor block must be not null");
260a6628e59SEugene Zhulenev     assert(usersInTheBlocks[ancestor] && "ancestor op must be not null");
261a6628e59SEugene Zhulenev   }
262a6628e59SEugene Zhulenev 
263a6628e59SEugene Zhulenev   // Find blocks where the `value` dies: the value is in `liveIn` set and not
264a6628e59SEugene Zhulenev   // in the `liveOut` set. We place `drop_ref` immediately after the last use
265a6628e59SEugene Zhulenev   // of the `value` in such regions (after handling few special cases).
266a6628e59SEugene Zhulenev   //
267a6628e59SEugene Zhulenev   // We do not traverse all the blocks in the `definingRegion`, because the
268a6628e59SEugene Zhulenev   // `value` can be in the live in set only if it has users in the block, or it
269a6628e59SEugene Zhulenev   // is defined in the block.
270a6628e59SEugene Zhulenev   //
271a6628e59SEugene Zhulenev   // Values with zero users (only definition) handled explicitly above.
272a6628e59SEugene Zhulenev   for (auto &blockAndUser : usersInTheBlocks) {
273a6628e59SEugene Zhulenev     Block *block = blockAndUser.getFirst();
274a6628e59SEugene Zhulenev     Operation *userInTheBlock = blockAndUser.getSecond();
275a6628e59SEugene Zhulenev 
276a6628e59SEugene Zhulenev     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);
277a6628e59SEugene Zhulenev 
278a6628e59SEugene Zhulenev     // Value must be in the live input set or defined in the block.
279a6628e59SEugene Zhulenev     assert(blockLiveness->isLiveIn(value) ||
280a6628e59SEugene Zhulenev            blockLiveness->getBlock() == value.getParentBlock());
281a6628e59SEugene Zhulenev 
282a6628e59SEugene Zhulenev     // If value is in the live out set, it means it doesn't "die" in the block.
283a6628e59SEugene Zhulenev     if (blockLiveness->isLiveOut(value))
284a6628e59SEugene Zhulenev       continue;
285a6628e59SEugene Zhulenev 
286a6628e59SEugene Zhulenev     // At this point we proved that `value` dies in the `block`. Find the last
287a6628e59SEugene Zhulenev     // use of the `value` inside the `block`, this is where it "dies".
288a6628e59SEugene Zhulenev     Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
289a6628e59SEugene Zhulenev     assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
290a6628e59SEugene Zhulenev     lastUsers.insert(lastUser);
291a6628e59SEugene Zhulenev   }
292a6628e59SEugene Zhulenev 
293a6628e59SEugene Zhulenev   // Process all the last users of the `value` inside each block where the value
294a6628e59SEugene Zhulenev   // dies.
295a6628e59SEugene Zhulenev   for (Operation *lastUser : lastUsers) {
296a6628e59SEugene Zhulenev     // Return like operations forward reference count.
297a6628e59SEugene Zhulenev     if (lastUser->hasTrait<OpTrait::ReturnLike>())
298a6628e59SEugene Zhulenev       continue;
299a6628e59SEugene Zhulenev 
300a6628e59SEugene Zhulenev     // We can't currently handle other types of terminators.
301a6628e59SEugene Zhulenev     if (lastUser->hasTrait<OpTrait::IsTerminator>())
302a6628e59SEugene Zhulenev       return lastUser->emitError() << "async reference counting can't handle "
303a6628e59SEugene Zhulenev                                       "terminators that are not ReturnLike";
304a6628e59SEugene Zhulenev 
305a6628e59SEugene Zhulenev     // Add a drop_ref immediately after the last user.
306a6628e59SEugene Zhulenev     builder.setInsertionPointAfter(lastUser);
307*92db09cdSEugene Zhulenev     builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1));
308a6628e59SEugene Zhulenev   }
309a6628e59SEugene Zhulenev 
310a6628e59SEugene Zhulenev   return success();
311a6628e59SEugene Zhulenev }
312a6628e59SEugene Zhulenev 
313a6628e59SEugene Zhulenev LogicalResult
314a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
315a6628e59SEugene Zhulenev   OpBuilder builder(value.getContext());
316a6628e59SEugene Zhulenev   Location loc = value.getLoc();
317a6628e59SEugene Zhulenev 
318a6628e59SEugene Zhulenev   for (Operation *user : value.getUsers()) {
319a6628e59SEugene Zhulenev     if (!isa<CallOp>(user))
320a6628e59SEugene Zhulenev       continue;
321a6628e59SEugene Zhulenev 
322a6628e59SEugene Zhulenev     // Add a reference before the function call to pass the value at `+1`
323a6628e59SEugene Zhulenev     // reference to the function entry block.
324a6628e59SEugene Zhulenev     builder.setInsertionPoint(user);
325*92db09cdSEugene Zhulenev     builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1));
326a6628e59SEugene Zhulenev   }
327a6628e59SEugene Zhulenev 
328a6628e59SEugene Zhulenev   return success();
329a6628e59SEugene Zhulenev }
330a6628e59SEugene Zhulenev 
331c412979cSEugene Zhulenev LogicalResult
332c412979cSEugene Zhulenev AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
333c412979cSEugene Zhulenev     Value value) {
334c412979cSEugene Zhulenev   using BlockSet = llvm::SmallPtrSet<Block *, 4>;
335c412979cSEugene Zhulenev 
336a6628e59SEugene Zhulenev   OpBuilder builder(value.getContext());
337a6628e59SEugene Zhulenev 
338c412979cSEugene Zhulenev   // If a block has successors with different `liveIn` property of the `value`,
339c412979cSEugene Zhulenev   // record block successors that do not thave the `value` in the `liveIn` set.
340c412979cSEugene Zhulenev   llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks;
341a6628e59SEugene Zhulenev 
342a6628e59SEugene Zhulenev   // Use liveness analysis to find the placement of `drop_ref`operation.
343a6628e59SEugene Zhulenev   auto &liveness = getAnalysis<Liveness>();
344a6628e59SEugene Zhulenev 
345a6628e59SEugene Zhulenev   // Because we only add `drop_ref` operations to the region that defines the
346a6628e59SEugene Zhulenev   // `value` we can only process CFG for the same region.
347a6628e59SEugene Zhulenev   Region *definingRegion = value.getParentRegion();
348a6628e59SEugene Zhulenev 
349a6628e59SEugene Zhulenev   // Collect blocks with successors with mismatching `liveIn` sets.
350a6628e59SEugene Zhulenev   for (Block &block : definingRegion->getBlocks()) {
351a6628e59SEugene Zhulenev     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
352a6628e59SEugene Zhulenev 
353a6628e59SEugene Zhulenev     // Skip the block if value is not in the `liveOut` set.
3549136b7d0SEugene Zhulenev     if (!blockLiveness || !blockLiveness->isLiveOut(value))
355a6628e59SEugene Zhulenev       continue;
356a6628e59SEugene Zhulenev 
357c412979cSEugene Zhulenev     BlockSet liveInSuccessors;   // `value` is in `liveIn` set
358c412979cSEugene Zhulenev     BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set
359a6628e59SEugene Zhulenev 
360a6628e59SEugene Zhulenev     // Collect successors that do not have `value` in the `liveIn` set.
361a6628e59SEugene Zhulenev     for (Block *successor : block.getSuccessors()) {
362a6628e59SEugene Zhulenev       const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
3639136b7d0SEugene Zhulenev       if (succLiveness && succLiveness->isLiveIn(value))
364a6628e59SEugene Zhulenev         liveInSuccessors.insert(successor);
365a6628e59SEugene Zhulenev       else
366a6628e59SEugene Zhulenev         noLiveInSuccessors.insert(successor);
367a6628e59SEugene Zhulenev     }
368a6628e59SEugene Zhulenev 
369a6628e59SEugene Zhulenev     // Block has successors with different `liveIn` property of the `value`.
370a6628e59SEugene Zhulenev     if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
371c412979cSEugene Zhulenev       divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors);
372a6628e59SEugene Zhulenev   }
373a6628e59SEugene Zhulenev 
374c412979cSEugene Zhulenev   // Try to insert `dropRef` operations to handle blocks with divergent liveness
375c412979cSEugene Zhulenev   // in successors blocks.
376c412979cSEugene Zhulenev   for (auto kv : divergentLivenessBlocks) {
377c412979cSEugene Zhulenev     Block *block = kv.getFirst();
378c412979cSEugene Zhulenev     BlockSet &successors = kv.getSecond();
379c412979cSEugene Zhulenev 
380c412979cSEugene Zhulenev     // Coroutine suspension is a special case terminator for wich we do not
381c412979cSEugene Zhulenev     // need to create additional reference counting (see details above).
382a6628e59SEugene Zhulenev     Operation *terminator = block->getTerminator();
383a6628e59SEugene Zhulenev     if (isa<CoroSuspendOp>(terminator))
384a6628e59SEugene Zhulenev       continue;
385a6628e59SEugene Zhulenev 
386c412979cSEugene Zhulenev     // We only support successor blocks with empty block argument list.
387c412979cSEugene Zhulenev     auto hasArgs = [](Block *block) { return !block->getArguments().empty(); };
388c412979cSEugene Zhulenev     if (llvm::any_of(successors, hasArgs))
389c412979cSEugene Zhulenev       return terminator->emitOpError()
390c412979cSEugene Zhulenev              << "successor have different `liveIn` property of the reference "
391c412979cSEugene Zhulenev                 "counted value";
392c412979cSEugene Zhulenev 
393c412979cSEugene Zhulenev     // Make sure that `dropRef` operation is called when branched into the
394c412979cSEugene Zhulenev     // successor block without `value` in the `liveIn` set.
395c412979cSEugene Zhulenev     for (Block *successor : successors) {
396c412979cSEugene Zhulenev       // If successor has a unique predecessor, it is safe to create `dropRef`
397c412979cSEugene Zhulenev       // operations directly in the successor block.
398c412979cSEugene Zhulenev       //
399c412979cSEugene Zhulenev       // Otherwise we need to create a special block for reference counting
400c412979cSEugene Zhulenev       // operations, and branch from it to the original successor block.
401c412979cSEugene Zhulenev       Block *refCountingBlock = nullptr;
402c412979cSEugene Zhulenev 
403c412979cSEugene Zhulenev       if (successor->getUniquePredecessor() == block) {
404c412979cSEugene Zhulenev         refCountingBlock = successor;
405c412979cSEugene Zhulenev       } else {
406c412979cSEugene Zhulenev         refCountingBlock = &successor->getParent()->emplaceBlock();
407c412979cSEugene Zhulenev         refCountingBlock->moveBefore(successor);
408c412979cSEugene Zhulenev         OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
409c412979cSEugene Zhulenev         builder.create<BranchOp>(value.getLoc(), successor);
410c412979cSEugene Zhulenev       }
411c412979cSEugene Zhulenev 
412c412979cSEugene Zhulenev       OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
413c412979cSEugene Zhulenev       builder.create<RuntimeDropRefOp>(value.getLoc(), value,
414*92db09cdSEugene Zhulenev                                        builder.getI64IntegerAttr(1));
415c412979cSEugene Zhulenev 
416c412979cSEugene Zhulenev       // No need to update the terminator operation.
417c412979cSEugene Zhulenev       if (successor == refCountingBlock)
418c412979cSEugene Zhulenev         continue;
419c412979cSEugene Zhulenev 
420c412979cSEugene Zhulenev       // Update terminator `successor` block to `refCountingBlock`.
421c412979cSEugene Zhulenev       for (auto pair : llvm::enumerate(terminator->getSuccessors()))
422c412979cSEugene Zhulenev         if (pair.value() == successor)
423c412979cSEugene Zhulenev           terminator->setSuccessor(refCountingBlock, pair.index());
424c412979cSEugene Zhulenev     }
425a6628e59SEugene Zhulenev   }
426a6628e59SEugene Zhulenev 
427a6628e59SEugene Zhulenev   return success();
428a6628e59SEugene Zhulenev }
429a6628e59SEugene Zhulenev 
430a6628e59SEugene Zhulenev LogicalResult
431a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
432f57b2420SEugene Zhulenev   // Short-circuit reference counting for values without uses.
433f57b2420SEugene Zhulenev   if (succeeded(dropRefIfNoUses(value)))
434a6628e59SEugene Zhulenev     return success();
435a6628e59SEugene Zhulenev 
436a6628e59SEugene Zhulenev   // Add `drop_ref` operations based on the liveness analysis.
437a6628e59SEugene Zhulenev   if (failed(addDropRefAfterLastUse(value)))
438a6628e59SEugene Zhulenev     return failure();
439a6628e59SEugene Zhulenev 
440a6628e59SEugene Zhulenev   // Add `add_ref` operations before function calls.
441a6628e59SEugene Zhulenev   if (failed(addAddRefBeforeFunctionCall(value)))
442a6628e59SEugene Zhulenev     return failure();
443a6628e59SEugene Zhulenev 
444c412979cSEugene Zhulenev   // Add `drop_ref` operations to successors with divergent `value` liveness.
445c412979cSEugene Zhulenev   if (failed(addDropRefInDivergentLivenessSuccessor(value)))
446a6628e59SEugene Zhulenev     return failure();
447a6628e59SEugene Zhulenev 
448a6628e59SEugene Zhulenev   return success();
449a6628e59SEugene Zhulenev }
450a6628e59SEugene Zhulenev 
4518a316b00SEugene Zhulenev void AsyncRuntimeRefCountingPass::runOnOperation() {
452f57b2420SEugene Zhulenev   auto functor = [&](Value value) { return addAutomaticRefCounting(value); };
453f57b2420SEugene Zhulenev   if (failed(walkReferenceCountedValues(getOperation(), functor)))
454a6628e59SEugene Zhulenev     signalPassFailure();
455a6628e59SEugene Zhulenev }
456a6628e59SEugene Zhulenev 
457f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===//
458f57b2420SEugene Zhulenev // Reference counting based on the user defined policy.
459f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===//
460f57b2420SEugene Zhulenev 
461f57b2420SEugene Zhulenev namespace {
462f57b2420SEugene Zhulenev 
463f57b2420SEugene Zhulenev class AsyncRuntimePolicyBasedRefCountingPass
464f57b2420SEugene Zhulenev     : public AsyncRuntimePolicyBasedRefCountingBase<
465f57b2420SEugene Zhulenev           AsyncRuntimePolicyBasedRefCountingPass> {
466f57b2420SEugene Zhulenev public:
467f57b2420SEugene Zhulenev   AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
468f57b2420SEugene Zhulenev 
469f57b2420SEugene Zhulenev   void runOnOperation() override;
470f57b2420SEugene Zhulenev 
471f57b2420SEugene Zhulenev private:
472f57b2420SEugene Zhulenev   // Adds a reference counting operations for all uses of the `value` according
473f57b2420SEugene Zhulenev   // to the reference counting policy.
474f57b2420SEugene Zhulenev   LogicalResult addRefCounting(Value value);
475f57b2420SEugene Zhulenev 
476f57b2420SEugene Zhulenev   void initializeDefaultPolicy();
477f57b2420SEugene Zhulenev 
478f57b2420SEugene Zhulenev   llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy;
479f57b2420SEugene Zhulenev };
480f57b2420SEugene Zhulenev 
481f57b2420SEugene Zhulenev } // namespace
482f57b2420SEugene Zhulenev 
483f57b2420SEugene Zhulenev LogicalResult
484f57b2420SEugene Zhulenev AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) {
485f57b2420SEugene Zhulenev   // Short-circuit reference counting for values without uses.
486f57b2420SEugene Zhulenev   if (succeeded(dropRefIfNoUses(value)))
487f57b2420SEugene Zhulenev     return success();
488f57b2420SEugene Zhulenev 
489f57b2420SEugene Zhulenev   OpBuilder b(value.getContext());
490f57b2420SEugene Zhulenev 
491f57b2420SEugene Zhulenev   // Consult the user defined policy for every value use.
492f57b2420SEugene Zhulenev   for (OpOperand &operand : value.getUses()) {
493f57b2420SEugene Zhulenev     Location loc = operand.getOwner()->getLoc();
494f57b2420SEugene Zhulenev 
495f57b2420SEugene Zhulenev     for (auto &func : policy) {
496f57b2420SEugene Zhulenev       FailureOr<int> refCount = func(operand);
497f57b2420SEugene Zhulenev       if (failed(refCount))
498f57b2420SEugene Zhulenev         return failure();
499f57b2420SEugene Zhulenev 
500f57b2420SEugene Zhulenev       int cnt = refCount.getValue();
501f57b2420SEugene Zhulenev 
502f57b2420SEugene Zhulenev       // Create `add_ref` operation before the operand owner.
503f57b2420SEugene Zhulenev       if (cnt > 0) {
504f57b2420SEugene Zhulenev         b.setInsertionPoint(operand.getOwner());
505*92db09cdSEugene Zhulenev         b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt));
506f57b2420SEugene Zhulenev       }
507f57b2420SEugene Zhulenev 
508f57b2420SEugene Zhulenev       // Create `drop_ref` operation after the operand owner.
509f57b2420SEugene Zhulenev       if (cnt < 0) {
510f57b2420SEugene Zhulenev         b.setInsertionPointAfter(operand.getOwner());
511*92db09cdSEugene Zhulenev         b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt));
512f57b2420SEugene Zhulenev       }
513f57b2420SEugene Zhulenev     }
514f57b2420SEugene Zhulenev   }
515f57b2420SEugene Zhulenev 
516f57b2420SEugene Zhulenev   return success();
517f57b2420SEugene Zhulenev }
518f57b2420SEugene Zhulenev 
519f57b2420SEugene Zhulenev void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
520f57b2420SEugene Zhulenev   policy.push_back([](OpOperand &operand) -> FailureOr<int> {
521f57b2420SEugene Zhulenev     Operation *op = operand.getOwner();
522f57b2420SEugene Zhulenev     Type type = operand.get().getType();
523f57b2420SEugene Zhulenev 
524f57b2420SEugene Zhulenev     bool isToken = type.isa<TokenType>();
525f57b2420SEugene Zhulenev     bool isGroup = type.isa<GroupType>();
526f57b2420SEugene Zhulenev     bool isValue = type.isa<ValueType>();
527f57b2420SEugene Zhulenev 
528f57b2420SEugene Zhulenev     // Drop reference after async token or group error check (coro await).
529f57b2420SEugene Zhulenev     if (auto await = dyn_cast<RuntimeIsErrorOp>(op))
530f57b2420SEugene Zhulenev       return (isToken || isGroup) ? -1 : 0;
531f57b2420SEugene Zhulenev 
532f57b2420SEugene Zhulenev     // Drop reference after async value load.
533f57b2420SEugene Zhulenev     if (auto load = dyn_cast<RuntimeLoadOp>(op))
534f57b2420SEugene Zhulenev       return isValue ? -1 : 0;
535f57b2420SEugene Zhulenev 
536f57b2420SEugene Zhulenev     // Drop reference after async token added to the group.
537f57b2420SEugene Zhulenev     if (auto add = dyn_cast<RuntimeAddToGroupOp>(op))
538f57b2420SEugene Zhulenev       return isToken ? -1 : 0;
539f57b2420SEugene Zhulenev 
540f57b2420SEugene Zhulenev     return 0;
541f57b2420SEugene Zhulenev   });
542f57b2420SEugene Zhulenev }
543f57b2420SEugene Zhulenev 
544f57b2420SEugene Zhulenev void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
545f57b2420SEugene Zhulenev   auto functor = [&](Value value) { return addRefCounting(value); };
546f57b2420SEugene Zhulenev   if (failed(walkReferenceCountedValues(getOperation(), functor)))
547f57b2420SEugene Zhulenev     signalPassFailure();
548f57b2420SEugene Zhulenev }
549f57b2420SEugene Zhulenev 
550f57b2420SEugene Zhulenev //----------------------------------------------------------------------------//
551f57b2420SEugene Zhulenev 
5528a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
553a6628e59SEugene Zhulenev   return std::make_unique<AsyncRuntimeRefCountingPass>();
554a6628e59SEugene Zhulenev }
555f57b2420SEugene Zhulenev 
556f57b2420SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() {
557f57b2420SEugene Zhulenev   return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>();
558f57b2420SEugene Zhulenev }
559