//===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements automatic reference counting for Async runtime
// operations and types.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-runtime-ref-counting"

namespace {

/// Pass that adds automatic reference counting to every value of a reference
/// counted async type (as decided by `isRefCounted`) by inserting
/// `RuntimeAddRefOp` / `RuntimeDropRefOp` operations into the IR.
class AsyncRuntimeRefCountingPass
    : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
public:
  AsyncRuntimeRefCountingPass() = default;

  /// Rejects IR that still contains high level async operations (`ExecuteOp`,
  /// `AwaitOp`, `AwaitAllOp`, `YieldOp`), then adds reference counting to
  /// every reference counted block argument and operation result nested in
  /// the pass root operation. Signals pass failure on any error.
  void runOnOperation() override;

private:
  /// Adds automatic reference counting to the `value`.
  ///
  /// All values (token, group or value) are semantically created with a
  /// reference count of +1 and it is the responsibility of the async value user
  /// to place the `add_ref` and `drop_ref` operations to ensure that the value
  /// is destroyed after the last use.
  ///
  /// The function returns failure if it can't deduce the locations where
  /// to place the reference counting operations.
  ///
  /// Async values are "semantically created" when:
  ///   1. An operation returns an async result (e.g. `async.runtime.create`)
  ///   2. An async value is passed in as a block argument (or function
  ///      argument, because function arguments are just entry block arguments)
  ///
  /// Passing an async value as a function argument (or block argument) does not
  /// really mean that a new async value is created, it only means that the
  /// caller of a function transferred ownership of a `+1` reference to the
  /// callee. It is convenient to think that from the callee perspective the
  /// async value was "created" with a `+1` reference by the block argument.
  ///
  /// Automatic reference counting algorithm outline:
  ///
  /// #1 Insert `drop_ref` operations after the last use of the `value`.
  /// #2 Insert `add_ref` operations before function calls with a reference
  ///    counted `value` operand (the newly created `+1` reference will be
  ///    transferred to the callee).
  /// #3 Verify that divergent control flow does not lead to leaked reference
  ///    counted objects.
  ///
  /// The async runtime reference counting optimization pass will optimize away
  /// some of the redundant `add_ref` and `drop_ref` operations inserted by this
  /// strategy (see `async-runtime-ref-counting-opt`).
  LogicalResult addAutomaticRefCounting(Value value);

  /// (#1) Adds the `drop_ref` operation after the last use of the `value`
  /// relying on the liveness analysis.
  ///
  /// If the `value` is in the block `liveIn` set and it is not in the block
  /// `liveOut` set, it means that it "dies" in the block. We find the last
  /// use of the value in such a block and:
  ///
  ///   1. If the last user is a `ReturnLike` operation we do nothing, because
  ///      it forwards the ownership to the caller.
  ///   2. If the last user is any other terminator, we return failure.
  ///   3. Otherwise we add a `drop_ref` operation immediately after the last
  ///      use.
  LogicalResult addDropRefAfterLastUse(Value value);

  /// (#2) Adds the `add_ref` operation before each function call taking the
  /// `value` as an operand to ensure that the value passed to the function
  /// entry block has a `+1` reference count.
  LogicalResult addAddRefBeforeFunctionCall(Value value);

  /// (#3) Verifies that if a block has a value in the `liveOut` set, then the
  /// value is in the `liveIn` set of all successors.
  ///
  /// Example:
  ///
  ///   ^entry:
  ///     %token = async.runtime.create : !async.token
  ///     cond_br %cond, ^bb1, ^bb2
  ///   ^bb1:
  ///     async.runtime.await %token
  ///     return
  ///   ^bb2:
  ///     return
  ///
  /// This CFG will be rejected because ^bb2 does not have `%token` in the
  /// `liveIn` set, and it will leak a reference counted object.
  ///
  /// Blocks with an `async.coro.suspend` terminator are an exception to this
  /// rule, because in the Async to LLVM lowering it is guaranteed that the
  /// control flow will jump into the resume block, and then follow into the
  /// cleanup and suspend blocks.
  ///
  /// Example:
  ///
  ///   ^entry(%value: !async.value<f32>):
  ///     async.runtime.await_and_resume %value, %hdl : !async.value<f32>
  ///     async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
  ///   ^resume:
  ///     %0 = async.runtime.load %value
  ///     br ^cleanup
  ///   ^cleanup:
  ///     ...
  ///   ^suspend:
  ///     ...
  ///
  /// Although the cleanup and suspend blocks do not have the `value` in the
  /// `liveIn` set, it is guaranteed that execution will eventually continue in
  /// the resume block (we never explicitly destroy coroutines).
  LogicalResult verifySuccessors(Value value);
};

} // namespace

