119efe141SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
219efe141SMatthias Springer //
319efe141SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
419efe141SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
519efe141SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
619efe141SMatthias Springer //
719efe141SMatthias Springer //===----------------------------------------------------------------------===//
819efe141SMatthias Springer
98b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
1019efe141SMatthias Springer
1119efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1219efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
131e1eeae8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
1672de7588SNicolas Vasilache #include "mlir/Dialect/Tensor/IR/Tensor.h"
17a0f843fdSNicolas Vasilache #include "mlir/Dialect/Utils/StaticValueUtils.h"
1819efe141SMatthias Springer #include "mlir/IR/Dialect.h"
1919efe141SMatthias Springer #include "mlir/IR/Operation.h"
2019efe141SMatthias Springer #include "mlir/IR/PatternMatch.h"
2119efe141SMatthias Springer
2219efe141SMatthias Springer using namespace mlir;
2319efe141SMatthias Springer using namespace mlir::bufferization;
2419efe141SMatthias Springer using namespace mlir::scf;
2519efe141SMatthias Springer
2619efe141SMatthias Springer namespace mlir {
2719efe141SMatthias Springer namespace scf {
2819efe141SMatthias Springer namespace {
2919efe141SMatthias Springer
3019efe141SMatthias Springer // bufferization.to_memref is not allowed to change the rank.
ensureToMemrefOpIsValid(Value tensor,Type memrefType)3119efe141SMatthias Springer static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
3219efe141SMatthias Springer #ifndef NDEBUG
3319efe141SMatthias Springer auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
3419efe141SMatthias Springer assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
3519efe141SMatthias Springer rankedTensorType.getRank())) &&
3619efe141SMatthias Springer "to_memref would be invalid: mismatching ranks");
3719efe141SMatthias Springer #endif
3819efe141SMatthias Springer }
3919efe141SMatthias Springer
4019efe141SMatthias Springer /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
4119efe141SMatthias Springer /// fully implemented at the moment.
4219efe141SMatthias Springer struct ExecuteRegionOpInterface
4319efe141SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
4419efe141SMatthias Springer scf::ExecuteRegionOp> {
4519efe141SMatthias Springer SmallVector<OpOperand *>
getAliasingOpOperandmlir::scf::__anon76a8a75a0111::ExecuteRegionOpInterface4619efe141SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult,
479597b16aSMatthias Springer const AnalysisState &state) const {
4819efe141SMatthias Springer // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
4919efe141SMatthias Springer // any SSA value that is in scope. To allow for use-def chain traversal
5019efe141SMatthias Springer // through ExecuteRegionOps in the analysis, the corresponding yield value
5119efe141SMatthias Springer // is considered to be aliasing with the result.
5219efe141SMatthias Springer auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
5319efe141SMatthias Springer size_t resultNum = std::distance(op->getOpResults().begin(),
5419efe141SMatthias Springer llvm::find(op->getOpResults(), opResult));
5519efe141SMatthias Springer // TODO: Support multiple blocks.
5619efe141SMatthias Springer assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
5719efe141SMatthias Springer "expected exactly 1 block");
5819efe141SMatthias Springer auto yieldOp = dyn_cast<scf::YieldOp>(
5919efe141SMatthias Springer executeRegionOp.getRegion().front().getTerminator());
6019efe141SMatthias Springer assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
6119efe141SMatthias Springer return {&yieldOp->getOpOperand(resultNum)};
6219efe141SMatthias Springer }
6319efe141SMatthias Springer
6419efe141SMatthias Springer // TODO: For better bufferization results, this could return `true` only if
6519efe141SMatthias Springer // there is a memory write in the region.
isMemoryWritemlir::scf::__anon76a8a75a0111::ExecuteRegionOpInterface6619efe141SMatthias Springer bool isMemoryWrite(Operation *op, OpResult opResult,
679597b16aSMatthias Springer const AnalysisState &state) const {
6819efe141SMatthias Springer // Similar to scf.if, results of this op are always considered memory writes
6919efe141SMatthias Springer // in the analysis. This is a useful pattern for all ops that have tensor
7019efe141SMatthias Springer // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
7119efe141SMatthias Springer // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
7219efe141SMatthias Springer // ops without OpOperands.
7319efe141SMatthias Springer return true;
7419efe141SMatthias Springer }
7519efe141SMatthias Springer
bufferizemlir::scf::__anon76a8a75a0111::ExecuteRegionOpInterface7619efe141SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
77b55d55ecSMatthias Springer const BufferizationOptions &options) const {
7819efe141SMatthias Springer auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
798e691e1fSMatthias Springer assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
808e691e1fSMatthias Springer "only 1 block supported");
818e691e1fSMatthias Springer auto yieldOp =
828e691e1fSMatthias Springer cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
838e691e1fSMatthias Springer TypeRange newResultTypes(yieldOp.getResults());
8419efe141SMatthias Springer
8519efe141SMatthias Springer // Create new op and move over region.
8619efe141SMatthias Springer auto newOp =
8719efe141SMatthias Springer rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
8819efe141SMatthias Springer newOp.getRegion().takeBody(executeRegionOp.getRegion());
8919efe141SMatthias Springer
9019efe141SMatthias Springer // Update all uses of the old op.
9119efe141SMatthias Springer rewriter.setInsertionPointAfter(newOp);
9219efe141SMatthias Springer SmallVector<Value> newResults;
93bb6119ebSMehdi Amini for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
9419efe141SMatthias Springer if (it.value().isa<TensorType>()) {
9519efe141SMatthias Springer newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
9619efe141SMatthias Springer executeRegionOp.getLoc(), newOp->getResult(it.index())));
9719efe141SMatthias Springer } else {
9819efe141SMatthias Springer newResults.push_back(newOp->getResult(it.index()));
9919efe141SMatthias Springer }
10019efe141SMatthias Springer }
10119efe141SMatthias Springer
10219efe141SMatthias Springer // Replace old op.
10319efe141SMatthias Springer rewriter.replaceOp(executeRegionOp, newResults);
10419efe141SMatthias Springer
10519efe141SMatthias Springer return success();
10619efe141SMatthias Springer }
10719efe141SMatthias Springer
bufferRelationmlir::scf::__anon76a8a75a0111::ExecuteRegionOpInterface10819efe141SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
1099597b16aSMatthias Springer const AnalysisState &state) const {
11019efe141SMatthias Springer return BufferRelation::Equivalent;
11119efe141SMatthias Springer }
11219efe141SMatthias Springer };
11319efe141SMatthias Springer
/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  /// Return the two yield operands (then-branch and else-branch) that alias
  /// with the given OpResult.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    // Position of `opResult` within the op's result list.
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  /// Replace the scf.if with a new scf.if that yields memrefs. The branch
  /// bodies have already been bufferized at this point; only the terminators
  /// and the op itself are rewritten here.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Reconcile type mismatches between then/else branches by inserting memref
    // casts.
    SmallVector<Value> thenResults, elseResults;
    bool insertedCast = false;
    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
      Value thenValue = thenYieldOp.getResults()[i];
      Value elseValue = elseYieldOp.getResults()[i];
      // Matching types need no reconciliation.
      if (thenValue.getType() == elseValue.getType()) {
        thenResults.push_back(thenValue);
        elseResults.push_back(elseValue);
        continue;
      }

      // Type mismatch between then/else yield value. Cast both to a memref type
      // with a fully dynamic layout map.
      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
      // A memref.cast cannot change the memory space; this is unrecoverable.
      if (thenMemrefType.getMemorySpaceAsInt() !=
          elseMemrefType.getMemorySpaceAsInt())
        return op->emitError("inconsistent memory space on then/else branches");
      rewriter.setInsertionPoint(thenYieldOp);
      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
          ifOp.getResultTypes()[i].cast<TensorType>(),
          thenMemrefType.getMemorySpaceAsInt());
      thenResults.push_back(rewriter.create<memref::CastOp>(
          thenYieldOp.getLoc(), memrefType, thenValue));
      rewriter.setInsertionPoint(elseYieldOp);
      elseResults.push_back(rewriter.create<memref::CastOp>(
          elseYieldOp.getLoc(), memrefType, elseValue));
      insertedCast = true;
    }

    // If any cast was inserted, rebuild both terminators with the new
    // (possibly cast) operands.
    if (insertedCast) {
      rewriter.setInsertionPoint(thenYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
      rewriter.setInsertionPoint(elseYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
    }

    // Create new op. Both branches now yield the same types, so the new
    // result types can be taken from the then-branch values.
    rewriter.setInsertionPoint(ifOp);
    ValueRange resultsValueRange(thenResults);
    TypeRange newTypes(resultsValueRange);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};
23919efe141SMatthias Springer
240417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the indices of all values
241417e1c7dSMatthias Springer /// that have a tensor type.
getTensorIndices(ValueRange values)242417e1c7dSMatthias Springer static DenseSet<int64_t> getTensorIndices(ValueRange values) {
243417e1c7dSMatthias Springer DenseSet<int64_t> result;
244417e1c7dSMatthias Springer for (const auto &it : llvm::enumerate(values))
245417e1c7dSMatthias Springer if (it.value().getType().isa<TensorType>())
246417e1c7dSMatthias Springer result.insert(it.index());
247417e1c7dSMatthias Springer return result;
248417e1c7dSMatthias Springer }
249417e1c7dSMatthias Springer
250417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the indices of all
251417e1c7dSMatthias Springer /// bbArg/yielded value pairs who's buffer relation is "Equivalent".
getEquivalentBuffers(Block::BlockArgListType bbArgs,ValueRange yieldedValues,const AnalysisState & state)252a5d09c63SMatthias Springer DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
253417e1c7dSMatthias Springer ValueRange yieldedValues,
254417e1c7dSMatthias Springer const AnalysisState &state) {
255996834e6SMatthias Springer unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
256417e1c7dSMatthias Springer DenseSet<int64_t> result;
257996834e6SMatthias Springer for (unsigned int i = 0; i < minSize; ++i) {
258996834e6SMatthias Springer if (!bbArgs[i].getType().isa<TensorType>() ||
259996834e6SMatthias Springer !yieldedValues[i].getType().isa<TensorType>())
260417e1c7dSMatthias Springer continue;
261996834e6SMatthias Springer if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
262996834e6SMatthias Springer result.insert(i);
263417e1c7dSMatthias Springer }
264417e1c7dSMatthias Springer return result;
265417e1c7dSMatthias Springer }
266417e1c7dSMatthias Springer
267417e1c7dSMatthias Springer /// Helper function for loop bufferization. Cast the given buffer to the given
268417e1c7dSMatthias Springer /// memref type.
castBuffer(OpBuilder & b,Value buffer,Type type)269417e1c7dSMatthias Springer static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
270417e1c7dSMatthias Springer assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
271417e1c7dSMatthias Springer assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
272417e1c7dSMatthias Springer // If the buffer already has the correct type, no cast is needed.
273417e1c7dSMatthias Springer if (buffer.getType() == type)
274417e1c7dSMatthias Springer return buffer;
275417e1c7dSMatthias Springer // TODO: In case `type` has a layout map that is not the fully dynamic
276417e1c7dSMatthias Springer // one, we may not be able to cast the buffer. In that case, the loop
277417e1c7dSMatthias Springer // iter_arg's layout map must be changed (see uses of `castBuffer`).
278417e1c7dSMatthias Springer assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
279417e1c7dSMatthias Springer "scf.while op bufferization: cast incompatible");
280417e1c7dSMatthias Springer return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
281417e1c7dSMatthias Springer }
282417e1c7dSMatthias Springer
283417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the bufferized values of the
284417e1c7dSMatthias Springer /// given OpOperands. If an operand is not a tensor, return the original value.
2855d50f51cSMatthias Springer static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase & rewriter,MutableArrayRef<OpOperand> operands,const BufferizationOptions & options)2865d50f51cSMatthias Springer getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
287b55d55ecSMatthias Springer const BufferizationOptions &options) {
288417e1c7dSMatthias Springer SmallVector<Value> result;
289417e1c7dSMatthias Springer for (OpOperand &opOperand : operands) {
290417e1c7dSMatthias Springer if (opOperand.get().getType().isa<TensorType>()) {
2915d50f51cSMatthias Springer FailureOr<Value> resultBuffer =
2925d50f51cSMatthias Springer getBuffer(rewriter, opOperand.get(), options);
2935d50f51cSMatthias Springer if (failed(resultBuffer))
2945d50f51cSMatthias Springer return failure();
2955d50f51cSMatthias Springer result.push_back(*resultBuffer);
296417e1c7dSMatthias Springer } else {
297417e1c7dSMatthias Springer result.push_back(opOperand.get());
298417e1c7dSMatthias Springer }
299417e1c7dSMatthias Springer }
300417e1c7dSMatthias Springer return result;
301417e1c7dSMatthias Springer }
302417e1c7dSMatthias Springer
303417e1c7dSMatthias Springer /// Helper function for loop bufferization. Compute the buffer that should be
304b3ebe3beSMatthias Springer /// yielded from a loop block (loop body or loop condition).
getYieldedBuffer(RewriterBase & rewriter,Value tensor,BaseMemRefType type,const BufferizationOptions & options)3055d50f51cSMatthias Springer static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
306b55d55ecSMatthias Springer BaseMemRefType type,
307b55d55ecSMatthias Springer const BufferizationOptions &options) {
308417e1c7dSMatthias Springer assert(tensor.getType().isa<TensorType>() && "expected tensor");
309417e1c7dSMatthias Springer ensureToMemrefOpIsValid(tensor, type);
3105d50f51cSMatthias Springer FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
3115d50f51cSMatthias Springer if (failed(yieldedVal))
3125d50f51cSMatthias Springer return failure();
3135d50f51cSMatthias Springer return castBuffer(rewriter, *yieldedVal, type);
314417e1c7dSMatthias Springer }
315417e1c7dSMatthias Springer
316417e1c7dSMatthias Springer /// Helper function for loop bufferization. Given a range of values, apply
317417e1c7dSMatthias Springer /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
318417e1c7dSMatthias Springer /// value in the result vector.
3195d50f51cSMatthias Springer static FailureOr<SmallVector<Value>>
convertTensorValues(ValueRange values,const DenseSet<int64_t> & tensorIndices,llvm::function_ref<FailureOr<Value> (Value,int64_t)> func)320417e1c7dSMatthias Springer convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
3215d50f51cSMatthias Springer llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
322417e1c7dSMatthias Springer SmallVector<Value> result;
323417e1c7dSMatthias Springer for (const auto &it : llvm::enumerate(values)) {
324417e1c7dSMatthias Springer size_t idx = it.index();
325417e1c7dSMatthias Springer Value val = it.value();
3265d50f51cSMatthias Springer if (tensorIndices.contains(idx)) {
3275d50f51cSMatthias Springer FailureOr<Value> maybeVal = func(val, idx);
3285d50f51cSMatthias Springer if (failed(maybeVal))
3295d50f51cSMatthias Springer return failure();
3305d50f51cSMatthias Springer result.push_back(*maybeVal);
3315d50f51cSMatthias Springer } else {
3325d50f51cSMatthias Springer result.push_back(val);
3335d50f51cSMatthias Springer }
334417e1c7dSMatthias Springer }
335417e1c7dSMatthias Springer return result;
336417e1c7dSMatthias Springer }
337417e1c7dSMatthias Springer
338417e1c7dSMatthias Springer /// Helper function for loop bufferization. Given a list of pre-bufferization
339417e1c7dSMatthias Springer /// yielded values, compute the list of bufferized yielded values.
3405d50f51cSMatthias Springer FailureOr<SmallVector<Value>>
getYieldedValues(RewriterBase & rewriter,ValueRange values,TypeRange bufferizedTypes,const DenseSet<int64_t> & tensorIndices,const BufferizationOptions & options)3415d50f51cSMatthias Springer getYieldedValues(RewriterBase &rewriter, ValueRange values,
342417e1c7dSMatthias Springer TypeRange bufferizedTypes,
343417e1c7dSMatthias Springer const DenseSet<int64_t> &tensorIndices,
344b55d55ecSMatthias Springer const BufferizationOptions &options) {
345417e1c7dSMatthias Springer return convertTensorValues(
346417e1c7dSMatthias Springer values, tensorIndices, [&](Value val, int64_t index) {
347417e1c7dSMatthias Springer return getYieldedBuffer(rewriter, val,
348417e1c7dSMatthias Springer bufferizedTypes[index].cast<BaseMemRefType>(),
349b55d55ecSMatthias Springer options);
350417e1c7dSMatthias Springer });
351417e1c7dSMatthias Springer }
352417e1c7dSMatthias Springer
353a5d09c63SMatthias Springer /// Helper function for loop bufferization. Given a list of bbArgs of the new
354a5d09c63SMatthias Springer /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
355a5d09c63SMatthias Springer /// ToTensorOps, so that the block body can be moved over to the new op.
356a5d09c63SMatthias Springer SmallVector<Value>
getBbArgReplacements(RewriterBase & rewriter,Block::BlockArgListType bbArgs,const DenseSet<int64_t> & tensorIndices)357a5d09c63SMatthias Springer getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
358a5d09c63SMatthias Springer const DenseSet<int64_t> &tensorIndices) {
3595d50f51cSMatthias Springer SmallVector<Value> result;
3605d50f51cSMatthias Springer for (const auto &it : llvm::enumerate(bbArgs)) {
3615d50f51cSMatthias Springer size_t idx = it.index();
3625d50f51cSMatthias Springer Value val = it.value();
3635d50f51cSMatthias Springer if (tensorIndices.contains(idx)) {
3645d50f51cSMatthias Springer result.push_back(
3655d50f51cSMatthias Springer rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
3665d50f51cSMatthias Springer .getResult());
3675d50f51cSMatthias Springer } else {
3685d50f51cSMatthias Springer result.push_back(val);
3695d50f51cSMatthias Springer }
3705d50f51cSMatthias Springer }
3715d50f51cSMatthias Springer return result;
372a5d09c63SMatthias Springer }
373a5d09c63SMatthias Springer
37419efe141SMatthias Springer /// Bufferization of scf.for. Replace with a new scf.for that operates on
37519efe141SMatthias Springer /// memrefs.
37619efe141SMatthias Springer struct ForOpInterface
37719efe141SMatthias Springer : public BufferizableOpInterface::ExternalModel<ForOpInterface,
37819efe141SMatthias Springer scf::ForOp> {
bufferizesToMemoryReadmlir::scf::__anon76a8a75a0111::ForOpInterface37919efe141SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
3809597b16aSMatthias Springer const AnalysisState &state) const {
38119efe141SMatthias Springer // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
38219efe141SMatthias Springer // its matching bbArg may.
38319efe141SMatthias Springer auto forOp = cast<scf::ForOp>(op);
38419efe141SMatthias Springer return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
38519efe141SMatthias Springer }
38619efe141SMatthias Springer
  /// All tensor iter_args are conservatively treated as written to.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }
