//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                           ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  auto inputType = inputs[0].getType();
  if (inputType.isa<TensorType>())
    return nullptr;

  // A detensored value is converted back by creating a new tensor from its
  // element(s).
  return builder.create<tensor::FromElementsOp>(
      loc, RankedTensorType::get({}, inputType), inputs[0]);
}

namespace {
/// Defines the criteria a TensorType must follow in order to be considered
/// "detensorable".
///
/// NOTE: For now, only 0-D tensors are supported.
///
/// Returns true if tensorType can be detensored.
bool canBeDetensored(TensorType tensorType) {
  return tensorType.hasRank() && tensorType.getRank() == 0;
}

bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
  return genericOp &&
         llvm::all_of(
             genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
               return !typeConverter.isLegal(opOperand->get().getType());
             });
}

/// A conversion patttern for detensoring `linalg.generic` ops.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inling its region.
    Block *opEntryBlock = &*op.region().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());

    // Split the op's region before the op. This way, we have a clear insertion
    // point in which the op can be inlined.
    Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
    rewriter.inlineRegionBefore(op.region(), newBlock);
    // Now that op's region is inlined, the operands of its YieldOp are mapped
    // to the materialized target values. Therefore, we can replace the op's
    // uses with those of its YielOp's operands.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // No need for these intermediate blocks, merge them into 1.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
    rewriter.mergeBlocks(newBlock, originalBlock, {});

    rewriter.eraseOp(&*Block::iterator(yieldOp));

    return success();
  }
};

/// A conversion pattern for detensoring internal (non-entry) blocks within a
/// function.
struct FunctionNonEntryBlockConversion
    : public OpInterfaceConversionPattern<FunctionOpInterface> {
  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                  DenseSet<BlockArgument> blockArgsToDetensor)
      : OpInterfaceConversionPattern(converter, ctx),
        blockArgsToDetensor(std::move(blockArgsToDetensor)) {}

  LogicalResult
  matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    Region &region = op.getBody();
    SmallVector<TypeConverter::SignatureConversion, 2> conversions;

    for (Block &block : llvm::drop_begin(region, 1)) {
      conversions.emplace_back(block.getNumArguments());
      TypeConverter::SignatureConversion &back = conversions.back();

      for (BlockArgument blockArgument : block.getArguments()) {
        int idx = blockArgument.getArgNumber();

        if (blockArgsToDetensor.count(blockArgument))
          back.addInputs(idx, {getTypeConverter()->convertType(
                                  block.getArgumentTypes()[idx])});
        else
          back.addInputs(idx, {block.getArgumentTypes()[idx]});
      }
    }

    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
                                                   conversions))) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }

    rewriter.finalizeRootUpdate(op);
    return success();
  }

private:
  const DenseSet<BlockArgument> blockArgsToDetensor;
};

class DetensorizeTypeConverter : public TypeConverter {
public:
  DetensorizeTypeConverter() {
    addConversion([](Type type) { return type; });

    // A TensorType that can be detensored, is converted to the underlying
    // element type.
    addConversion([](TensorType tensorType) -> Type {
      if (canBeDetensored(tensorType))
        return tensorType.getElementType();

      return tensorType;
    });

    // A tensor value is detensoried by extracting its element(s).
    addTargetMaterialization([](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) -> Value {
      return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
    });

    addSourceMaterialization(sourceMaterializationCallback);
    addArgumentMaterialization(sourceMaterializationCallback);
  }
};

