//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Debug-only sanity check: bufferization.to_memref is not allowed to change
/// the rank, so a ranked `tensor` must be paired with a memref type of equal
/// rank. Tensors that are not ranked are not checked. The body is compiled out
/// when NDEBUG is defined.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  if (auto rankedTy = tensor.getType().dyn_cast<RankedTensorType>()) {
    int64_t memrefRank = memrefType.cast<MemRefType>().getRank();
    assert(memrefRank == rankedTy.getRank() &&
           "to_memref would be invalid: mismatching ranks");
  }
#endif
}

/// External model for scf.execute_region. The op can be analyzed, but its
/// bufferization is only implemented for regions with a single block.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  /// Return the scf.yield operand that corresponds to `opResult`.
  ///
  /// scf.execute_region has no tensor OpOperands; any SSA value in scope may
  /// be yielded. Reporting the yielded value as aliasing with the result lets
  /// the analysis follow use-def chains through the op.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    auto regionOp = cast<scf::ExecuteRegionOp>(op);
    // Position of `opResult` within the result list of `op`.
    auto allResults = op->getOpResults();
    size_t position =
        std::distance(allResults.begin(), llvm::find(allResults, opResult));
    // TODO: Support multiple blocks.
    assert(regionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto terminator =
        dyn_cast<scf::YieldOp>(regionOp.getRegion().front().getTerminator());
    assert(terminator && "expected scf.yield terminator in scf.execute_region");
    return {&terminator->getOpOperand(position)};
  }

  /// Results are unconditionally reported as memory writes (same approach as
  /// scf.if). The default `isMemoryWrite` is derived from
  /// `bufferizesToMemoryWrite`, which does not work for ops that have tensor
  /// OpResults but no tensor OpOperands.
  ///
  /// TODO: For better bufferization results, this could return `true` only if
  /// there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    return true;
  }

  /// Create a new scf.execute_region whose result types are the types of the
  /// values currently yielded from the region, move the region body over, and
  /// replace the old op. Results that had a tensor type on the old op are
  /// wrapped in bufferization.to_tensor; the others are forwarded directly.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto oldOp = cast<scf::ExecuteRegionOp>(op);
    assert(oldOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    auto terminator =
        cast<scf::YieldOp>(oldOp.getRegion().front().getTerminator());
    TypeRange yieldedTypes(terminator.getResults());

    // Build the replacement op and transfer the region body to it.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), yieldedTypes);
    newOp.getRegion().takeBody(oldOp.getRegion());

    // Compute one replacement value per result of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> replacements;
    for (unsigned i = 0, e = oldOp->getNumResults(); i < e; ++i) {
      if (oldOp->getResult(i).getType().isa<TensorType>()) {
        replacements.push_back(rewriter.create<bufferization::ToTensorOp>(
            oldOp.getLoc(), newOp->getResult(i)));
      } else {
        replacements.push_back(newOp->getResult(i));
      }
    }

    rewriter.replaceOp(oldOp, replacements);
    return success();
  }

  /// Every result is reported as equivalent to its aliasing yield operand.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// External model for scf.if. Bufferization replaces the op with a new scf.if