/// Inserts a `drop_ref` operation after the last use of `value` in every block
/// of the defining region where the value dies (it is in that block's `liveIn`
/// set, or defined there, but is not in its `liveOut` set).
///
/// Returns failure (with a diagnostic) if the last user in such a block is a
/// terminator that is not `ReturnLike`; `ReturnLike` last users forward the
/// reference to the caller and get no `drop_ref`.
LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
  OpBuilder builder(value.getContext());
  Location loc = value.getLoc();

  // Use liveness analysis to find the placement of `drop_ref` operations.
  auto &liveness = getAnalysis<Liveness>();

  // We analyse only the blocks of the region that defines the `value`, and do
  // not check nested blocks attached to operations.
  //
  // By analyzing only the `definingRegion` CFG we potentially lose an
  // opportunity to drop the reference count earlier and can extend the
  // lifetime of the reference counted value longer than is really required.
  //
  // We also assume that all nested regions finish their execution before the
  // completion of the owner operation. The only exception to this rule is the
  // `async.execute` operation, and we verify that such operations are lowered
  // to the `async.runtime` operations before adding automatic reference
  // counting.
  Region *definingRegion = value.getParentRegion();

  // Last users of the `value` inside all blocks where the value dies.
  llvm::SmallSet<Operation *, 4> lastUsers;

  // Find blocks in the `definingRegion` that have users of the `value` (if
  // there are multiple users in the block, which one will be selected is
  // unspecified). The recorded operation might not be the actual user of the
  // value, but the operation in the block that has a "real user" in one of its
  // attached regions.
  llvm::DenseMap<Block *, Operation *> usersInTheBlocks;

  for (Operation *user : value.getUsers()) {
    Block *userBlock = user->getBlock();
    Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
    // Check the invariants before dereferencing: a null `ancestor` must be
    // caught here, not after `ancestor->findAncestorOpInBlock` was called.
    assert(ancestor && "ancestor block must be not null");
    Operation *ancestorOp = ancestor->findAncestorOpInBlock(*user);
    assert(ancestorOp && "ancestor op must be not null");
    usersInTheBlocks[ancestor] = ancestorOp;
  }

  // Find blocks where the `value` dies: the value is in the `liveIn` set and
  // not in the `liveOut` set. We place `drop_ref` immediately after the last
  // use of the `value` in such blocks (after handling a few special cases).
  //
  // We do not traverse all the blocks in the `definingRegion`, because the
  // `value` can be in the live in set only if it has users in the block, or it
  // is defined in the block.
  //
  // Values with zero users (only a definition) never reach this point: the
  // caller (`addAutomaticRefCounting`) drops their reference directly.
  for (auto &blockAndUser : usersInTheBlocks) {
    Block *block = blockAndUser.getFirst();
    Operation *userInTheBlock = blockAndUser.getSecond();

    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);

    // Value must be in the live input set or defined in the block.
    assert(blockLiveness->isLiveIn(value) ||
           blockLiveness->getBlock() == value.getParentBlock());

    // If value is in the live out set, it means it doesn't "die" in the block.
    if (blockLiveness->isLiveOut(value))
      continue;

    // At this point we proved that `value` dies in the `block`. Find the last
    // use of the `value` inside the `block`, this is where it "dies".
    Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
    assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
    lastUsers.insert(lastUser);
  }

  // Process all the last users of the `value` inside each block where the value
  // dies.
  for (Operation *lastUser : lastUsers) {
    // Return like operations forward reference count.
    if (lastUser->hasTrait<OpTrait::ReturnLike>())
      continue;

    // We can't currently handle other types of terminators.
    if (lastUser->hasTrait<OpTrait::IsTerminator>())
      return lastUser->emitError() << "async reference counting can't handle "
                                      "terminators that are not ReturnLike";

    // Add a drop_ref immediately after the last user.
    builder.setInsertionPointAfter(lastUser);
    builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1));
  }

  return success();
}

/// Inserts an `add_ref` (count 1) immediately before every `CallOp` that uses
/// `value`, so the callee's entry block receives the value at `+1`. A call that
/// uses `value` several times appears once per use and gets one `add_ref` per
/// use. Always succeeds.
LogicalResult
AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
  OpBuilder builder(value.getContext());

  for (Operation *user : value.getUsers()) {
    if (isa<CallOp>(user)) {
      // The reference created here is handed over to the callee.
      builder.setInsertionPoint(user);
      builder.create<RuntimeAddRefOp>(value.getLoc(), value,
                                      builder.getI32IntegerAttr(1));
    }
  }

  return success();
}