/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
  LinalgDetensorize() = default;

  class CostModel {
  public:
    virtual ~CostModel() = default;

    /// A cost model algorithm computes the following outputs:
    ///
    /// - opsToDetensor: the list of linalg ops that should be
    /// detensored.
    ///
    /// - blockArgsToDetensor: since the operands and results of detensored
    /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
    /// from a BB argument and a linalg op's output can be passed to successor
    /// BBs), we need to maintain the sub-set of arguments that should be
    /// detensored (i.e. converted by typeConverter) for each affected BB.
    ///
    /// Example:
    ///
    /// For the following snippet:
    /// ...
    /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
    ///   %7 = linalg.init_tensor [] : tensor<i32>
    ///   %8 = linalg.generic #attrs
    ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
    ///     outs(%7 : tensor<i32>) {
    ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
    ///       %9 = arith.addi %arg0, %arg1 : i32
    ///       linalg.yield %9 : i32
    ///   } -> tensor<i32>
    ///   %10 = "some.op"(%9)
    ///   br ^bb2(%8 : tensor<i32>)
    /// ...
    ///
    /// if the cost model decides that the linalg.generic op should be
    /// detensored, then:
    /// - opsToDetensor should be = {linalg.generic{add}}.
    /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
    virtual void compute(FunctionOpInterface func,
                         DetensorizeTypeConverter typeConverter,
                         DenseSet<Operation *> &opsToDetensor,
                         DenseSet<BlockArgument> &blockArgsToDetensor) = 0;

    /// From the blockArgsToDetensor set computed by a CostModel
    /// implementation, this method computes the corresponding branch op
    /// detensoring. The result is a map from a branch op to a subset of indices
    /// of its operands. The indices specify which of the branch op's operands
    /// should be detensored.
    ///
    /// For the previous example, this method would compute: {bb2 -> {0}}.
    static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
        const DenseSet<BlockArgument> &blockArgsToDetensor) {
      DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;

      for (auto blockArgumentElem : blockArgsToDetensor) {
        Block *block = blockArgumentElem.getOwner();

        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (!blockOperands || blockOperands->empty())
            continue;

          detensorableBranchOps[terminator].insert(
              blockOperands->getBeginOperandIndex() +
              blockArgumentElem.getArgNumber());
        }
      }

      return detensorableBranchOps;
    }
  };

  /// Detensorize linalg ops involved in control-flow within a function.
  ///
  /// This model starts from BranchOps and CondBranchOps within a function. For
  /// each such branch, the model then walks the use-def chain for the branch's
  /// condition backwards in order to understand where the condition's value
  /// comes from. If the condition value is (indirectly) computed by a linalg op
  /// that can be detensored, the model then continues walking the use-def chain
  /// in order to understand where the linalg op's operands come from. This
  /// leads to discovering a "detensoring component". A detensoring component is
  /// the set of operations + block arguments that are involved in control-flow
  /// AND can be detensored.
  class ControlFlowDetectionModel : public CostModel {
  public:
    void compute(FunctionOpInterface func,
                 DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      func->walk([&](cf::CondBranchOp condBr) {
        for (auto operand : condBr.getOperands()) {
          workList.push_back(operand);
        }
      });

      func->walk([&](cf::BranchOp br) {
        for (auto operand : br.getOperands()) {
          workList.push_back(operand);
        }
      });

      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;

      // For a (to-be-detesored) value, check if it "escapes" the block by being
      // passed to terminator. If it does, then workList is updated with the
      // corresponding argument to the successor block.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
              Value operand = terminator->getOperand(operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIdx);

                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
                  workList.push_back(*succBlockArg);
              }
            }
          };

      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        if (!visitedValues.insert(currentItem).second)
          continue;

        // 1   - Look forward:
        // 1.1 - If currentItem escapes to one or more successors, add
        // the corresponding successor arguments to workList.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             currentItem.getParentBlock()->getTerminator()));

        // 1.2 - For each user of currentItem, add the defined values to
        // workList. This way, the user ops can be inspected later if they are
        // detensorable and if so, their operands will be added to workList to
        // potentially discover other parts of the detensorable component.
        for (auto *user : currentItem.getUsers())
          for (Value result : user->getResults())
            workList.push_back(result);

        // 2   - Look backward:
        // 2.1 - The current item is defined by a block argument. If the owner
        // block is a non-entry one, then:
        //       * Add the argument to blockArgsToDetensor.
        //       * Walk the use-def chain backwards to add each predecessor's
        //       terminator-operands corresponding to currentItem to workList.
        if (currentItem.dyn_cast<BlockArgument>()) {
          BlockArgument currentItemBlockArgument =
              currentItem.cast<BlockArgument>();
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Function arguments are not detensored/converted.
          if (&*ownerBlock->getParent()->begin() == ownerBlock)
            continue;

          // This inner-block argument is involved in control-flow, it should be
          // detensored.
          blockArgsToDetensor.insert(currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface predTerminator =
                dyn_cast<BranchOpInterface>((*pred)->getTerminator());

            // TODO: For now, we give up if any of the control-flow components
            // in a function is not detensorable. Fix that.
            if (!predTerminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                predTerminator.getSuccessorOperands(pred.getSuccessorIndex());

            if (!ownerBlockOperands || ownerBlockOperands->empty())
              continue;

            // For each predecessor, add the value it passes to that argument to
            // workList to find out how it's computed.
            workList.push_back(
                ownerBlockOperands
                    .getValue()[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }

        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        if (!visitedOps.insert(currentItemDefiningOp).second)
          continue;

        // 2.2 - The current item is computed by a GenericOp. If the op should
        // be detensored, then:
        //       * Add it to opsToDetensor.
        //       * Add its operands to workList to discover other parts of the
        //       potentially detensorable component.
        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
          // The op was encountered already, no need to inspect it again.
          if (opsToDetensor.count(genericOp))
            continue;

          // The op should not be detensored, give up on it but continue with
          // discovering the rest of the control-flow component.
          if (!shouldBeDetensored(genericOp, typeConverter)) {
            continue;
          }

          opsToDetensor.insert(genericOp);

          for (Value genericOpOperand : genericOp.inputs())
            workList.push_back(genericOpOperand);

          continue;
        }

        // 2.3 - The current item is the result of a FromElementsOp, it will be
        // trivially detensored later as part of canonicalization patterns
        // applied at the end of detensoring.
        //
        // Note: No need to check whether the result type of this op is
        // detensorable since if it wasn't we wouldn't reach that point in the
        // work list.
        if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
          continue;

        // 2.4 - The current item is the result of a scalar op, add all its
        // operands to the work list.
        if (llvm::all_of(
                currentItemDefiningOp->getResultTypes(),
                [&](Type resultType) { return resultType.isIntOrFloat(); }))
          for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
            workList.push_back(scalarOpOperand);
      }

      // Since the cost model gives up on some ops (see the details of step 2.2
      // above), block arguments that correspond to the values produced by those
      // ops should not be detensored as well.

      DenseSet<BlockArgument> blockArgsToRemove;

      for (auto &blockArg : blockArgsToDetensor) {
        Block *block = blockArg.getParentBlock();

        // For the potentially detensorable block argument, find the
        // correpsonding operands in predecessor blocks.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (!blockOperands || blockOperands->empty())
            continue;

          Operation *definingOp =
              terminator
                  ->getOperand(blockOperands->getBeginOperandIndex() +
                               blockArg.getArgNumber())
                  .getDefiningOp();

          // If the operand is defined by a GenericOp that will not be
          // detensored, then do not detensor the corresponding block argument.
          if (dyn_cast_or_null<GenericOp>(definingOp) &&
              opsToDetensor.count(definingOp) == 0) {
            blockArgsToRemove.insert(blockArg);
            break;
          }
        }
      }

      for (auto &blockArg : blockArgsToRemove) {
        blockArgsToDetensor.erase(blockArg);
      }
    }
  };

  /// Detensorize everything that can detensored.
  class AggressiveDetensoringModel : public CostModel {
  public:
    void compute(FunctionOpInterface func,
                 DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      func->walk([&](GenericOp genericOp) {
        if (shouldBeDetensored(genericOp, typeConverter))
          opsToDetensor.insert(genericOp);
      });

      for (Block &block : llvm::drop_begin(func.getBody(), 1))
        for (BlockArgument blockArgument : block.getArguments())
          blockArgsToDetensor.insert(blockArgument);
    }
  };

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    DetensorizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    DenseSet<Operation *> opsToDetensor;
    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
    DenseSet<BlockArgument> blockArgsToDetensor;
    FunctionOpInterface funcOp = cast<FunctionOpInterface>(getOperation());

    if (aggressiveMode.getValue()) {
      AggressiveDetensoringModel costModel;
      costModel.compute(funcOp, typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    } else {
      ControlFlowDetectionModel costModel;
      costModel.compute(funcOp, typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    }

    detensorableBranchOps =
        CostModel::computeBranchOpDetensoring(blockArgsToDetensor);

    target.addDynamicallyLegalOp<GenericOp>(
        [&](GenericOp op) { return !opsToDetensor.count(op); });

    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      // A function is legal if all of its non-entry blocks are legal. We
      // don't legalize the entry block (i.e. the function's signature)
      // since detensoring can't happen along external calling convention
      // boundaries, which we conservatively approximate as all function
      // signatures.
      if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
        Region &body = funcOp.getBody();
        return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
          return !llvm::any_of(
              blockArgsToDetensor, [&](BlockArgument blockArgument) {
                return blockArgument.getOwner() == &block &&
                       !typeConverter.isLegal(blockArgument.getType());
              });
        });
      }

      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
                                                  /*returnOpAlwaysLegal*/ true))
        return true;

      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
        if (!detensorableBranchOps.count(branchOp))
          return true;

        for (auto operandIdx : detensorableBranchOps[branchOp])
          if (!typeConverter.isLegal(
                  branchOp->getOperand(operandIdx).getType()))
            return false;

        return true;
      }

      return false;
    });

    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
    patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter,
                                                     blockArgsToDetensor);
    // Since non-entry block arguments get detensorized, we also need to
    // update the control flow inside the function to reflect the correct
    // types.
    auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
                                          int operandIdx) -> bool {
      return detensorableBranchOps.count(branchOp) &&
             detensorableBranchOps[branchOp].count(operandIdx);
    };

    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                   shouldConvertBranchOperand);

    if (failed(
            applyFullConversion(getOperation(), target, std::move(patterns))))
      signalPassFailure();

    RewritePatternSet canonPatterns(context);
    tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(canonPatterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
  return std::make_unique<LinalgDetensorize>();
}
