1 //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements automatic reference counting for Async runtime
10 // operations and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Analysis/Liveness.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Async/Passes.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "llvm/ADT/SmallSet.h"
23 
24 using namespace mlir;
25 using namespace mlir::async;
26 
27 #define DEBUG_TYPE "async-runtime-ref-counting"
28 
29 //===----------------------------------------------------------------------===//
30 // Utility functions shared by reference counting passes.
31 //===----------------------------------------------------------------------===//
32 
33 // Drop the reference count immediately if the value has no uses.
34 static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) {
35   if (!value.getUses().empty())
36     return failure();
37 
38   OpBuilder b(value.getContext());
39 
40   // Set insertion point after the operation producing a value, or at the
41   // beginning of the block if the value defined by the block argument.
42   if (Operation *op = value.getDefiningOp())
43     b.setInsertionPointAfter(op);
44   else
45     b.setInsertionPointToStart(value.getParentBlock());
46 
47   b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1));
48   return success();
49 }
50 
51 // Calls `addRefCounting` for every reference counted value defined by the
52 // operation `op` (block arguments and values defined in nested regions).
53 static LogicalResult walkReferenceCountedValues(
54     Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) {
55   // Check that we do not have high level async operations in the IR because
56   // otherwise reference counting will produce incorrect results after high
57   // level async operations will be lowered to `async.runtime`
58   WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult {
59     if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
60       return WalkResult::advance();
61 
62     return op->emitError()
63            << "async operations must be lowered to async runtime operations";
64   });
65 
66   if (checkNoAsyncWalk.wasInterrupted())
67     return failure();
68 
69   // Add reference counting to block arguments.
70   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
71     for (BlockArgument arg : block->getArguments())
72       if (isRefCounted(arg.getType()))
73         if (failed(addRefCounting(arg)))
74           return WalkResult::interrupt();
75 
76     return WalkResult::advance();
77   });
78 
79   if (blockWalk.wasInterrupted())
80     return failure();
81 
82   // Add reference counting to operation results.
83   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
84     for (unsigned i = 0; i < op->getNumResults(); ++i)
85       if (isRefCounted(op->getResultTypes()[i]))
86         if (failed(addRefCounting(op->getResult(i))))
87           return WalkResult::interrupt();
88 
89     return WalkResult::advance();
90   });
91 
92   if (opWalk.wasInterrupted())
93     return failure();
94 
95   return success();
96 }
97 
98 //===----------------------------------------------------------------------===//
99 // Automatic reference counting based on the liveness analysis.
100 //===----------------------------------------------------------------------===//
101 
102 namespace {
103 
104 class AsyncRuntimeRefCountingPass
105     : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
106 public:
107   AsyncRuntimeRefCountingPass() = default;
108   void runOnOperation() override;
109 
110 private:
111   /// Adds an automatic reference counting to the `value`.
112   ///
113   /// All values (token, group or value) are semantically created with a
114   /// reference count of +1 and it is the responsibility of the async value user
115   /// to place the `add_ref` and `drop_ref` operations to ensure that the value
116   /// is destroyed after the last use.
117   ///
118   /// The function returns failure if it can't deduce the locations where
119   /// to place the reference counting operations.
120   ///
121   /// Async values "semantically created" when:
122   ///   1. Operation returns async result (e.g. `async.runtime.create`)
123   ///   2. Async value passed in as a block argument (or function argument,
124   ///      because function arguments are just entry block arguments)
125   ///
126   /// Passing async value as a function argument (or block argument) does not
127   /// really mean that a new async value is created, it only means that the
128   /// caller of a function transfered ownership of `+1` reference to the callee.
129   /// It is convenient to think that from the callee perspective async value was
130   /// "created" with `+1` reference by the block argument.
131   ///
132   /// Automatic reference counting algorithm outline:
133   ///
134   /// #1 Insert `drop_ref` operations after last use of the `value`.
135   /// #2 Insert `add_ref` operations before functions calls with reference
136   ///    counted `value` operand (newly created `+1` reference will be
137   ///    transferred to the callee).
138   /// #3 Verify that divergent control flow does not lead to leaked reference
139   ///    counted objects.
140   ///
141   /// Async runtime reference counting optimization pass will optimize away
142   /// some of the redundant `add_ref` and `drop_ref` operations inserted by this
143   /// strategy (see `async-runtime-ref-counting-opt`).
144   LogicalResult addAutomaticRefCounting(Value value);
145 
146   /// (#1) Adds the `drop_ref` operation after the last use of the `value`
147   /// relying on the liveness analysis.
148   ///
149   /// If the `value` is in the block `liveIn` set and it is not in the block
150   /// `liveOut` set, it means that it "dies" in the block. We find the last
151   /// use of the value in such block and:
152   ///
153   ///   1. If the last user is a `ReturnLike` operation we do nothing, because
154   ///      it forwards the ownership to the caller.
155   ///   2. Otherwise we add a `drop_ref` operation immediately after the last
156   ///      use.
157   LogicalResult addDropRefAfterLastUse(Value value);
158 
159   /// (#2) Adds the `add_ref` operation before the function call taking `value`
160   /// operand to ensure that the value passed to the function entry block
161   /// has a `+1` reference count.
162   LogicalResult addAddRefBeforeFunctionCall(Value value);
163 
164   /// (#3) Adds the `drop_ref` operation to account for successor blocks with
165   /// divergent `liveIn` property: `value` is not in the `liveIn` set of all
166   /// successor blocks.
167   ///
168   /// Example:
169   ///
170   ///   ^entry:
171   ///     %token = async.runtime.create : !async.token
172   ///     cond_br %cond, ^bb1, ^bb2
173   ///   ^bb1:
174   ///     async.runtime.await %token
175   ///     async.runtime.drop_ref %token
176   ///     br ^bb2
177   ///   ^bb2:
178   ///     return
179   ///
180   /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have
181   /// to branch into a special "reference counting block" from the ^entry that
182   /// will have a `drop_ref` operation, and then branch into the ^bb2.
183   ///
184   /// After transformation:
185   ///
186   ///   ^entry:
187   ///     %token = async.runtime.create : !async.token
188   ///     cond_br %cond, ^bb1, ^reference_counting
189   ///   ^bb1:
190   ///     async.runtime.await %token
191   ///     async.runtime.drop_ref %token
192   ///     br ^bb2
193   ///   ^reference_counting:
194   ///     async.runtime.drop_ref %token
195   ///     br ^bb2
196   ///   ^bb2:
197   ///     return
198   ///
199   /// An exception to this rule are blocks with `async.coro.suspend` terminator,
200   /// because in Async to LLVM lowering it is guaranteed that the control flow
201   /// will jump into the resume block, and then follow into the cleanup and
202   /// suspend blocks.
203   ///
204   /// Example:
205   ///
206   ///  ^entry(%value: !async.value<f32>):
207   ///     async.runtime.await_and_resume %value, %hdl : !async.value<f32>
208   ///     async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
209   ///   ^resume:
210   ///     %0 = async.runtime.load %value
211   ///     br ^cleanup
212   ///   ^cleanup:
213   ///     ...
214   ///   ^suspend:
215   ///     ...
216   ///
217   /// Although cleanup and suspend blocks do not have the `value` in the
218   /// `liveIn` set, it is guaranteed that execution will eventually continue in
219   /// the resume block (we never explicitly destroy coroutines).
220   LogicalResult addDropRefInDivergentLivenessSuccessor(Value value);
221 };
222 
223 } // namespace
224 
225 LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
226   OpBuilder builder(value.getContext());
227   Location loc = value.getLoc();
228 
229   // Use liveness analysis to find the placement of `drop_ref`operation.
230   auto &liveness = getAnalysis<Liveness>();
231 
232   // We analyse only the blocks of the region that defines the `value`, and do
233   // not check nested blocks attached to operations.
234   //
235   // By analyzing only the `definingRegion` CFG we potentially loose an
236   // opportunity to drop the reference count earlier and can extend the lifetime
237   // of reference counted value longer then it is really required.
238   //
239   // We also assume that all nested regions finish their execution before the
240   // completion of the owner operation. The only exception to this rule is
241   // `async.execute` operation, and we verify that they are lowered to the
242   // `async.runtime` operations before adding automatic reference counting.
243   Region *definingRegion = value.getParentRegion();
244 
245   // Last users of the `value` inside all blocks where the value dies.
246   llvm::SmallSet<Operation *, 4> lastUsers;
247 
248   // Find blocks in the `definingRegion` that have users of the `value` (if
249   // there are multiple users in the block, which one will be selected is
250   // undefined). User operation might be not the actual user of the value, but
251   // the operation in the block that has a "real user" in one of the attached
252   // regions.
253   llvm::DenseMap<Block *, Operation *> usersInTheBlocks;
254 
255   for (Operation *user : value.getUsers()) {
256     Block *userBlock = user->getBlock();
257     Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
258     usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user);
259     assert(ancestor && "ancestor block must be not null");
260     assert(usersInTheBlocks[ancestor] && "ancestor op must be not null");
261   }
262 
263   // Find blocks where the `value` dies: the value is in `liveIn` set and not
264   // in the `liveOut` set. We place `drop_ref` immediately after the last use
265   // of the `value` in such regions (after handling few special cases).
266   //
267   // We do not traverse all the blocks in the `definingRegion`, because the
268   // `value` can be in the live in set only if it has users in the block, or it
269   // is defined in the block.
270   //
271   // Values with zero users (only definition) handled explicitly above.
272   for (auto &blockAndUser : usersInTheBlocks) {
273     Block *block = blockAndUser.getFirst();
274     Operation *userInTheBlock = blockAndUser.getSecond();
275 
276     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);
277 
278     // Value must be in the live input set or defined in the block.
279     assert(blockLiveness->isLiveIn(value) ||
280            blockLiveness->getBlock() == value.getParentBlock());
281 
282     // If value is in the live out set, it means it doesn't "die" in the block.
283     if (blockLiveness->isLiveOut(value))
284       continue;
285 
286     // At this point we proved that `value` dies in the `block`. Find the last
287     // use of the `value` inside the `block`, this is where it "dies".
288     Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
289     assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
290     lastUsers.insert(lastUser);
291   }
292 
293   // Process all the last users of the `value` inside each block where the value
294   // dies.
295   for (Operation *lastUser : lastUsers) {
296     // Return like operations forward reference count.
297     if (lastUser->hasTrait<OpTrait::ReturnLike>())
298       continue;
299 
300     // We can't currently handle other types of terminators.
301     if (lastUser->hasTrait<OpTrait::IsTerminator>())
302       return lastUser->emitError() << "async reference counting can't handle "
303                                       "terminators that are not ReturnLike";
304 
305     // Add a drop_ref immediately after the last user.
306     builder.setInsertionPointAfter(lastUser);
307     builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1));
308   }
309 
310   return success();
311 }
312 
313 LogicalResult
314 AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
315   OpBuilder builder(value.getContext());
316   Location loc = value.getLoc();
317 
318   for (Operation *user : value.getUsers()) {
319     if (!isa<CallOp>(user))
320       continue;
321 
322     // Add a reference before the function call to pass the value at `+1`
323     // reference to the function entry block.
324     builder.setInsertionPoint(user);
325     builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1));
326   }
327 
328   return success();
329 }
330 
331 LogicalResult
332 AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
333     Value value) {
334   using BlockSet = llvm::SmallPtrSet<Block *, 4>;
335 
336   OpBuilder builder(value.getContext());
337 
338   // If a block has successors with different `liveIn` property of the `value`,
339   // record block successors that do not thave the `value` in the `liveIn` set.
340   llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks;
341 
342   // Use liveness analysis to find the placement of `drop_ref`operation.
343   auto &liveness = getAnalysis<Liveness>();
344 
345   // Because we only add `drop_ref` operations to the region that defines the
346   // `value` we can only process CFG for the same region.
347   Region *definingRegion = value.getParentRegion();
348 
349   // Collect blocks with successors with mismatching `liveIn` sets.
350   for (Block &block : definingRegion->getBlocks()) {
351     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
352 
353     // Skip the block if value is not in the `liveOut` set.
354     if (!blockLiveness || !blockLiveness->isLiveOut(value))
355       continue;
356 
357     BlockSet liveInSuccessors;   // `value` is in `liveIn` set
358     BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set
359 
360     // Collect successors that do not have `value` in the `liveIn` set.
361     for (Block *successor : block.getSuccessors()) {
362       const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
363       if (succLiveness && succLiveness->isLiveIn(value))
364         liveInSuccessors.insert(successor);
365       else
366         noLiveInSuccessors.insert(successor);
367     }
368 
369     // Block has successors with different `liveIn` property of the `value`.
370     if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
371       divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors);
372   }
373 
374   // Try to insert `dropRef` operations to handle blocks with divergent liveness
375   // in successors blocks.
376   for (auto kv : divergentLivenessBlocks) {
377     Block *block = kv.getFirst();
378     BlockSet &successors = kv.getSecond();
379 
380     // Coroutine suspension is a special case terminator for wich we do not
381     // need to create additional reference counting (see details above).
382     Operation *terminator = block->getTerminator();
383     if (isa<CoroSuspendOp>(terminator))
384       continue;
385 
386     // We only support successor blocks with empty block argument list.
387     auto hasArgs = [](Block *block) { return !block->getArguments().empty(); };
388     if (llvm::any_of(successors, hasArgs))
389       return terminator->emitOpError()
390              << "successor have different `liveIn` property of the reference "
391                 "counted value";
392 
393     // Make sure that `dropRef` operation is called when branched into the
394     // successor block without `value` in the `liveIn` set.
395     for (Block *successor : successors) {
396       // If successor has a unique predecessor, it is safe to create `dropRef`
397       // operations directly in the successor block.
398       //
399       // Otherwise we need to create a special block for reference counting
400       // operations, and branch from it to the original successor block.
401       Block *refCountingBlock = nullptr;
402 
403       if (successor->getUniquePredecessor() == block) {
404         refCountingBlock = successor;
405       } else {
406         refCountingBlock = &successor->getParent()->emplaceBlock();
407         refCountingBlock->moveBefore(successor);
408         OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
409         builder.create<BranchOp>(value.getLoc(), successor);
410       }
411 
412       OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
413       builder.create<RuntimeDropRefOp>(value.getLoc(), value,
414                                        builder.getI64IntegerAttr(1));
415 
416       // No need to update the terminator operation.
417       if (successor == refCountingBlock)
418         continue;
419 
420       // Update terminator `successor` block to `refCountingBlock`.
421       for (auto pair : llvm::enumerate(terminator->getSuccessors()))
422         if (pair.value() == successor)
423           terminator->setSuccessor(refCountingBlock, pair.index());
424     }
425   }
426 
427   return success();
428 }
429 
430 LogicalResult
431 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
432   // Short-circuit reference counting for values without uses.
433   if (succeeded(dropRefIfNoUses(value)))
434     return success();
435 
436   // Add `drop_ref` operations based on the liveness analysis.
437   if (failed(addDropRefAfterLastUse(value)))
438     return failure();
439 
440   // Add `add_ref` operations before function calls.
441   if (failed(addAddRefBeforeFunctionCall(value)))
442     return failure();
443 
444   // Add `drop_ref` operations to successors with divergent `value` liveness.
445   if (failed(addDropRefInDivergentLivenessSuccessor(value)))
446     return failure();
447 
448   return success();
449 }
450 
451 void AsyncRuntimeRefCountingPass::runOnOperation() {
452   auto functor = [&](Value value) { return addAutomaticRefCounting(value); };
453   if (failed(walkReferenceCountedValues(getOperation(), functor)))
454     signalPassFailure();
455 }
456 
457 //===----------------------------------------------------------------------===//
458 // Reference counting based on the user defined policy.
459 //===----------------------------------------------------------------------===//
460 
461 namespace {
462 
463 class AsyncRuntimePolicyBasedRefCountingPass
464     : public AsyncRuntimePolicyBasedRefCountingBase<
465           AsyncRuntimePolicyBasedRefCountingPass> {
466 public:
467   AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
468 
469   void runOnOperation() override;
470 
471 private:
472   // Adds a reference counting operations for all uses of the `value` according
473   // to the reference counting policy.
474   LogicalResult addRefCounting(Value value);
475 
476   void initializeDefaultPolicy();
477 
478   llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy;
479 };
480 
481 } // namespace
482 
483 LogicalResult
484 AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) {
485   // Short-circuit reference counting for values without uses.
486   if (succeeded(dropRefIfNoUses(value)))
487     return success();
488 
489   OpBuilder b(value.getContext());
490 
491   // Consult the user defined policy for every value use.
492   for (OpOperand &operand : value.getUses()) {
493     Location loc = operand.getOwner()->getLoc();
494 
495     for (auto &func : policy) {
496       FailureOr<int> refCount = func(operand);
497       if (failed(refCount))
498         return failure();
499 
500       int cnt = refCount.getValue();
501 
502       // Create `add_ref` operation before the operand owner.
503       if (cnt > 0) {
504         b.setInsertionPoint(operand.getOwner());
505         b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt));
506       }
507 
508       // Create `drop_ref` operation after the operand owner.
509       if (cnt < 0) {
510         b.setInsertionPointAfter(operand.getOwner());
511         b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt));
512       }
513     }
514   }
515 
516   return success();
517 }
518 
519 void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
520   policy.push_back([](OpOperand &operand) -> FailureOr<int> {
521     Operation *op = operand.getOwner();
522     Type type = operand.get().getType();
523 
524     bool isToken = type.isa<TokenType>();
525     bool isGroup = type.isa<GroupType>();
526     bool isValue = type.isa<ValueType>();
527 
528     // Drop reference after async token or group error check (coro await).
529     if (auto await = dyn_cast<RuntimeIsErrorOp>(op))
530       return (isToken || isGroup) ? -1 : 0;
531 
532     // Drop reference after async value load.
533     if (auto load = dyn_cast<RuntimeLoadOp>(op))
534       return isValue ? -1 : 0;
535 
536     // Drop reference after async token added to the group.
537     if (auto add = dyn_cast<RuntimeAddToGroupOp>(op))
538       return isToken ? -1 : 0;
539 
540     return 0;
541   });
542 }
543 
544 void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
545   auto functor = [&](Value value) { return addRefCounting(value); };
546   if (failed(walkReferenceCountedValues(getOperation(), functor)))
547     signalPassFailure();
548 }
549 
550 //----------------------------------------------------------------------------//
551 
552 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
553   return std::make_unique<AsyncRuntimeRefCountingPass>();
554 }
555 
556 std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() {
557   return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>();
558 }
559