39219efe141SMatthias Springer
getAliasingOpResultmlir::scf::__anon76a8a75a0111::ForOpInterface3939597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
3949597b16aSMatthias Springer const AnalysisState &state) const {
39519efe141SMatthias Springer auto forOp = cast<scf::ForOp>(op);
396585a8a32SMatthias Springer return {forOp.getResultForOpOperand(opOperand)};
39719efe141SMatthias Springer }
39819efe141SMatthias Springer
bufferRelationmlir::scf::__anon76a8a75a0111::ForOpInterface39919efe141SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
4009597b16aSMatthias Springer const AnalysisState &state) const {
40119efe141SMatthias Springer // ForOp results are equivalent to their corresponding init_args if the
40219efe141SMatthias Springer // corresponding iter_args and yield values are equivalent.
40319efe141SMatthias Springer auto forOp = cast<scf::ForOp>(op);
40419efe141SMatthias Springer OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
40519efe141SMatthias Springer auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
4061e1eeae8SMatthias Springer auto yieldOp =
4071e1eeae8SMatthias Springer cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
40819efe141SMatthias Springer bool equivalentYield = state.areEquivalentBufferizedValues(
40919efe141SMatthias Springer bbArg, yieldOp->getOperand(opResult.getResultNumber()));
41019efe141SMatthias Springer return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
41119efe141SMatthias Springer }
41219efe141SMatthias Springer
  /// Region iter_args of an scf.for are always writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    // alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    // bufferizes to that too.
    return true;
  }
