1 //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements automatic reference counting for Async runtime
10 // operations and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Analysis/Liveness.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Async/Passes.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "llvm/ADT/SmallSet.h"
23 
24 using namespace mlir;
25 using namespace mlir::async;
26 
27 #define DEBUG_TYPE "async-runtime-ref-counting"
28 
namespace {

/// Pass that adds automatic reference counting (`add_ref` and `drop_ref`
/// operations) for reference counted values: block arguments and operation
/// results whose type satisfies `isRefCounted`.
class AsyncRuntimeRefCountingPass
    : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
public:
  AsyncRuntimeRefCountingPass() = default;

  /// Pass entry point: rejects IR that still contains high level async
  /// operations, then adds reference counting to every reference counted block
  /// argument and operation result.
  void runOnOperation() override;

private:
  /// Adds an automatic reference counting to the `value`.
  ///
  /// All values (token, group or value) are semantically created with a
  /// reference count of +1 and it is the responsibility of the async value user
  /// to place the `add_ref` and `drop_ref` operations to ensure that the value
  /// is destroyed after the last use.
  ///
  /// The function returns failure if it can't deduce the locations where
  /// to place the reference counting operations.
  ///
  /// Async values are "semantically created" when:
  ///   1. Operation returns async result (e.g. `async.runtime.create`)
  ///   2. Async value passed in as a block argument (or function argument,
  ///      because function arguments are just entry block arguments)
  ///
  /// Passing async value as a function argument (or block argument) does not
  /// really mean that a new async value is created, it only means that the
  /// caller of a function transferred ownership of `+1` reference to the
  /// callee. It is convenient to think that from the callee perspective async
  /// value was "created" with `+1` reference by the block argument.
  ///
  /// Automatic reference counting algorithm outline:
  ///
  /// #1 Insert `drop_ref` operations after last use of the `value`.
  /// #2 Insert `add_ref` operations before function calls with reference
  ///    counted `value` operand (newly created `+1` reference will be
  ///    transferred to the callee).
  /// #3 Verify that divergent control flow does not lead to leaked reference
  ///    counted objects.
  ///
  /// Async runtime reference counting optimization pass will optimize away
  /// some of the redundant `add_ref` and `drop_ref` operations inserted by this
  /// strategy (see `async-runtime-ref-counting-opt`).
  LogicalResult addAutomaticRefCounting(Value value);

  /// (#1) Adds the `drop_ref` operation after the last use of the `value`
  /// relying on the liveness analysis.
  ///
  /// If the `value` is in the block `liveIn` set and it is not in the block
  /// `liveOut` set, it means that it "dies" in the block. We find the last
  /// use of the value in such block and:
  ///
  ///   1. If the last user is a `ReturnLike` operation we do nothing, because
  ///      it forwards the ownership to the caller.
  ///   2. Otherwise we add a `drop_ref` operation immediately after the last
  ///      use.
  LogicalResult addDropRefAfterLastUse(Value value);

  /// (#2) Adds the `add_ref` operation before the function call taking `value`
  /// operand to ensure that the value passed to the function entry block
  /// has a `+1` reference count.
  LogicalResult addAddRefBeforeFunctionCall(Value value);

