119efe141SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
219efe141SMatthias Springer //
319efe141SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
419efe141SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
519efe141SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
619efe141SMatthias Springer //
719efe141SMatthias Springer //===----------------------------------------------------------------------===//
819efe141SMatthias Springer 
919efe141SMatthias Springer #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
1019efe141SMatthias Springer 
1119efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1219efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
131e1eeae8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
1519efe141SMatthias Springer #include "mlir/Dialect/SCF/SCF.h"
1619efe141SMatthias Springer #include "mlir/IR/Dialect.h"
1719efe141SMatthias Springer #include "mlir/IR/Operation.h"
1819efe141SMatthias Springer #include "mlir/IR/PatternMatch.h"
1919efe141SMatthias Springer 
2019efe141SMatthias Springer using namespace mlir;
2119efe141SMatthias Springer using namespace mlir::bufferization;
2219efe141SMatthias Springer using namespace mlir::scf;
2319efe141SMatthias Springer 
2419efe141SMatthias Springer namespace mlir {
2519efe141SMatthias Springer namespace scf {
2619efe141SMatthias Springer namespace {
2719efe141SMatthias Springer 
2819efe141SMatthias Springer // bufferization.to_memref is not allowed to change the rank.
2919efe141SMatthias Springer static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
3019efe141SMatthias Springer #ifndef NDEBUG
3119efe141SMatthias Springer   auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
3219efe141SMatthias Springer   assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
3319efe141SMatthias Springer                                 rankedTensorType.getRank())) &&
3419efe141SMatthias Springer          "to_memref would be invalid: mismatching ranks");
3519efe141SMatthias Springer #endif
3619efe141SMatthias Springer }
3719efe141SMatthias Springer 
3819efe141SMatthias Springer /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
3919efe141SMatthias Springer /// fully implemented at the moment.
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  /// Return the scf.yield operand at the same position as `opResult` as the
  /// aliasing OpOperand of that result.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // Position of `opResult` within the op's result list.
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

  /// Bufferize scf.execute_region: create a new scf.execute_region op whose
  /// tensor results are turned into memref results, move the single-block
  /// region over, rewrite the scf.yield to yield memrefs (via to_memref), and
  /// replace all uses of the old op (wrapping memref results in to_tensor).
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types: tensors become memrefs; all other types are
    // passed through unchanged.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator: tensor yield operands are materialized as memrefs.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op: memref results are wrapped in to_tensor
    // so that remaining tensor users keep type-correct IR.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  /// Each result buffer is exactly the yielded buffer (see
  /// getAliasingOpOperand), hence the relation is "Equivalent".
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};
13519efe141SMatthias Springer 
13619efe141SMatthias Springer /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
13719efe141SMatthias Springer struct IfOpInterface
13819efe141SMatthias Springer     : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
13919efe141SMatthias Springer   SmallVector<OpOperand *>
14019efe141SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
1419597b16aSMatthias Springer                        const AnalysisState &state) const {
14219efe141SMatthias Springer     // IfOps do not have tensor OpOperands. The yielded value can be any SSA
14319efe141SMatthias Springer     // value that is in scope. To allow for use-def chain traversal through
14419efe141SMatthias Springer     // IfOps in the analysis, both corresponding yield values from the then/else
14519efe141SMatthias Springer     // branches are considered to be aliasing with the result.
14619efe141SMatthias Springer     auto ifOp = cast<scf::IfOp>(op);
14719efe141SMatthias Springer     size_t resultNum = std::distance(op->getOpResults().begin(),
14819efe141SMatthias Springer                                      llvm::find(op->getOpResults(), opResult));
14919efe141SMatthias Springer     return {&ifOp.thenYield()->getOpOperand(resultNum),
15019efe141SMatthias Springer             &ifOp.elseYield()->getOpOperand(resultNum)};
15119efe141SMatthias Springer   }
15219efe141SMatthias Springer 
15319efe141SMatthias Springer   // TODO: For better bufferization results, this could return `true` only if
15419efe141SMatthias Springer   // there is a memory write in one (or both) of the branches. Since this is not
15519efe141SMatthias Springer   // allowed at the moment, we should never encounter scf.ifs that yield
15619efe141SMatthias Springer   // unmodified tensors. Such scf.yield ops could just fold away.
15719efe141SMatthias Springer   bool isMemoryWrite(Operation *op, OpResult opResult,
1589597b16aSMatthias Springer                      const AnalysisState &state) const {
15919efe141SMatthias Springer     // IfOp results are always considered memory writes in the analysis. This
16019efe141SMatthias Springer     // design decision simplifies the analysis considerably. E.g., consider the
16119efe141SMatthias Springer     // following test case:
16219efe141SMatthias Springer     //
16319efe141SMatthias Springer     // %0 = "some_writing_op" : tensor<?xf32>
16419efe141SMatthias Springer     // %r = scf.if %c -> (tensor<?xf32>) {
16519efe141SMatthias Springer     //   scf.yield %0
16619efe141SMatthias Springer     // } else {
16719efe141SMatthias Springer     //   %1 = "another_writing_op"(%0) : tensor<?xf32>
16819efe141SMatthias Springer     // }
16919efe141SMatthias Springer     // "some_reading_op"(%r)
17019efe141SMatthias Springer     //
17119efe141SMatthias Springer     // "another_writing_op" in the above example should be able to bufferize
17219efe141SMatthias Springer     // inplace in the absence of another read of %0. However, if the scf.if op
17319efe141SMatthias Springer     // would not be considered a "write", the analysis would detect the
17419efe141SMatthias Springer     // following conflict:
17519efe141SMatthias Springer     //
17619efe141SMatthias Springer     // * read = some_reading_op
17719efe141SMatthias Springer     // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
17819efe141SMatthias Springer     // * conflictingWrite = %1
17919efe141SMatthias Springer     //
18019efe141SMatthias Springer     // For more details, check the "scf.IfOp" section of the design document.
18119efe141SMatthias Springer     return true;
18219efe141SMatthias Springer   }
18319efe141SMatthias Springer 
18419efe141SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1859597b16aSMatthias Springer                           BufferizationState &state) const {
18619efe141SMatthias Springer     auto ifOp = cast<scf::IfOp>(op);
18719efe141SMatthias Springer 
18819efe141SMatthias Springer     // Compute new types of the bufferized scf.if op.
18919efe141SMatthias Springer     SmallVector<Type> newTypes;
19019efe141SMatthias Springer     for (Type returnType : ifOp->getResultTypes()) {
19119efe141SMatthias Springer       if (auto tensorType = returnType.dyn_cast<TensorType>()) {
192*12e41d92SMatthias Springer         // TODO: Infer the result type instead of computing it.
19319efe141SMatthias Springer         newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
19419efe141SMatthias Springer       } else {
19519efe141SMatthias Springer         newTypes.push_back(returnType);
19619efe141SMatthias Springer       }
19719efe141SMatthias Springer     }
19819efe141SMatthias Springer 
19919efe141SMatthias Springer     // Create new op.
20019efe141SMatthias Springer     auto newIfOp =
20119efe141SMatthias Springer         rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
20219efe141SMatthias Springer                                    /*withElseRegion=*/true);
20319efe141SMatthias Springer 
20419efe141SMatthias Springer     // Remove terminators.
20519efe141SMatthias Springer     if (!newIfOp.thenBlock()->empty()) {
20619efe141SMatthias Springer       rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
20719efe141SMatthias Springer       rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
20819efe141SMatthias Springer     }
20919efe141SMatthias Springer 
21019efe141SMatthias Springer     // Move over then/else blocks.
21119efe141SMatthias Springer     rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
21219efe141SMatthias Springer     rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
21319efe141SMatthias Springer 
21419efe141SMatthias Springer     // Update scf.yield of new then-block.
21519efe141SMatthias Springer     auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
21619efe141SMatthias Springer     rewriter.setInsertionPoint(thenYieldOp);
21719efe141SMatthias Springer     SmallVector<Value> thenYieldValues;
21819efe141SMatthias Springer     for (OpOperand &operand : thenYieldOp->getOpOperands()) {
21919efe141SMatthias Springer       if (operand.get().getType().isa<TensorType>()) {
22019efe141SMatthias Springer         ensureToMemrefOpIsValid(operand.get(),
22119efe141SMatthias Springer                                 newTypes[operand.getOperandNumber()]);
22219efe141SMatthias Springer         Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
22319efe141SMatthias Springer             operand.get().getLoc(), newTypes[operand.getOperandNumber()],
22419efe141SMatthias Springer             operand.get());
22519efe141SMatthias Springer         operand.set(toMemrefOp);
22619efe141SMatthias Springer       }
22719efe141SMatthias Springer     }
22819efe141SMatthias Springer 
22919efe141SMatthias Springer     // Update scf.yield of new else-block.
23019efe141SMatthias Springer     auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
23119efe141SMatthias Springer     rewriter.setInsertionPoint(elseYieldOp);
23219efe141SMatthias Springer     SmallVector<Value> elseYieldValues;
23319efe141SMatthias Springer     for (OpOperand &operand : elseYieldOp->getOpOperands()) {
23419efe141SMatthias Springer       if (operand.get().getType().isa<TensorType>()) {
23519efe141SMatthias Springer         ensureToMemrefOpIsValid(operand.get(),
23619efe141SMatthias Springer                                 newTypes[operand.getOperandNumber()]);
23719efe141SMatthias Springer         Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
23819efe141SMatthias Springer             operand.get().getLoc(), newTypes[operand.getOperandNumber()],
23919efe141SMatthias Springer             operand.get());
24019efe141SMatthias Springer         operand.set(toMemrefOp);
24119efe141SMatthias Springer       }
24219efe141SMatthias Springer     }
24319efe141SMatthias Springer 
24419efe141SMatthias Springer     // Replace op results.
24519efe141SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
24619efe141SMatthias Springer 
24719efe141SMatthias Springer     return success();
24819efe141SMatthias Springer   }
24919efe141SMatthias Springer 
25019efe141SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
2519597b16aSMatthias Springer                                 const AnalysisState &state) const {
25219efe141SMatthias Springer     // IfOp results are equivalent to their corresponding yield values if both
25319efe141SMatthias Springer     // yield values are equivalent to each other.
25419efe141SMatthias Springer     auto bufferizableOp = cast<BufferizableOpInterface>(op);
25519efe141SMatthias Springer     SmallVector<OpOperand *> yieldValues =
25619efe141SMatthias Springer         bufferizableOp.getAliasingOpOperand(opResult, state);
25719efe141SMatthias Springer     assert(yieldValues.size() == 2 && "expected 2 yield values");
25819efe141SMatthias Springer     bool equivalentYields = state.areEquivalentBufferizedValues(
25919efe141SMatthias Springer         yieldValues[0]->get(), yieldValues[1]->get());
26019efe141SMatthias Springer     return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
26119efe141SMatthias Springer   }
26219efe141SMatthias Springer };
26319efe141SMatthias Springer 
264417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the indices of all values
265417e1c7dSMatthias Springer /// that have a tensor type.
266417e1c7dSMatthias Springer static DenseSet<int64_t> getTensorIndices(ValueRange values) {
267417e1c7dSMatthias Springer   DenseSet<int64_t> result;
268417e1c7dSMatthias Springer   for (const auto &it : llvm::enumerate(values))
269417e1c7dSMatthias Springer     if (it.value().getType().isa<TensorType>())
270417e1c7dSMatthias Springer       result.insert(it.index());
271417e1c7dSMatthias Springer   return result;
272417e1c7dSMatthias Springer }
273417e1c7dSMatthias Springer 
274417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the indices of all
275417e1c7dSMatthias Springer /// bbArg/yielded value pairs who's buffer relation is "Equivalent".
276a5d09c63SMatthias Springer DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
277417e1c7dSMatthias Springer                                        ValueRange yieldedValues,
278417e1c7dSMatthias Springer                                        const AnalysisState &state) {
279417e1c7dSMatthias Springer   DenseSet<int64_t> result;
280417e1c7dSMatthias Springer   int64_t counter = 0;
281417e1c7dSMatthias Springer   for (const auto &it : llvm::zip(bbArgs, yieldedValues)) {
282417e1c7dSMatthias Springer     if (!std::get<0>(it).getType().isa<TensorType>())
283417e1c7dSMatthias Springer       continue;
284417e1c7dSMatthias Springer     if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it)))
285417e1c7dSMatthias Springer       result.insert(counter);
286417e1c7dSMatthias Springer     counter++;
287417e1c7dSMatthias Springer   }
288417e1c7dSMatthias Springer   return result;
289417e1c7dSMatthias Springer }
290417e1c7dSMatthias Springer 
291417e1c7dSMatthias Springer /// Helper function for loop bufferization. Cast the given buffer to the given
292417e1c7dSMatthias Springer /// memref type.
293417e1c7dSMatthias Springer static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
294417e1c7dSMatthias Springer   assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
295417e1c7dSMatthias Springer   assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
296417e1c7dSMatthias Springer   // If the buffer already has the correct type, no cast is needed.
297417e1c7dSMatthias Springer   if (buffer.getType() == type)
298417e1c7dSMatthias Springer     return buffer;
299417e1c7dSMatthias Springer   // TODO: In case `type` has a layout map that is not the fully dynamic
300417e1c7dSMatthias Springer   // one, we may not be able to cast the buffer. In that case, the loop
301417e1c7dSMatthias Springer   // iter_arg's layout map must be changed (see uses of `castBuffer`).
302417e1c7dSMatthias Springer   assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
303417e1c7dSMatthias Springer          "scf.while op bufferization: cast incompatible");
304417e1c7dSMatthias Springer   return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
305417e1c7dSMatthias Springer }
306417e1c7dSMatthias Springer 
307417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the bufferized values of the
308417e1c7dSMatthias Springer /// given OpOperands. If an operand is not a tensor, return the original value.
309417e1c7dSMatthias Springer static SmallVector<Value> getBuffers(RewriterBase &rewriter,
310417e1c7dSMatthias Springer                                      MutableArrayRef<OpOperand> operands,
311417e1c7dSMatthias Springer                                      BufferizationState &state) {
312417e1c7dSMatthias Springer   SmallVector<Value> result;
313417e1c7dSMatthias Springer   for (OpOperand &opOperand : operands) {
314417e1c7dSMatthias Springer     if (opOperand.get().getType().isa<TensorType>()) {
315417e1c7dSMatthias Springer       FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
316417e1c7dSMatthias Springer       if (failed(resultBuffer))
317417e1c7dSMatthias Springer         return {};
318417e1c7dSMatthias Springer       result.push_back(*resultBuffer);
319417e1c7dSMatthias Springer     } else {
320417e1c7dSMatthias Springer       result.push_back(opOperand.get());
321417e1c7dSMatthias Springer     }
322417e1c7dSMatthias Springer   }
323417e1c7dSMatthias Springer   return result;
324417e1c7dSMatthias Springer }
325417e1c7dSMatthias Springer 
326417e1c7dSMatthias Springer /// Helper function for loop bufferization. Compute the buffer that should be
327417e1c7dSMatthias Springer /// yielded from a loop block (loop body or loop condition). If the given tensor
328417e1c7dSMatthias Springer /// is equivalent to the corresponding block argument (as indicated by
329417e1c7dSMatthias Springer /// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
330417e1c7dSMatthias Springer /// copy must be yielded.
331417e1c7dSMatthias Springer ///
332417e1c7dSMatthias Springer /// According to the `BufferizableOpInterface` implementation of scf loops, a
333417e1c7dSMatthias Springer /// a bufferized OpResult may alias only with the corresponding bufferized
334417e1c7dSMatthias Springer /// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
335417e1c7dSMatthias Springer /// the i-th init_arg; but not with any other OpOperand. If a corresponding
336417e1c7dSMatthias Springer /// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by
337417e1c7dSMatthias Springer /// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
338417e1c7dSMatthias Springer /// cannot be sure and must yield a new buffer copy. (New buffer copies do not
339417e1c7dSMatthias Springer /// alias with any buffer.)
/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition). If the given tensor
/// is equivalent to the corresponding block argument (as indicated by
/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
/// copy must be yielded.
///
/// According to the `BufferizableOpInterface` implementation of scf loops, a
/// a bufferized OpResult may alias only with the corresponding bufferized
/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
/// the i-th init_arg; but not with any other OpOperand. If a corresponding
/// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by
/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
/// alias with any buffer.)
///
/// `type` is the memref type the yielded value must have (e.g., the iter_arg's
/// buffer type, possibly carrying a layout map).
static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                              BaseMemRefType type, bool isEquivalent,
                              BufferizationState &state) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  // Look up the buffer that `tensor` was bufferized to.
  Value yieldedVal =
      bufferization::lookupBuffer(rewriter, tensor, state.getOptions());

  if (isEquivalent)
    // Yielded value is equivalent to the corresponding iter_arg bbArg.
    // Yield the value directly. Most IR should be like that. Everything
    // else must be resolved with copies and is potentially inefficient.
    // By default, such problematic IR would already have been rejected
    // during `verifyAnalysis`, unless `allow-return-allocs`.
    return castBuffer(rewriter, yieldedVal, type);

  // It is not certain that the yielded value and the iter_arg bbArg
  // have the same buffer. Allocate a new buffer and copy. The yielded
  // buffer will get deallocated by `deallocateBuffers`.

  // TODO: There are cases in which it is not neccessary to return a new
  // buffer allocation. E.g., when equivalent values are yielded in a
  // different order. This could be resolved with copies.
  Optional<Value> yieldedAlloc = state.createAlloc(
      rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
  // TODO: We should rollback, but for now just assume that this always
  // succeeds.
  assert(yieldedAlloc.hasValue() && "could not create alloc");
  // Copy the yielded buffer into the fresh allocation.
  LogicalResult copyStatus = state.getOptions().createMemCpy(
      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc);
  (void)copyStatus;
  assert(succeeded(copyStatus) && "could not create memcpy");

  // The iter_arg memref type may have a layout map. Cast the new buffer
  // to the same type if needed.
  return castBuffer(rewriter, *yieldedAlloc, type);
}
377417e1c7dSMatthias Springer 
378417e1c7dSMatthias Springer /// Helper function for loop bufferization. Given a range of values, apply
379417e1c7dSMatthias Springer /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
380417e1c7dSMatthias Springer /// value in the result vector.
381417e1c7dSMatthias Springer static SmallVector<Value>
382417e1c7dSMatthias Springer convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
383417e1c7dSMatthias Springer                     llvm::function_ref<Value(Value, int64_t)> func) {
384417e1c7dSMatthias Springer   SmallVector<Value> result;
385417e1c7dSMatthias Springer   for (const auto &it : llvm::enumerate(values)) {
386417e1c7dSMatthias Springer     size_t idx = it.index();
387417e1c7dSMatthias Springer     Value val = it.value();
388417e1c7dSMatthias Springer     result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
389417e1c7dSMatthias Springer   }
390417e1c7dSMatthias Springer   return result;
391417e1c7dSMatthias Springer }
392417e1c7dSMatthias Springer 
393417e1c7dSMatthias Springer /// Helper function for loop bufferization. Given a list of pre-bufferization
394417e1c7dSMatthias Springer /// yielded values, compute the list of bufferized yielded values.
395417e1c7dSMatthias Springer SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
396417e1c7dSMatthias Springer                                     TypeRange bufferizedTypes,
397417e1c7dSMatthias Springer                                     const DenseSet<int64_t> &tensorIndices,
398417e1c7dSMatthias Springer                                     const DenseSet<int64_t> &equivalentTensors,
399417e1c7dSMatthias Springer                                     BufferizationState &state) {
400417e1c7dSMatthias Springer   return convertTensorValues(
401417e1c7dSMatthias Springer       values, tensorIndices, [&](Value val, int64_t index) {
402417e1c7dSMatthias Springer         return getYieldedBuffer(rewriter, val,
403417e1c7dSMatthias Springer                                 bufferizedTypes[index].cast<BaseMemRefType>(),
404417e1c7dSMatthias Springer                                 equivalentTensors.contains(index), state);
405417e1c7dSMatthias Springer       });
406417e1c7dSMatthias Springer }
407417e1c7dSMatthias Springer 
408a5d09c63SMatthias Springer /// Helper function for loop bufferization. Given a list of bbArgs of the new
409a5d09c63SMatthias Springer /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
410a5d09c63SMatthias Springer /// ToTensorOps, so that the block body can be moved over to the new op.
411a5d09c63SMatthias Springer SmallVector<Value>
412a5d09c63SMatthias Springer getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
413a5d09c63SMatthias Springer                      const DenseSet<int64_t> &tensorIndices) {
414a5d09c63SMatthias Springer   return convertTensorValues(
415a5d09c63SMatthias Springer       bbArgs, tensorIndices, [&](Value val, int64_t index) {
416a5d09c63SMatthias Springer         return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
417a5d09c63SMatthias Springer       });
418a5d09c63SMatthias Springer }
419a5d09c63SMatthias Springer 
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  /// Return "true" if the given init_arg operand is read, i.e., if the
  /// matching region iter_arg has at least one reading use in the loop body.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  /// Return "true": every tensor init_arg is conservatively treated as
  /// written-to by the loop.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  /// Each init_arg operand aliases with the loop result at the same position.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  /// Return "Equivalent" for a loop result whose corresponding yielded value
  /// bufferizes to the same buffer as the matching iter_arg; "None" otherwise.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  /// Rewrite the scf.for op into a new scf.for op whose iter_args are memrefs.
  /// The old loop body (which operates on tensors) is moved into the new op;
  /// ToTensorOps at the block entry and memref yields at the terminator bridge
  /// between the two worlds.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto oldYieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields =
        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
                             state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), state);
    // One buffer is expected per tensor init_arg; a size mismatch presumably
    // means that buffer retrieval failed for some operand -- bail out.
    if (initArgs.size() != indices.size())
      return failure();

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    // Capture the new iter_arg types before `initArgs` is consumed below.
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    // Prepend the induction variable: it is bbArg #0 of the old body.
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    // NOTE(review): size == 1 means there are no iter_args (only the induction
    // variable); the ForOp builder then appears to create an implicit
    // scf.yield that must be removed before merging -- confirm against the
    // scf::ForOp builder.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
                         equivalentYields, state);
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    // If returning allocations is allowed, nothing needs to be verified.
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      // Only tensor results are subject to bufferization.
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
56019efe141SMatthias Springer 
/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  /// Each init operand aliases with the WhileOp result at the same position.
  /// NOTE(review): this (and `bufferize` below) indexes before-args,
  /// after-args, condition args, and results all by the same position, i.e.,
  /// it assumes they line up one-to-one. The general scf.while form where the
  /// before/after signatures differ does not seem to be handled -- confirm.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    return {whileOp->getResult(opOperand.getOperandNumber())};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // "Before" block: the bbArg must be equivalent to the matching
    // scf.condition operand.
    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    // "After" block: the bbArg must be equivalent to the matching scf.yield
    // operand.
    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  /// Rewrite the scf.while op into a new scf.while op whose operands, bbArgs
  /// (in both regions) and results are memrefs. Both old blocks (which operate
  /// on tensors) are moved into the new op; ToTensorOps at the block entries
  /// and memref operands on scf.condition/scf.yield bridge the two worlds.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg? (Checked separately for the "before" and "after" regions.)
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
        state.getAnalysisState());
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
        state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), state);
    // One buffer is expected per tensor operand; a size mismatch presumably
    // means that buffer retrieval failed for some operand -- bail out.
    if (initArgs.size() != indices.size())
      return failure();

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRange(initArgs);
    TypeRange argsTypes(argsRange);
    auto newWhileOp =
        rewriter.create<scf::WhileOp>(whileOp.getLoc(), argsTypes, initArgs);
    // Add before/after regions to the new op. Both regions receive bbArgs of
    // the same (memref) types as the new init args.
    SmallVector<Location> bbArgLocs(initArgs.size(), whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indices);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    SmallVector<Value> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices,
                         equivalentYieldsBefore, state);
    newConditionOp.getArgsMutable().assign(newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    SmallVector<Value> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices,
                         equivalentYieldsAfter, state);
    newYieldOp.getResultsMutable().assign(newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    // If returning allocations is allowed, nothing needs to be verified.
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    // "Before" region: every tensor operand of scf.condition must be
    // equivalent to the bbArg at the same position.
    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    // "After" region: same check for every tensor operand of scf.yield.
    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
741a5d09c63SMatthias Springer 
74219efe141SMatthias Springer /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
74319efe141SMatthias Springer /// this is for analysis only.
74419efe141SMatthias Springer struct YieldOpInterface
74519efe141SMatthias Springer     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
74619efe141SMatthias Springer                                                     scf::YieldOp> {
74719efe141SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
7489597b16aSMatthias Springer                               const AnalysisState &state) const {
74919efe141SMatthias Springer     return true;
75019efe141SMatthias Springer   }
75119efe141SMatthias Springer 
75219efe141SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
7539597b16aSMatthias Springer                                const AnalysisState &state) const {
75419efe141SMatthias Springer     return false;
75519efe141SMatthias Springer   }
75619efe141SMatthias Springer 
7579597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
7589597b16aSMatthias Springer                                             const AnalysisState &state) const {
75919efe141SMatthias Springer     if (isa<scf::IfOp>(op->getParentOp()))
760585a8a32SMatthias Springer       return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
76119efe141SMatthias Springer     if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
762585a8a32SMatthias Springer       return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
763585a8a32SMatthias Springer     return {};
76419efe141SMatthias Springer   }
76519efe141SMatthias Springer 
76619efe141SMatthias Springer   bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
7679597b16aSMatthias Springer                             const AnalysisState &state) const {
76819efe141SMatthias Springer     // Yield operands always bufferize inplace. Otherwise, an alloc + copy
76919efe141SMatthias Springer     // may be generated inside the block. We should not return/yield allocations
77019efe141SMatthias Springer     // when possible.
77119efe141SMatthias Springer     return true;
77219efe141SMatthias Springer   }
77319efe141SMatthias Springer 
77419efe141SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
7759597b16aSMatthias Springer                           BufferizationState &state) const {
77619efe141SMatthias Springer     auto yieldOp = cast<scf::YieldOp>(op);
777a5d09c63SMatthias Springer     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
77819efe141SMatthias Springer             yieldOp->getParentOp()))
77919efe141SMatthias Springer       return yieldOp->emitError("unsupported scf::YieldOp parent");
78019efe141SMatthias Springer     return success();
78119efe141SMatthias Springer   }
78219efe141SMatthias Springer };
78319efe141SMatthias Springer 
78419efe141SMatthias Springer } // namespace
78519efe141SMatthias Springer } // namespace scf
78619efe141SMatthias Springer } // namespace mlir
78719efe141SMatthias Springer 
// Register all SCF bufferization external models. The extension callback runs
// when the SCF dialect is loaded into a context, attaching one model per op.
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
798