42319efe141SMatthias Springer
  /// Resolve tensor conflicts and enforce the aliasing invariants of scf.for:
  /// every yielded value that is not equivalent to its iter_arg is replaced by
  /// a yielded copy.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    // First run the default conflict resolution on the op's tensor operands.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
    SmallVector<Value> yieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      // Non-tensor yields and equivalent tensor yields can be kept as-is.
      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
        yieldValues.push_back(value);
        continue;
      }
      // Otherwise, yield a copy in a newly allocated tensor.
      // NOTE(review): `/*escape=*/true` presumably means the allocation
      // escapes this op and is not deallocated locally — confirm against
      // `allocateTensorForShapedValue`.
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    // Swap in the new yield operands without replacing the terminator op.
    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }
473d361ecbdSMatthias Springer
4745d50f51cSMatthias Springer FailureOr<BaseMemRefType>
getBufferTypemlir::scf::__anon76a8a75a0111::ForOpInterface4755d50f51cSMatthias Springer getBufferType(Operation *op, BlockArgument bbArg,
4763ff93f83SMatthias Springer const BufferizationOptions &options) const {
4773ff93f83SMatthias Springer auto forOp = cast<scf::ForOp>(op);
4783ff93f83SMatthias Springer return bufferization::getBufferType(
4793ff93f83SMatthias Springer forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
4803ff93f83SMatthias Springer }
4813ff93f83SMatthias Springer
bufferizemlir::scf::__anon76a8a75a0111::ForOpInterface48219efe141SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
483b55d55ecSMatthias Springer const BufferizationOptions &options) const {
48419efe141SMatthias Springer auto forOp = cast<scf::ForOp>(op);
48519efe141SMatthias Springer Block *oldLoopBody = &forOp.getLoopBody().front();
48619efe141SMatthias Springer
48719efe141SMatthias Springer // Indices of all iter_args that have tensor type. These are the ones that
48819efe141SMatthias Springer // are bufferized.
489417e1c7dSMatthias Springer DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
49019efe141SMatthias Springer
491417e1c7dSMatthias Springer // The new memref init_args of the loop.
4925d50f51cSMatthias Springer FailureOr<SmallVector<Value>> maybeInitArgs =
493b55d55ecSMatthias Springer getBuffers(rewriter, forOp.getIterOpOperands(), options);
4945d50f51cSMatthias Springer if (failed(maybeInitArgs))
4955d50f51cSMatthias Springer return failure();
4965d50f51cSMatthias Springer SmallVector<Value> initArgs = *maybeInitArgs;
49719efe141SMatthias Springer
49819efe141SMatthias Springer // Construct a new scf.for op with memref instead of tensor values.
49919efe141SMatthias Springer auto newForOp = rewriter.create<scf::ForOp>(
50019efe141SMatthias Springer forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
50119efe141SMatthias Springer forOp.getStep(), initArgs);
502413fbb04SLei Zhang newForOp->setAttrs(forOp->getAttrs());
503417e1c7dSMatthias Springer ValueRange initArgsRange(initArgs);
504417e1c7dSMatthias Springer TypeRange initArgsTypes(initArgsRange);
50519efe141SMatthias Springer Block *loopBody = &newForOp.getLoopBody().front();
50619efe141SMatthias Springer
50719efe141SMatthias Springer // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
50819efe141SMatthias Springer // iter_args of the new loop in ToTensorOps.
50919efe141SMatthias Springer rewriter.setInsertionPointToStart(loopBody);
510a5d09c63SMatthias Springer SmallVector<Value> iterArgs =
511a5d09c63SMatthias Springer getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
51219efe141SMatthias Springer iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
51319efe141SMatthias Springer
51419efe141SMatthias Springer // Move loop body to new loop.
51519efe141SMatthias Springer rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
51619efe141SMatthias Springer
51719efe141SMatthias Springer // Replace loop results.
51819efe141SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
51919efe141SMatthias Springer
52019efe141SMatthias Springer return success();
52119efe141SMatthias Springer }
5224ec00fb3SMatthias Springer
  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    // When returning allocations is allowed, non-equivalent yields are legal
    // (resolveConflicts turns them into new allocations).
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      // Only tensor results are subject to bufferization.
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
55419efe141SMatthias Springer };
55519efe141SMatthias Springer
/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    // "Before" region: the bbArg must be equivalent to the corresponding
    // scf.condition operand.
    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    // "After" region: the bbArg must be equivalent to the corresponding
    // scf.yield operand.
    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  /// Enforce the aliasing invariants of scf.while: a value yielded from
  /// either region that is not equivalent to its corresponding bbArg is
  /// replaced with a copy in a new allocation before being yielded.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
    auto yieldOp = whileOp.getYieldOp();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region: non-equivalent scf.condition operands are
    // replaced with copies in new allocations.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!indicesBefore.contains(idx) ||
          equivalentYieldsBefore.contains(idx)) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    // Update "after" region: non-equivalent scf.yield operands are replaced
    // with copies in new allocations.
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> afterYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
        afterYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      afterYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().assign(afterYieldValues);
    });

    return success();
  }

  // TODO: Implement getBufferType interface method and infer buffer types.

  /// Bufferize the scf.while op: replace it with a new scf.while that
  /// operates on memrefs instead of tensors.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          // TODO: error handling
          return bufferization::getBufferType(bbArg, options)->cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, options);
    if (failed(newConditionArgs))
      return failure();
    newConditionOp.getArgsMutable().assign(*newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, options);
    if (failed(newYieldValues))
      return failure();
    newYieldOp.getResultsMutable().assign(*newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    // "Before" region: every tensor scf.condition operand must be equivalent
    // to the corresponding "before" bbArg.
    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    // "After" region: every tensor scf.yield operand must be equivalent to
    // the corresponding "after" bbArg.
    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
849a5d09c63SMatthias Springer
85019efe141SMatthias Springer /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
85119efe141SMatthias Springer /// this is for analysis only.
85219efe141SMatthias Springer struct YieldOpInterface
85319efe141SMatthias Springer : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
85419efe141SMatthias Springer scf::YieldOp> {
bufferizesToMemoryReadmlir::scf::__anon76a8a75a0111::YieldOpInterface85519efe141SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
8569597b16aSMatthias Springer const AnalysisState &state) const {
85719efe141SMatthias Springer return true;
85819efe141SMatthias Springer }
85919efe141SMatthias Springer
bufferizesToMemoryWritemlir::scf::__anon76a8a75a0111::YieldOpInterface86019efe141SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
8619597b16aSMatthias Springer const AnalysisState &state) const {
86219efe141SMatthias Springer return false;
86319efe141SMatthias Springer }
86419efe141SMatthias Springer
getAliasingOpResultmlir::scf::__anon76a8a75a0111::YieldOpInterface8659597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
8669597b16aSMatthias Springer const AnalysisState &state) const {
86719efe141SMatthias Springer if (isa<scf::IfOp>(op->getParentOp()))
868585a8a32SMatthias Springer return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
86919efe141SMatthias Springer if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
870585a8a32SMatthias Springer return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
871585a8a32SMatthias Springer return {};
87219efe141SMatthias Springer }
87319efe141SMatthias Springer
mustBufferizeInPlacemlir::scf::__anon76a8a75a0111::YieldOpInterface87419efe141SMatthias Springer bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
8759597b16aSMatthias Springer const AnalysisState &state) const {
87619efe141SMatthias Springer // Yield operands always bufferize inplace. Otherwise, an alloc + copy
87719efe141SMatthias Springer // may be generated inside the block. We should not return/yield allocations
87819efe141SMatthias Springer // when possible.
87919efe141SMatthias Springer return true;
88019efe141SMatthias Springer }
88119efe141SMatthias Springer
bufferizemlir::scf::__anon76a8a75a0111::YieldOpInterface88219efe141SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
883b55d55ecSMatthias Springer const BufferizationOptions &options) const {
88419efe141SMatthias Springer auto yieldOp = cast<scf::YieldOp>(op);
885a5d09c63SMatthias Springer if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
88619efe141SMatthias Springer yieldOp->getParentOp()))
88719efe141SMatthias Springer return yieldOp->emitError("unsupported scf::YieldOp parent");
8888e691e1fSMatthias Springer
8893ff93f83SMatthias Springer // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
8903ff93f83SMatthias Springer // together with scf.while.)
8913ff93f83SMatthias Springer if (isa<scf::WhileOp>(yieldOp->getParentOp()))
8928e691e1fSMatthias Springer return success();
8938e691e1fSMatthias Springer
8948e691e1fSMatthias Springer SmallVector<Value> newResults;
8958e691e1fSMatthias Springer for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
8968e691e1fSMatthias Springer Value value = it.value();
8978e691e1fSMatthias Springer if (value.getType().isa<TensorType>()) {
8985d50f51cSMatthias Springer FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
8995d50f51cSMatthias Springer if (failed(maybeBuffer))
9005d50f51cSMatthias Springer return failure();
9015d50f51cSMatthias Springer Value buffer = *maybeBuffer;
9023ff93f83SMatthias Springer if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
9035d50f51cSMatthias Springer FailureOr<BaseMemRefType> resultType =
9043ff93f83SMatthias Springer cast<BufferizableOpInterface>(forOp.getOperation())
9053ff93f83SMatthias Springer .getBufferType(forOp.getRegionIterArgs()[it.index()],
9063ff93f83SMatthias Springer options);
9075d50f51cSMatthias Springer if (failed(resultType))
9085d50f51cSMatthias Springer return failure();
9095d50f51cSMatthias Springer buffer = castBuffer(rewriter, buffer, *resultType);
9103ff93f83SMatthias Springer }
9118e691e1fSMatthias Springer newResults.push_back(buffer);
9128e691e1fSMatthias Springer } else {
9138e691e1fSMatthias Springer newResults.push_back(value);
9148e691e1fSMatthias Springer }
9158e691e1fSMatthias Springer }
9168e691e1fSMatthias Springer
9178e691e1fSMatthias Springer replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
91819efe141SMatthias Springer return success();
91919efe141SMatthias Springer }
92019efe141SMatthias Springer };
92119efe141SMatthias Springer
92272de7588SNicolas Vasilache /// Return the destinations that an ForeachThreadOp is inserting into. One per
92372de7588SNicolas Vasilache /// ParallelInsertSliceOp.
92472de7588SNicolas Vasilache static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp)92572de7588SNicolas Vasilache getInsertionDest(ForeachThreadOp foreachThreadOp) {
92672de7588SNicolas Vasilache PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
92772de7588SNicolas Vasilache SmallVector<OpOperand *> result;
928*7fbf55c9SNicolas Vasilache terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) {
92972de7588SNicolas Vasilache result.push_back(&insertOp->getOpOperand(1) /*dest*/);
93072de7588SNicolas Vasilache });
93172de7588SNicolas Vasilache return result;
93272de7588SNicolas Vasilache }
93372de7588SNicolas Vasilache
/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (PerformConcurrentlyOp
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
/// for bufferization.
struct ForeachThreadOpInterface
    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                    ForeachThreadOp> {
  // Each op result aliases the `dest` operand of the corresponding
  // tensor.parallel_insert_slice op in the terminator region.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
  }

  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // This op is a memory write. Stop lookup here to avoid finding false
    // conflicts involving this op and one of the ops in the region. This is
    // similar to how scf.if ops are analyzed.
    return true;
  }

  // Results are equivalent to their aliased destination operands.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  // Rebuilds the op without results: by the time this runs, all result uses
  // must already have been rerouted (by ParallelInsertSliceOpInterface).
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto foreachThreadOp = cast<ForeachThreadOp>(op);