/// Checks that whenever `value` is live-out of a block in its defining region,
/// the block's successors agree on whether `value` is live-in. Divergence is
/// only tolerated when the block is terminated by `CoroSuspendOp`. Returns
/// failure with a diagnostic on the first offending terminator, visiting
/// blocks in region order so the reported error is deterministic.
LogicalResult AsyncRuntimeRefCountingPass::verifySuccessors(Value value) {
  // Liveness analysis provides the `liveIn` / `liveOut` sets checked below.
  auto &liveness = getAnalysis<Liveness>();

  // Because we only add `drop_ref` operations to the region that defines the
  // `value` we can only process CFG for the same region.
  Region *definingRegion = value.getParentRegion();

  for (Block &block : definingRegion->getBlocks()) {
    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);

    // Skip the block if value is not in the `liveOut` set.
    if (!blockLiveness->isLiveOut(value))
      continue;

    // Record whether some successors have `value` in their `liveIn` set and
    // whether some others do not.
    bool hasLiveInSuccessor = false;
    bool hasNoLiveInSuccessor = false;
    for (Block *successor : block.getSuccessors()) {
      const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
      if (succLiveness->isLiveIn(value))
        hasLiveInSuccessor = true;
      else
        hasNoLiveInSuccessor = true;
    }

    // All successors agree on the `liveIn` property of the `value`.
    if (!hasLiveInSuccessor || !hasNoLiveInSuccessor)
      continue;

    // Divergent `liveIn` property is allowed only for blocks with an
    // `async.coro.suspend` terminator: control flow is guaranteed to
    // eventually continue in the resume block.
    Operation *terminator = block.getTerminator();
    if (isa<CoroSuspendOp>(terminator))
      continue;

    return terminator->emitOpError("successor have different `liveIn` property "
                                   "of the reference counted value: ");
  }

  return success();
}

/// Runs the automatic reference counting steps for a single `value`:
/// a value with no uses gets a `drop_ref` right after its definition (or at
/// the start of its block for block arguments). Otherwise `drop_ref`s are
/// placed after last uses, `add_ref`s before calls, and the successor liveness
/// is verified, stopping at the first step that fails.
LogicalResult
AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
  if (value.getUses().empty()) {
    // Release the `+1` reference immediately, right where the value appears.
    OpBuilder builder(value.getContext());
    if (Operation *definingOp = value.getDefiningOp())
      builder.setInsertionPointAfter(definingOp);
    else
      builder.setInsertionPointToStart(value.getParentBlock());
    builder.create<RuntimeDropRefOp>(value.getLoc(), value,
                                     builder.getI32IntegerAttr(1));
    return success();
  }

  // Short-circuit keeps the original ordering: no `add_ref` insertion happens
  // if placing `drop_ref` operations failed.
  if (failed(addDropRefAfterLastUse(value)) ||
      failed(addAddRefBeforeFunctionCall(value)))
    return failure();

  return verifySuccessors(value);
}

void AsyncRuntimeRefCountingPass::runOnOperation() {
  Operation *op = getOperation();

  // Check that we do not have high level async operations in the IR because
  // otherwise automatic reference counting will produce incorrect results after
  // execute operations will be lowered to `async.runtime`
  WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult {
    if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
      return WalkResult::advance();

    return op->emitError()
           << "async operations must be lowered to async runtime operations";
  });

  if (executeOpWalk.wasInterrupted()) {
    signalPassFailure();
    return;
  }

  // Add reference counting to block arguments.
  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
    for (BlockArgument arg : block->getArguments())
      if (isRefCounted(arg.getType()))
        if (failed(addAutomaticRefCounting(arg)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (blockWalk.wasInterrupted()) {
    signalPassFailure();
    return;
  }

  // Add reference counting to operation results.
  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      if (isRefCounted(op->getResultTypes()[i]))
        if (failed(addAutomaticRefCounting(op->getResult(i))))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (opWalk.wasInterrupted())
    signalPassFailure();
}

/// Creates a fresh, default-constructed async runtime reference counting pass.
std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
  std::unique_ptr<Pass> pass = std::make_unique<AsyncRuntimeRefCountingPass>();
  return pass;
}
