119efe141SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
219efe141SMatthias Springer //
319efe141SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
419efe141SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
519efe141SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
619efe141SMatthias Springer //
719efe141SMatthias Springer //===----------------------------------------------------------------------===//
819efe141SMatthias Springer 
919efe141SMatthias Springer #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
1019efe141SMatthias Springer 
1119efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1219efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1319efe141SMatthias Springer #include "mlir/Dialect/SCF/SCF.h"
1419efe141SMatthias Springer #include "mlir/IR/Dialect.h"
1519efe141SMatthias Springer #include "mlir/IR/Operation.h"
1619efe141SMatthias Springer #include "mlir/IR/PatternMatch.h"
1719efe141SMatthias Springer 
1819efe141SMatthias Springer using namespace mlir;
1919efe141SMatthias Springer using namespace mlir::bufferization;
2019efe141SMatthias Springer using namespace mlir::scf;
2119efe141SMatthias Springer 
2219efe141SMatthias Springer namespace mlir {
2319efe141SMatthias Springer namespace scf {
2419efe141SMatthias Springer namespace {
2519efe141SMatthias Springer 
2619efe141SMatthias Springer // bufferization.to_memref is not allowed to change the rank.
2719efe141SMatthias Springer static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
2819efe141SMatthias Springer #ifndef NDEBUG
2919efe141SMatthias Springer   auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
3019efe141SMatthias Springer   assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
3119efe141SMatthias Springer                                 rankedTensorType.getRank())) &&
3219efe141SMatthias Springer          "to_memref would be invalid: mismatching ranks");
3319efe141SMatthias Springer #endif
3419efe141SMatthias Springer }
3519efe141SMatthias Springer 
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  /// Return the scf.yield operand that corresponds to (and is treated as
  /// aliasing with) `opResult`.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // Position of `opResult` in the op's result list; the i-th result
    // corresponds to the i-th yield operand.
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

  /// Replace the scf.execute_region with a new one that has memref result
  /// types: tensor yield values are wrapped in to_memref ops and every tensor
  /// use of the old op is replaced with a to_tensor of the new result.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    // Wrap every tensor yield value in a to_memref op so the terminator
    // matches the new (memref) result types; other values pass through.
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    // Tensor results are rematerialized as tensors via to_tensor; all other
    // results are forwarded from the new op unchanged.
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  /// Each result buffer is considered equivalent to the buffer of its
  /// corresponding yield operand.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};