/// that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  /// Return the then/else scf.yield operands that correspond to `opResult`.
  ///
  /// scf.if has no tensor OpOperands; any SSA value in scope may be yielded.
  /// Both yielded values are reported as aliasing with the result so that the
  /// analysis can traverse use-def chains through the op.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    auto ifOp = cast<scf::IfOp>(op);
    // Position of `opResult` within the result list of `op`.
    auto allResults = op->getOpResults();
    size_t position =
        std::distance(allResults.begin(), llvm::find(allResults, opResult));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(position);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(position);
    return {thenOperand, elseOperand};
  }

  /// scf.if results are always treated as memory writes by the analysis. This
  /// simplifies the analysis considerably. Consider:
  ///
  ///   %0 = "some_writing_op" : tensor<?xf32>
  ///   %r = scf.if %c -> (tensor<?xf32>) {
  ///     scf.yield %0
  ///   } else {
  ///     %1 = "another_writing_op"(%0) : tensor<?xf32>
  ///     scf.yield %1
  ///   }
  ///   "some_reading_op"(%r)
  ///
  /// "another_writing_op" should be able to bufferize inplace when there is no
  /// other read of %0. If the scf.if were not considered a write, the analysis
  /// would report this conflict:
  ///
  ///   * read = some_reading_op
  ///   * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
  ///   * conflictingWrite = %1
  ///
  /// For more details, check the "scf.IfOp" section of the design document.
  ///
  /// TODO: For better bufferization results, this could return `true` only if
  /// there is a memory write in one (or both) of the branches. Since this is
  /// not allowed at the moment, we should never encounter scf.ifs that yield
  /// unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    return true;
  }

  /// Rebuild the scf.if with the types of the values yielded from the "then"
  /// branch. Where the two branches yield different types, both values are
  /// first cast to a memref type with a fully dynamic layout map; differing
  /// memory spaces are reported as an error.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYield = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYield = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Collect the values to yield from each branch, inserting memref casts
    // wherever the branch types disagree.
    SmallVector<Value> thenValues, elseValues;
    bool yieldsNeedUpdate = false;
    for (unsigned idx = 0; idx < thenYield.getResults().size(); ++idx) {
      Value thenVal = thenYield.getResults()[idx];
      Value elseVal = elseYield.getResults()[idx];
      if (thenVal.getType() != elseVal.getType()) {
        // Type mismatch: cast both sides to a common memref type.
        auto thenType = thenVal.getType().cast<BaseMemRefType>();
        auto elseType = elseVal.getType().cast<BaseMemRefType>();
        if (thenType.getMemorySpaceAsInt() != elseType.getMemorySpaceAsInt())
          return op->emitError(
              "inconsistent memory space on then/else branches");
        BaseMemRefType commonType = getMemRefTypeWithFullyDynamicLayout(
            ifOp.getResultTypes()[idx].cast<TensorType>(),
            thenType.getMemorySpaceAsInt());
        rewriter.setInsertionPoint(thenYield);
        thenValues.push_back(rewriter.create<memref::CastOp>(
            thenYield.getLoc(), commonType, thenVal));
        rewriter.setInsertionPoint(elseYield);
        elseValues.push_back(rewriter.create<memref::CastOp>(
            elseYield.getLoc(), commonType, elseVal));
        yieldsNeedUpdate = true;
      } else {
        thenValues.push_back(thenVal);
        elseValues.push_back(elseVal);
      }
    }

    // Only rewrite the terminators if at least one cast was inserted.
    if (yieldsNeedUpdate) {
      rewriter.setInsertionPoint(thenYield);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYield, thenValues);
      rewriter.setInsertionPoint(elseYield);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYield, elseValues);
    }

    // Create the new scf.if in front of the old one.
    rewriter.setInsertionPoint(ifOp);
    ValueRange thenValueRange(thenValues);
    TypeRange resultTypes(thenValueRange);
    auto newIfOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), resultTypes,
                                              ifOp.getCondition(),
                                              /*withElseRegion=*/true);

    // Move the bodies of both branches into the new op.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
    return success();
  }

  /// A result is equivalent to its yielded values if the then/else yielded
  /// values are equivalent to each other; otherwise there is no relation.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> aliasing =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(aliasing.size() == 2 && "expected 2 yield values");
    if (state.areEquivalentBufferizedValues(aliasing[0]->get(),
                                            aliasing[1]->get()))
      return BufferRelation::Equivalent;
    return BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Collect the positions of all
/// entries in `values` whose type is a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> tensorPositions;
  int64_t position = 0;
  for (Value value : values) {
    if (value.getType().isa<TensorType>())
      tensorPositions.insert(position);
    ++position;
  }
  return tensorPositions;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent". Only
/// positions present in both lists are considered, and pairs in which either
/// side is not a tensor are skipped.
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  DenseSet<int64_t> equivalent;
  unsigned int numPairs = std::min(bbArgs.size(), yieldedValues.size());
  for (unsigned int pos = 0; pos < numPairs; ++pos) {
    Value bbArg = bbArgs[pos];
    Value yielded = yieldedValues[pos];
    bool bothTensors = bbArg.getType().isa<TensorType>() &&
                       yielded.getType().isa<TensorType>();
    if (bothTensors && state.areEquivalentBufferizedValues(bbArg, yielded))
      equivalent.insert(pos);
  }
  return equivalent;
}

