119efe141SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
219efe141SMatthias Springer //
319efe141SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
419efe141SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
519efe141SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
619efe141SMatthias Springer //
719efe141SMatthias Springer //===----------------------------------------------------------------------===//
819efe141SMatthias Springer 
98b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
1019efe141SMatthias Springer 
1119efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1219efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
131e1eeae8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
1672de7588SNicolas Vasilache #include "mlir/Dialect/Tensor/IR/Tensor.h"
17a0f843fdSNicolas Vasilache #include "mlir/Dialect/Utils/StaticValueUtils.h"
1819efe141SMatthias Springer #include "mlir/IR/Dialect.h"
1919efe141SMatthias Springer #include "mlir/IR/Operation.h"
2019efe141SMatthias Springer #include "mlir/IR/PatternMatch.h"
2119efe141SMatthias Springer 
2219efe141SMatthias Springer using namespace mlir;
2319efe141SMatthias Springer using namespace mlir::bufferization;
2419efe141SMatthias Springer using namespace mlir::scf;
2519efe141SMatthias Springer 
2619efe141SMatthias Springer namespace mlir {
2719efe141SMatthias Springer namespace scf {
2819efe141SMatthias Springer namespace {
2919efe141SMatthias Springer 
3019efe141SMatthias Springer // bufferization.to_memref is not allowed to change the rank.
ensureToMemrefOpIsValid(Value tensor,Type memrefType)3119efe141SMatthias Springer static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
3219efe141SMatthias Springer #ifndef NDEBUG
3319efe141SMatthias Springer   auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
3419efe141SMatthias Springer   assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
3519efe141SMatthias Springer                                 rankedTensorType.getRank())) &&
3619efe141SMatthias Springer          "to_memref would be invalid: mismatching ranks");
3719efe141SMatthias Springer #endif
3819efe141SMatthias Springer }
3919efe141SMatthias Springer 
4019efe141SMatthias Springer /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
4119efe141SMatthias Springer /// fully implemented at the moment.
4219efe141SMatthias Springer struct ExecuteRegionOpInterface
4319efe141SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
4419efe141SMatthias Springer                                                     scf::ExecuteRegionOp> {
4519efe141SMatthias Springer   SmallVector<OpOperand *>
getAliasingOpOperandmlir::scf::__anon76a8a75a0111::ExecuteRegionOpInterface4619efe141SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
479597b16aSMatthias Springer                        const AnalysisState &state) const {
4819efe141SMatthias Springer     // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
4919efe141SMatthias Springer     // any SSA value that is in scope. To allow for use-def chain traversal
5019efe141SMatthias Springer     // through ExecuteRegionOps in the analysis, the corresponding yield value
5119efe141SMatthias Springer     // is considered to be aliasing with the result.
5219efe141SMatthias Springer     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
5319efe141SMatthias Springer     size_t resultNum = std::distance(op->getOpResults().begin(),
5419efe141SMatthias Springer                                      llvm::find(op->getOpResults(), opResult));
5519efe141SMatthias Springer     // TODO: Support multiple blocks.
5619efe141SMatthias Springer     assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
5719efe141SMatthias Springer            "expected exactly 1 block");
5819efe141SMatthias Springer     auto yieldOp = dyn_cast<scf::YieldOp>(
5919efe141SMatthias Springer         executeRegionOp.getRegion().front().getTerminator());
6019efe141SMatthias Springer     assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
6119efe141SMatthias Springer     return {&yieldOp->getOpOperand(resultNum)};
6219efe141SMatthias Springer   }
6319efe141SMatthias Springer 
6419efe141SMatthias Springer   // TODO: For better bufferization results, this could return `true` only if
6519efe141SMatthias Springer   // there is a memory write in the region.
isMemoryWritemlir::scf::__anon76a8a75a0111::ExecuteRegionOpInterface6619efe141SMatthias Springer   bool isMemoryWrite(Operation *op, OpResult opResult,
679597b16aSMatthias Springer                      const AnalysisState &state) const {
6819efe141SMatthias Springer     // Similar to scf.if, results of this op are always considered memory writes
6919efe141SMatthias Springer     // in the analysis. This is a useful pattern for all ops that have tensor
7019efe141SMatthias Springer     // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
7119efe141SMatthias Springer     // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
7219efe141SMatthias Springer     // ops without OpOperands.
7319efe141SMatthias Springer     return true;
7419efe141SMatthias Springer   }
7519efe141SMatthias Springer 
bufferizemlir::scf::__anon76a8a75a0111::ExecuteRegionOpInterface7619efe141SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
77b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
7819efe141SMatthias Springer     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
798e691e1fSMatthias Springer     assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
808e691e1fSMatthias Springer            "only 1 block supported");
818e691e1fSMatthias Springer     auto yieldOp =
828e691e1fSMatthias Springer         cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
838e691e1fSMatthias Springer     TypeRange newResultTypes(yieldOp.getResults());
8419efe141SMatthias Springer 
8519efe141SMatthias Springer     // Create new op and move over region.
8619efe141SMatthias Springer     auto newOp =
8719efe141SMatthias Springer         rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
8819efe141SMatthias Springer     newOp.getRegion().takeBody(executeRegionOp.getRegion());
8919efe141SMatthias Springer 
9019efe141SMatthias Springer     // Update all uses of the old op.
9119efe141SMatthias Springer     rewriter.setInsertionPointAfter(newOp);
9219efe141SMatthias Springer     SmallVector<Value> newResults;
93bb6119ebSMehdi Amini     for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
9419efe141SMatthias Springer       if (it.value().isa<TensorType>()) {
9519efe141SMatthias Springer         newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
9619efe141SMatthias Springer             executeRegionOp.getLoc(), newOp->getResult(it.index())));
9719efe141SMatthias Springer       } else {
9819efe141SMatthias Springer         newResults.push_back(newOp->getResult(it.index()));
9919efe141SMatthias Springer       }
10019efe141SMatthias Springer     }
10119efe141SMatthias Springer 
10219efe141SMatthias Springer     // Replace old op.
10319efe141SMatthias Springer     rewriter.replaceOp(executeRegionOp, newResults);
10419efe141SMatthias Springer 
10519efe141SMatthias Springer     return success();
10619efe141SMatthias Springer   }
10719efe141SMatthias Springer 
bufferRelationmlir::scf::__anon76a8a75a0111::ExecuteRegionOpInterface10819efe141SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
1099597b16aSMatthias Springer                                 const AnalysisState &state) const {
11019efe141SMatthias Springer     return BufferRelation::Equivalent;
11119efe141SMatthias Springer   }
11219efe141SMatthias Springer };
11319efe141SMatthias Springer 
/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    // Index of `opResult` among the op's results; used to pick the matching
    // operand of each branch's terminator.
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // Restores the insertion point on scope exit; the insertion point is moved
    // around several times below (then-yield, else-yield, before the ifOp).
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Reconcile type mismatches between then/else branches by inserting memref
    // casts.
    SmallVector<Value> thenResults, elseResults;
    bool insertedCast = false;
    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
      Value thenValue = thenYieldOp.getResults()[i];
      Value elseValue = elseYieldOp.getResults()[i];
      // Matching types: the yielded values can be used unchanged.
      if (thenValue.getType() == elseValue.getType()) {
        thenResults.push_back(thenValue);
        elseResults.push_back(elseValue);
        continue;
      }

      // Type mismatch between then/else yield value. Cast both to a memref type
      // with a fully dynamic layout map.
      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
      // A cast cannot reconcile differing memory spaces; report an error.
      if (thenMemrefType.getMemorySpaceAsInt() !=
          elseMemrefType.getMemorySpaceAsInt())
        return op->emitError("inconsistent memory space on then/else branches");
      rewriter.setInsertionPoint(thenYieldOp);
      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
          ifOp.getResultTypes()[i].cast<TensorType>(),
          thenMemrefType.getMemorySpaceAsInt());
      thenResults.push_back(rewriter.create<memref::CastOp>(
          thenYieldOp.getLoc(), memrefType, thenValue));
      rewriter.setInsertionPoint(elseYieldOp);
      elseResults.push_back(rewriter.create<memref::CastOp>(
          elseYieldOp.getLoc(), memrefType, elseValue));
      insertedCast = true;
    }

    // Only rebuild the terminators if at least one operand changed.
    if (insertedCast) {
      rewriter.setInsertionPoint(thenYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
      rewriter.setInsertionPoint(elseYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
    }

    // Create new op. Result types are derived from the (possibly casted)
    // then-branch yield values; both branches now agree on these types.
    rewriter.setInsertionPoint(ifOp);
    ValueRange resultsValueRange(thenResults);
    TypeRange newTypes(resultsValueRange);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    // Yields two operands: one from the then-branch, one from the else-branch.
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};
23919efe141SMatthias Springer 
240417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the indices of all values
241417e1c7dSMatthias Springer /// that have a tensor type.
getTensorIndices(ValueRange values)242417e1c7dSMatthias Springer static DenseSet<int64_t> getTensorIndices(ValueRange values) {
243417e1c7dSMatthias Springer   DenseSet<int64_t> result;
244417e1c7dSMatthias Springer   for (const auto &it : llvm::enumerate(values))
245417e1c7dSMatthias Springer     if (it.value().getType().isa<TensorType>())
246417e1c7dSMatthias Springer       result.insert(it.index());
247417e1c7dSMatthias Springer   return result;
248417e1c7dSMatthias Springer }
249417e1c7dSMatthias Springer 
250417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the indices of all
251417e1c7dSMatthias Springer /// bbArg/yielded value pairs who's buffer relation is "Equivalent".
getEquivalentBuffers(Block::BlockArgListType bbArgs,ValueRange yieldedValues,const AnalysisState & state)252a5d09c63SMatthias Springer DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
253417e1c7dSMatthias Springer                                        ValueRange yieldedValues,
254417e1c7dSMatthias Springer                                        const AnalysisState &state) {
255996834e6SMatthias Springer   unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
256417e1c7dSMatthias Springer   DenseSet<int64_t> result;
257996834e6SMatthias Springer   for (unsigned int i = 0; i < minSize; ++i) {
258996834e6SMatthias Springer     if (!bbArgs[i].getType().isa<TensorType>() ||
259996834e6SMatthias Springer         !yieldedValues[i].getType().isa<TensorType>())
260417e1c7dSMatthias Springer       continue;
261996834e6SMatthias Springer     if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
262996834e6SMatthias Springer       result.insert(i);
263417e1c7dSMatthias Springer   }
264417e1c7dSMatthias Springer   return result;
265417e1c7dSMatthias Springer }
266417e1c7dSMatthias Springer 
267417e1c7dSMatthias Springer /// Helper function for loop bufferization. Cast the given buffer to the given
268417e1c7dSMatthias Springer /// memref type.
castBuffer(OpBuilder & b,Value buffer,Type type)269417e1c7dSMatthias Springer static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
270417e1c7dSMatthias Springer   assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
271417e1c7dSMatthias Springer   assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
272417e1c7dSMatthias Springer   // If the buffer already has the correct type, no cast is needed.
273417e1c7dSMatthias Springer   if (buffer.getType() == type)
274417e1c7dSMatthias Springer     return buffer;
275417e1c7dSMatthias Springer   // TODO: In case `type` has a layout map that is not the fully dynamic
276417e1c7dSMatthias Springer   // one, we may not be able to cast the buffer. In that case, the loop
277417e1c7dSMatthias Springer   // iter_arg's layout map must be changed (see uses of `castBuffer`).
278417e1c7dSMatthias Springer   assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
279417e1c7dSMatthias Springer          "scf.while op bufferization: cast incompatible");
280417e1c7dSMatthias Springer   return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
281417e1c7dSMatthias Springer }
282417e1c7dSMatthias Springer 
283417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the bufferized values of the
284417e1c7dSMatthias Springer /// given OpOperands. If an operand is not a tensor, return the original value.
2855d50f51cSMatthias Springer static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase & rewriter,MutableArrayRef<OpOperand> operands,const BufferizationOptions & options)2865d50f51cSMatthias Springer getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
287b55d55ecSMatthias Springer            const BufferizationOptions &options) {
288417e1c7dSMatthias Springer   SmallVector<Value> result;
289417e1c7dSMatthias Springer   for (OpOperand &opOperand : operands) {
290417e1c7dSMatthias Springer     if (opOperand.get().getType().isa<TensorType>()) {
2915d50f51cSMatthias Springer       FailureOr<Value> resultBuffer =
2925d50f51cSMatthias Springer           getBuffer(rewriter, opOperand.get(), options);
2935d50f51cSMatthias Springer       if (failed(resultBuffer))
2945d50f51cSMatthias Springer         return failure();
2955d50f51cSMatthias Springer       result.push_back(*resultBuffer);
296417e1c7dSMatthias Springer     } else {
297417e1c7dSMatthias Springer       result.push_back(opOperand.get());
298417e1c7dSMatthias Springer     }
299417e1c7dSMatthias Springer   }
300417e1c7dSMatthias Springer   return result;
301417e1c7dSMatthias Springer }
302417e1c7dSMatthias Springer 
303417e1c7dSMatthias Springer /// Helper function for loop bufferization. Compute the buffer that should be
304b3ebe3beSMatthias Springer /// yielded from a loop block (loop body or loop condition).
getYieldedBuffer(RewriterBase & rewriter,Value tensor,BaseMemRefType type,const BufferizationOptions & options)3055d50f51cSMatthias Springer static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
306b55d55ecSMatthias Springer                                          BaseMemRefType type,
307b55d55ecSMatthias Springer                                          const BufferizationOptions &options) {
308417e1c7dSMatthias Springer   assert(tensor.getType().isa<TensorType>() && "expected tensor");
309417e1c7dSMatthias Springer   ensureToMemrefOpIsValid(tensor, type);
3105d50f51cSMatthias Springer   FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
3115d50f51cSMatthias Springer   if (failed(yieldedVal))
3125d50f51cSMatthias Springer     return failure();
3135d50f51cSMatthias Springer   return castBuffer(rewriter, *yieldedVal, type);
314417e1c7dSMatthias Springer }
315417e1c7dSMatthias Springer 
316417e1c7dSMatthias Springer /// Helper function for loop bufferization. Given a range of values, apply
317417e1c7dSMatthias Springer /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
318417e1c7dSMatthias Springer /// value in the result vector.
3195d50f51cSMatthias Springer static FailureOr<SmallVector<Value>>
convertTensorValues(ValueRange values,const DenseSet<int64_t> & tensorIndices,llvm::function_ref<FailureOr<Value> (Value,int64_t)> func)320417e1c7dSMatthias Springer convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
3215d50f51cSMatthias Springer                     llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
322417e1c7dSMatthias Springer   SmallVector<Value> result;
323417e1c7dSMatthias Springer   for (const auto &it : llvm::enumerate(values)) {
324417e1c7dSMatthias Springer     size_t idx = it.index();
325417e1c7dSMatthias Springer     Value val = it.value();
3265d50f51cSMatthias Springer     if (tensorIndices.contains(idx)) {
3275d50f51cSMatthias Springer       FailureOr<Value> maybeVal = func(val, idx);
3285d50f51cSMatthias Springer       if (failed(maybeVal))
3295d50f51cSMatthias Springer         return failure();
3305d50f51cSMatthias Springer       result.push_back(*maybeVal);
3315d50f51cSMatthias Springer     } else {
3325d50f51cSMatthias Springer       result.push_back(val);
3335d50f51cSMatthias Springer     }
334417e1c7dSMatthias Springer   }
335417e1c7dSMatthias Springer   return result;
336417e1c7dSMatthias Springer }
337417e1c7dSMatthias Springer 
338417e1c7dSMatthias Springer /// Helper function for loop bufferization. Given a list of pre-bufferization
339417e1c7dSMatthias Springer /// yielded values, compute the list of bufferized yielded values.
3405d50f51cSMatthias Springer FailureOr<SmallVector<Value>>
getYieldedValues(RewriterBase & rewriter,ValueRange values,TypeRange bufferizedTypes,const DenseSet<int64_t> & tensorIndices,const BufferizationOptions & options)3415d50f51cSMatthias Springer getYieldedValues(RewriterBase &rewriter, ValueRange values,
342417e1c7dSMatthias Springer                  TypeRange bufferizedTypes,
343417e1c7dSMatthias Springer                  const DenseSet<int64_t> &tensorIndices,
344b55d55ecSMatthias Springer                  const BufferizationOptions &options) {
345417e1c7dSMatthias Springer   return convertTensorValues(
346417e1c7dSMatthias Springer       values, tensorIndices, [&](Value val, int64_t index) {
347417e1c7dSMatthias Springer         return getYieldedBuffer(rewriter, val,
348417e1c7dSMatthias Springer                                 bufferizedTypes[index].cast<BaseMemRefType>(),
349b55d55ecSMatthias Springer                                 options);
350417e1c7dSMatthias Springer       });
351417e1c7dSMatthias Springer }
352417e1c7dSMatthias Springer 
353a5d09c63SMatthias Springer /// Helper function for loop bufferization. Given a list of bbArgs of the new
354a5d09c63SMatthias Springer /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
355a5d09c63SMatthias Springer /// ToTensorOps, so that the block body can be moved over to the new op.
356a5d09c63SMatthias Springer SmallVector<Value>
getBbArgReplacements(RewriterBase & rewriter,Block::BlockArgListType bbArgs,const DenseSet<int64_t> & tensorIndices)357a5d09c63SMatthias Springer getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
358a5d09c63SMatthias Springer                      const DenseSet<int64_t> &tensorIndices) {
3595d50f51cSMatthias Springer   SmallVector<Value> result;
3605d50f51cSMatthias Springer   for (const auto &it : llvm::enumerate(bbArgs)) {
3615d50f51cSMatthias Springer     size_t idx = it.index();
3625d50f51cSMatthias Springer     Value val = it.value();
3635d50f51cSMatthias Springer     if (tensorIndices.contains(idx)) {
3645d50f51cSMatthias Springer       result.push_back(
3655d50f51cSMatthias Springer           rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
3665d50f51cSMatthias Springer               .getResult());
3675d50f51cSMatthias Springer     } else {
3685d50f51cSMatthias Springer       result.push_back(val);
3695d50f51cSMatthias Springer     }
3705d50f51cSMatthias Springer   }
3715d50f51cSMatthias Springer   return result;
372a5d09c63SMatthias Springer }
373a5d09c63SMatthias Springer 
37419efe141SMatthias Springer /// Bufferization of scf.for. Replace with a new scf.for that operates on
37519efe141SMatthias Springer /// memrefs.
37619efe141SMatthias Springer struct ForOpInterface
37719efe141SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ForOpInterface,
37819efe141SMatthias Springer                                                     scf::ForOp> {
bufferizesToMemoryReadmlir::scf::__anon76a8a75a0111::ForOpInterface37919efe141SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
3809597b16aSMatthias Springer                               const AnalysisState &state) const {
38119efe141SMatthias Springer     // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
38219efe141SMatthias Springer     // its matching bbArg may.
38319efe141SMatthias Springer     auto forOp = cast<scf::ForOp>(op);
38419efe141SMatthias Springer     return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
38519efe141SMatthias Springer   }
38619efe141SMatthias Springer 
bufferizesToMemoryWritemlir::scf::__anon76a8a75a0111::ForOpInterface38719efe141SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
3889597b16aSMatthias Springer                                const AnalysisState &state) const {
3891e1eeae8SMatthias Springer     // Tensor iter_args of scf::ForOps are always considered as a write.
39019efe141SMatthias Springer     return true;
39119efe141SMatthias Springer   }
39219efe141SMatthias Springer 
getAliasingOpResultmlir::scf::__anon76a8a75a0111::ForOpInterface3939597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
3949597b16aSMatthias Springer                                             const AnalysisState &state) const {
39519efe141SMatthias Springer     auto forOp = cast<scf::ForOp>(op);
396585a8a32SMatthias Springer     return {forOp.getResultForOpOperand(opOperand)};
39719efe141SMatthias Springer   }
39819efe141SMatthias Springer 
bufferRelationmlir::scf::__anon76a8a75a0111::ForOpInterface39919efe141SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
4009597b16aSMatthias Springer                                 const AnalysisState &state) const {
40119efe141SMatthias Springer     // ForOp results are equivalent to their corresponding init_args if the
40219efe141SMatthias Springer     // corresponding iter_args and yield values are equivalent.
40319efe141SMatthias Springer     auto forOp = cast<scf::ForOp>(op);
40419efe141SMatthias Springer     OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
40519efe141SMatthias Springer     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
4061e1eeae8SMatthias Springer     auto yieldOp =
4071e1eeae8SMatthias Springer         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
40819efe141SMatthias Springer     bool equivalentYield = state.areEquivalentBufferizedValues(
40919efe141SMatthias Springer         bbArg, yieldOp->getOperand(opResult.getResultNumber()));
41019efe141SMatthias Springer     return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
41119efe141SMatthias Springer   }
41219efe141SMatthias Springer 
isWritablemlir::scf::__anon76a8a75a0111::ForOpInterface41319efe141SMatthias Springer   bool isWritable(Operation *op, Value value,
4149597b16aSMatthias Springer                   const AnalysisState &state) const {
41519efe141SMatthias Springer     // Interestingly, scf::ForOp's bbArg can **always** be viewed
41619efe141SMatthias Springer     // inplace from the perspective of ops nested under:
41719efe141SMatthias Springer     //   1. Either the matching iter operand is not bufferized inplace and an
41819efe141SMatthias Springer     //      alloc + optional copy makes the bbArg itself inplaceable.
41919efe141SMatthias Springer     //   2. Or the matching iter operand is bufferized inplace and bbArg just
42019efe141SMatthias Springer     //      bufferizes to that too.
42119efe141SMatthias Springer     return true;
42219efe141SMatthias Springer   }
42319efe141SMatthias Springer 
resolveConflictsmlir::scf::__anon76a8a75a0111::ForOpInterface424d361ecbdSMatthias Springer   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
425d361ecbdSMatthias Springer                                  const AnalysisState &state) const {
426d361ecbdSMatthias Springer     auto bufferizableOp = cast<BufferizableOpInterface>(op);
427d361ecbdSMatthias Springer     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
428d361ecbdSMatthias Springer       return failure();
429d361ecbdSMatthias Springer 
430d361ecbdSMatthias Springer     if (!state.getOptions().enforceAliasingInvariants)
431d361ecbdSMatthias Springer       return success();
432d361ecbdSMatthias Springer 
433d361ecbdSMatthias Springer     // According to the `getAliasing...` implementations, a bufferized OpResult
434d361ecbdSMatthias Springer     // may alias only with the corresponding bufferized init_arg and with no
435d361ecbdSMatthias Springer     // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
436d361ecbdSMatthias Springer     // but not with any other OpOperand. If a corresponding OpResult/init_arg
437d361ecbdSMatthias Springer     // pair bufferizes to equivalent buffers, this aliasing requirement is
438d361ecbdSMatthias Springer     // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
439d361ecbdSMatthias Springer     // (New buffer copies do not alias with any buffer.)
440d361ecbdSMatthias Springer     auto forOp = cast<scf::ForOp>(op);
441d361ecbdSMatthias Springer     auto yieldOp =
442d361ecbdSMatthias Springer         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
443d361ecbdSMatthias Springer     OpBuilder::InsertionGuard g(rewriter);
444d361ecbdSMatthias Springer     rewriter.setInsertionPoint(yieldOp);
445d361ecbdSMatthias Springer 
446d361ecbdSMatthias Springer     // Indices of all iter_args that have tensor type. These are the ones that
447d361ecbdSMatthias Springer     // are bufferized.
448d361ecbdSMatthias Springer     DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
449d361ecbdSMatthias Springer     // For every yielded value, is the value equivalent to its corresponding
450d361ecbdSMatthias Springer     // bbArg?
451d361ecbdSMatthias Springer     DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
452d361ecbdSMatthias Springer         forOp.getRegionIterArgs(), yieldOp.getResults(), state);
453d361ecbdSMatthias Springer     SmallVector<Value> yieldValues;
454d361ecbdSMatthias Springer     for (int64_t idx = 0;
455d361ecbdSMatthias Springer          idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
456d361ecbdSMatthias Springer       Value value = yieldOp.getResults()[idx];
457d361ecbdSMatthias Springer       if (!indices.contains(idx) || equivalentYields.contains(idx)) {
458d361ecbdSMatthias Springer         yieldValues.push_back(value);
459d361ecbdSMatthias Springer         continue;
460d361ecbdSMatthias Springer       }
46145b995cdSMatthias Springer       FailureOr<Value> alloc =
46245b995cdSMatthias Springer           allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
46345b995cdSMatthias Springer                                        /*escape=*/true, state.getOptions());
46445b995cdSMatthias Springer       if (failed(alloc))
46545b995cdSMatthias Springer         return failure();
46645b995cdSMatthias Springer       yieldValues.push_back(*alloc);
467d361ecbdSMatthias Springer     }
468d361ecbdSMatthias Springer 
469d361ecbdSMatthias Springer     rewriter.updateRootInPlace(
470d361ecbdSMatthias Springer         yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
471d361ecbdSMatthias Springer     return success();
472d361ecbdSMatthias Springer   }
473d361ecbdSMatthias Springer 
4745d50f51cSMatthias Springer   FailureOr<BaseMemRefType>
getBufferTypemlir::scf::__anon76a8a75a0111::ForOpInterface4755d50f51cSMatthias Springer   getBufferType(Operation *op, BlockArgument bbArg,
4763ff93f83SMatthias Springer                 const BufferizationOptions &options) const {
4773ff93f83SMatthias Springer     auto forOp = cast<scf::ForOp>(op);
4783ff93f83SMatthias Springer     return bufferization::getBufferType(
4793ff93f83SMatthias Springer         forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
4803ff93f83SMatthias Springer   }
4813ff93f83SMatthias Springer 
bufferizemlir::scf::__anon76a8a75a0111::ForOpInterface48219efe141SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
483b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
48419efe141SMatthias Springer     auto forOp = cast<scf::ForOp>(op);
48519efe141SMatthias Springer     Block *oldLoopBody = &forOp.getLoopBody().front();
48619efe141SMatthias Springer 
48719efe141SMatthias Springer     // Indices of all iter_args that have tensor type. These are the ones that
48819efe141SMatthias Springer     // are bufferized.
489417e1c7dSMatthias Springer     DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
49019efe141SMatthias Springer 
491417e1c7dSMatthias Springer     // The new memref init_args of the loop.
4925d50f51cSMatthias Springer     FailureOr<SmallVector<Value>> maybeInitArgs =
493b55d55ecSMatthias Springer         getBuffers(rewriter, forOp.getIterOpOperands(), options);
4945d50f51cSMatthias Springer     if (failed(maybeInitArgs))
4955d50f51cSMatthias Springer       return failure();
4965d50f51cSMatthias Springer     SmallVector<Value> initArgs = *maybeInitArgs;
49719efe141SMatthias Springer 
49819efe141SMatthias Springer     // Construct a new scf.for op with memref instead of tensor values.
49919efe141SMatthias Springer     auto newForOp = rewriter.create<scf::ForOp>(
50019efe141SMatthias Springer         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
50119efe141SMatthias Springer         forOp.getStep(), initArgs);
502413fbb04SLei Zhang     newForOp->setAttrs(forOp->getAttrs());
503417e1c7dSMatthias Springer     ValueRange initArgsRange(initArgs);
504417e1c7dSMatthias Springer     TypeRange initArgsTypes(initArgsRange);
50519efe141SMatthias Springer     Block *loopBody = &newForOp.getLoopBody().front();
50619efe141SMatthias Springer 
50719efe141SMatthias Springer     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
50819efe141SMatthias Springer     // iter_args of the new loop in ToTensorOps.
50919efe141SMatthias Springer     rewriter.setInsertionPointToStart(loopBody);
510a5d09c63SMatthias Springer     SmallVector<Value> iterArgs =
511a5d09c63SMatthias Springer         getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
51219efe141SMatthias Springer     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
51319efe141SMatthias Springer 
51419efe141SMatthias Springer     // Move loop body to new loop.
51519efe141SMatthias Springer     rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
51619efe141SMatthias Springer 
51719efe141SMatthias Springer     // Replace loop results.
51819efe141SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
51919efe141SMatthias Springer 
52019efe141SMatthias Springer     return success();
52119efe141SMatthias Springer   }
5224ec00fb3SMatthias Springer 
5231e1eeae8SMatthias Springer   /// Assert that yielded values of an scf.for op are equivalent to their
524f178c386SMatthias Springer   /// corresponding bbArgs. In that case, the buffer relations of the
525f178c386SMatthias Springer   /// corresponding OpResults are "Equivalent".
526f178c386SMatthias Springer   ///
527f178c386SMatthias Springer   /// If this is not the case, an allocs+copies are inserted and yielded from
528f178c386SMatthias Springer   /// the loop. This could be a performance problem, so it must be explicitly
529f178c386SMatthias Springer   /// activated with `alloc-return-allocs`.
verifyAnalysismlir::scf::__anon76a8a75a0111::ForOpInterface5304ec00fb3SMatthias Springer   LogicalResult verifyAnalysis(Operation *op,
5319597b16aSMatthias Springer                                const AnalysisState &state) const {
5321e1eeae8SMatthias Springer     const auto &options =
5331e1eeae8SMatthias Springer         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
5341e1eeae8SMatthias Springer     if (options.allowReturnAllocs)
5351e1eeae8SMatthias Springer       return success();
5361e1eeae8SMatthias Springer 
5374ec00fb3SMatthias Springer     auto forOp = cast<scf::ForOp>(op);
5384ec00fb3SMatthias Springer     auto yieldOp =
5394ec00fb3SMatthias Springer         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
540f178c386SMatthias Springer     for (OpResult opResult : op->getOpResults()) {
541f178c386SMatthias Springer       if (!opResult.getType().isa<TensorType>())
5424ec00fb3SMatthias Springer         continue;
5434ec00fb3SMatthias Springer 
5444ec00fb3SMatthias Springer       // Note: This is overly strict. We should check for aliasing bufferized
5454ec00fb3SMatthias Springer       // values. But we don't have a "must-alias" analysis yet.
546f178c386SMatthias Springer       if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
5474ec00fb3SMatthias Springer         return yieldOp->emitError()
548f178c386SMatthias Springer                << "Yield operand #" << opResult.getResultNumber()
549e3006825SMatthias Springer                << " is not equivalent to the corresponding iter bbArg";
5504ec00fb3SMatthias Springer     }
551f178c386SMatthias Springer 
5524ec00fb3SMatthias Springer     return success();
5534ec00fb3SMatthias Springer   }
55419efe141SMatthias Springer };
55519efe141SMatthias Springer 
556a5d09c63SMatthias Springer /// Bufferization of scf.while. Replace with a new scf.while that operates on
557a5d09c63SMatthias Springer /// memrefs.
558a5d09c63SMatthias Springer struct WhileOpInterface
559a5d09c63SMatthias Springer     : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
560a5d09c63SMatthias Springer                                                     scf::WhileOp> {
bufferizesToMemoryReadmlir::scf::__anon76a8a75a0111::WhileOpInterface561a5d09c63SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
562a5d09c63SMatthias Springer                               const AnalysisState &state) const {
563a5d09c63SMatthias Springer     // Tensor iter_args of scf::WhileOps are always considered as a read.
564a5d09c63SMatthias Springer     return true;
565a5d09c63SMatthias Springer   }
566a5d09c63SMatthias Springer 
bufferizesToMemoryWritemlir::scf::__anon76a8a75a0111::WhileOpInterface567a5d09c63SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
568a5d09c63SMatthias Springer                                const AnalysisState &state) const {
569a5d09c63SMatthias Springer     // Tensor iter_args of scf::WhileOps are always considered as a write.
570a5d09c63SMatthias Springer     return true;
571a5d09c63SMatthias Springer   }
572a5d09c63SMatthias Springer 
getAliasingOpResultmlir::scf::__anon76a8a75a0111::WhileOpInterface573a5d09c63SMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
574a5d09c63SMatthias Springer                                             const AnalysisState &state) const {
575a5d09c63SMatthias Springer     auto whileOp = cast<scf::WhileOp>(op);
576996834e6SMatthias Springer     unsigned int idx = opOperand.getOperandNumber();
577996834e6SMatthias Springer 
578996834e6SMatthias Springer     // The OpResults and OpOperands may not match. They may not even have the
579996834e6SMatthias Springer     // same type. The number of OpResults and OpOperands can also differ.
580996834e6SMatthias Springer     if (idx >= op->getNumResults() ||
581996834e6SMatthias Springer         opOperand.get().getType() != op->getResult(idx).getType())
582996834e6SMatthias Springer       return {};
583996834e6SMatthias Springer 
584996834e6SMatthias Springer     // The only aliasing OpResult may be the one at the same index.
585996834e6SMatthias Springer     return {whileOp->getResult(idx)};
586a5d09c63SMatthias Springer   }
587a5d09c63SMatthias Springer 
bufferRelationmlir::scf::__anon76a8a75a0111::WhileOpInterface588a5d09c63SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
589a5d09c63SMatthias Springer                                 const AnalysisState &state) const {
590a5d09c63SMatthias Springer     // WhileOp results are equivalent to their corresponding init_args if the
591a5d09c63SMatthias Springer     // corresponding iter_args and yield values are equivalent (for both the
592a5d09c63SMatthias Springer     // "before" and the "after" block).
593a5d09c63SMatthias Springer     unsigned int resultNumber = opResult.getResultNumber();
594a5d09c63SMatthias Springer     auto whileOp = cast<scf::WhileOp>(op);
595a5d09c63SMatthias Springer 
596996834e6SMatthias Springer     // The "before" region bbArgs and the OpResults may not match.
597996834e6SMatthias Springer     if (resultNumber >= whileOp.getBeforeArguments().size())
598996834e6SMatthias Springer       return BufferRelation::None;
599996834e6SMatthias Springer     if (opResult.getType() !=
600996834e6SMatthias Springer         whileOp.getBeforeArguments()[resultNumber].getType())
601996834e6SMatthias Springer       return BufferRelation::None;
602996834e6SMatthias Springer 
603a5d09c63SMatthias Springer     auto conditionOp = whileOp.getConditionOp();
604a5d09c63SMatthias Springer     BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
605a5d09c63SMatthias Springer     Value conditionOperand = conditionOp.getArgs()[resultNumber];
606a5d09c63SMatthias Springer     bool equivCondition =
607a5d09c63SMatthias Springer         state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
608a5d09c63SMatthias Springer 
609a5d09c63SMatthias Springer     auto yieldOp = whileOp.getYieldOp();
610a5d09c63SMatthias Springer     BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
611a5d09c63SMatthias Springer     Value yieldOperand = yieldOp.getOperand(resultNumber);
612a5d09c63SMatthias Springer     bool equivYield =
613a5d09c63SMatthias Springer         state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
614a5d09c63SMatthias Springer 
615a5d09c63SMatthias Springer     return equivCondition && equivYield ? BufferRelation::Equivalent
616a5d09c63SMatthias Springer                                         : BufferRelation::None;
617a5d09c63SMatthias Springer   }
618a5d09c63SMatthias Springer 
isWritablemlir::scf::__anon76a8a75a0111::WhileOpInterface619a5d09c63SMatthias Springer   bool isWritable(Operation *op, Value value,
620a5d09c63SMatthias Springer                   const AnalysisState &state) const {
621a5d09c63SMatthias Springer     // Interestingly, scf::WhileOp's bbArg can **always** be viewed
622a5d09c63SMatthias Springer     // inplace from the perspective of ops nested under:
623a5d09c63SMatthias Springer     //   1. Either the matching iter operand is not bufferized inplace and an
624a5d09c63SMatthias Springer     //      alloc + optional copy makes the bbArg itself inplaceable.
625a5d09c63SMatthias Springer     //   2. Or the matching iter operand is bufferized inplace and bbArg just
626a5d09c63SMatthias Springer     //      bufferizes to that too.
627a5d09c63SMatthias Springer     return true;
628a5d09c63SMatthias Springer   }
629a5d09c63SMatthias Springer 
resolveConflictsmlir::scf::__anon76a8a75a0111::WhileOpInterface630d361ecbdSMatthias Springer   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
631d361ecbdSMatthias Springer                                  const AnalysisState &state) const {
632d361ecbdSMatthias Springer     auto bufferizableOp = cast<BufferizableOpInterface>(op);
633d361ecbdSMatthias Springer     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
634d361ecbdSMatthias Springer       return failure();
635d361ecbdSMatthias Springer 
636d361ecbdSMatthias Springer     if (!state.getOptions().enforceAliasingInvariants)
637d361ecbdSMatthias Springer       return success();
638d361ecbdSMatthias Springer 
639d361ecbdSMatthias Springer     // According to the `getAliasing...` implementations, a bufferized OpResult
640d361ecbdSMatthias Springer     // may alias only with the corresponding bufferized init_arg and with no
641d361ecbdSMatthias Springer     // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
642d361ecbdSMatthias Springer     // but not with any other OpOperand. If a corresponding OpResult/init_arg
643d361ecbdSMatthias Springer     // pair bufferizes to equivalent buffers, this aliasing requirement is
644d361ecbdSMatthias Springer     // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
645d361ecbdSMatthias Springer     // (New buffer copies do not alias with any buffer.)
646d361ecbdSMatthias Springer     OpBuilder::InsertionGuard g(rewriter);
647d361ecbdSMatthias Springer     auto whileOp = cast<scf::WhileOp>(op);
648d361ecbdSMatthias Springer     auto conditionOp = whileOp.getConditionOp();
649d361ecbdSMatthias Springer     auto yieldOp = whileOp.getYieldOp();
650d361ecbdSMatthias Springer 
651d361ecbdSMatthias Springer     // Indices of all bbArgs that have tensor type. These are the ones that
652d361ecbdSMatthias Springer     // are bufferized. The "before" and "after" regions may have different args.
653d361ecbdSMatthias Springer     DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
654d361ecbdSMatthias Springer     DenseSet<int64_t> indicesAfter =
655d361ecbdSMatthias Springer         getTensorIndices(whileOp.getAfterArguments());
656d361ecbdSMatthias Springer 
657d361ecbdSMatthias Springer     // For every yielded value, is the value equivalent to its corresponding
658d361ecbdSMatthias Springer     // bbArg?
659d361ecbdSMatthias Springer     DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
660d361ecbdSMatthias Springer         whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
661d361ecbdSMatthias Springer     DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
662d361ecbdSMatthias Springer         whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
663d361ecbdSMatthias Springer 
664d361ecbdSMatthias Springer     // Update "before" region.
665d361ecbdSMatthias Springer     rewriter.setInsertionPoint(conditionOp);
666d361ecbdSMatthias Springer     SmallVector<Value> beforeYieldValues;
667d361ecbdSMatthias Springer     for (int64_t idx = 0;
668d361ecbdSMatthias Springer          idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
669d361ecbdSMatthias Springer       Value value = conditionOp.getArgs()[idx];
670d361ecbdSMatthias Springer       if (!indicesBefore.contains(idx) ||
671d361ecbdSMatthias Springer           equivalentYieldsBefore.contains(idx)) {
672d361ecbdSMatthias Springer         beforeYieldValues.push_back(value);
673d361ecbdSMatthias Springer         continue;
674d361ecbdSMatthias Springer       }
67545b995cdSMatthias Springer       FailureOr<Value> alloc =
67645b995cdSMatthias Springer           allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
67745b995cdSMatthias Springer                                        /*escape=*/true, state.getOptions());
67845b995cdSMatthias Springer       if (failed(alloc))
67945b995cdSMatthias Springer         return failure();
68045b995cdSMatthias Springer       beforeYieldValues.push_back(*alloc);
681d361ecbdSMatthias Springer     }
682d361ecbdSMatthias Springer     rewriter.updateRootInPlace(conditionOp, [&]() {
683d361ecbdSMatthias Springer       conditionOp.getArgsMutable().assign(beforeYieldValues);
684d361ecbdSMatthias Springer     });
685d361ecbdSMatthias Springer 
686d361ecbdSMatthias Springer     // Update "after" region.
687d361ecbdSMatthias Springer     rewriter.setInsertionPoint(yieldOp);
688d361ecbdSMatthias Springer     SmallVector<Value> afterYieldValues;
689d361ecbdSMatthias Springer     for (int64_t idx = 0;
690d361ecbdSMatthias Springer          idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
691d361ecbdSMatthias Springer       Value value = yieldOp.getResults()[idx];
692d361ecbdSMatthias Springer       if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
693d361ecbdSMatthias Springer         afterYieldValues.push_back(value);
694d361ecbdSMatthias Springer         continue;
695d361ecbdSMatthias Springer       }
69645b995cdSMatthias Springer       FailureOr<Value> alloc =
69745b995cdSMatthias Springer           allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
69845b995cdSMatthias Springer                                        /*escape=*/true, state.getOptions());
69945b995cdSMatthias Springer       if (failed(alloc))
70045b995cdSMatthias Springer         return failure();
70145b995cdSMatthias Springer       afterYieldValues.push_back(*alloc);
702d361ecbdSMatthias Springer     }
703d361ecbdSMatthias Springer     rewriter.updateRootInPlace(yieldOp, [&]() {
704d361ecbdSMatthias Springer       yieldOp.getResultsMutable().assign(afterYieldValues);
705d361ecbdSMatthias Springer     });
706d361ecbdSMatthias Springer 
707d361ecbdSMatthias Springer     return success();
708d361ecbdSMatthias Springer   }
709d361ecbdSMatthias Springer 
710c0b0b6a0SMatthias Springer   // TODO: Implement getBufferType interface method and infer buffer types.
711c0b0b6a0SMatthias Springer 
bufferizemlir::scf::__anon76a8a75a0111::WhileOpInterface712a5d09c63SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
713b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
714a5d09c63SMatthias Springer     auto whileOp = cast<scf::WhileOp>(op);
715a5d09c63SMatthias Springer 
716a5d09c63SMatthias Springer     assert(whileOp.getBefore().getBlocks().size() == 1 &&
717a5d09c63SMatthias Springer            "regions with multiple blocks not supported");
718a5d09c63SMatthias Springer     Block *beforeBody = &whileOp.getBefore().front();
719a5d09c63SMatthias Springer     assert(whileOp.getAfter().getBlocks().size() == 1 &&
720a5d09c63SMatthias Springer            "regions with multiple blocks not supported");
721a5d09c63SMatthias Springer     Block *afterBody = &whileOp.getAfter().front();
722a5d09c63SMatthias Springer 
723996834e6SMatthias Springer     // Indices of all bbArgs that have tensor type. These are the ones that
724996834e6SMatthias Springer     // are bufferized. The "before" and "after" regions may have different args.
725996834e6SMatthias Springer     DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
726996834e6SMatthias Springer     DenseSet<int64_t> indicesAfter =
727996834e6SMatthias Springer         getTensorIndices(whileOp.getAfterArguments());
728996834e6SMatthias Springer 
729a5d09c63SMatthias Springer     // The new memref init_args of the loop.
7305d50f51cSMatthias Springer     FailureOr<SmallVector<Value>> maybeInitArgs =
731b55d55ecSMatthias Springer         getBuffers(rewriter, whileOp->getOpOperands(), options);
7325d50f51cSMatthias Springer     if (failed(maybeInitArgs))
7335d50f51cSMatthias Springer       return failure();
7345d50f51cSMatthias Springer     SmallVector<Value> initArgs = *maybeInitArgs;
735996834e6SMatthias Springer 
736996834e6SMatthias Springer     // The result types of a WhileOp are the same as the "after" bbArg types.
737996834e6SMatthias Springer     SmallVector<Type> argsTypesAfter = llvm::to_vector(
738996834e6SMatthias Springer         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
7395d50f51cSMatthias Springer           // TODO: error handling
7405d50f51cSMatthias Springer           return bufferization::getBufferType(bbArg, options)->cast<Type>();
741996834e6SMatthias Springer         }));
742a5d09c63SMatthias Springer 
743a5d09c63SMatthias Springer     // Construct a new scf.while op with memref instead of tensor values.
744996834e6SMatthias Springer     ValueRange argsRangeBefore(initArgs);
745996834e6SMatthias Springer     TypeRange argsTypesBefore(argsRangeBefore);
746996834e6SMatthias Springer     auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
747996834e6SMatthias Springer                                                     argsTypesAfter, initArgs);
748996834e6SMatthias Springer 
749a5d09c63SMatthias Springer     // Add before/after regions to the new op.
750996834e6SMatthias Springer     SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
751996834e6SMatthias Springer     SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
752996834e6SMatthias Springer                                          whileOp.getLoc());
753a5d09c63SMatthias Springer     Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
754996834e6SMatthias Springer     newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
755a5d09c63SMatthias Springer     Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
756996834e6SMatthias Springer     newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
757a5d09c63SMatthias Springer 
758a5d09c63SMatthias Springer     // Set up new iter_args and move the loop condition block to the new op.
759a5d09c63SMatthias Springer     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
760a5d09c63SMatthias Springer     // in ToTensorOps.
761a5d09c63SMatthias Springer     rewriter.setInsertionPointToStart(newBeforeBody);
762a5d09c63SMatthias Springer     SmallVector<Value> newBeforeArgs = getBbArgReplacements(
763996834e6SMatthias Springer         rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
764a5d09c63SMatthias Springer     rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
765a5d09c63SMatthias Springer 
766a5d09c63SMatthias Springer     // Update scf.condition of new loop.
767a5d09c63SMatthias Springer     auto newConditionOp = newWhileOp.getConditionOp();
768a5d09c63SMatthias Springer     rewriter.setInsertionPoint(newConditionOp);
769996834e6SMatthias Springer     // Only equivalent buffers or new buffer allocations may be yielded to the
770996834e6SMatthias Springer     // "after" region.
771996834e6SMatthias Springer     // TODO: This could be relaxed for better bufferization results.
7725d50f51cSMatthias Springer     FailureOr<SmallVector<Value>> newConditionArgs =
773996834e6SMatthias Springer         getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
774b55d55ecSMatthias Springer                          indicesAfter, options);
7755d50f51cSMatthias Springer     if (failed(newConditionArgs))
7765d50f51cSMatthias Springer       return failure();
7775d50f51cSMatthias Springer     newConditionOp.getArgsMutable().assign(*newConditionArgs);
778a5d09c63SMatthias Springer 
779a5d09c63SMatthias Springer     // Set up new iter_args and move the loop body block to the new op.
780a5d09c63SMatthias Springer     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
781a5d09c63SMatthias Springer     // in ToTensorOps.
782a5d09c63SMatthias Springer     rewriter.setInsertionPointToStart(newAfterBody);
783996834e6SMatthias Springer     SmallVector<Value> newAfterArgs = getBbArgReplacements(
784996834e6SMatthias Springer         rewriter, newWhileOp.getAfterArguments(), indicesAfter);
785a5d09c63SMatthias Springer     rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
786a5d09c63SMatthias Springer 
787a5d09c63SMatthias Springer     // Update scf.yield of the new loop.
788a5d09c63SMatthias Springer     auto newYieldOp = newWhileOp.getYieldOp();
789a5d09c63SMatthias Springer     rewriter.setInsertionPoint(newYieldOp);
790996834e6SMatthias Springer     // Only equivalent buffers or new buffer allocations may be yielded to the
791996834e6SMatthias Springer     // "before" region.
792996834e6SMatthias Springer     // TODO: This could be relaxed for better bufferization results.
7935d50f51cSMatthias Springer     FailureOr<SmallVector<Value>> newYieldValues =
794996834e6SMatthias Springer         getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
795b55d55ecSMatthias Springer                          indicesBefore, options);
7965d50f51cSMatthias Springer     if (failed(newYieldValues))
7975d50f51cSMatthias Springer       return failure();
7985d50f51cSMatthias Springer     newYieldOp.getResultsMutable().assign(*newYieldValues);
799a5d09c63SMatthias Springer 
800a5d09c63SMatthias Springer     // Replace loop results.
801a5d09c63SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
802a5d09c63SMatthias Springer 
803a5d09c63SMatthias Springer     return success();
804a5d09c63SMatthias Springer   }
805a5d09c63SMatthias Springer 
806a5d09c63SMatthias Springer   /// Assert that yielded values of an scf.while op are equivalent to their
807a5d09c63SMatthias Springer   /// corresponding bbArgs. In that case, the buffer relations of the
808a5d09c63SMatthias Springer   /// corresponding OpResults are "Equivalent".
809a5d09c63SMatthias Springer   ///
810a5d09c63SMatthias Springer   /// If this is not the case, allocs+copies are inserted and yielded from
811a5d09c63SMatthias Springer   /// the loop. This could be a performance problem, so it must be explicitly
812a5d09c63SMatthias Springer   /// activated with `alloc-return-allocs`.
813a5d09c63SMatthias Springer   ///
814a5d09c63SMatthias Springer   /// Not: In contrast to scf::ForOp, scf::WhileOp has two regions and the
815a5d09c63SMatthias Springer   /// equivalence condition must be checked for both.
verifyAnalysismlir::scf::__anon76a8a75a0111::WhileOpInterface816a5d09c63SMatthias Springer   LogicalResult verifyAnalysis(Operation *op,
817a5d09c63SMatthias Springer                                const AnalysisState &state) const {
818a5d09c63SMatthias Springer     auto whileOp = cast<scf::WhileOp>(op);
819a5d09c63SMatthias Springer     const auto &options =
820a5d09c63SMatthias Springer         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
821a5d09c63SMatthias Springer     if (options.allowReturnAllocs)
822a5d09c63SMatthias Springer       return success();
823a5d09c63SMatthias Springer 
824a5d09c63SMatthias Springer     auto conditionOp = whileOp.getConditionOp();
825a5d09c63SMatthias Springer     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
826a5d09c63SMatthias Springer       if (!it.value().getType().isa<TensorType>())
827a5d09c63SMatthias Springer         continue;
828a5d09c63SMatthias Springer       if (!state.areEquivalentBufferizedValues(
829a5d09c63SMatthias Springer               it.value(), conditionOp->getBlock()->getArgument(it.index())))
830a5d09c63SMatthias Springer         return conditionOp->emitError()
831a5d09c63SMatthias Springer                << "Condition arg #" << it.index()
832a5d09c63SMatthias Springer                << " is not equivalent to the corresponding iter bbArg";
833a5d09c63SMatthias Springer     }
834a5d09c63SMatthias Springer 
835a5d09c63SMatthias Springer     auto yieldOp = whileOp.getYieldOp();
836a5d09c63SMatthias Springer     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
837a5d09c63SMatthias Springer       if (!it.value().getType().isa<TensorType>())
838a5d09c63SMatthias Springer         continue;
839a5d09c63SMatthias Springer       if (!state.areEquivalentBufferizedValues(
840a5d09c63SMatthias Springer               it.value(), yieldOp->getBlock()->getArgument(it.index())))
841a5d09c63SMatthias Springer         return yieldOp->emitError()
842a5d09c63SMatthias Springer                << "Yield operand #" << it.index()
843a5d09c63SMatthias Springer                << " is not equivalent to the corresponding iter bbArg";
844a5d09c63SMatthias Springer     }
845a5d09c63SMatthias Springer 
846a5d09c63SMatthias Springer     return success();
847a5d09c63SMatthias Springer   }
848a5d09c63SMatthias Springer };
849a5d09c63SMatthias Springer 
85019efe141SMatthias Springer /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
85119efe141SMatthias Springer /// this is for analysis only.
85219efe141SMatthias Springer struct YieldOpInterface
85319efe141SMatthias Springer     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
85419efe141SMatthias Springer                                                     scf::YieldOp> {
bufferizesToMemoryReadmlir::scf::__anon76a8a75a0111::YieldOpInterface85519efe141SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
8569597b16aSMatthias Springer                               const AnalysisState &state) const {
85719efe141SMatthias Springer     return true;
85819efe141SMatthias Springer   }
85919efe141SMatthias Springer 
bufferizesToMemoryWritemlir::scf::__anon76a8a75a0111::YieldOpInterface86019efe141SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
8619597b16aSMatthias Springer                                const AnalysisState &state) const {
86219efe141SMatthias Springer     return false;
86319efe141SMatthias Springer   }
86419efe141SMatthias Springer 
getAliasingOpResultmlir::scf::__anon76a8a75a0111::YieldOpInterface8659597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
8669597b16aSMatthias Springer                                             const AnalysisState &state) const {
86719efe141SMatthias Springer     if (isa<scf::IfOp>(op->getParentOp()))
868585a8a32SMatthias Springer       return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
86919efe141SMatthias Springer     if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
870585a8a32SMatthias Springer       return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
871585a8a32SMatthias Springer     return {};
87219efe141SMatthias Springer   }
87319efe141SMatthias Springer 
mustBufferizeInPlacemlir::scf::__anon76a8a75a0111::YieldOpInterface87419efe141SMatthias Springer   bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
8759597b16aSMatthias Springer                             const AnalysisState &state) const {
87619efe141SMatthias Springer     // Yield operands always bufferize inplace. Otherwise, an alloc + copy
87719efe141SMatthias Springer     // may be generated inside the block. We should not return/yield allocations
87819efe141SMatthias Springer     // when possible.
87919efe141SMatthias Springer     return true;
88019efe141SMatthias Springer   }
88119efe141SMatthias Springer 
bufferizemlir::scf::__anon76a8a75a0111::YieldOpInterface88219efe141SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
883b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
88419efe141SMatthias Springer     auto yieldOp = cast<scf::YieldOp>(op);
885a5d09c63SMatthias Springer     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
88619efe141SMatthias Springer             yieldOp->getParentOp()))
88719efe141SMatthias Springer       return yieldOp->emitError("unsupported scf::YieldOp parent");
8888e691e1fSMatthias Springer 
8893ff93f83SMatthias Springer     // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
8903ff93f83SMatthias Springer     // together with scf.while.)
8913ff93f83SMatthias Springer     if (isa<scf::WhileOp>(yieldOp->getParentOp()))
8928e691e1fSMatthias Springer       return success();
8938e691e1fSMatthias Springer 
8948e691e1fSMatthias Springer     SmallVector<Value> newResults;
8958e691e1fSMatthias Springer     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
8968e691e1fSMatthias Springer       Value value = it.value();
8978e691e1fSMatthias Springer       if (value.getType().isa<TensorType>()) {
8985d50f51cSMatthias Springer         FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
8995d50f51cSMatthias Springer         if (failed(maybeBuffer))
9005d50f51cSMatthias Springer           return failure();
9015d50f51cSMatthias Springer         Value buffer = *maybeBuffer;
9023ff93f83SMatthias Springer         if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
9035d50f51cSMatthias Springer           FailureOr<BaseMemRefType> resultType =
9043ff93f83SMatthias Springer               cast<BufferizableOpInterface>(forOp.getOperation())
9053ff93f83SMatthias Springer                   .getBufferType(forOp.getRegionIterArgs()[it.index()],
9063ff93f83SMatthias Springer                                  options);
9075d50f51cSMatthias Springer           if (failed(resultType))
9085d50f51cSMatthias Springer             return failure();
9095d50f51cSMatthias Springer           buffer = castBuffer(rewriter, buffer, *resultType);
9103ff93f83SMatthias Springer         }
9118e691e1fSMatthias Springer         newResults.push_back(buffer);
9128e691e1fSMatthias Springer       } else {
9138e691e1fSMatthias Springer         newResults.push_back(value);
9148e691e1fSMatthias Springer       }
9158e691e1fSMatthias Springer     }
9168e691e1fSMatthias Springer 
9178e691e1fSMatthias Springer     replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
91819efe141SMatthias Springer     return success();
91919efe141SMatthias Springer   }
92019efe141SMatthias Springer };
92119efe141SMatthias Springer 
92272de7588SNicolas Vasilache /// Return the destinations that an ForeachThreadOp is inserting into. One per
92372de7588SNicolas Vasilache /// ParallelInsertSliceOp.
92472de7588SNicolas Vasilache static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp)92572de7588SNicolas Vasilache getInsertionDest(ForeachThreadOp foreachThreadOp) {
92672de7588SNicolas Vasilache   PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
92772de7588SNicolas Vasilache   SmallVector<OpOperand *> result;
928*7fbf55c9SNicolas Vasilache   terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) {
92972de7588SNicolas Vasilache     result.push_back(&insertOp->getOpOperand(1) /*dest*/);
93072de7588SNicolas Vasilache   });
93172de7588SNicolas Vasilache   return result;
93272de7588SNicolas Vasilache }
93372de7588SNicolas Vasilache 
93472de7588SNicolas Vasilache /// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
93572de7588SNicolas Vasilache /// region. There are op interfaces for the terminators (PerformConcurrentlyOp
93672de7588SNicolas Vasilache /// and ParallelInsertSliceOp), but these are only used during analysis. Not
93772de7588SNicolas Vasilache /// for bufferization.
93872de7588SNicolas Vasilache struct ForeachThreadOpInterface
93972de7588SNicolas Vasilache     : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
94072de7588SNicolas Vasilache                                                     ForeachThreadOp> {
94172de7588SNicolas Vasilache   SmallVector<OpOperand *>
getAliasingOpOperandmlir::scf::__anon76a8a75a0111::ForeachThreadOpInterface94272de7588SNicolas Vasilache   getAliasingOpOperand(Operation *op, OpResult opResult,
94372de7588SNicolas Vasilache                        const AnalysisState &state) const {
94472de7588SNicolas Vasilache     // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
94572de7588SNicolas Vasilache     auto foreachThreadOp = cast<ForeachThreadOp>(op);
94672de7588SNicolas Vasilache     return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
94772de7588SNicolas Vasilache   }
94872de7588SNicolas Vasilache 
isMemoryWritemlir::scf::__anon76a8a75a0111::ForeachThreadOpInterface94972de7588SNicolas Vasilache   bool isMemoryWrite(Operation *op, OpResult opResult,
95072de7588SNicolas Vasilache                      const AnalysisState &state) const {
95172de7588SNicolas Vasilache     // This op is a memory write. Stop lookup here to avoid finding false
95272de7588SNicolas Vasilache     // conflicts involving this op and one of the ops in the region. This is
95372de7588SNicolas Vasilache     // similar to how scf.if ops are analyzed.
95472de7588SNicolas Vasilache     return true;
95572de7588SNicolas Vasilache   }
95672de7588SNicolas Vasilache 
bufferRelationmlir::scf::__anon76a8a75a0111::ForeachThreadOpInterface95772de7588SNicolas Vasilache   BufferRelation bufferRelation(Operation *op, OpResult opResult,
95872de7588SNicolas Vasilache                                 const AnalysisState &state) const {
95972de7588SNicolas Vasilache     return BufferRelation::Equivalent;
96072de7588SNicolas Vasilache   }
96172de7588SNicolas Vasilache 
bufferizemlir::scf::__anon76a8a75a0111::ForeachThreadOpInterface9627ebf70d8SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
963b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
96472de7588SNicolas Vasilache     auto foreachThreadOp = cast<ForeachThreadOp>(op);
96572de7588SNicolas Vasilache 
9667ebf70d8SMatthias Springer #ifndef NDEBUG
9677ebf70d8SMatthias Springer     // ParallelInsertSliceOpInterface replaces all uses.
9687ebf70d8SMatthias Springer     for (OpResult opResult : foreachThreadOp->getOpResults())
9697ebf70d8SMatthias Springer       assert(opResult.getUses().empty() &&
9707ebf70d8SMatthias Springer              "expected that all uses were already replaced");
9717ebf70d8SMatthias Springer #endif // NDEBUG
97272de7588SNicolas Vasilache 
97372de7588SNicolas Vasilache     // Create new ForeachThreadOp without any results and drop the automatically
97472de7588SNicolas Vasilache     // introduced terminator.
97572de7588SNicolas Vasilache     TypeRange newResultTypes;
9767ebf70d8SMatthias Springer     auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
9777ebf70d8SMatthias Springer         foreachThreadOp.getLoc(), newResultTypes,
978a0f843fdSNicolas Vasilache         foreachThreadOp.getNumThreads(),
979a0f843fdSNicolas Vasilache         extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
98072de7588SNicolas Vasilache     newForeachThreadOp.getBody()->getTerminator()->erase();
98172de7588SNicolas Vasilache 
98272de7588SNicolas Vasilache     // Move over block contents of the old op.
9837ebf70d8SMatthias Springer     rewriter.mergeBlocks(foreachThreadOp.getBody(),
9847ebf70d8SMatthias Springer                          newForeachThreadOp.getBody(),
98572de7588SNicolas Vasilache                          {newForeachThreadOp.getBody()->getArguments()});
98672de7588SNicolas Vasilache 
9877ebf70d8SMatthias Springer     // Remove the old op.
9887ebf70d8SMatthias Springer     rewriter.eraseOp(op);
98972de7588SNicolas Vasilache 
99072de7588SNicolas Vasilache     return success();
99172de7588SNicolas Vasilache   }
99272de7588SNicolas Vasilache };
99372de7588SNicolas Vasilache 
99472de7588SNicolas Vasilache /// Nothing to do for PerformConcurrentlyOp.
99572de7588SNicolas Vasilache struct PerformConcurrentlyOpInterface
99672de7588SNicolas Vasilache     : public BufferizableOpInterface::ExternalModel<
99772de7588SNicolas Vasilache           PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
bufferizemlir::scf::__anon76a8a75a0111::PerformConcurrentlyOpInterface99872de7588SNicolas Vasilache   LogicalResult bufferize(Operation *op, RewriterBase &b,
999b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
1000b3ebe3beSMatthias Springer     llvm_unreachable("op does not have any tensor OpOperands / OpResults");
100172de7588SNicolas Vasilache     return failure();
100272de7588SNicolas Vasilache   }
100372de7588SNicolas Vasilache };
100472de7588SNicolas Vasilache 
100519efe141SMatthias Springer } // namespace
100619efe141SMatthias Springer } // namespace scf
100719efe141SMatthias Springer } // namespace mlir
100819efe141SMatthias Springer 
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)100919efe141SMatthias Springer void mlir::scf::registerBufferizableOpInterfaceExternalModels(
101019efe141SMatthias Springer     DialectRegistry &registry) {
101177eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
101277eee579SRiver Riddle     ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
101377eee579SRiver Riddle     ForOp::attachInterface<ForOpInterface>(*ctx);
101477eee579SRiver Riddle     IfOp::attachInterface<IfOpInterface>(*ctx);
101572de7588SNicolas Vasilache     ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
101672de7588SNicolas Vasilache     PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
101772de7588SNicolas Vasilache         *ctx);
1018a5d09c63SMatthias Springer     WhileOp::attachInterface<WhileOpInterface>(*ctx);
101977eee579SRiver Riddle     YieldOp::attachInterface<YieldOpInterface>(*ctx);
102077eee579SRiver Riddle   });
102119efe141SMatthias Springer }
1022