1 //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements automatic reference counting for Async runtime
10 // operations and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Analysis/Liveness.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Async/Passes.h"
18 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/ADT/SmallSet.h"
24 
25 using namespace mlir;
26 using namespace mlir::async;
27 
28 #define DEBUG_TYPE "async-runtime-ref-counting"
29 
30 //===----------------------------------------------------------------------===//
31 // Utility functions shared by reference counting passes.
32 //===----------------------------------------------------------------------===//
33 
34 // Drop the reference count immediately if the value has no uses.
dropRefIfNoUses(Value value,unsigned count=1)35 static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) {
36   if (!value.getUses().empty())
37     return failure();
38 
39   OpBuilder b(value.getContext());
40 
41   // Set insertion point after the operation producing a value, or at the
42   // beginning of the block if the value defined by the block argument.
43   if (Operation *op = value.getDefiningOp())
44     b.setInsertionPointAfter(op);
45   else
46     b.setInsertionPointToStart(value.getParentBlock());
47 
48   b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1));
49   return success();
50 }
51 
52 // Calls `addRefCounting` for every reference counted value defined by the
53 // operation `op` (block arguments and values defined in nested regions).
walkReferenceCountedValues(Operation * op,llvm::function_ref<LogicalResult (Value)> addRefCounting)54 static LogicalResult walkReferenceCountedValues(
55     Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) {
56   // Check that we do not have high level async operations in the IR because
57   // otherwise reference counting will produce incorrect results after high
58   // level async operations will be lowered to `async.runtime`
59   WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult {
60     if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
61       return WalkResult::advance();
62 
63     return op->emitError()
64            << "async operations must be lowered to async runtime operations";
65   });
66 
67   if (checkNoAsyncWalk.wasInterrupted())
68     return failure();
69 
70   // Add reference counting to block arguments.
71   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
72     for (BlockArgument arg : block->getArguments())
73       if (isRefCounted(arg.getType()))
74         if (failed(addRefCounting(arg)))
75           return WalkResult::interrupt();
76 
77     return WalkResult::advance();
78   });
79 
80   if (blockWalk.wasInterrupted())
81     return failure();
82 
83   // Add reference counting to operation results.
84   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
85     for (unsigned i = 0; i < op->getNumResults(); ++i)
86       if (isRefCounted(op->getResultTypes()[i]))
87         if (failed(addRefCounting(op->getResult(i))))
88           return WalkResult::interrupt();
89 
90     return WalkResult::advance();
91   });
92 
93   if (opWalk.wasInterrupted())
94     return failure();
95 
96   return success();
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // Automatic reference counting based on the liveness analysis.
101 //===----------------------------------------------------------------------===//
102 
103 namespace {
104 
/// Pass that adds automatic reference counting (`add_ref` / `drop_ref`
/// operations) for every reference counted async value, relying on the
/// liveness analysis to place the reference counting operations.
class AsyncRuntimeRefCountingPass
    : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
public:
  AsyncRuntimeRefCountingPass() = default;
  void runOnOperation() override;

private:
  /// Adds automatic reference counting to the `value`.
  ///
  /// All values (token, group or value) are semantically created with a
  /// reference count of +1, and it is the responsibility of the async value
  /// user to place the `add_ref` and `drop_ref` operations to ensure that the
  /// value is destroyed after the last use.
  ///
  /// The function returns failure if it can't deduce the locations where
  /// to place the reference counting operations.
  ///
  /// Async values are "semantically created" when:
  ///   1. An operation returns an async result (e.g. `async.runtime.create`)
  ///   2. An async value is passed in as a block argument (or a function
  ///      argument, because function arguments are just entry block
  ///      arguments)
  ///
  /// Passing an async value as a function argument (or block argument) does
  /// not really mean that a new async value is created, it only means that the
  /// caller of a function transferred ownership of a `+1` reference to the
  /// callee. It is convenient to think that from the callee perspective the
  /// async value was "created" with a `+1` reference by the block argument.
  ///
  /// Automatic reference counting algorithm outline:
  ///
  /// #1 Insert `drop_ref` operations after the last use of the `value`.
  /// #2 Insert `add_ref` operations before function calls with a reference
  ///    counted `value` operand (the newly created `+1` reference will be
  ///    transferred to the callee).
  /// #3 Verify that divergent control flow does not lead to leaked reference
  ///    counted objects.
  ///
  /// The async runtime reference counting optimization pass will optimize
  /// away some of the redundant `add_ref` and `drop_ref` operations inserted
  /// by this strategy (see `async-runtime-ref-counting-opt`).
  LogicalResult addAutomaticRefCounting(Value value);

  /// (#1) Adds the `drop_ref` operation after the last use of the `value`
  /// relying on the liveness analysis.
  ///
  /// If the `value` is in the block `liveIn` set and it is not in the block
  /// `liveOut` set, it means that it "dies" in the block. We find the last
  /// use of the value in such a block and:
  ///
  ///   1. If the last user is a `ReturnLike` operation we do nothing, because
  ///      it forwards the ownership to the caller.
  ///   2. If the last user is any other terminator, an error is emitted and
  ///      the function returns failure (not supported).
  ///   3. Otherwise we add a `drop_ref` operation immediately after the last
  ///      use.
  LogicalResult addDropRefAfterLastUse(Value value);

  /// (#2) Adds the `add_ref` operation before every `func.call` taking the
  /// `value` operand to ensure that the value passed to the function entry
  /// block has a `+1` reference count.
  LogicalResult addAddRefBeforeFunctionCall(Value value);

  /// (#3) Adds the `drop_ref` operation to account for successor blocks with
  /// divergent `liveIn` property: `value` is not in the `liveIn` set of all
  /// successor blocks.
  ///
  /// Example:
  ///
  ///   ^entry:
  ///     %token = async.runtime.create : !async.token
  ///     cf.cond_br %cond, ^bb1, ^bb2
  ///   ^bb1:
  ///     async.runtime.await %token
  ///     async.runtime.drop_ref %token
  ///     cf.br ^bb2
  ///   ^bb2:
  ///     return
  ///
  /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have
  /// to branch into a special "reference counting block" from the ^entry that
  /// will have a `drop_ref` operation, and then branch into the ^bb2.
  ///
  /// After transformation:
  ///
  ///   ^entry:
  ///     %token = async.runtime.create : !async.token
  ///     cf.cond_br %cond, ^bb1, ^reference_counting
  ///   ^bb1:
  ///     async.runtime.await %token
  ///     async.runtime.drop_ref %token
  ///     cf.br ^bb2
  ///   ^reference_counting:
  ///     async.runtime.drop_ref %token
  ///     cf.br ^bb2
  ///   ^bb2:
  ///     return
  ///
  /// An exception to this rule are blocks with `async.coro.suspend` terminator,
  /// because in Async to LLVM lowering it is guaranteed that the control flow
  /// will jump into the resume block, and then follow into the cleanup and
  /// suspend blocks.
  ///
  /// Example:
  ///
  ///   ^entry(%value: !async.value<f32>):
  ///     async.runtime.await_and_resume %value, %hdl : !async.value<f32>
  ///     async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
  ///   ^resume:
  ///     %0 = async.runtime.load %value
  ///     cf.br ^cleanup
  ///   ^cleanup:
  ///     ...
  ///   ^suspend:
  ///     ...
  ///
  /// Although cleanup and suspend blocks do not have the `value` in the
  /// `liveIn` set, it is guaranteed that execution will eventually continue in
  /// the resume block (we never explicitly destroy coroutines).
  LogicalResult addDropRefInDivergentLivenessSuccessor(Value value);
};
223 
224 } // namespace
225 
addDropRefAfterLastUse(Value value)226 LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
227   OpBuilder builder(value.getContext());
228   Location loc = value.getLoc();
229 
230   // Use liveness analysis to find the placement of `drop_ref`operation.
231   auto &liveness = getAnalysis<Liveness>();
232 
233   // We analyse only the blocks of the region that defines the `value`, and do
234   // not check nested blocks attached to operations.
235   //
236   // By analyzing only the `definingRegion` CFG we potentially loose an
237   // opportunity to drop the reference count earlier and can extend the lifetime
238   // of reference counted value longer then it is really required.
239   //
240   // We also assume that all nested regions finish their execution before the
241   // completion of the owner operation. The only exception to this rule is
242   // `async.execute` operation, and we verify that they are lowered to the
243   // `async.runtime` operations before adding automatic reference counting.
244   Region *definingRegion = value.getParentRegion();
245 
246   // Last users of the `value` inside all blocks where the value dies.
247   llvm::SmallSet<Operation *, 4> lastUsers;
248 
249   // Find blocks in the `definingRegion` that have users of the `value` (if
250   // there are multiple users in the block, which one will be selected is
251   // undefined). User operation might be not the actual user of the value, but
252   // the operation in the block that has a "real user" in one of the attached
253   // regions.
254   llvm::DenseMap<Block *, Operation *> usersInTheBlocks;
255 
256   for (Operation *user : value.getUsers()) {
257     Block *userBlock = user->getBlock();
258     Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
259     usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user);
260     assert(ancestor && "ancestor block must be not null");
261     assert(usersInTheBlocks[ancestor] && "ancestor op must be not null");
262   }
263 
264   // Find blocks where the `value` dies: the value is in `liveIn` set and not
265   // in the `liveOut` set. We place `drop_ref` immediately after the last use
266   // of the `value` in such regions (after handling few special cases).
267   //
268   // We do not traverse all the blocks in the `definingRegion`, because the
269   // `value` can be in the live in set only if it has users in the block, or it
270   // is defined in the block.
271   //
272   // Values with zero users (only definition) handled explicitly above.
273   for (auto &blockAndUser : usersInTheBlocks) {
274     Block *block = blockAndUser.getFirst();
275     Operation *userInTheBlock = blockAndUser.getSecond();
276 
277     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);
278 
279     // Value must be in the live input set or defined in the block.
280     assert(blockLiveness->isLiveIn(value) ||
281            blockLiveness->getBlock() == value.getParentBlock());
282 
283     // If value is in the live out set, it means it doesn't "die" in the block.
284     if (blockLiveness->isLiveOut(value))
285       continue;
286 
287     // At this point we proved that `value` dies in the `block`. Find the last
288     // use of the `value` inside the `block`, this is where it "dies".
289     Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
290     assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
291     lastUsers.insert(lastUser);
292   }
293 
294   // Process all the last users of the `value` inside each block where the value
295   // dies.
296   for (Operation *lastUser : lastUsers) {
297     // Return like operations forward reference count.
298     if (lastUser->hasTrait<OpTrait::ReturnLike>())
299       continue;
300 
301     // We can't currently handle other types of terminators.
302     if (lastUser->hasTrait<OpTrait::IsTerminator>())
303       return lastUser->emitError() << "async reference counting can't handle "
304                                       "terminators that are not ReturnLike";
305 
306     // Add a drop_ref immediately after the last user.
307     builder.setInsertionPointAfter(lastUser);
308     builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1));
309   }
310 
311   return success();
312 }
313 
314 LogicalResult
addAddRefBeforeFunctionCall(Value value)315 AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
316   OpBuilder builder(value.getContext());
317   Location loc = value.getLoc();
318 
319   for (Operation *user : value.getUsers()) {
320     if (!isa<func::CallOp>(user))
321       continue;
322 
323     // Add a reference before the function call to pass the value at `+1`
324     // reference to the function entry block.
325     builder.setInsertionPoint(user);
326     builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1));
327   }
328 
329   return success();
330 }
331 
332 LogicalResult
addDropRefInDivergentLivenessSuccessor(Value value)333 AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
334     Value value) {
335   using BlockSet = llvm::SmallPtrSet<Block *, 4>;
336 
337   OpBuilder builder(value.getContext());
338 
339   // If a block has successors with different `liveIn` property of the `value`,
340   // record block successors that do not thave the `value` in the `liveIn` set.
341   llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks;
342 
343   // Use liveness analysis to find the placement of `drop_ref`operation.
344   auto &liveness = getAnalysis<Liveness>();
345 
346   // Because we only add `drop_ref` operations to the region that defines the
347   // `value` we can only process CFG for the same region.
348   Region *definingRegion = value.getParentRegion();
349 
350   // Collect blocks with successors with mismatching `liveIn` sets.
351   for (Block &block : definingRegion->getBlocks()) {
352     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
353 
354     // Skip the block if value is not in the `liveOut` set.
355     if (!blockLiveness || !blockLiveness->isLiveOut(value))
356       continue;
357 
358     BlockSet liveInSuccessors;   // `value` is in `liveIn` set
359     BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set
360 
361     // Collect successors that do not have `value` in the `liveIn` set.
362     for (Block *successor : block.getSuccessors()) {
363       const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
364       if (succLiveness && succLiveness->isLiveIn(value))
365         liveInSuccessors.insert(successor);
366       else
367         noLiveInSuccessors.insert(successor);
368     }
369 
370     // Block has successors with different `liveIn` property of the `value`.
371     if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
372       divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors);
373   }
374 
375   // Try to insert `dropRef` operations to handle blocks with divergent liveness
376   // in successors blocks.
377   for (auto kv : divergentLivenessBlocks) {
378     Block *block = kv.getFirst();
379     BlockSet &successors = kv.getSecond();
380 
381     // Coroutine suspension is a special case terminator for wich we do not
382     // need to create additional reference counting (see details above).
383     Operation *terminator = block->getTerminator();
384     if (isa<CoroSuspendOp>(terminator))
385       continue;
386 
387     // We only support successor blocks with empty block argument list.
388     auto hasArgs = [](Block *block) { return !block->getArguments().empty(); };
389     if (llvm::any_of(successors, hasArgs))
390       return terminator->emitOpError()
391              << "successor have different `liveIn` property of the reference "
392                 "counted value";
393 
394     // Make sure that `dropRef` operation is called when branched into the
395     // successor block without `value` in the `liveIn` set.
396     for (Block *successor : successors) {
397       // If successor has a unique predecessor, it is safe to create `dropRef`
398       // operations directly in the successor block.
399       //
400       // Otherwise we need to create a special block for reference counting
401       // operations, and branch from it to the original successor block.
402       Block *refCountingBlock = nullptr;
403 
404       if (successor->getUniquePredecessor() == block) {
405         refCountingBlock = successor;
406       } else {
407         refCountingBlock = &successor->getParent()->emplaceBlock();
408         refCountingBlock->moveBefore(successor);
409         OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
410         builder.create<cf::BranchOp>(value.getLoc(), successor);
411       }
412 
413       OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
414       builder.create<RuntimeDropRefOp>(value.getLoc(), value,
415                                        builder.getI64IntegerAttr(1));
416 
417       // No need to update the terminator operation.
418       if (successor == refCountingBlock)
419         continue;
420 
421       // Update terminator `successor` block to `refCountingBlock`.
422       for (const auto &pair : llvm::enumerate(terminator->getSuccessors()))
423         if (pair.value() == successor)
424           terminator->setSuccessor(refCountingBlock, pair.index());
425     }
426   }
427 
428   return success();
429 }
430 
431 LogicalResult
addAutomaticRefCounting(Value value)432 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
433   // Short-circuit reference counting for values without uses.
434   if (succeeded(dropRefIfNoUses(value)))
435     return success();
436 
437   // Add `drop_ref` operations based on the liveness analysis.
438   if (failed(addDropRefAfterLastUse(value)))
439     return failure();
440 
441   // Add `add_ref` operations before function calls.
442   if (failed(addAddRefBeforeFunctionCall(value)))
443     return failure();
444 
445   // Add `drop_ref` operations to successors with divergent `value` liveness.
446   if (failed(addDropRefInDivergentLivenessSuccessor(value)))
447     return failure();
448 
449   return success();
450 }
451 
runOnOperation()452 void AsyncRuntimeRefCountingPass::runOnOperation() {
453   auto functor = [&](Value value) { return addAutomaticRefCounting(value); };
454   if (failed(walkReferenceCountedValues(getOperation(), functor)))
455     signalPassFailure();
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // Reference counting based on the user defined policy.
460 //===----------------------------------------------------------------------===//
461 
462 namespace {
463 
/// Pass that inserts `add_ref` / `drop_ref` operations around the uses of
/// reference counted async values, as decided by a list of policy functions.
class AsyncRuntimePolicyBasedRefCountingPass
    : public AsyncRuntimePolicyBasedRefCountingBase<
          AsyncRuntimePolicyBasedRefCountingPass> {
public:
  // Registers the default reference counting policy.
  AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }

  void runOnOperation() override;

private:
  // Adds reference counting operations for all uses of the `value` according
  // to the reference counting policy.
  LogicalResult addRefCounting(Value value);

  // Appends the default policy function to `policy`.
  void initializeDefaultPolicy();

  // Policy functions consulted for every use of a reference counted value.
  // Each returns the reference count change for that use: a positive value
  // adds references before the user, a negative value drops references after
  // the user, zero leaves the count unchanged, and failure fails the pass.
  llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy;
};
481 
482 } // namespace
483 
484 LogicalResult
addRefCounting(Value value)485 AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) {
486   // Short-circuit reference counting for values without uses.
487   if (succeeded(dropRefIfNoUses(value)))
488     return success();
489 
490   OpBuilder b(value.getContext());
491 
492   // Consult the user defined policy for every value use.
493   for (OpOperand &operand : value.getUses()) {
494     Location loc = operand.getOwner()->getLoc();
495 
496     for (auto &func : policy) {
497       FailureOr<int> refCount = func(operand);
498       if (failed(refCount))
499         return failure();
500 
501       int cnt = *refCount;
502 
503       // Create `add_ref` operation before the operand owner.
504       if (cnt > 0) {
505         b.setInsertionPoint(operand.getOwner());
506         b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt));
507       }
508 
509       // Create `drop_ref` operation after the operand owner.
510       if (cnt < 0) {
511         b.setInsertionPointAfter(operand.getOwner());
512         b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt));
513       }
514     }
515   }
516 
517   return success();
518 }
519 
initializeDefaultPolicy()520 void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
521   policy.push_back([](OpOperand &operand) -> FailureOr<int> {
522     Operation *op = operand.getOwner();
523     Type type = operand.get().getType();
524 
525     bool isToken = type.isa<TokenType>();
526     bool isGroup = type.isa<GroupType>();
527     bool isValue = type.isa<ValueType>();
528 
529     // Drop reference after async token or group error check (coro await).
530     if (auto await = dyn_cast<RuntimeIsErrorOp>(op))
531       return (isToken || isGroup) ? -1 : 0;
532 
533     // Drop reference after async value load.
534     if (auto load = dyn_cast<RuntimeLoadOp>(op))
535       return isValue ? -1 : 0;
536 
537     // Drop reference after async token added to the group.
538     if (auto add = dyn_cast<RuntimeAddToGroupOp>(op))
539       return isToken ? -1 : 0;
540 
541     return 0;
542   });
543 }
544 
runOnOperation()545 void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
546   auto functor = [&](Value value) { return addRefCounting(value); };
547   if (failed(walkReferenceCountedValues(getOperation(), functor)))
548     signalPassFailure();
549 }
550 
551 //----------------------------------------------------------------------------//
552 
createAsyncRuntimeRefCountingPass()553 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
554   return std::make_unique<AsyncRuntimeRefCountingPass>();
555 }
556 
createAsyncRuntimePolicyBasedRefCountingPass()557 std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() {
558   return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>();
559 }
560