/// Helper function for loop bufferization. Return `buffer` as a value of
/// memref type `type`: the buffer itself if it already has that type,
/// otherwise the result of a newly created memref.cast.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  Type currentType = buffer.getType();
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(currentType.isa<BaseMemRefType>() && "expected BaseMemRefType");
  if (currentType != type) {
    // TODO: In case `type` has a layout map that is not the fully dynamic
    // one, we may not be able to cast the buffer. In that case, the loop
    // iter_arg's layout map must be changed (see uses of `castBuffer`).
    assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
           "scf.while op bufferization: cast incompatible");
    return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
  }
  // Types already match, so no cast is needed.
  return buffer;
}

/// Helper function for loop bufferization. Map every operand to its buffer:
/// tensor operands are replaced by the result of `getBuffer`, all other
/// operands are forwarded unchanged. Fails as soon as one buffer lookup fails.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
           const BufferizationOptions &options) {
  SmallVector<Value> buffers;
  for (OpOperand &operand : operands) {
    Value value = operand.get();
    if (!value.getType().isa<TensorType>()) {
      buffers.push_back(value);
      continue;
    }
    FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
    if (failed(maybeBuffer))
      return failure();
    buffers.push_back(*maybeBuffer);
  }
  return buffers;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition): look up the buffer
/// of `tensor` and cast it to `type` if needed. Fails if the lookup fails.
static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                                         BaseMemRefType type,
                                         const BufferizationOptions &options) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  FailureOr<Value> buffer = getBuffer(rewriter, tensor, options);
  if (succeeded(buffer))
    return castBuffer(rewriter, *buffer, type);
  return failure();
}

/// Helper function for loop bufferization. Build a vector with one entry per
/// element of `values`: positions listed in `tensorIndices` hold the result of
/// `func(value, position)`, all other positions hold the original value.
/// Fails as soon as `func` fails.
static FailureOr<SmallVector<Value>>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
  SmallVector<Value> converted;
  int64_t nextPosition = 0;
  for (Value value : values) {
    int64_t position = nextPosition++;
    if (!tensorIndices.contains(position)) {
      converted.push_back(value);
      continue;
    }
    FailureOr<Value> replacement = func(value, position);
    if (failed(replacement))
      return failure();
    converted.push_back(*replacement);
  }
  return converted;
}

/// Helper function for loop bufferization. Given the pre-bufferization yielded
/// values, compute the bufferized yielded values: every position in
/// `tensorIndices` is replaced by its yielded buffer, cast to the memref type
/// found at the same position in `bufferizedTypes`.
FailureOr<SmallVector<Value>>
getYieldedValues(RewriterBase &rewriter, ValueRange values,
                 TypeRange bufferizedTypes,
                 const DenseSet<int64_t> &tensorIndices,
                 const BufferizationOptions &options) {
  auto toYieldedBuffer = [&](Value val, int64_t index) -> FailureOr<Value> {
    auto targetType = bufferizedTypes[index].cast<BaseMemRefType>();
    return getYieldedBuffer(rewriter, val, targetType, options);
  };
  return convertTensorValues(values, tensorIndices, toYieldedBuffer);
}