#ifndef NDEBUG
    // ParallelInsertSliceOpInterface replaces all uses.
    for (OpResult opResult : foreachThreadOp->getOpResults())
      assert(opResult.getUses().empty() &&
             "expected that all uses were already replaced");
#endif // NDEBUG

    // Create new ForeachThreadOp without any results and drop the automatically
    // introduced terminator. (The old terminator is moved over below along with
    // the rest of the block, so the auto-created one must be erased first.)
    TypeRange newResultTypes;
    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
        foreachThreadOp.getLoc(), newResultTypes,
        foreachThreadOp.getNumThreads(),
        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
    newForeachThreadOp.getBody()->getTerminator()->erase();

    // Move over block contents of the old op.
    rewriter.mergeBlocks(foreachThreadOp.getBody(),
                         newForeachThreadOp.getBody(),
                         {newForeachThreadOp.getBody()->getArguments()});

    // Remove the old op.
    rewriter.eraseOp(op);

    return success();
  }
};
99372de7588SNicolas Vasilache
99472de7588SNicolas Vasilache /// Nothing to do for PerformConcurrentlyOp.
99572de7588SNicolas Vasilache struct PerformConcurrentlyOpInterface
99672de7588SNicolas Vasilache : public BufferizableOpInterface::ExternalModel<
99772de7588SNicolas Vasilache PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
bufferizemlir::scf::__anon76a8a75a0111::PerformConcurrentlyOpInterface99872de7588SNicolas Vasilache LogicalResult bufferize(Operation *op, RewriterBase &b,
999b55d55ecSMatthias Springer const BufferizationOptions &options) const {
1000b3ebe3beSMatthias Springer llvm_unreachable("op does not have any tensor OpOperands / OpResults");
100172de7588SNicolas Vasilache return failure();
100272de7588SNicolas Vasilache }
100372de7588SNicolas Vasilache };
100472de7588SNicolas Vasilache
100519efe141SMatthias Springer } // namespace
100619efe141SMatthias Springer } // namespace scf
100719efe141SMatthias Springer } // namespace mlir
100819efe141SMatthias Springer
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)100919efe141SMatthias Springer void mlir::scf::registerBufferizableOpInterfaceExternalModels(
101019efe141SMatthias Springer DialectRegistry ®istry) {
101177eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
101277eee579SRiver Riddle ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
101377eee579SRiver Riddle ForOp::attachInterface<ForOpInterface>(*ctx);
101477eee579SRiver Riddle IfOp::attachInterface<IfOpInterface>(*ctx);
101572de7588SNicolas Vasilache ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
101672de7588SNicolas Vasilache PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
101772de7588SNicolas Vasilache *ctx);
1018a5d09c63SMatthias Springer WhileOp::attachInterface<WhileOpInterface>(*ctx);
101977eee579SRiver Riddle YieldOp::attachInterface<YieldOpInterface>(*ctx);
102077eee579SRiver Riddle });
102119efe141SMatthias Springer }
1022