13219efe141SMatthias Springer 
13319efe141SMatthias Springer /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
13419efe141SMatthias Springer struct IfOpInterface
13519efe141SMatthias Springer     : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
13619efe141SMatthias Springer   SmallVector<OpOperand *>
13719efe141SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
13819efe141SMatthias Springer                        const BufferizationState &state) const {
13919efe141SMatthias Springer     // IfOps do not have tensor OpOperands. The yielded value can be any SSA
14019efe141SMatthias Springer     // value that is in scope. To allow for use-def chain traversal through
14119efe141SMatthias Springer     // IfOps in the analysis, both corresponding yield values from the then/else
14219efe141SMatthias Springer     // branches are considered to be aliasing with the result.
14319efe141SMatthias Springer     auto ifOp = cast<scf::IfOp>(op);
14419efe141SMatthias Springer     size_t resultNum = std::distance(op->getOpResults().begin(),
14519efe141SMatthias Springer                                      llvm::find(op->getOpResults(), opResult));
14619efe141SMatthias Springer     return {&ifOp.thenYield()->getOpOperand(resultNum),
14719efe141SMatthias Springer             &ifOp.elseYield()->getOpOperand(resultNum)};
14819efe141SMatthias Springer   }
14919efe141SMatthias Springer 
15019efe141SMatthias Springer   // TODO: For better bufferization results, this could return `true` only if
15119efe141SMatthias Springer   // there is a memory write in one (or both) of the branches. Since this is not
15219efe141SMatthias Springer   // allowed at the moment, we should never encounter scf.ifs that yield
15319efe141SMatthias Springer   // unmodified tensors. Such scf.yield ops could just fold away.
15419efe141SMatthias Springer   bool isMemoryWrite(Operation *op, OpResult opResult,
15519efe141SMatthias Springer                      const BufferizationState &state) const {
15619efe141SMatthias Springer     // IfOp results are always considered memory writes in the analysis. This
15719efe141SMatthias Springer     // design decision simplifies the analysis considerably. E.g., consider the
15819efe141SMatthias Springer     // following test case:
15919efe141SMatthias Springer     //
16019efe141SMatthias Springer     // %0 = "some_writing_op" : tensor<?xf32>
16119efe141SMatthias Springer     // %r = scf.if %c -> (tensor<?xf32>) {
16219efe141SMatthias Springer     //   scf.yield %0
16319efe141SMatthias Springer     // } else {
16419efe141SMatthias Springer     //   %1 = "another_writing_op"(%0) : tensor<?xf32>
16519efe141SMatthias Springer     // }
16619efe141SMatthias Springer     // "some_reading_op"(%r)
16719efe141SMatthias Springer     //
16819efe141SMatthias Springer     // "another_writing_op" in the above example should be able to bufferize
16919efe141SMatthias Springer     // inplace in the absence of another read of %0. However, if the scf.if op
17019efe141SMatthias Springer     // would not be considered a "write", the analysis would detect the
17119efe141SMatthias Springer     // following conflict:
17219efe141SMatthias Springer     //
17319efe141SMatthias Springer     // * read = some_reading_op
17419efe141SMatthias Springer     // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
17519efe141SMatthias Springer     // * conflictingWrite = %1
17619efe141SMatthias Springer     //
17719efe141SMatthias Springer     // For more details, check the "scf.IfOp" section of the design document.
17819efe141SMatthias Springer     return true;
17919efe141SMatthias Springer   }
18019efe141SMatthias Springer 
18119efe141SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
18219efe141SMatthias Springer                           const BufferizationState &state) const {
18319efe141SMatthias Springer     auto ifOp = cast<scf::IfOp>(op);
18419efe141SMatthias Springer 
18519efe141SMatthias Springer     // Compute new types of the bufferized scf.if op.
18619efe141SMatthias Springer     SmallVector<Type> newTypes;
18719efe141SMatthias Springer     for (Type returnType : ifOp->getResultTypes()) {
18819efe141SMatthias Springer       if (auto tensorType = returnType.dyn_cast<TensorType>()) {
18919efe141SMatthias Springer         newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
19019efe141SMatthias Springer       } else {
19119efe141SMatthias Springer         newTypes.push_back(returnType);
19219efe141SMatthias Springer       }
19319efe141SMatthias Springer     }
19419efe141SMatthias Springer 
19519efe141SMatthias Springer     // Create new op.
19619efe141SMatthias Springer     auto newIfOp =
19719efe141SMatthias Springer         rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
19819efe141SMatthias Springer                                    /*withElseRegion=*/true);
19919efe141SMatthias Springer 
20019efe141SMatthias Springer     // Remove terminators.
20119efe141SMatthias Springer     if (!newIfOp.thenBlock()->empty()) {
20219efe141SMatthias Springer       rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
20319efe141SMatthias Springer       rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
20419efe141SMatthias Springer     }
20519efe141SMatthias Springer 
20619efe141SMatthias Springer     // Move over then/else blocks.
20719efe141SMatthias Springer     rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
20819efe141SMatthias Springer     rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
20919efe141SMatthias Springer 
21019efe141SMatthias Springer     // Update scf.yield of new then-block.
21119efe141SMatthias Springer     auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
21219efe141SMatthias Springer     rewriter.setInsertionPoint(thenYieldOp);
21319efe141SMatthias Springer     SmallVector<Value> thenYieldValues;
21419efe141SMatthias Springer     for (OpOperand &operand : thenYieldOp->getOpOperands()) {
21519efe141SMatthias Springer       if (operand.get().getType().isa<TensorType>()) {
21619efe141SMatthias Springer         ensureToMemrefOpIsValid(operand.get(),
21719efe141SMatthias Springer                                 newTypes[operand.getOperandNumber()]);
21819efe141SMatthias Springer         Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
21919efe141SMatthias Springer             operand.get().getLoc(), newTypes[operand.getOperandNumber()],
22019efe141SMatthias Springer             operand.get());
22119efe141SMatthias Springer         operand.set(toMemrefOp);
22219efe141SMatthias Springer       }
22319efe141SMatthias Springer     }
22419efe141SMatthias Springer 
22519efe141SMatthias Springer     // Update scf.yield of new else-block.
22619efe141SMatthias Springer     auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
22719efe141SMatthias Springer     rewriter.setInsertionPoint(elseYieldOp);
22819efe141SMatthias Springer     SmallVector<Value> elseYieldValues;
22919efe141SMatthias Springer     for (OpOperand &operand : elseYieldOp->getOpOperands()) {
23019efe141SMatthias Springer       if (operand.get().getType().isa<TensorType>()) {
23119efe141SMatthias Springer         ensureToMemrefOpIsValid(operand.get(),
23219efe141SMatthias Springer                                 newTypes[operand.getOperandNumber()]);
23319efe141SMatthias Springer         Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
23419efe141SMatthias Springer             operand.get().getLoc(), newTypes[operand.getOperandNumber()],
23519efe141SMatthias Springer             operand.get());
23619efe141SMatthias Springer         operand.set(toMemrefOp);
23719efe141SMatthias Springer       }
23819efe141SMatthias Springer     }
23919efe141SMatthias Springer 
24019efe141SMatthias Springer     // Replace op results.
24119efe141SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
24219efe141SMatthias Springer 
24319efe141SMatthias Springer     return success();
24419efe141SMatthias Springer   }
24519efe141SMatthias Springer 
24619efe141SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
24719efe141SMatthias Springer                                 const BufferizationState &state) const {
24819efe141SMatthias Springer     // IfOp results are equivalent to their corresponding yield values if both
24919efe141SMatthias Springer     // yield values are equivalent to each other.
25019efe141SMatthias Springer     auto bufferizableOp = cast<BufferizableOpInterface>(op);
25119efe141SMatthias Springer     SmallVector<OpOperand *> yieldValues =
25219efe141SMatthias Springer         bufferizableOp.getAliasingOpOperand(opResult, state);
25319efe141SMatthias Springer     assert(yieldValues.size() == 2 && "expected 2 yield values");
25419efe141SMatthias Springer     bool equivalentYields = state.areEquivalentBufferizedValues(
25519efe141SMatthias Springer         yieldValues[0]->get(), yieldValues[1]->get());
25619efe141SMatthias Springer     return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
25719efe141SMatthias Springer   }
25819efe141SMatthias Springer };
25919efe141SMatthias Springer 
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  /// An init_arg counts as a read only if the matching region iter_arg is
  /// read somewhere inside the loop body.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write. This is
    // to simplify the analysis.
    // TODO: Consider doing sth. like isValueWritten.
    return true;
  }

  /// The i-th tensor init_arg aliases with the i-th op result; non-tensor
  /// operands alias with nothing.
  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    if (!opOperand.get().getType().isa<RankedTensorType>())
      return {};
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const BufferizationState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  /// Rewrite the scf.for into one whose tensor iter_args are replaced by
  /// memrefs. The loop body keeps operating on tensors: the memref iter_args
  /// are wrapped in to_tensor ops at the top of the body, and yielded tensors
  /// are wrapped in to_memref ops at the terminator.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices;
    for (const auto &it : llvm::enumerate(forOp.getInitArgs()))
      if (it.value().getType().isa<TensorType>())
        indices.insert(it.index());

    // Given a range of values, apply `func` to those marked in `indices`.
    // Otherwise, store the unmodified value in the result vector.
    auto convert = [&](ValueRange values,
                       llvm::function_ref<Value(Value, int64_t)> func) {
      SmallVector<Value> result;
      for (const auto &it : llvm::enumerate(values)) {
        size_t idx = it.index();
        Value val = it.value();
        result.push_back(indices.contains(idx) ? func(val, idx) : val);
      }
      return result;
    };

    // Construct a new scf.for op with memref instead of tensor values.
    SmallVector<Value> initArgs;
    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
      if (opOperand.get().getType().isa<TensorType>()) {
        // Replace each tensor init_arg by its buffer. getBuffer may fail, in
        // which case bufferization of the whole op fails.
        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
        if (failed(resultBuffer))
          return failure();
        initArgs.push_back(*resultBuffer);
      } else {
        initArgs.push_back(opOperand.get());
      }
    }
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
        });
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    // Note: iterArgs.size() == 1 means the loop has no iter_args (only the
    // induction variable); in that case the new body apparently carries an
    // auto-created terminator that must go before merging the old body —
    // TODO confirm against the scf::ForOp builder.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
          ensureToMemrefOpIsValid(val, initArgs[index].getType());
          return rewriter.create<bufferization::ToMemrefOp>(
              val.getLoc(), initArgs[index].getType(), val);
        });
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are aliasing with their
  /// corresponding bbArgs. This is required because the i-th OpResult of an
  /// scf.for op is currently assumed to alias with the i-th iter_arg (in the
  /// absence of conflicts).
  LogicalResult verifyAnalysis(Operation *op,
                               const BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpOperand &operand : yieldOp->getOpOperands()) {
      // Only tensor yield operands are subject to the equivalence check.
      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
      if (!tensorType)
        continue;

      OpOperand &forOperand = forOp.getOpOperandForResult(
          forOp->getResult(operand.getOperandNumber()));
      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (!state.areEquivalentBufferizedValues(operand.get(), bbArg))
        // TODO: this could get resolved with copies but it can also turn into
        // swaps so we need to be careful about order of copies.
        return yieldOp->emitError()
               << "Yield operand #" << operand.getOperandNumber()
               << " does not bufferize to a buffer that is aliasing the "
                  "matching"
               << " enclosing scf::for operand";
    }
    return success();
  }
};
42019efe141SMatthias Springer 
42119efe141SMatthias Springer /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
42219efe141SMatthias Springer /// this is for analysis only.
42319efe141SMatthias Springer struct YieldOpInterface
42419efe141SMatthias Springer     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
42519efe141SMatthias Springer                                                     scf::YieldOp> {
42619efe141SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
42719efe141SMatthias Springer                               const BufferizationState &state) const {
42819efe141SMatthias Springer     return true;
42919efe141SMatthias Springer   }
43019efe141SMatthias Springer 
43119efe141SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
43219efe141SMatthias Springer                                const BufferizationState &state) const {
43319efe141SMatthias Springer     return false;
43419efe141SMatthias Springer   }
43519efe141SMatthias Springer 
436585a8a32SMatthias Springer   SmallVector<OpResult>
437585a8a32SMatthias Springer   getAliasingOpResult(Operation *op, OpOperand &opOperand,
43819efe141SMatthias Springer                       const BufferizationState &state) const {
43919efe141SMatthias Springer     if (isa<scf::IfOp>(op->getParentOp()))
440585a8a32SMatthias Springer       return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
44119efe141SMatthias Springer     if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
442585a8a32SMatthias Springer       return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
443585a8a32SMatthias Springer     return {};
44419efe141SMatthias Springer   }
44519efe141SMatthias Springer 
44619efe141SMatthias Springer   bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
44719efe141SMatthias Springer                             const BufferizationState &state) const {
44819efe141SMatthias Springer     // Yield operands always bufferize inplace. Otherwise, an alloc + copy
44919efe141SMatthias Springer     // may be generated inside the block. We should not return/yield allocations
45019efe141SMatthias Springer     // when possible.
45119efe141SMatthias Springer     return true;
45219efe141SMatthias Springer   }
45319efe141SMatthias Springer 
45419efe141SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
45519efe141SMatthias Springer                           const BufferizationState &state) const {
45619efe141SMatthias Springer     auto yieldOp = cast<scf::YieldOp>(op);
45719efe141SMatthias Springer     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
45819efe141SMatthias Springer             yieldOp->getParentOp()))
45919efe141SMatthias Springer       return yieldOp->emitError("unsupported scf::YieldOp parent");
46019efe141SMatthias Springer     return success();
46119efe141SMatthias Springer   }
46219efe141SMatthias Springer };
46319efe141SMatthias Springer 
46419efe141SMatthias Springer } // namespace
46519efe141SMatthias Springer } // namespace scf
46619efe141SMatthias Springer } // namespace mlir
46719efe141SMatthias Springer 
// Register the BufferizableOpInterface external models defined above with the
// SCF dialect ops they belong to.
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>();
  registry.addOpInterface<ForOp, ForOpInterface>();
  registry.addOpInterface<IfOp, IfOpInterface>();
  registry.addOpInterface<YieldOp, YieldOpInterface>();
  // scf.parallel gets no bufferization model here; per the template's name it
  // only acts as an allocation-hoisting barrier.
  registry
      .addOpInterface<ParallelOp, AllocationHoistingBarrierOnly<ParallelOp>>();
}
477