/// Helper function for loop bufferization. Given the bbArgs of the new
/// (bufferized) loop op, produce the values that replace the old block's
/// arguments: bbArgs at positions in `tensorIndices` (now memrefs) are wrapped
/// in ToTensorOps so that the old block body can be moved over; all other
/// bbArgs are used as they are.
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> replacements;
  int64_t position = 0;
  for (Value bbArg : bbArgs) {
    Value replacement = bbArg;
    if (tensorIndices.contains(position))
      replacement =
          rewriter.create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg)
              .getResult();
    replacements.push_back(replacement);
    ++position;
  }
  return replacements;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  /// An init_arg counts as read if the region iter_arg that matches
  /// `opOperand` is read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  /// Tensor iter_args of scf::ForOps are always considered as a write.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  /// The only OpResult reported as aliasing with an operand is the loop result
  /// that corresponds to it.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  /// ForOp results are equivalent to their corresponding init_args if the
  /// corresponding iter_args and yield values are equivalent. Otherwise, no
  /// relation is reported.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  /// Values queried through this model are always reported as writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  /// Resolve out-of-place operands via `resolveTensorOpOperandConflicts`.
  /// Then, if `enforceAliasingInvariants` is set, rewrite the scf.yield so
  /// that every yielded tensor that is not equivalent to its iter_arg is
  /// replaced by a newly allocated tensor for that value.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
    SmallVector<Value> yieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      // Non-tensor values and equivalent yields are kept as they are.
      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
        yieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  /// The buffer type of a region iter_arg is the buffer type of the operand
  /// that `getOpOperandForRegionIterArg` returns for it.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, BlockArgument bbArg,
                const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    return bufferization::getBufferType(
        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
  }

  /// Create a new scf.for whose init_args are the buffers of the old
  /// init_args, move the old loop body into it and replace the old op. Fails
  /// if the buffer of an init_arg cannot be obtained.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    // The induction variable is the first argument of the loop body block.
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`. When that option is set, this
  /// verification is skipped entirely.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  /// Report every queried operand of an scf.while as a memory read; the
  /// operand and the analysis state are not inspected.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  /// Report every queried operand of an scf.while as a memory write; the
  /// operand and the analysis state are not inspected.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  /// Return the OpResult at the same index as `opOperand`, provided that such
  /// a result exists and has the same type as the operand. Otherwise, return
  /// an empty list.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int position = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match: their number can differ and
    // so can the types at a given index.
    bool hasMatchingResult =
        position < op->getNumResults() &&
        opOperand.get().getType() == op->getResult(position).getType();

    // The only aliasing OpResult may be the one at the same index.
    if (hasMatchingResult)
      return {whileOp->getResult(position)};
    return {};
  }

  /// WhileOp results are equivalent to their corresponding init_args if the
  /// corresponding iter_args and yielded values are equivalent, for both the
  /// "before" and the "after" block. No relation is reported if there is no
  /// "before" bbArg at the result's index or if its type differs from the
  /// result type.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int position = opResult.getResultNumber();

    // The "before" region bbArgs and the OpResults may not match.
    auto beforeArgs = whileOp.getBeforeArguments();
    if (position >= beforeArgs.size() ||
        opResult.getType() != beforeArgs[position].getType())
      return BufferRelation::None;

    // "before" block: bbArg vs. the value forwarded by scf.condition.
    BlockArgument beforeBbArg = beforeArgs[position];
    Value forwardedValue = whileOp.getConditionOp().getArgs()[position];
    bool beforeIsEquivalent =
        state.areEquivalentBufferizedValues(beforeBbArg, forwardedValue);

    // "after" block: bbArg vs. the value yielded by scf.yield.
    BlockArgument afterBbArg = whileOp.getAfterArguments()[position];
    Value yieldedValue = whileOp.getYieldOp().getOperand(position);
    bool afterIsEquivalent =
        state.areEquivalentBufferizedValues(afterBbArg, yieldedValue);

    if (beforeIsEquivalent && afterIsEquivalent)
      return BufferRelation::Equivalent;
    return BufferRelation::None;
  }

  /// Report every queried value as writable; the value and the analysis state
  /// are not inspected.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  /// Resolve out-of-place operands via `resolveTensorOpOperandConflicts`.
  /// Then, if `enforceAliasingInvariants` is set, rewrite the operands of the
  /// scf.condition and scf.yield terminators: a value that is selected by the
  /// tensor-index set and is not equivalent to its bbArg is replaced by a
  /// newly allocated tensor for that value. Returns failure if conflict
  /// resolution or any allocation fails.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
    auto yieldOp = whileOp.getYieldOp();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    // NOTE(review): `indicesBefore` is derived from the init operands but is
    // consulted below while iterating the scf.condition args, and
    // `indicesAfter` is derived from the "after" bbArgs but is consulted while
    // iterating the scf.yield operands. Confirm that this pairing is intended
    // when the two regions have different argument lists.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      // Keep values that are not selected for bufferization or that are
      // already equivalent to their bbArg.
      if (!indicesBefore.contains(idx) ||
          equivalentYieldsBefore.contains(idx)) {
        beforeYieldValues.push_back(value);
        continue;
      }
      // Otherwise, forward a newly allocated tensor instead.
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    // Update "after" region.
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> afterYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      // Same selection rule as for the "before" region.
      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
        afterYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      afterYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().assign(afterYieldValues);
    });

    return success();
  }

  // TODO: Implement getBufferType interface method and infer buffer types.

  /// Bufferize an scf.while op: create a new scf.while op that takes the
  /// buffers of the init operands, move the "before" and "after" blocks of the
  /// old op into it and rewrite both terminators so that they forward values
  /// of the expected (bufferized) types.
  ///
  /// Returns failure if the buffers of the init operands, the buffer type of a
  /// tensor "after" bbArg, or the values forwarded by one of the terminators
  /// cannot be computed.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // The result types of a WhileOp are the same as the "after" bbArg types.
    // Only tensor bbArgs are bufferized; all other bbArgs keep their type.
    SmallVector<Type> argsTypesAfter;
    argsTypesAfter.reserve(whileOp.getAfterArguments().size());
    for (BlockArgument bbArg : whileOp.getAfterArguments()) {
      if (!bbArg.getType().isa<TensorType>()) {
        argsTypesAfter.push_back(bbArg.getType());
        continue;
      }
      // Propagate the failure instead of dereferencing a failed result.
      FailureOr<BaseMemRefType> bufferType =
          bufferization::getBufferType(bbArg, options);
      if (failed(bufferType))
        return failure();
      argsTypesAfter.push_back(*bufferType);
    }

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, options);
    if (failed(newConditionArgs))
      return failure();
    newConditionOp.getArgsMutable().assign(*newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, options);
    if (failed(newYieldValues))
      return failure();
    newYieldOp.getResultsMutable().assign(*newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both. The two regions may have
  /// a different number of bbArgs; a forwarded tensor that has no bbArg with
  /// the same index in its own block is reported as not equivalent.
  ///
  /// Returns success if `allowReturnAllocs` is set or if all forwarded tensors
  /// are equivalent to their bbArgs; otherwise emits an error on the offending
  /// terminator.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    Block *conditionBlock = conditionOp->getBlock();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      // Guard the bbArg lookup: the terminator may forward more values than
      // its block has arguments.
      if (it.index() >= conditionBlock->getNumArguments() ||
          !state.areEquivalentBufferizedValues(
              it.value(), conditionBlock->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    Block *yieldBlock = yieldOp->getBlock();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      // Same guard as above for the "after" block.
      if (it.index() >= yieldBlock->getNumArguments() ||
          !state.areEquivalentBufferizedValues(
              it.value(), yieldBlock->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// External model for scf.yield. A yield nested in scf.while is left untouched
/// here (it is rewritten together with the loop); for the other supported
/// parents the yield is recreated with buffers in place of tensor operands.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  /// Every operand of a yield is reported as read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// No operand of a yield is reported as written.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// For scf.if and scf.execute_region parents, the operand aliases with the
  /// parent result that has the same index. No aliasing result is reported for
  /// any other parent.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    Operation *parentOp = op->getParentOp();
    if (!isa<scf::IfOp, scf::ExecuteRegionOp>(parentOp))
      return {};
    return {parentOp->getResult(opOperand.getOperandNumber())};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Operands of a yield are always bufferized in place: an out-of-place
    // bufferization may create an alloc + copy inside the block, and yielding
    // or returning allocations should be avoided when possible.
    return true;
  }

  /// Recreate the yield with each tensor operand replaced by its buffer. When
  /// the parent is an scf.for, the buffer is additionally cast to the buffer
  /// type of the matching region iter_arg.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    Operation *parentOp = yieldOp->getParentOp();
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            parentOp))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
    // together with scf.while.)
    if (isa<scf::WhileOp>(parentOp))
      return success();

    // Null unless the yield terminates an scf.for body.
    auto parentForOp = dyn_cast<scf::ForOp>(parentOp);

    SmallVector<Value> bufferizedOperands;
    for (const auto &operand : llvm::enumerate(yieldOp.getResults())) {
      Value yielded = operand.value();

      // Non-tensor operands are forwarded unchanged.
      if (!yielded.getType().isa<TensorType>()) {
        bufferizedOperands.push_back(yielded);
        continue;
      }

      FailureOr<Value> maybeBuffer = getBuffer(rewriter, yielded, options);
      if (failed(maybeBuffer))
        return failure();
      Value buffer = *maybeBuffer;

      if (parentForOp) {
        auto bufferizableForOp =
            cast<BufferizableOpInterface>(parentForOp.getOperation());
        FailureOr<BaseMemRefType> iterArgType = bufferizableForOp.getBufferType(
            parentForOp.getRegionIterArgs()[operand.index()], options);
        if (failed(iterArgType))
          return failure();
        buffer = castBuffer(rewriter, buffer, *iterArgType);
      }
      bufferizedOperands.push_back(buffer);
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op,
                                               bufferizedOperands);
    return success();
  }
};

using tensor::ExtractSliceOp;

/// Collect the destinations that a ForeachThreadOp inserts into: one
/// OpOperand per tensor::ParallelInsertSliceOp found while walking the op's
/// terminator.
static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
  SmallVector<OpOperand *> destOperands;
  PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
  performConcurrentlyOp.walk(
      [&destOperands](tensor::ParallelInsertSliceOp insertOp) {
        // Operand #1 of the insert op is its destination.
        OpOperand &dest = insertOp->getOpOperand(1);
        destOperands.push_back(&dest);
      });
  return destOperands;
}

/// External model for ForeachThreadOp. Bufferizing the op also takes care of
/// the region terminator. The terminator ops (PerformConcurrentlyOp and
/// ParallelInsertSliceOp) have op interfaces of their own, but those are only
/// used during analysis, not for bufferization.
struct ForeachThreadOpInterface
    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                    ForeachThreadOp> {
  /// The aliasing operand of a result is the destination operand at the same
  /// position in the list returned by getInsertionDest.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    SmallVector<OpOperand *> destOperands =
        getInsertionDest(cast<ForeachThreadOp>(op));
    return {destOperands[opResult.getResultNumber()]};
  }

  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Treat the op itself as a memory write so that the lookup stops here.
    // Otherwise, false conflicts between this op and ops inside its region
    // could be found. scf.if ops are analyzed in a similar way.
    return true;
  }

  /// Results are always reported as equivalent.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  /// Replace the op with a result-less ForeachThreadOp that takes over the
  /// body of the old op.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto oldOp = cast<ForeachThreadOp>(op);

    // ParallelInsertSliceOpInterface replaces all uses.
    assert(llvm::all_of(oldOp->getOpResults(),
                        [](OpResult result) { return result.use_empty(); }) &&
           "expected that all uses were already replaced");

    // Build the replacement op without results. The builder adds a terminator
    // automatically; drop it because the old body brings its own.
    TypeRange noResultTypes;
    auto newOp = rewriter.create<ForeachThreadOp>(
        oldOp.getLoc(), noResultTypes, oldOp.getNumThreads(),
        extractFromI64ArrayAttr(oldOp.getThreadDimMapping()));
    Block *newBody = newOp.getBody();
    newBody->getTerminator()->erase();

    // Transfer the contents of the old body into the new one.
    rewriter.mergeBlocks(oldOp.getBody(), newBody, {newBody->getArguments()});

    // The old op is no longer needed.
    rewriter.eraseOp(op);

    return success();
  }
};

/// External model for PerformConcurrentlyOp. There is nothing to bufferize for
/// this op: `bufferize` is marked unreachable because the op does not have any
/// tensor OpOperands / OpResults.
struct PerformConcurrentlyOpInterface
    : public BufferizableOpInterface::ExternalModel<
          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
  /// Not expected to be called; see the message passed to llvm_unreachable.
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    // Fallback return value following the unreachable marker.
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

/// Add an extension for the SCF dialect to `registry` that attaches the
/// BufferizableOpInterface external models of this file to their ops.
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(
      +[](MLIRContext *context, scf::SCFDialect *scfDialect) {
        // One external model per op, listed alphabetically by op name.
        ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*context);
        ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*context);
        ForOp::attachInterface<ForOpInterface>(*context);
        IfOp::attachInterface<IfOpInterface>(*context);
        PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
            *context);
        WhileOp::attachInterface<WhileOpInterface>(*context);
        YieldOp::attachInterface<YieldOpInterface>(*context);
      });
}
