149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
249e37000SMatthias Springer //
349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
649e37000SMatthias Springer //
749e37000SMatthias Springer //===----------------------------------------------------------------------===//
849e37000SMatthias Springer
949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1149e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12b3ebe3beSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1349e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
148b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
1549e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1649e37000SMatthias Springer #include "mlir/IR/Dialect.h"
1749e37000SMatthias Springer #include "mlir/IR/Operation.h"
1849e37000SMatthias Springer
1949e37000SMatthias Springer using namespace mlir;
2049e37000SMatthias Springer using namespace mlir::bufferization;
2149e37000SMatthias Springer using namespace mlir::tensor;
2249e37000SMatthias Springer
2349e37000SMatthias Springer namespace mlir {
2449e37000SMatthias Springer namespace tensor {
2549e37000SMatthias Springer namespace {
2649e37000SMatthias Springer
2749e37000SMatthias Springer struct CastOpInterface
2849e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<CastOpInterface,
2949e37000SMatthias Springer tensor::CastOp> {
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::CastOpInterface3049e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
319597b16aSMatthias Springer const AnalysisState &state) const {
3249e37000SMatthias Springer return false;
3349e37000SMatthias Springer }
3449e37000SMatthias Springer
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::CastOpInterface3549e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
369597b16aSMatthias Springer const AnalysisState &state) const {
3749e37000SMatthias Springer return false;
3849e37000SMatthias Springer }
3949e37000SMatthias Springer
getAliasingOpResultmlir::tensor::__anonb90e36390111::CastOpInterface409597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
419597b16aSMatthias Springer const AnalysisState &state) const {
42585a8a32SMatthias Springer return {op->getResult(0)};
4349e37000SMatthias Springer }
4449e37000SMatthias Springer
bufferRelationmlir::tensor::__anonb90e36390111::CastOpInterface4549e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
469597b16aSMatthias Springer const AnalysisState &state) const {
4749e37000SMatthias Springer return BufferRelation::Equivalent;
4849e37000SMatthias Springer }
4949e37000SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::CastOpInterface5049e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51b55d55ecSMatthias Springer const BufferizationOptions &options) const {
5249e37000SMatthias Springer auto castOp = cast<tensor::CastOp>(op);
5349e37000SMatthias Springer
5449e37000SMatthias Springer // The result buffer still has the old (pre-cast) type.
555d50f51cSMatthias Springer FailureOr<Value> resultBuffer =
565d50f51cSMatthias Springer getBuffer(rewriter, castOp.getSource(), options);
575d50f51cSMatthias Springer if (failed(resultBuffer))
585d50f51cSMatthias Springer return failure();
595d50f51cSMatthias Springer auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
6049e37000SMatthias Springer TensorType resultTensorType =
6149e37000SMatthias Springer castOp.getResult().getType().cast<TensorType>();
6249e37000SMatthias Springer MemRefLayoutAttrInterface layout;
6349e37000SMatthias Springer
6449e37000SMatthias Springer if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
6549e37000SMatthias Springer if (resultTensorType.isa<RankedTensorType>())
6649e37000SMatthias Springer layout = rankedMemRefType.getLayout();
6749e37000SMatthias Springer
6849e37000SMatthias Springer // Compute the new memref type.
69b55d55ecSMatthias Springer Type resultMemRefType =
70606f7c8fSMatthias Springer getMemRefType(castOp.getResult(), options, layout,
71b06614e2SMatthias Springer sourceMemRefType.getMemorySpaceAsInt());
7249e37000SMatthias Springer
7349e37000SMatthias Springer // Replace the op with a memref.cast.
745d50f51cSMatthias Springer assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
7549e37000SMatthias Springer resultMemRefType) &&
7649e37000SMatthias Springer "CallOp::bufferize: cast incompatible");
7749e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
785d50f51cSMatthias Springer *resultBuffer);
7949e37000SMatthias Springer
8049e37000SMatthias Springer return success();
8149e37000SMatthias Springer }
8249e37000SMatthias Springer };
8349e37000SMatthias Springer
84e6f69161SMatthias Springer /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
85e6f69161SMatthias Springer struct CollapseShapeOpInterface
86e6f69161SMatthias Springer : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
87e6f69161SMatthias Springer tensor::CollapseShapeOp> {
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::CollapseShapeOpInterface88e6f69161SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
899597b16aSMatthias Springer const AnalysisState &state) const {
90e6f69161SMatthias Springer return false;
91e6f69161SMatthias Springer }
92e6f69161SMatthias Springer
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::CollapseShapeOpInterface93e6f69161SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
949597b16aSMatthias Springer const AnalysisState &state) const {
95e6f69161SMatthias Springer return false;
96e6f69161SMatthias Springer }
97e6f69161SMatthias Springer
getAliasingOpResultmlir::tensor::__anonb90e36390111::CollapseShapeOpInterface989597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
999597b16aSMatthias Springer const AnalysisState &state) const {
100e6f69161SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*src*/)
101e6f69161SMatthias Springer return {op->getOpResult(0)};
102e6f69161SMatthias Springer return {};
103e6f69161SMatthias Springer }
104e6f69161SMatthias Springer
bufferRelationmlir::tensor::__anonb90e36390111::CollapseShapeOpInterface105e6f69161SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
1069597b16aSMatthias Springer const AnalysisState &state) const {
107e6f69161SMatthias Springer return BufferRelation::Equivalent;
108e6f69161SMatthias Springer }
109e6f69161SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::CollapseShapeOpInterface110e6f69161SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
111b55d55ecSMatthias Springer const BufferizationOptions &options) const {
112e6f69161SMatthias Springer auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
11351df6238SMatthias Springer RankedTensorType tensorResultType = collapseShapeOp.getResultType();
1145d50f51cSMatthias Springer FailureOr<Value> maybeBuffer =
1155d50f51cSMatthias Springer getBuffer(rewriter, collapseShapeOp.getSrc(), options);
1165d50f51cSMatthias Springer if (failed(maybeBuffer))
1175d50f51cSMatthias Springer return failure();
1185d50f51cSMatthias Springer Value buffer = *maybeBuffer;
119b3ebe3beSMatthias Springer auto bufferType = buffer.getType().cast<MemRefType>();
12051df6238SMatthias Springer
12151df6238SMatthias Springer if (tensorResultType.getRank() == 0) {
12251df6238SMatthias Springer // 0-d collapses must go through a different op builder.
12373c0333dSMatthias Springer MemRefType resultType;
12473c0333dSMatthias Springer
12573c0333dSMatthias Springer if (bufferType.getLayout().isIdentity()) {
12673c0333dSMatthias Springer // Standard layout: result type has no offset.
12751df6238SMatthias Springer MemRefLayoutAttrInterface layout;
12873c0333dSMatthias Springer resultType = MemRefType::get({}, tensorResultType.getElementType(),
12951df6238SMatthias Springer layout, bufferType.getMemorySpace());
13073c0333dSMatthias Springer } else {
13173c0333dSMatthias Springer // Source memref has a layout map: result type has the same offset as
13273c0333dSMatthias Springer // the source type.
13373c0333dSMatthias Springer SmallVector<int64_t> strides;
13473c0333dSMatthias Springer int64_t offset;
13573c0333dSMatthias Springer if (failed(getStridesAndOffset(bufferType, strides, offset)))
13673c0333dSMatthias Springer return failure();
13773c0333dSMatthias Springer AffineMap resultLayout =
13873c0333dSMatthias Springer makeStridedLinearLayoutMap({}, offset, op->getContext());
13973c0333dSMatthias Springer resultType =
14073c0333dSMatthias Springer MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
14173c0333dSMatthias Springer bufferType.getMemorySpaceAsInt());
14273c0333dSMatthias Springer }
14373c0333dSMatthias Springer
144e6f69161SMatthias Springer replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
1458df54a6aSJacques Pienaar rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
146e6f69161SMatthias Springer return success();
147e6f69161SMatthias Springer }
14851df6238SMatthias Springer
149d7a9bf91SMatthias Springer // If the dims are not collapsible (due to an incompatible source layout
150d7a9bf91SMatthias Springer // map), force an out-of-place bufferization, i.e., a buffer copy. This
151d7a9bf91SMatthias Springer // newly allocated buffer will have no layout map and thus be collapsible.
152a74e5a89SAdrian Kuegel bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
153d7a9bf91SMatthias Springer bufferType, collapseShapeOp.getReassociationIndices());
154b3ebe3beSMatthias Springer if (!canBeCollapsed) {
155b3ebe3beSMatthias Springer // TODO: Create alloc_tensor ops during TensorCopyInsertion.
156b55d55ecSMatthias Springer AnalysisState analysisState(options);
15745b995cdSMatthias Springer FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1588df54a6aSJacques Pienaar rewriter, op->getLoc(), collapseShapeOp.getSrc(),
15945b995cdSMatthias Springer analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
16045b995cdSMatthias Springer if (failed(tensorAlloc))
16145b995cdSMatthias Springer return failure();
162b3ebe3beSMatthias Springer auto memrefType =
163b3ebe3beSMatthias Springer MemRefType::get(collapseShapeOp.getSrcType().getShape(),
164b3ebe3beSMatthias Springer collapseShapeOp.getSrcType().getElementType(),
165b3ebe3beSMatthias Springer AffineMap(), bufferType.getMemorySpaceAsInt());
166b3ebe3beSMatthias Springer buffer = rewriter.create<bufferization::ToMemrefOp>(
16745b995cdSMatthias Springer op->getLoc(), memrefType, *tensorAlloc);
168b3ebe3beSMatthias Springer }
169d7a9bf91SMatthias Springer
17051df6238SMatthias Springer // Result type is inferred by the builder.
17151df6238SMatthias Springer replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
172b3ebe3beSMatthias Springer rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
17351df6238SMatthias Springer return success();
17451df6238SMatthias Springer }
175e6f69161SMatthias Springer };
176e6f69161SMatthias Springer
17749e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim.
17849e37000SMatthias Springer struct DimOpInterface
17949e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<DimOpInterface,
18049e37000SMatthias Springer tensor::DimOp> {
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::DimOpInterface18149e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1829597b16aSMatthias Springer const AnalysisState &state) const {
18349e37000SMatthias Springer return true;
18449e37000SMatthias Springer }
18549e37000SMatthias Springer
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::DimOpInterface18649e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1879597b16aSMatthias Springer const AnalysisState &state) const {
18849e37000SMatthias Springer return false;
18949e37000SMatthias Springer }
19049e37000SMatthias Springer
getAliasingOpResultmlir::tensor::__anonb90e36390111::DimOpInterface1919597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1929597b16aSMatthias Springer const AnalysisState &state) const {
193585a8a32SMatthias Springer return {};
19449e37000SMatthias Springer }
19549e37000SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::DimOpInterface19649e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
197b55d55ecSMatthias Springer const BufferizationOptions &options) const {
19849e37000SMatthias Springer auto dimOp = cast<tensor::DimOp>(op);
1995d50f51cSMatthias Springer FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
2005d50f51cSMatthias Springer if (failed(v))
2015d50f51cSMatthias Springer return failure();
2025d50f51cSMatthias Springer replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
203136d746eSJacques Pienaar dimOp.getIndex());
20449e37000SMatthias Springer return success();
20549e37000SMatthias Springer }
20649e37000SMatthias Springer };
20749e37000SMatthias Springer
208e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
209e6f69161SMatthias Springer struct ExpandShapeOpInterface
210e6f69161SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
211e6f69161SMatthias Springer tensor::ExpandShapeOp> {
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::ExpandShapeOpInterface212e6f69161SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2139597b16aSMatthias Springer const AnalysisState &state) const {
214e6f69161SMatthias Springer return false;
215e6f69161SMatthias Springer }
216e6f69161SMatthias Springer
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::ExpandShapeOpInterface217e6f69161SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
2189597b16aSMatthias Springer const AnalysisState &state) const {
219e6f69161SMatthias Springer return false;
220e6f69161SMatthias Springer }
221e6f69161SMatthias Springer
getAliasingOpResultmlir::tensor::__anonb90e36390111::ExpandShapeOpInterface2229597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
2239597b16aSMatthias Springer const AnalysisState &state) const {
224e6f69161SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*src*/)
225e6f69161SMatthias Springer return {op->getOpResult(0)};
226e6f69161SMatthias Springer return {};
227e6f69161SMatthias Springer }
228e6f69161SMatthias Springer
bufferRelationmlir::tensor::__anonb90e36390111::ExpandShapeOpInterface229e6f69161SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
2309597b16aSMatthias Springer const AnalysisState &state) const {
231e6f69161SMatthias Springer return BufferRelation::Equivalent;
232e6f69161SMatthias Springer }
233e6f69161SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::ExpandShapeOpInterface234e6f69161SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
235b55d55ecSMatthias Springer const BufferizationOptions &options) const {
236e6f69161SMatthias Springer auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
23751df6238SMatthias Springer auto tensorResultType = expandShapeOp.getResultType();
2385d50f51cSMatthias Springer FailureOr<Value> buffer =
2395d50f51cSMatthias Springer getBuffer(rewriter, expandShapeOp.getSrc(), options);
2405d50f51cSMatthias Springer if (failed(buffer))
2415d50f51cSMatthias Springer return failure();
24251df6238SMatthias Springer
24351df6238SMatthias Springer // Memref result type is inferred by the builder based on reassociation
24451df6238SMatthias Springer // indices and result shape.
245e6f69161SMatthias Springer replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
2465d50f51cSMatthias Springer rewriter, op, tensorResultType.getShape(), *buffer,
24751df6238SMatthias Springer expandShapeOp.getReassociationIndices());
248e6f69161SMatthias Springer return success();
249e6f69161SMatthias Springer }
250e6f69161SMatthias Springer };
251e6f69161SMatthias Springer
25249e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview.
25349e37000SMatthias Springer struct ExtractSliceOpInterface
25449e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
25549e37000SMatthias Springer tensor::ExtractSliceOp> {
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::ExtractSliceOpInterface25649e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2579597b16aSMatthias Springer const AnalysisState &state) const {
25849e37000SMatthias Springer return false;
25949e37000SMatthias Springer }
26049e37000SMatthias Springer
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::ExtractSliceOpInterface26149e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
2629597b16aSMatthias Springer const AnalysisState &state) const {
26349e37000SMatthias Springer return false;
26449e37000SMatthias Springer }
26549e37000SMatthias Springer
getAliasingOpResultmlir::tensor::__anonb90e36390111::ExtractSliceOpInterface2669597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
2679597b16aSMatthias Springer const AnalysisState &state) const {
268585a8a32SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*source*/)
269585a8a32SMatthias Springer return {op->getOpResult(0)};
270585a8a32SMatthias Springer return {};
27149e37000SMatthias Springer }
27249e37000SMatthias Springer
bufferRelationmlir::tensor::__anonb90e36390111::ExtractSliceOpInterface27349e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
2749597b16aSMatthias Springer const AnalysisState &state) const {
27549e37000SMatthias Springer return BufferRelation::None;
27649e37000SMatthias Springer }
27749e37000SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::ExtractSliceOpInterface27849e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
279b55d55ecSMatthias Springer const BufferizationOptions &options) const {
28049e37000SMatthias Springer auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
2816c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
2826c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
2836c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
28449e37000SMatthias Springer Location loc = extractSliceOp.getLoc();
285d7a9bf91SMatthias Springer
2866c3c5f80SMatthias Springer // Get source buffer.
2875d50f51cSMatthias Springer FailureOr<Value> srcMemref =
2885d50f51cSMatthias Springer getBuffer(rewriter, extractSliceOp.getSource(), options);
2895d50f51cSMatthias Springer if (failed(srcMemref))
2905d50f51cSMatthias Springer return failure();
2915d50f51cSMatthias Springer auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
29249e37000SMatthias Springer
2936c3c5f80SMatthias Springer // Take a subview of the source buffer.
2946c3c5f80SMatthias Springer auto subviewMemRefType =
2956c3c5f80SMatthias Springer memref::SubViewOp::inferRankReducedResultType(
2966c3c5f80SMatthias Springer extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
2976c3c5f80SMatthias Springer mixedSizes, mixedStrides)
29849e37000SMatthias Springer .cast<MemRefType>();
29949e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>(
3005d50f51cSMatthias Springer loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
30149e37000SMatthias Springer mixedStrides);
30249e37000SMatthias Springer
30349e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, subView);
30449e37000SMatthias Springer return success();
30549e37000SMatthias Springer }
30649e37000SMatthias Springer };
30749e37000SMatthias Springer
30849e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load.
30949e37000SMatthias Springer struct ExtractOpInterface
31049e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
31149e37000SMatthias Springer tensor::ExtractOp> {
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::ExtractOpInterface31249e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
3139597b16aSMatthias Springer const AnalysisState &state) const {
31449e37000SMatthias Springer return true;
31549e37000SMatthias Springer }
31649e37000SMatthias Springer
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::ExtractOpInterface31749e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
3189597b16aSMatthias Springer const AnalysisState &state) const {
31949e37000SMatthias Springer return false;
32049e37000SMatthias Springer }
32149e37000SMatthias Springer
getAliasingOpResultmlir::tensor::__anonb90e36390111::ExtractOpInterface3229597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
3239597b16aSMatthias Springer const AnalysisState &state) const {
324585a8a32SMatthias Springer return {};
32549e37000SMatthias Springer }
32649e37000SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::ExtractOpInterface32749e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
328b55d55ecSMatthias Springer const BufferizationOptions &options) const {
32949e37000SMatthias Springer auto extractOp = cast<tensor::ExtractOp>(op);
3305d50f51cSMatthias Springer FailureOr<Value> srcMemref =
3315d50f51cSMatthias Springer getBuffer(rewriter, extractOp.getTensor(), options);
3325d50f51cSMatthias Springer if (failed(srcMemref))
3335d50f51cSMatthias Springer return failure();
3345d50f51cSMatthias Springer replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
335136d746eSJacques Pienaar extractOp.getIndices());
33649e37000SMatthias Springer return success();
33749e37000SMatthias Springer }
33849e37000SMatthias Springer };
33949e37000SMatthias Springer
340d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while
341d581c94dSMatthias Springer // iterating over op.elements().
createStores(RewriterBase & rewriter,Location loc,int dim,Value buffer,ArrayRef<int64_t> shape,ArrayRef<Value> constants,OperandRange::iterator & elementIt,SmallVectorImpl<Value> & indices)342d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim,
343d581c94dSMatthias Springer Value buffer, ArrayRef<int64_t> shape,
344d581c94dSMatthias Springer ArrayRef<Value> constants,
345d581c94dSMatthias Springer OperandRange::iterator &elementIt,
346d581c94dSMatthias Springer SmallVectorImpl<Value> &indices) {
347d581c94dSMatthias Springer if (dim == static_cast<int>(shape.size()) - 1) {
348d581c94dSMatthias Springer for (int i = 0; i < shape.back(); ++i) {
349d581c94dSMatthias Springer indices.back() = constants[i];
350d581c94dSMatthias Springer rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
351d581c94dSMatthias Springer ++elementIt;
352d581c94dSMatthias Springer }
353d581c94dSMatthias Springer return;
354d581c94dSMatthias Springer }
355d581c94dSMatthias Springer for (int i = 0; i < shape[dim]; ++i) {
356d581c94dSMatthias Springer indices[dim] = constants[i];
357d581c94dSMatthias Springer createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
358d581c94dSMatthias Springer indices);
359d581c94dSMatthias Springer }
360d581c94dSMatthias Springer }
361d581c94dSMatthias Springer
362d581c94dSMatthias Springer /// Bufferization of tensor.from_elements.
363d581c94dSMatthias Springer struct FromElementsOpInterface
364d581c94dSMatthias Springer : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
365d581c94dSMatthias Springer tensor::FromElementsOp> {
366664ffa46SMatthias Springer
bufferizesToAllocationmlir::tensor::__anonb90e36390111::FromElementsOpInterface367664ffa46SMatthias Springer bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
368664ffa46SMatthias Springer return true;
369664ffa46SMatthias Springer }
370664ffa46SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::FromElementsOpInterface371d581c94dSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
372b55d55ecSMatthias Springer const BufferizationOptions &options) const {
373d581c94dSMatthias Springer auto fromElementsOp = cast<tensor::FromElementsOp>(op);
374664ffa46SMatthias Springer // Should the buffer be deallocated?
375664ffa46SMatthias Springer bool dealloc = shouldDeallocateOpResult(
376664ffa46SMatthias Springer fromElementsOp.getResult().cast<OpResult>(), options);
377d581c94dSMatthias Springer
378c0b0b6a0SMatthias Springer // TODO: Implement memory space for this op.
379c0b0b6a0SMatthias Springer if (options.defaultMemorySpace != static_cast<unsigned>(0))
380c0b0b6a0SMatthias Springer return op->emitError("memory space not implemented yet");
381c0b0b6a0SMatthias Springer
382d581c94dSMatthias Springer // Allocate a buffer for the result.
383d581c94dSMatthias Springer Location loc = op->getLoc();
384d581c94dSMatthias Springer auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
385d581c94dSMatthias Springer auto shape = tensorType.getShape();
386b3ebe3beSMatthias Springer // TODO: Create alloc_tensor ops during TensorCopyInsertion.
387664ffa46SMatthias Springer FailureOr<Value> tensorAlloc =
388664ffa46SMatthias Springer allocateTensorForShapedValue(rewriter, loc, fromElementsOp.getResult(),
389664ffa46SMatthias Springer /*escape=*/!dealloc, options,
390b3ebe3beSMatthias Springer /*copy=*/false);
39145b995cdSMatthias Springer if (failed(tensorAlloc))
39245b995cdSMatthias Springer return failure();
393b3ebe3beSMatthias Springer auto memrefType =
394b3ebe3beSMatthias Springer MemRefType::get(tensorType.getShape(), tensorType.getElementType());
395b3ebe3beSMatthias Springer Value buffer = rewriter.create<bufferization::ToMemrefOp>(
39645b995cdSMatthias Springer op->getLoc(), memrefType, *tensorAlloc);
397d581c94dSMatthias Springer
398d581c94dSMatthias Springer // Case: tensor<0xelem_type>.
3998df54a6aSJacques Pienaar if (fromElementsOp.getElements().empty()) {
400d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer);
401d581c94dSMatthias Springer return success();
402d581c94dSMatthias Springer }
403d581c94dSMatthias Springer
404d581c94dSMatthias Springer // Case: tensor<elem_type>.
405d581c94dSMatthias Springer if (shape.empty()) {
4068df54a6aSJacques Pienaar rewriter.create<memref::StoreOp>(
4078df54a6aSJacques Pienaar loc, fromElementsOp.getElements().front(), buffer);
408d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer);
409d581c94dSMatthias Springer return success();
410d581c94dSMatthias Springer }
411d581c94dSMatthias Springer
412d581c94dSMatthias Springer // Create constants for the range of possible indices [0, max{shape_i}).
413d581c94dSMatthias Springer auto maxDim = *std::max_element(shape.begin(), shape.end());
414d581c94dSMatthias Springer SmallVector<Value, 2> constants;
415d581c94dSMatthias Springer constants.reserve(maxDim);
416d581c94dSMatthias Springer for (int i = 0; i < maxDim; ++i)
417d581c94dSMatthias Springer constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
418d581c94dSMatthias Springer
419d581c94dSMatthias Springer // Traverse all `elements` and create `memref.store` ops.
4208df54a6aSJacques Pienaar auto elementIt = fromElementsOp.getElements().begin();
421d581c94dSMatthias Springer SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
422d581c94dSMatthias Springer createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
423d581c94dSMatthias Springer indices);
424d581c94dSMatthias Springer
425d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer);
426664ffa46SMatthias Springer
427d581c94dSMatthias Springer return success();
428d581c94dSMatthias Springer }
429d581c94dSMatthias Springer };
430d581c94dSMatthias Springer
43171bbb78bSMatthias Springer /// Bufferization of tensor.generate.
43271bbb78bSMatthias Springer struct GenerateOpInterface
43371bbb78bSMatthias Springer : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
43471bbb78bSMatthias Springer tensor::GenerateOp> {
435664ffa46SMatthias Springer
bufferizesToAllocationmlir::tensor::__anonb90e36390111::GenerateOpInterface436664ffa46SMatthias Springer bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
437664ffa46SMatthias Springer return true;
438664ffa46SMatthias Springer }
439664ffa46SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::GenerateOpInterface44071bbb78bSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
441b55d55ecSMatthias Springer const BufferizationOptions &options) const {
44271bbb78bSMatthias Springer auto generateOp = cast<tensor::GenerateOp>(op);
443664ffa46SMatthias Springer // Should the buffer be deallocated?
444664ffa46SMatthias Springer bool dealloc = shouldDeallocateOpResult(
445664ffa46SMatthias Springer generateOp.getResult().cast<OpResult>(), options);
446c0b0b6a0SMatthias Springer
447c0b0b6a0SMatthias Springer // TODO: Implement memory space for this op.
448c0b0b6a0SMatthias Springer if (options.defaultMemorySpace != static_cast<unsigned>(0))
449c0b0b6a0SMatthias Springer return op->emitError("memory space not implemented yet");
450c0b0b6a0SMatthias Springer
451b3ebe3beSMatthias Springer auto tensorType = generateOp.getType().cast<RankedTensorType>();
45271bbb78bSMatthias Springer // Allocate memory.
45371bbb78bSMatthias Springer Location loc = op->getLoc();
454b3ebe3beSMatthias Springer // TODO: Create alloc_tensor ops during TensorCopyInsertion.
455664ffa46SMatthias Springer FailureOr<Value> tensorAlloc =
456664ffa46SMatthias Springer allocateTensorForShapedValue(rewriter, loc, generateOp.getResult(),
457664ffa46SMatthias Springer /*escape=*/!dealloc, options,
458b3ebe3beSMatthias Springer /*copy=*/false);
45945b995cdSMatthias Springer if (failed(tensorAlloc))
46045b995cdSMatthias Springer return failure();
461b3ebe3beSMatthias Springer auto memrefType =
462b3ebe3beSMatthias Springer MemRefType::get(tensorType.getShape(), tensorType.getElementType());
463b3ebe3beSMatthias Springer Value buffer = rewriter.create<bufferization::ToMemrefOp>(
46445b995cdSMatthias Springer op->getLoc(), memrefType, *tensorAlloc);
46571bbb78bSMatthias Springer
46671bbb78bSMatthias Springer // Collect loop bounds.
46771bbb78bSMatthias Springer int64_t rank = memrefType.getRank();
46871bbb78bSMatthias Springer Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
46971bbb78bSMatthias Springer Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
47071bbb78bSMatthias Springer SmallVector<Value, 4> lowerBounds(rank, zero);
47171bbb78bSMatthias Springer SmallVector<Value, 4> steps(rank, one);
47271bbb78bSMatthias Springer SmallVector<Value, 4> upperBounds;
47371bbb78bSMatthias Springer int nextDynamicIndex = 0;
47471bbb78bSMatthias Springer for (int i = 0; i < rank; i++) {
4758df54a6aSJacques Pienaar Value upperBound =
4768df54a6aSJacques Pienaar memrefType.isDynamicDim(i)
4778df54a6aSJacques Pienaar ? generateOp.getDynamicExtents()[nextDynamicIndex++]
47871bbb78bSMatthias Springer : rewriter.create<arith::ConstantIndexOp>(
47971bbb78bSMatthias Springer loc, memrefType.getDimSize(i));
48071bbb78bSMatthias Springer upperBounds.push_back(upperBound);
48171bbb78bSMatthias Springer }
48271bbb78bSMatthias Springer
48371bbb78bSMatthias Springer // Generate tensor elements with a parallel loop that stores into
48471bbb78bSMatthias Springer // each element of the resulting memref. We use mergeBlockBefore to "move"
48571bbb78bSMatthias Springer // this op's body into the scf.parallel's body.
48671bbb78bSMatthias Springer auto parallel =
48771bbb78bSMatthias Springer rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
48871bbb78bSMatthias Springer Block *parallelBody = parallel.getBody();
489eca86cb2SJacques Pienaar rewriter.mergeBlockBefore(&generateOp.getBody().front(),
49071bbb78bSMatthias Springer parallelBody->getTerminator(),
49171bbb78bSMatthias Springer parallelBody->getArguments());
49271bbb78bSMatthias Springer // Replace the inlined yield op with a store op. The scf.parallel's builder
49371bbb78bSMatthias Springer // already populated an scf.yield at the end, so we don't need to worry
49471bbb78bSMatthias Springer // about creating that.
49571bbb78bSMatthias Springer Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
49671bbb78bSMatthias Springer rewriter.setInsertionPointAfter(elementYield);
49771bbb78bSMatthias Springer rewriter.replaceOpWithNewOp<memref::StoreOp>(
498b3ebe3beSMatthias Springer elementYield, elementYield->getOperands()[0], buffer,
49971bbb78bSMatthias Springer parallelBody->getArguments());
50071bbb78bSMatthias Springer
501b3ebe3beSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer);
502664ffa46SMatthias Springer
50371bbb78bSMatthias Springer return success();
50471bbb78bSMatthias Springer }
50571bbb78bSMatthias Springer };
50671bbb78bSMatthias Springer
50749e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store.
50849e37000SMatthias Springer struct InsertOpInterface
50949e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
51049e37000SMatthias Springer tensor::InsertOp> {
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::InsertOpInterface51149e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
5129597b16aSMatthias Springer const AnalysisState &state) const {
51349e37000SMatthias Springer return true;
51449e37000SMatthias Springer }
51549e37000SMatthias Springer
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::InsertOpInterface51649e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
5179597b16aSMatthias Springer const AnalysisState &state) const {
51849e37000SMatthias Springer return true;
51949e37000SMatthias Springer }
52049e37000SMatthias Springer
getAliasingOpResultmlir::tensor::__anonb90e36390111::InsertOpInterface5219597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
5229597b16aSMatthias Springer const AnalysisState &state) const {
52349e37000SMatthias Springer assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
52449e37000SMatthias Springer "expected dest OpOperand");
525585a8a32SMatthias Springer return {op->getOpResult(0)};
52649e37000SMatthias Springer }
52749e37000SMatthias Springer
52849e37000SMatthias Springer SmallVector<OpOperand *>
getAliasingOpOperandmlir::tensor::__anonb90e36390111::InsertOpInterface52949e37000SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult,
5309597b16aSMatthias Springer const AnalysisState &state) const {
53149e37000SMatthias Springer return {&op->getOpOperand(1) /*dest*/};
53249e37000SMatthias Springer }
53349e37000SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::InsertOpInterface53449e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
535b55d55ecSMatthias Springer const BufferizationOptions &options) const {
53649e37000SMatthias Springer auto insertOp = cast<tensor::InsertOp>(op);
5375d50f51cSMatthias Springer FailureOr<Value> destMemref =
5385d50f51cSMatthias Springer getBuffer(rewriter, insertOp.getDest(), options);
5395d50f51cSMatthias Springer if (failed(destMemref))
5405d50f51cSMatthias Springer return failure();
5418df54a6aSJacques Pienaar rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
5425d50f51cSMatthias Springer *destMemref, insertOp.getIndices());
5435d50f51cSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *destMemref);
54449e37000SMatthias Springer return success();
54549e37000SMatthias Springer }
54649e37000SMatthias Springer
bufferRelationmlir::tensor::__anonb90e36390111::InsertOpInterface54749e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
5489597b16aSMatthias Springer const AnalysisState &state) const {
54949e37000SMatthias Springer return BufferRelation::Equivalent;
55049e37000SMatthias Springer }
55149e37000SMatthias Springer };
55249e37000SMatthias Springer
55349e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
55449e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification).
555*1defec87SMatthias Springer template <typename OpTy>
areEquivalentExtractSliceOps(const AnalysisState & state,ExtractSliceOp extractSliceOp,OpTy insertSliceOp)5569597b16aSMatthias Springer static bool areEquivalentExtractSliceOps(const AnalysisState &state,
557*1defec87SMatthias Springer ExtractSliceOp extractSliceOp,
558*1defec87SMatthias Springer OpTy insertSliceOp) {
559*1defec87SMatthias Springer if (!extractSliceOp || !insertSliceOp)
56049e37000SMatthias Springer return false;
561*1defec87SMatthias Springer if (extractSliceOp != insertSliceOp &&
562*1defec87SMatthias Springer !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
563*1defec87SMatthias Springer insertSliceOp.getDest()))
56449e37000SMatthias Springer return false;
565*1defec87SMatthias Springer if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
566*1defec87SMatthias Springer isEqualConstantIntOrValue))
56749e37000SMatthias Springer return false;
56849e37000SMatthias Springer return true;
56949e37000SMatthias Springer }
57049e37000SMatthias Springer
57149e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches
57249e37000SMatthias Springer /// the given InsertSliceOp.
573*1defec87SMatthias Springer template <typename OpTy>
hasMatchingExtractSliceOp(const AnalysisState & state,Value value,OpTy insertSliceOp)5749597b16aSMatthias Springer static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
575*1defec87SMatthias Springer OpTy insertSliceOp) {
57649e37000SMatthias Springer auto condition = [&](Value val) {
577*1defec87SMatthias Springer if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
578*1defec87SMatthias Springer if (areEquivalentExtractSliceOps(state, extractSliceOp, insertSliceOp))
57949e37000SMatthias Springer return true;
58049e37000SMatthias Springer return false;
58149e37000SMatthias Springer };
58249e37000SMatthias Springer
58349e37000SMatthias Springer return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
58449e37000SMatthias Springer condition);
58549e37000SMatthias Springer }
58649e37000SMatthias Springer
587*1defec87SMatthias Springer template <typename OpTy>
isNotConflictingInsertSliceLikeOp(Operation * op,OpOperand * uRead,OpOperand * uConflictingWrite,const AnalysisState & state)588*1defec87SMatthias Springer static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
58949e37000SMatthias Springer OpOperand *uConflictingWrite,
590*1defec87SMatthias Springer const AnalysisState &state) {
59149e37000SMatthias Springer Operation *readingOp = uRead->getOwner();
59249e37000SMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner();
59349e37000SMatthias Springer
59449e37000SMatthias Springer // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
59549e37000SMatthias Springer // uRead is an InsertSliceOp...
596*1defec87SMatthias Springer if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
59749e37000SMatthias Springer // As an example, consider the following IR.
59849e37000SMatthias Springer //
59949e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
60049e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] }
60149e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
60249e37000SMatthias Springer // {inplace= [true] }
60349e37000SMatthias Springer
60449e37000SMatthias Springer // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
60549e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
60649e37000SMatthias Springer hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
60749e37000SMatthias Springer insertSliceOp))
60849e37000SMatthias Springer // Case 1: The main insight is that InsertSliceOp reads only part of
60949e37000SMatthias Springer // the destination tensor. The overwritten area is not read. If
61049e37000SMatthias Springer // uConflictingWrite writes into exactly the memory location that is
61149e37000SMatthias Springer // being read by uRead, this is not a conflict.
61249e37000SMatthias Springer //
61349e37000SMatthias Springer // In the above example:
61449e37000SMatthias Springer // uRead = OpOperand 1 (%t) of tensor.insert_slice
61549e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
61649e37000SMatthias Springer //
61749e37000SMatthias Springer // The read of %t does not conflict with the write of the FillOp
61849e37000SMatthias Springer // (same aliases!) because the area that the FillOp operates on is
61949e37000SMatthias Springer // exactly the one that is *not* read via %t.
62049e37000SMatthias Springer return true;
62149e37000SMatthias Springer
62249e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
62349e37000SMatthias Springer uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
62449e37000SMatthias Springer hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
62549e37000SMatthias Springer // Case 2: The read of the source tensor and the write to the dest
62649e37000SMatthias Springer // tensor via an InsertSliceOp is not a conflict if the read is
62749e37000SMatthias Springer // reading exactly that part of an equivalent tensor that the
62849e37000SMatthias Springer // InsertSliceOp is writing.
62949e37000SMatthias Springer //
63049e37000SMatthias Springer // In the above example:
63149e37000SMatthias Springer // uRead = OpOperand 0 (%1) of tensor.insert_slice
63249e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
63349e37000SMatthias Springer return true;
63449e37000SMatthias Springer }
63549e37000SMatthias Springer
63649e37000SMatthias Springer // If uConflictingWrite is an InsertSliceOp...
637*1defec87SMatthias Springer if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
63849e37000SMatthias Springer // As an example, consider the following IR.
63949e37000SMatthias Springer //
64049e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
64149e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] }
64249e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
64349e37000SMatthias Springer // {inplace= [true] }
64449e37000SMatthias Springer // %3 = vector.transfer_read %1, %cst
64549e37000SMatthias Springer //
64649e37000SMatthias Springer // In the above example:
64749e37000SMatthias Springer // uRead = OpOperand 0 (%1) of vector.transfer_read
64849e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
64949e37000SMatthias Springer // lastWrite = %1
65049e37000SMatthias Springer //
65149e37000SMatthias Springer // This is not a conflict because the InsertSliceOp overwrites the
65249e37000SMatthias Springer // memory segment of %1 with the exact same data. (Effectively, there
65349e37000SMatthias Springer // is no memory write here.)
65449e37000SMatthias Springer if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
65549e37000SMatthias Springer state.areEquivalentBufferizedValues(uRead->get(),
6568df54a6aSJacques Pienaar insertSliceOp.getSource()) &&
6578df54a6aSJacques Pienaar hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
65849e37000SMatthias Springer insertSliceOp))
65949e37000SMatthias Springer return true;
66049e37000SMatthias Springer
66149e37000SMatthias Springer return false;
66249e37000SMatthias Springer }
66349e37000SMatthias Springer
664*1defec87SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
665*1defec87SMatthias Springer /// certain circumstances, this op can also be a no-op.
666*1defec87SMatthias Springer struct InsertSliceOpInterface
667*1defec87SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
668*1defec87SMatthias Springer tensor::InsertSliceOp> {
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::InsertSliceOpInterface669*1defec87SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
670*1defec87SMatthias Springer const AnalysisState &state) const {
671*1defec87SMatthias Springer return true;
672*1defec87SMatthias Springer }
673*1defec87SMatthias Springer
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::InsertSliceOpInterface674*1defec87SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
675*1defec87SMatthias Springer const AnalysisState &state) const {
676*1defec87SMatthias Springer return &opOperand == &op->getOpOperand(1) /*dest*/;
677*1defec87SMatthias Springer }
678*1defec87SMatthias Springer
getAliasingOpResultmlir::tensor::__anonb90e36390111::InsertSliceOpInterface679*1defec87SMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
680*1defec87SMatthias Springer const AnalysisState &state) const {
681*1defec87SMatthias Springer if (&opOperand == &op->getOpOperand(1) /*dest*/)
682*1defec87SMatthias Springer return {op->getResult(0)};
683*1defec87SMatthias Springer return {};
684*1defec87SMatthias Springer }
685*1defec87SMatthias Springer
bufferRelationmlir::tensor::__anonb90e36390111::InsertSliceOpInterface686*1defec87SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
687*1defec87SMatthias Springer const AnalysisState &state) const {
688*1defec87SMatthias Springer return BufferRelation::Equivalent;
689*1defec87SMatthias Springer }
690*1defec87SMatthias Springer
isNotConflictingmlir::tensor::__anonb90e36390111::InsertSliceOpInterface691*1defec87SMatthias Springer bool isNotConflicting(Operation *op, OpOperand *uRead,
692*1defec87SMatthias Springer OpOperand *uConflictingWrite,
693*1defec87SMatthias Springer const AnalysisState &state) const {
694*1defec87SMatthias Springer return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
695*1defec87SMatthias Springer op, uRead, uConflictingWrite, state);
696*1defec87SMatthias Springer }
697*1defec87SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::InsertSliceOpInterface69849e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
699b55d55ecSMatthias Springer const BufferizationOptions &options) const {
70049e37000SMatthias Springer // insert_slice ops arise from tiling and bufferizing them out-of-place is
70149e37000SMatthias Springer // generally a deal breaker. When used with loops, this ends up cloning the
70249e37000SMatthias Springer // whole tensor on every single iteration and is a symptom of a
70349e37000SMatthias Springer // catastrophically bad scheduling decision.
70449e37000SMatthias Springer // TODO: be very loud about it or even consider failing the pass.
70549e37000SMatthias Springer auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
7066c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
7076c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
7086c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
70949e37000SMatthias Springer Location loc = insertSliceOp.getLoc();
7106c3c5f80SMatthias Springer
7116c3c5f80SMatthias Springer // Get destination buffer.
7125d50f51cSMatthias Springer FailureOr<Value> dstMemref =
7135d50f51cSMatthias Springer getBuffer(rewriter, insertSliceOp.getDest(), options);
7145d50f51cSMatthias Springer if (failed(dstMemref))
7155d50f51cSMatthias Springer return failure();
71649e37000SMatthias Springer
7176c3c5f80SMatthias Springer // Take a subview of the destination buffer.
7185d50f51cSMatthias Springer auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
71949e37000SMatthias Springer auto subviewMemRefType =
72049e37000SMatthias Springer memref::SubViewOp::inferRankReducedResultType(
7216c3c5f80SMatthias Springer insertSliceOp.getSourceType().getShape(), dstMemrefType,
72249e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides)
72349e37000SMatthias Springer .cast<MemRefType>();
72449e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>(
7255d50f51cSMatthias Springer loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
72649e37000SMatthias Springer mixedStrides);
72749e37000SMatthias Springer
72849e37000SMatthias Springer // Copy tensor. If this tensor.insert_slice has a matching
72949e37000SMatthias Springer // tensor.extract_slice, the copy operation will eventually fold away.
7305d50f51cSMatthias Springer FailureOr<Value> srcMemref =
7315d50f51cSMatthias Springer getBuffer(rewriter, insertSliceOp.getSource(), options);
7325d50f51cSMatthias Springer if (failed(srcMemref))
7335d50f51cSMatthias Springer return failure();
7345d50f51cSMatthias Springer if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
73549e37000SMatthias Springer return failure();
73649e37000SMatthias Springer
7375d50f51cSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
73849e37000SMatthias Springer return success();
73949e37000SMatthias Springer }
74049e37000SMatthias Springer };
74149e37000SMatthias Springer
742fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank.
743fc08d1c2SMatthias Springer struct RankOpInterface
744fc08d1c2SMatthias Springer : public BufferizableOpInterface::ExternalModel<RankOpInterface,
745fc08d1c2SMatthias Springer tensor::RankOp> {
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::RankOpInterface746fc08d1c2SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
7479597b16aSMatthias Springer const AnalysisState &state) const {
748fc08d1c2SMatthias Springer return true;
749fc08d1c2SMatthias Springer }
750fc08d1c2SMatthias Springer
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::RankOpInterface751fc08d1c2SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
7529597b16aSMatthias Springer const AnalysisState &state) const {
753fc08d1c2SMatthias Springer return false;
754fc08d1c2SMatthias Springer }
755fc08d1c2SMatthias Springer
getAliasingOpResultmlir::tensor::__anonb90e36390111::RankOpInterface7569597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
7579597b16aSMatthias Springer const AnalysisState &state) const {
758585a8a32SMatthias Springer return {};
759fc08d1c2SMatthias Springer }
760fc08d1c2SMatthias Springer
bufferizemlir::tensor::__anonb90e36390111::RankOpInterface761fc08d1c2SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
762b55d55ecSMatthias Springer const BufferizationOptions &options) const {
763fc08d1c2SMatthias Springer auto rankOp = cast<tensor::RankOp>(op);
7645d50f51cSMatthias Springer FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
7655d50f51cSMatthias Springer if (failed(v))
7665d50f51cSMatthias Springer return failure();
767fc08d1c2SMatthias Springer replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
7685d50f51cSMatthias Springer *v);
769fc08d1c2SMatthias Springer return success();
770fc08d1c2SMatthias Springer }
771fc08d1c2SMatthias Springer };
772fc08d1c2SMatthias Springer
773e287d647SAshay Rane /// Bufferization of tensor.reshape. Replace with memref.reshape.
774e287d647SAshay Rane struct ReshapeOpInterface
775e287d647SAshay Rane : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
776e287d647SAshay Rane tensor::ReshapeOp> {
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::ReshapeOpInterface777e287d647SAshay Rane bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
778e287d647SAshay Rane const AnalysisState &state) const {
779e287d647SAshay Rane if (&opOperand == &op->getOpOperand(1) /* shape */)
780e287d647SAshay Rane return true;
781e287d647SAshay Rane return false;
782e287d647SAshay Rane }
783e287d647SAshay Rane
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::ReshapeOpInterface784e287d647SAshay Rane bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
785e287d647SAshay Rane const AnalysisState &state) const {
786e287d647SAshay Rane return false;
787e287d647SAshay Rane }
788e287d647SAshay Rane
getAliasingOpResultmlir::tensor::__anonb90e36390111::ReshapeOpInterface789e287d647SAshay Rane SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
790e287d647SAshay Rane const AnalysisState &state) const {
791e287d647SAshay Rane return {op->getOpResult(0)};
792e287d647SAshay Rane }
793e287d647SAshay Rane
bufferRelationmlir::tensor::__anonb90e36390111::ReshapeOpInterface794e287d647SAshay Rane BufferRelation bufferRelation(Operation *op, OpResult opResult,
795e287d647SAshay Rane const AnalysisState &state) const {
796e287d647SAshay Rane return BufferRelation::Equivalent;
797e287d647SAshay Rane }
798e287d647SAshay Rane
bufferizemlir::tensor::__anonb90e36390111::ReshapeOpInterface799e287d647SAshay Rane LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
800b55d55ecSMatthias Springer const BufferizationOptions &options) const {
801e287d647SAshay Rane auto reshapeOp = cast<tensor::ReshapeOp>(op);
8025d50f51cSMatthias Springer FailureOr<Value> srcBuffer =
8035d50f51cSMatthias Springer getBuffer(rewriter, reshapeOp.getSource(), options);
8045d50f51cSMatthias Springer FailureOr<Value> shapeBuffer =
8055d50f51cSMatthias Springer getBuffer(rewriter, reshapeOp.getShape(), options);
8065d50f51cSMatthias Springer if (failed(srcBuffer) || failed(shapeBuffer))
8075d50f51cSMatthias Springer return failure();
808c0b0b6a0SMatthias Springer auto resultMemRefType = getMemRefType(
809606f7c8fSMatthias Springer reshapeOp.getResult(), options, /*layout=*/{},
810c0b0b6a0SMatthias Springer srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
811e287d647SAshay Rane replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
8125d50f51cSMatthias Springer rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
813e287d647SAshay Rane return success();
814e287d647SAshay Rane }
815e287d647SAshay Rane };
816e287d647SAshay Rane
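/// Analysis and bufferization of tensor.parallel_insert_slice. The op is bufferized into a memref.subview of the destination buffer plus a memory copy from the source buffer.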
8187fbf55c9SNicolas Vasilache struct ParallelInsertSliceOpInterface
8197fbf55c9SNicolas Vasilache : public BufferizableOpInterface::ExternalModel<
8207fbf55c9SNicolas Vasilache ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
getAliasingOpResultmlir::tensor::__anonb90e36390111::ParallelInsertSliceOpInterface8217fbf55c9SNicolas Vasilache SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
8227fbf55c9SNicolas Vasilache const AnalysisState &state) const {
8237fbf55c9SNicolas Vasilache if (&opOperand != &op->getOpOperand(1) /*dest*/)
8247fbf55c9SNicolas Vasilache return {};
8257fbf55c9SNicolas Vasilache
8267fbf55c9SNicolas Vasilache // ParallelInsertSliceOp itself has no results, query its tied op results.
8277fbf55c9SNicolas Vasilache auto insertOp = cast<ParallelInsertSliceOp>(op);
8287fbf55c9SNicolas Vasilache return {insertOp.getTiedOpResult()};
8297fbf55c9SNicolas Vasilache }
8307fbf55c9SNicolas Vasilache
bufferizesToMemoryReadmlir::tensor::__anonb90e36390111::ParallelInsertSliceOpInterface8317fbf55c9SNicolas Vasilache bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
8327fbf55c9SNicolas Vasilache const AnalysisState &state) const {
8337fbf55c9SNicolas Vasilache return true;
8347fbf55c9SNicolas Vasilache }
8357fbf55c9SNicolas Vasilache
bufferizesToMemoryWritemlir::tensor::__anonb90e36390111::ParallelInsertSliceOpInterface8367fbf55c9SNicolas Vasilache bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
8377fbf55c9SNicolas Vasilache const AnalysisState &state) const {
8387fbf55c9SNicolas Vasilache return &opOperand == &op->getOpOperand(1) /*dest*/;
8397fbf55c9SNicolas Vasilache }
8407fbf55c9SNicolas Vasilache
bufferRelationmlir::tensor::__anonb90e36390111::ParallelInsertSliceOpInterface8417fbf55c9SNicolas Vasilache BufferRelation bufferRelation(Operation *op, OpResult opResult,
8427fbf55c9SNicolas Vasilache const AnalysisState &state) const {
8437fbf55c9SNicolas Vasilache return BufferRelation::Equivalent;
8447fbf55c9SNicolas Vasilache }
8457fbf55c9SNicolas Vasilache
resolveConflictsmlir::tensor::__anonb90e36390111::ParallelInsertSliceOpInterface8467fbf55c9SNicolas Vasilache LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
8477fbf55c9SNicolas Vasilache const AnalysisState &state) const {
8487fbf55c9SNicolas Vasilache // This interface method is overridden because we want to set a custom
8497fbf55c9SNicolas Vasilache // insertion point for tensor copies. They should be inserted right before
8507fbf55c9SNicolas Vasilache // the ForeachThreadOp. E.g.:
8517fbf55c9SNicolas Vasilache //
8527fbf55c9SNicolas Vasilache // %r0, %r1 = foreach_thead ... {
8537fbf55c9SNicolas Vasilache // ...
8547fbf55c9SNicolas Vasilache // perform_concurrently {
8557fbf55c9SNicolas Vasilache // parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
8567fbf55c9SNicolas Vasilache // parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
8577fbf55c9SNicolas Vasilache // }
8587fbf55c9SNicolas Vasilache // }
8597fbf55c9SNicolas Vasilache //
8607fbf55c9SNicolas Vasilache // After TensorCopyInsertion:
8617fbf55c9SNicolas Vasilache //
8627fbf55c9SNicolas Vasilache // %copy = bufferization.alloc_tensor() copy(%d)
8637fbf55c9SNicolas Vasilache // %r0, %r1 = foreach_thead ... {
8647fbf55c9SNicolas Vasilache // ...
8657fbf55c9SNicolas Vasilache // perform_concurrently {
8667fbf55c9SNicolas Vasilache // parallel_insert_slice %a into %b ...
8677fbf55c9SNicolas Vasilache // parallel_insert_slice %c into %copy ...
8687fbf55c9SNicolas Vasilache // }
8697fbf55c9SNicolas Vasilache // }
8707fbf55c9SNicolas Vasilache
8717fbf55c9SNicolas Vasilache OpBuilder::InsertionGuard g(rewriter);
8727fbf55c9SNicolas Vasilache auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
8737fbf55c9SNicolas Vasilache ParallelCombiningOpInterface parallelCombiningParent =
8747fbf55c9SNicolas Vasilache parallelInsertSliceOp.getParallelCombiningParent();
8757fbf55c9SNicolas Vasilache Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
8767fbf55c9SNicolas Vasilache
8777fbf55c9SNicolas Vasilache // Nothing to do if the destination tensor is inplace.
8787fbf55c9SNicolas Vasilache assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
8797fbf55c9SNicolas Vasilache "source is always in-place");
8807fbf55c9SNicolas Vasilache if (state.isInPlace(op->getOpOperand(1) /*dest*/))
8817fbf55c9SNicolas Vasilache return success();
8827fbf55c9SNicolas Vasilache
8837fbf55c9SNicolas Vasilache // Find corresponding OpResult.
8847fbf55c9SNicolas Vasilache OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
8857fbf55c9SNicolas Vasilache
8867fbf55c9SNicolas Vasilache // Insert tensor allocation right before the ForeachThreadOp.
8877fbf55c9SNicolas Vasilache rewriter.setInsertionPoint(parallelIteratingOp);
8887fbf55c9SNicolas Vasilache bool isYielded = state.isTensorYielded(opResult);
8897fbf55c9SNicolas Vasilache FailureOr<Value> alloc = allocateTensorForShapedValue(
8907fbf55c9SNicolas Vasilache rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
8917fbf55c9SNicolas Vasilache /*escape=*/isYielded, state.getOptions());
8927fbf55c9SNicolas Vasilache if (failed(alloc))
8937fbf55c9SNicolas Vasilache return failure();
8947fbf55c9SNicolas Vasilache
8957fbf55c9SNicolas Vasilache // Update destination operand.
8967fbf55c9SNicolas Vasilache rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
8977fbf55c9SNicolas Vasilache parallelInsertSliceOp.getDestMutable().assign(*alloc);
8987fbf55c9SNicolas Vasilache });
8997fbf55c9SNicolas Vasilache
9007fbf55c9SNicolas Vasilache return success();
9017fbf55c9SNicolas Vasilache }
9027fbf55c9SNicolas Vasilache
bufferizemlir::tensor::__anonb90e36390111::ParallelInsertSliceOpInterface9037fbf55c9SNicolas Vasilache LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
9047fbf55c9SNicolas Vasilache const BufferizationOptions &options) const {
9057fbf55c9SNicolas Vasilache OpBuilder::InsertionGuard g(rewriter);
9067fbf55c9SNicolas Vasilache auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
9077fbf55c9SNicolas Vasilache ParallelCombiningOpInterface parallelCombiningParent =
9087fbf55c9SNicolas Vasilache parallelInsertSliceOp.getParallelCombiningParent();
9097fbf55c9SNicolas Vasilache Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
9107fbf55c9SNicolas Vasilache
9117fbf55c9SNicolas Vasilache // Get destination buffer.
9127fbf55c9SNicolas Vasilache FailureOr<Value> destBuffer =
9137fbf55c9SNicolas Vasilache getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
9147fbf55c9SNicolas Vasilache if (failed(destBuffer))
9157fbf55c9SNicolas Vasilache return failure();
9167fbf55c9SNicolas Vasilache
9177fbf55c9SNicolas Vasilache // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
9187fbf55c9SNicolas Vasilache rewriter.setInsertionPoint(parallelCombiningParent);
9197fbf55c9SNicolas Vasilache FailureOr<Value> srcBuffer =
9207fbf55c9SNicolas Vasilache getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
9217fbf55c9SNicolas Vasilache if (failed(srcBuffer))
9227fbf55c9SNicolas Vasilache return failure();
9236c3c5f80SMatthias Springer
9246c3c5f80SMatthias Springer // Take a subview of the destination buffer.
9256c3c5f80SMatthias Springer auto destBufferType = destBuffer->getType().cast<MemRefType>();
9266c3c5f80SMatthias Springer auto subviewMemRefType =
9276c3c5f80SMatthias Springer memref::SubViewOp::inferRankReducedResultType(
9286c3c5f80SMatthias Springer parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
9296c3c5f80SMatthias Springer parallelInsertSliceOp.getMixedOffsets(),
9306c3c5f80SMatthias Springer parallelInsertSliceOp.getMixedSizes(),
9316c3c5f80SMatthias Springer parallelInsertSliceOp.getMixedStrides())
9326c3c5f80SMatthias Springer .cast<MemRefType>();
9337fbf55c9SNicolas Vasilache Value subview = rewriter.create<memref::SubViewOp>(
9346c3c5f80SMatthias Springer parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
9357fbf55c9SNicolas Vasilache parallelInsertSliceOp.getMixedOffsets(),
9367fbf55c9SNicolas Vasilache parallelInsertSliceOp.getMixedSizes(),
9377fbf55c9SNicolas Vasilache parallelInsertSliceOp.getMixedStrides());
9386c3c5f80SMatthias Springer
9397fbf55c9SNicolas Vasilache // This memcpy will fold away if everything bufferizes in-place.
9407fbf55c9SNicolas Vasilache if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
9417fbf55c9SNicolas Vasilache *srcBuffer, subview)))
9427fbf55c9SNicolas Vasilache return failure();
9437fbf55c9SNicolas Vasilache
9447fbf55c9SNicolas Vasilache // Replace all uses of parallelIteratingOp (just the corresponding result).
9457fbf55c9SNicolas Vasilache rewriter.setInsertionPointAfter(parallelIteratingOp);
9467fbf55c9SNicolas Vasilache Value toTensorOp =
9477fbf55c9SNicolas Vasilache rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
9487fbf55c9SNicolas Vasilache // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
9497fbf55c9SNicolas Vasilache SmallVector<OpOperand *> resultUses = llvm::to_vector(
9507fbf55c9SNicolas Vasilache llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
9517fbf55c9SNicolas Vasilache [](OpOperand &use) { return &use; }));
9527fbf55c9SNicolas Vasilache for (OpOperand *use : resultUses) {
9537fbf55c9SNicolas Vasilache rewriter.updateRootInPlace(use->getOwner(),
9547fbf55c9SNicolas Vasilache [&]() { use->set(toTensorOp); });
9557fbf55c9SNicolas Vasilache }
9567fbf55c9SNicolas Vasilache rewriter.eraseOp(op);
9577fbf55c9SNicolas Vasilache return success();
9587fbf55c9SNicolas Vasilache }
9597fbf55c9SNicolas Vasilache
isNotConflictingmlir::tensor::__anonb90e36390111::ParallelInsertSliceOpInterface9607fbf55c9SNicolas Vasilache bool isNotConflicting(Operation *op, OpOperand *uRead,
9617fbf55c9SNicolas Vasilache OpOperand *uConflictingWrite,
9627fbf55c9SNicolas Vasilache const AnalysisState &state) const {
963*1defec87SMatthias Springer return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
964*1defec87SMatthias Springer op, uRead, uConflictingWrite, state);
9657fbf55c9SNicolas Vasilache }
9667fbf55c9SNicolas Vasilache };
9677fbf55c9SNicolas Vasilache
96849e37000SMatthias Springer } // namespace
96949e37000SMatthias Springer } // namespace tensor
97049e37000SMatthias Springer } // namespace mlir
97149e37000SMatthias Springer
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)97249e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
97349e37000SMatthias Springer DialectRegistry ®istry) {
97477eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
97577eee579SRiver Riddle CastOp::attachInterface<CastOpInterface>(*ctx);
97677eee579SRiver Riddle CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
97777eee579SRiver Riddle DimOp::attachInterface<DimOpInterface>(*ctx);
97877eee579SRiver Riddle ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
97977eee579SRiver Riddle ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
98077eee579SRiver Riddle ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
98177eee579SRiver Riddle FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
98277eee579SRiver Riddle GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
98377eee579SRiver Riddle InsertOp::attachInterface<InsertOpInterface>(*ctx);
98477eee579SRiver Riddle InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
9857fbf55c9SNicolas Vasilache ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
9867fbf55c9SNicolas Vasilache *ctx);
98777eee579SRiver Riddle RankOp::attachInterface<RankOpInterface>(*ctx);
988e287d647SAshay Rane ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
9895f5f71e7SMatthias Springer
9905f5f71e7SMatthias Springer // Load additional dialects of which ops may get created.
9915f5f71e7SMatthias Springer ctx->loadDialect<arith::ArithmeticDialect, scf::SCFDialect>();
99277eee579SRiver Riddle });
99349e37000SMatthias Springer }
994