  /// (#3) Adds the `drop_ref` operation to account for successor blocks with
  /// divergent `liveIn` property: `value` is not in the `liveIn` set of all
  /// successor blocks.
  ///
  /// Example:
  ///
  ///   ^entry:
  ///     %token = async.runtime.create : !async.token
  ///     cond_br %cond, ^bb1, ^bb2
  ///   ^bb1:
  ///     async.runtime.await %token
  ///     async.runtime.drop_ref %token
  ///     br ^bb2
  ///   ^bb2:
  ///     return
  ///
  /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have
  /// to branch into a special "reference counting block" from the ^entry that
  /// will have a `drop_ref` operation, and then branch into the ^bb2.
  ///
  /// After transformation:
  ///
  ///   ^entry:
  ///     %token = async.runtime.create : !async.token
  ///     cond_br %cond, ^bb1, ^reference_counting
  ///   ^bb1:
  ///     async.runtime.await %token
  ///     async.runtime.drop_ref %token
  ///     br ^bb2
  ///   ^reference_counting:
  ///     async.runtime.drop_ref %token
  ///     br ^bb2
  ///   ^bb2:
  ///     return
  ///
  /// An exception to this rule are blocks with `async.coro.suspend` terminator,
  /// because in Async to LLVM lowering it is guaranteed that the control flow
  /// will jump into the resume block, and then follow into the cleanup and
  /// suspend blocks.
  ///
  /// Example:
  ///
  ///  ^entry(%value: !async.value<f32>):
  ///     async.runtime.await_and_resume %value, %hdl : !async.value<f32>
  ///     async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
  ///   ^resume:
  ///     %0 = async.runtime.load %value
  ///     br ^cleanup
  ///   ^cleanup:
  ///     ...
  ///   ^suspend:
  ///     ...
  ///
  /// Although cleanup and suspend blocks do not have the `value` in the
  /// `liveIn` set, it is guaranteed that execution will eventually continue in
  /// the resume block (we never explicitly destroy coroutines).
  LogicalResult addDropRefInDivergentLivenessSuccessor(Value value);
};

} // namespace
151 
152 LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
153   OpBuilder builder(value.getContext());
154   Location loc = value.getLoc();
155 
156   // Use liveness analysis to find the placement of `drop_ref`operation.
157   auto &liveness = getAnalysis<Liveness>();
158 
159   // We analyse only the blocks of the region that defines the `value`, and do
160   // not check nested blocks attached to operations.
161   //
162   // By analyzing only the `definingRegion` CFG we potentially loose an
163   // opportunity to drop the reference count earlier and can extend the lifetime
164   // of reference counted value longer then it is really required.
165   //
166   // We also assume that all nested regions finish their execution before the
167   // completion of the owner operation. The only exception to this rule is
168   // `async.execute` operation, and we verify that they are lowered to the
169   // `async.runtime` operations before adding automatic reference counting.
170   Region *definingRegion = value.getParentRegion();
171 
172   // Last users of the `value` inside all blocks where the value dies.
173   llvm::SmallSet<Operation *, 4> lastUsers;
174 
175   // Find blocks in the `definingRegion` that have users of the `value` (if
176   // there are multiple users in the block, which one will be selected is
177   // undefined). User operation might be not the actual user of the value, but
178   // the operation in the block that has a "real user" in one of the attached
179   // regions.
180   llvm::DenseMap<Block *, Operation *> usersInTheBlocks;
181 
182   for (Operation *user : value.getUsers()) {
183     Block *userBlock = user->getBlock();
184     Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
185     usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user);
186     assert(ancestor && "ancestor block must be not null");
187     assert(usersInTheBlocks[ancestor] && "ancestor op must be not null");
188   }
189 
190   // Find blocks where the `value` dies: the value is in `liveIn` set and not
191   // in the `liveOut` set. We place `drop_ref` immediately after the last use
192   // of the `value` in such regions (after handling few special cases).
193   //
194   // We do not traverse all the blocks in the `definingRegion`, because the
195   // `value` can be in the live in set only if it has users in the block, or it
196   // is defined in the block.
197   //
198   // Values with zero users (only definition) handled explicitly above.
199   for (auto &blockAndUser : usersInTheBlocks) {
200     Block *block = blockAndUser.getFirst();
201     Operation *userInTheBlock = blockAndUser.getSecond();
202 
203     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);
204 
205     // Value must be in the live input set or defined in the block.
206     assert(blockLiveness->isLiveIn(value) ||
207            blockLiveness->getBlock() == value.getParentBlock());
208 
209     // If value is in the live out set, it means it doesn't "die" in the block.
210     if (blockLiveness->isLiveOut(value))
211       continue;
212 
213     // At this point we proved that `value` dies in the `block`. Find the last
214     // use of the `value` inside the `block`, this is where it "dies".
215     Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
216     assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
217     lastUsers.insert(lastUser);
218   }
219 
220   // Process all the last users of the `value` inside each block where the value
221   // dies.
222   for (Operation *lastUser : lastUsers) {
223     // Return like operations forward reference count.
224     if (lastUser->hasTrait<OpTrait::ReturnLike>())
225       continue;
226 
227     // We can't currently handle other types of terminators.
228     if (lastUser->hasTrait<OpTrait::IsTerminator>())
229       return lastUser->emitError() << "async reference counting can't handle "
230                                       "terminators that are not ReturnLike";
231 
232     // Add a drop_ref immediately after the last user.
233     builder.setInsertionPointAfter(lastUser);
234     builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1));
235   }
236 
237   return success();
238 }
239 
240 LogicalResult
241 AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
242   OpBuilder builder(value.getContext());
243   Location loc = value.getLoc();
244 
245   for (Operation *user : value.getUsers()) {
246     if (!isa<CallOp>(user))
247       continue;
248 
249     // Add a reference before the function call to pass the value at `+1`
250     // reference to the function entry block.
251     builder.setInsertionPoint(user);
252     builder.create<RuntimeAddRefOp>(loc, value, builder.getI32IntegerAttr(1));
253   }
254 
255   return success();
256 }
257 
258 LogicalResult
259 AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
260     Value value) {
261   using BlockSet = llvm::SmallPtrSet<Block *, 4>;
262 
263   OpBuilder builder(value.getContext());
264 
265   // If a block has successors with different `liveIn` property of the `value`,
266   // record block successors that do not thave the `value` in the `liveIn` set.
267   llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks;
268 
269   // Use liveness analysis to find the placement of `drop_ref`operation.
270   auto &liveness = getAnalysis<Liveness>();
271 
272   // Because we only add `drop_ref` operations to the region that defines the
273   // `value` we can only process CFG for the same region.
274   Region *definingRegion = value.getParentRegion();
275 
276   // Collect blocks with successors with mismatching `liveIn` sets.
277   for (Block &block : definingRegion->getBlocks()) {
278     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
279 
280     // Skip the block if value is not in the `liveOut` set.
281     if (!blockLiveness || !blockLiveness->isLiveOut(value))
282       continue;
283 
284     BlockSet liveInSuccessors;   // `value` is in `liveIn` set
285     BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set
286 
287     // Collect successors that do not have `value` in the `liveIn` set.
288     for (Block *successor : block.getSuccessors()) {
289       const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
290       if (succLiveness && succLiveness->isLiveIn(value))
291         liveInSuccessors.insert(successor);
292       else
293         noLiveInSuccessors.insert(successor);
294     }
295 
296     // Block has successors with different `liveIn` property of the `value`.
297     if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
298       divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors);
299   }
300 
301   // Try to insert `dropRef` operations to handle blocks with divergent liveness
302   // in successors blocks.
303   for (auto kv : divergentLivenessBlocks) {
304     Block *block = kv.getFirst();
305     BlockSet &successors = kv.getSecond();
306 
307     // Coroutine suspension is a special case terminator for wich we do not
308     // need to create additional reference counting (see details above).
309     Operation *terminator = block->getTerminator();
310     if (isa<CoroSuspendOp>(terminator))
311       continue;
312 
313     // We only support successor blocks with empty block argument list.
314     auto hasArgs = [](Block *block) { return !block->getArguments().empty(); };
315     if (llvm::any_of(successors, hasArgs))
316       return terminator->emitOpError()
317              << "successor have different `liveIn` property of the reference "
318                 "counted value";
319 
320     // Make sure that `dropRef` operation is called when branched into the
321     // successor block without `value` in the `liveIn` set.
322     for (Block *successor : successors) {
323       // If successor has a unique predecessor, it is safe to create `dropRef`
324       // operations directly in the successor block.
325       //
326       // Otherwise we need to create a special block for reference counting
327       // operations, and branch from it to the original successor block.
328       Block *refCountingBlock = nullptr;
329 
330       if (successor->getUniquePredecessor() == block) {
331         refCountingBlock = successor;
332       } else {
333         refCountingBlock = &successor->getParent()->emplaceBlock();
334         refCountingBlock->moveBefore(successor);
335         OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
336         builder.create<BranchOp>(value.getLoc(), successor);
337       }
338 
339       OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
340       builder.create<RuntimeDropRefOp>(value.getLoc(), value,
341                                        builder.getI32IntegerAttr(1));
342 
343       // No need to update the terminator operation.
344       if (successor == refCountingBlock)
345         continue;
346 
347       // Update terminator `successor` block to `refCountingBlock`.
348       for (auto pair : llvm::enumerate(terminator->getSuccessors()))
349         if (pair.value() == successor)
350           terminator->setSuccessor(refCountingBlock, pair.index());
351     }
352   }
353 
354   return success();
355 }
356 
357 LogicalResult
358 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
359   OpBuilder builder(value.getContext());
360   Location loc = value.getLoc();
361 
362   // Set inserton point after the operation producing a value, or at the
363   // beginning of the block if the value defined by the block argument.
364   if (Operation *op = value.getDefiningOp())
365     builder.setInsertionPointAfter(op);
366   else
367     builder.setInsertionPointToStart(value.getParentBlock());
368 
369   // Drop the reference count immediately if the value has no uses.
370   if (value.getUses().empty()) {
371     builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1));
372     return success();
373   }
374 
375   // Add `drop_ref` operations based on the liveness analysis.
376   if (failed(addDropRefAfterLastUse(value)))
377     return failure();
378 
379   // Add `add_ref` operations before function calls.
380   if (failed(addAddRefBeforeFunctionCall(value)))
381     return failure();
382 
383   // Add `drop_ref` operations to successors with divergent `value` liveness.
384   if (failed(addDropRefInDivergentLivenessSuccessor(value)))
385     return failure();
386 
387   return success();
388 }
389 
390 void AsyncRuntimeRefCountingPass::runOnOperation() {
391   Operation *op = getOperation();
392 
393   // Check that we do not have high level async operations in the IR because
394   // otherwise automatic reference counting will produce incorrect results after
395   // execute operations will be lowered to `async.runtime`
396   WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult {
397     if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
398       return WalkResult::advance();
399 
400     return op->emitError()
401            << "async operations must be lowered to async runtime operations";
402   });
403 
404   if (executeOpWalk.wasInterrupted()) {
405     signalPassFailure();
406     return;
407   }
408 
409   // Add reference counting to block arguments.
410   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
411     for (BlockArgument arg : block->getArguments())
412       if (isRefCounted(arg.getType()))
413         if (failed(addAutomaticRefCounting(arg)))
414           return WalkResult::interrupt();
415 
416     return WalkResult::advance();
417   });
418 
419   if (blockWalk.wasInterrupted()) {
420     signalPassFailure();
421     return;
422   }
423 
424   // Add reference counting to operation results.
425   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
426     for (unsigned i = 0; i < op->getNumResults(); ++i)
427       if (isRefCounted(op->getResultTypes()[i]))
428         if (failed(addAutomaticRefCounting(op->getResult(i))))
429           return WalkResult::interrupt();
430 
431     return WalkResult::advance();
432   });
433 
434   if (opWalk.wasInterrupted())
435     signalPassFailure();
436 }
437 
/// Creates a new instance of the async runtime reference counting pass.
std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
  return std::make_unique<AsyncRuntimeRefCountingPass>();
}
441