149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
249e37000SMatthias Springer //
349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
649e37000SMatthias Springer //
749e37000SMatthias Springer //===----------------------------------------------------------------------===//
849e37000SMatthias Springer 
949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1149e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12b3ebe3beSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1349e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
148b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
1549e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1649e37000SMatthias Springer #include "mlir/IR/Dialect.h"
1749e37000SMatthias Springer #include "mlir/IR/Operation.h"
1849e37000SMatthias Springer 
1949e37000SMatthias Springer using namespace mlir;
2049e37000SMatthias Springer using namespace mlir::bufferization;
2149e37000SMatthias Springer using namespace mlir::tensor;
2249e37000SMatthias Springer 
2349e37000SMatthias Springer namespace mlir {
2449e37000SMatthias Springer namespace tensor {
2549e37000SMatthias Springer namespace {
2649e37000SMatthias Springer 
2749e37000SMatthias Springer struct CastOpInterface
2849e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
2949e37000SMatthias Springer                                                     tensor::CastOp> {
3049e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
319597b16aSMatthias Springer                               const AnalysisState &state) const {
3249e37000SMatthias Springer     return false;
3349e37000SMatthias Springer   }
3449e37000SMatthias Springer 
3549e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
369597b16aSMatthias Springer                                const AnalysisState &state) const {
3749e37000SMatthias Springer     return false;
3849e37000SMatthias Springer   }
3949e37000SMatthias Springer 
409597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
419597b16aSMatthias Springer                                             const AnalysisState &state) const {
42585a8a32SMatthias Springer     return {op->getResult(0)};
4349e37000SMatthias Springer   }
4449e37000SMatthias Springer 
4549e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
469597b16aSMatthias Springer                                 const AnalysisState &state) const {
4749e37000SMatthias Springer     return BufferRelation::Equivalent;
4849e37000SMatthias Springer   }
4949e37000SMatthias Springer 
5049e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
5249e37000SMatthias Springer     auto castOp = cast<tensor::CastOp>(op);
5349e37000SMatthias Springer 
5449e37000SMatthias Springer     // The result buffer still has the old (pre-cast) type.
555d50f51cSMatthias Springer     FailureOr<Value> resultBuffer =
565d50f51cSMatthias Springer         getBuffer(rewriter, castOp.getSource(), options);
575d50f51cSMatthias Springer     if (failed(resultBuffer))
585d50f51cSMatthias Springer       return failure();
595d50f51cSMatthias Springer     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
6049e37000SMatthias Springer     TensorType resultTensorType =
6149e37000SMatthias Springer         castOp.getResult().getType().cast<TensorType>();
6249e37000SMatthias Springer     MemRefLayoutAttrInterface layout;
6349e37000SMatthias Springer 
6449e37000SMatthias Springer     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
6549e37000SMatthias Springer       if (resultTensorType.isa<RankedTensorType>())
6649e37000SMatthias Springer         layout = rankedMemRefType.getLayout();
6749e37000SMatthias Springer 
6849e37000SMatthias Springer     // Compute the new memref type.
69b55d55ecSMatthias Springer     Type resultMemRefType =
70b06614e2SMatthias Springer         getMemRefType(resultTensorType, options, layout,
71b06614e2SMatthias Springer                       sourceMemRefType.getMemorySpaceAsInt());
7249e37000SMatthias Springer 
7349e37000SMatthias Springer     // Replace the op with a memref.cast.
745d50f51cSMatthias Springer     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
7549e37000SMatthias Springer                                              resultMemRefType) &&
7649e37000SMatthias Springer            "CallOp::bufferize: cast incompatible");
7749e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
785d50f51cSMatthias Springer                                                  *resultBuffer);
7949e37000SMatthias Springer 
8049e37000SMatthias Springer     return success();
8149e37000SMatthias Springer   }
8249e37000SMatthias Springer };
8349e37000SMatthias Springer 
84e6f69161SMatthias Springer /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  /// The op only reshapes metadata; it does not read the tensor contents.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  /// The op never writes to memory.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// The result aliases the "src" operand (operand 0); other operands (none
  /// for this op) alias nothing.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  /// Result and source share the same underlying buffer (just reshaped).
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        // 0-d layout: no strides, only the (possibly dynamic) offset survives.
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        // NOTE(review): this branch passes getMemorySpaceAsInt() while the
        // identity branch passes getMemorySpace() — confirm both overloads
        // produce the same memory space attribute.
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      // Allocate a fresh tensor (copying the source) whose buffer is
      // guaranteed to have an identity layout.
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(),
          analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
      if (failed(tensorAlloc))
        return failure();
      // Identity-layout memref type for the copy; memory space is inherited
      // from the original source buffer.
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};
176e6f69161SMatthias Springer 
17749e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim.
17849e37000SMatthias Springer struct DimOpInterface
17949e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
18049e37000SMatthias Springer                                                     tensor::DimOp> {
18149e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1829597b16aSMatthias Springer                               const AnalysisState &state) const {
18349e37000SMatthias Springer     return true;
18449e37000SMatthias Springer   }
18549e37000SMatthias Springer 
18649e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1879597b16aSMatthias Springer                                const AnalysisState &state) const {
18849e37000SMatthias Springer     return false;
18949e37000SMatthias Springer   }
19049e37000SMatthias Springer 
1919597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1929597b16aSMatthias Springer                                             const AnalysisState &state) const {
193585a8a32SMatthias Springer     return {};
19449e37000SMatthias Springer   }
19549e37000SMatthias Springer 
19649e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
197b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
19849e37000SMatthias Springer     auto dimOp = cast<tensor::DimOp>(op);
1995d50f51cSMatthias Springer     FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
2005d50f51cSMatthias Springer     if (failed(v))
2015d50f51cSMatthias Springer       return failure();
2025d50f51cSMatthias Springer     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
2035d50f51cSMatthias Springer                                                 dimOp.index());
20449e37000SMatthias Springer     return success();
20549e37000SMatthias Springer   }
20649e37000SMatthias Springer };
20749e37000SMatthias Springer 
208e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
209e6f69161SMatthias Springer struct ExpandShapeOpInterface
210e6f69161SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
211e6f69161SMatthias Springer                                                     tensor::ExpandShapeOp> {
212e6f69161SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2139597b16aSMatthias Springer                               const AnalysisState &state) const {
214e6f69161SMatthias Springer     return false;
215e6f69161SMatthias Springer   }
216e6f69161SMatthias Springer 
217e6f69161SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
2189597b16aSMatthias Springer                                const AnalysisState &state) const {
219e6f69161SMatthias Springer     return false;
220e6f69161SMatthias Springer   }
221e6f69161SMatthias Springer 
2229597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
2239597b16aSMatthias Springer                                             const AnalysisState &state) const {
224e6f69161SMatthias Springer     if (&opOperand == &op->getOpOperand(0) /*src*/)
225e6f69161SMatthias Springer       return {op->getOpResult(0)};
226e6f69161SMatthias Springer     return {};
227e6f69161SMatthias Springer   }
228e6f69161SMatthias Springer 
229e6f69161SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
2309597b16aSMatthias Springer                                 const AnalysisState &state) const {
231e6f69161SMatthias Springer     return BufferRelation::Equivalent;
232e6f69161SMatthias Springer   }
233e6f69161SMatthias Springer 
234e6f69161SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
235b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
236e6f69161SMatthias Springer     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
23751df6238SMatthias Springer     auto tensorResultType = expandShapeOp.getResultType();
2385d50f51cSMatthias Springer     FailureOr<Value> buffer =
2395d50f51cSMatthias Springer         getBuffer(rewriter, expandShapeOp.getSrc(), options);
2405d50f51cSMatthias Springer     if (failed(buffer))
2415d50f51cSMatthias Springer       return failure();
24251df6238SMatthias Springer 
24351df6238SMatthias Springer     // Memref result type is inferred by the builder based on reassociation
24451df6238SMatthias Springer     // indices and result shape.
245e6f69161SMatthias Springer     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
2465d50f51cSMatthias Springer         rewriter, op, tensorResultType.getShape(), *buffer,
24751df6238SMatthias Springer         expandShapeOp.getReassociationIndices());
248e6f69161SMatthias Springer     return success();
249e6f69161SMatthias Springer   }
250e6f69161SMatthias Springer };
251e6f69161SMatthias Springer 
25249e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview.
/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  /// Taking a subview is metadata-only: no read of the tensor contents.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  /// The op never writes to memory.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// The result aliases the "source" operand (operand 0).
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  /// The result is a (possibly strided) view into the source buffer, not the
  /// whole buffer, so the relation is not "Equivalent".
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.getResult().getType().cast<RankedTensorType>();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    // Missing trailing dimensions get offset 0, stride 1 and the full source
    // size; the size callback materializes a memref.dim for dynamic dims and
    // folds static dims into an attribute.
    OffsetSizeAndStrideOpInterface::expandToRank(
        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview. The result type is rank-reduced to match the
    // destination tensor's rank.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};
31949e37000SMatthias Springer 
32049e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load.
32149e37000SMatthias Springer struct ExtractOpInterface
32249e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
32349e37000SMatthias Springer                                                     tensor::ExtractOp> {
32449e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
3259597b16aSMatthias Springer                               const AnalysisState &state) const {
32649e37000SMatthias Springer     return true;
32749e37000SMatthias Springer   }
32849e37000SMatthias Springer 
32949e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
3309597b16aSMatthias Springer                                const AnalysisState &state) const {
33149e37000SMatthias Springer     return false;
33249e37000SMatthias Springer   }
33349e37000SMatthias Springer 
3349597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
3359597b16aSMatthias Springer                                             const AnalysisState &state) const {
336585a8a32SMatthias Springer     return {};
33749e37000SMatthias Springer   }
33849e37000SMatthias Springer 
33949e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
340b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
34149e37000SMatthias Springer     auto extractOp = cast<tensor::ExtractOp>(op);
3425d50f51cSMatthias Springer     FailureOr<Value> srcMemref =
3435d50f51cSMatthias Springer         getBuffer(rewriter, extractOp.getTensor(), options);
3445d50f51cSMatthias Springer     if (failed(srcMemref))
3455d50f51cSMatthias Springer       return failure();
3465d50f51cSMatthias Springer     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
3475d50f51cSMatthias Springer                                                  extractOp.indices());
34849e37000SMatthias Springer     return success();
34949e37000SMatthias Springer   }
35049e37000SMatthias Springer };
35149e37000SMatthias Springer 
352d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while
353d581c94dSMatthias Springer // iterating over op.elements().
354d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim,
355d581c94dSMatthias Springer                          Value buffer, ArrayRef<int64_t> shape,
356d581c94dSMatthias Springer                          ArrayRef<Value> constants,
357d581c94dSMatthias Springer                          OperandRange::iterator &elementIt,
358d581c94dSMatthias Springer                          SmallVectorImpl<Value> &indices) {
359d581c94dSMatthias Springer   if (dim == static_cast<int>(shape.size()) - 1) {
360d581c94dSMatthias Springer     for (int i = 0; i < shape.back(); ++i) {
361d581c94dSMatthias Springer       indices.back() = constants[i];
362d581c94dSMatthias Springer       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
363d581c94dSMatthias Springer       ++elementIt;
364d581c94dSMatthias Springer     }
365d581c94dSMatthias Springer     return;
366d581c94dSMatthias Springer   }
367d581c94dSMatthias Springer   for (int i = 0; i < shape[dim]; ++i) {
368d581c94dSMatthias Springer     indices[dim] = constants[i];
369d581c94dSMatthias Springer     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
370d581c94dSMatthias Springer                  indices);
371d581c94dSMatthias Springer   }
372d581c94dSMatthias Springer }
373d581c94dSMatthias Springer 
374d581c94dSMatthias Springer /// Bufferization of tensor.from_elements.
/// Bufferization of tensor.from_elements. Allocates a new buffer and emits one
/// memref.store per element.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    // copy=false: the buffer is fully overwritten by the stores below, so no
    // initial copy is needed.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(),
        analysisState.isTensorYielded(fromElementsOp.getResult()), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    // Identity-layout memref of the result shape (default memory space).
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>. Nothing to store; the empty buffer is the
    // result.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>. A 0-d store takes no indices.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    // These are shared across all dimensions by createStores.
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};
434d581c94dSMatthias Springer 
/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  /// Bufferize tensor.generate: allocate a buffer of the result's shape, then
  /// compute every element with an scf.parallel loop whose body is the
  /// generator region of the op, storing each computed value into the buffer.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    // Allocate a tensor for the result. `isTensorYielded` is forwarded as the
    // escape flag of the allocation; no copy of existing data is needed since
    // every element is overwritten below.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(),
        analysisState.isTensorYielded(generateOp.getResult()), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    // Materialize the allocated tensor as a memref (identity layout).
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    // Static dimensions get constant upper bounds; dynamic dimensions consume
    // the op's dynamic extent operands in order.
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          memrefType.isDynamicDim(i)
              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
              : rewriter.create<arith::ConstantIndexOp>(
                    loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};
50271bbb78bSMatthias Springer 
50349e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store.
50449e37000SMatthias Springer struct InsertOpInterface
50549e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
50649e37000SMatthias Springer                                                     tensor::InsertOp> {
50749e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
5089597b16aSMatthias Springer                               const AnalysisState &state) const {
50949e37000SMatthias Springer     return true;
51049e37000SMatthias Springer   }
51149e37000SMatthias Springer 
51249e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
5139597b16aSMatthias Springer                                const AnalysisState &state) const {
51449e37000SMatthias Springer     return true;
51549e37000SMatthias Springer   }
51649e37000SMatthias Springer 
5179597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
5189597b16aSMatthias Springer                                             const AnalysisState &state) const {
51949e37000SMatthias Springer     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
52049e37000SMatthias Springer            "expected dest OpOperand");
521585a8a32SMatthias Springer     return {op->getOpResult(0)};
52249e37000SMatthias Springer   }
52349e37000SMatthias Springer 
52449e37000SMatthias Springer   SmallVector<OpOperand *>
52549e37000SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
5269597b16aSMatthias Springer                        const AnalysisState &state) const {
52749e37000SMatthias Springer     return {&op->getOpOperand(1) /*dest*/};
52849e37000SMatthias Springer   }
52949e37000SMatthias Springer 
53049e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
531b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
53249e37000SMatthias Springer     auto insertOp = cast<tensor::InsertOp>(op);
5335d50f51cSMatthias Springer     FailureOr<Value> destMemref =
5345d50f51cSMatthias Springer         getBuffer(rewriter, insertOp.getDest(), options);
5355d50f51cSMatthias Springer     if (failed(destMemref))
5365d50f51cSMatthias Springer       return failure();
5378df54a6aSJacques Pienaar     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
5385d50f51cSMatthias Springer                                      *destMemref, insertOp.getIndices());
5395d50f51cSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
54049e37000SMatthias Springer     return success();
54149e37000SMatthias Springer   }
54249e37000SMatthias Springer 
54349e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
5449597b16aSMatthias Springer                                 const AnalysisState &state) const {
54549e37000SMatthias Springer     return BufferRelation::Equivalent;
54649e37000SMatthias Springer   }
54749e37000SMatthias Springer };
54849e37000SMatthias Springer 
54949e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
55049e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification).
55149e37000SMatthias Springer ///
55249e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that
55349e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and
55449e37000SMatthias Springer /// exposed as interfaces on the proper types.
5559597b16aSMatthias Springer static bool areEquivalentExtractSliceOps(const AnalysisState &state,
55649e37000SMatthias Springer                                          ExtractSliceOp st, InsertSliceOp sti) {
55749e37000SMatthias Springer   if (!st || !sti)
55849e37000SMatthias Springer     return false;
55949e37000SMatthias Springer   if (sti != sti &&
5608df54a6aSJacques Pienaar       !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
56149e37000SMatthias Springer     return false;
56249e37000SMatthias Springer   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
56349e37000SMatthias Springer     return false;
56449e37000SMatthias Springer   return true;
56549e37000SMatthias Springer }
56649e37000SMatthias Springer 
56749e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches
56849e37000SMatthias Springer /// the given InsertSliceOp.
5699597b16aSMatthias Springer static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
5709597b16aSMatthias Springer                                       InsertSliceOp insertOp) {
57149e37000SMatthias Springer   auto condition = [&](Value val) {
57249e37000SMatthias Springer     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
57349e37000SMatthias Springer       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
57449e37000SMatthias Springer         return true;
57549e37000SMatthias Springer     return false;
57649e37000SMatthias Springer   };
57749e37000SMatthias Springer 
57849e37000SMatthias Springer   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
57949e37000SMatthias Springer                       condition);
58049e37000SMatthias Springer }
58149e37000SMatthias Springer 
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  /// Both source and dest tensor operands are read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// Only the dest operand is written to.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  /// The op result aliases only with the dest operand.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  /// Result and dest bufferize to the same buffer.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  /// Special conflict analysis for matching ExtractSliceOp/InsertSliceOp
  /// pairs: a read/write pair that would normally be flagged as a RaW
  /// conflict is allowed when the slice being written is provably the slice
  /// that was extracted (same buffer, same offsets/sizes/strides).
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }

  /// Bufferize tensor.insert_slice: take a subview of the dest buffer and
  /// copy the source buffer into it. The copy can later fold away when a
  /// matching tensor.extract_slice exists.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          // Dynamic dims are materialized via memref.dim; static dims become
          // index attributes.
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};
73649e37000SMatthias Springer 
737fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank.
738fc08d1c2SMatthias Springer struct RankOpInterface
739fc08d1c2SMatthias Springer     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
740fc08d1c2SMatthias Springer                                                     tensor::RankOp> {
741fc08d1c2SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
7429597b16aSMatthias Springer                               const AnalysisState &state) const {
743fc08d1c2SMatthias Springer     return true;
744fc08d1c2SMatthias Springer   }
745fc08d1c2SMatthias Springer 
746fc08d1c2SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
7479597b16aSMatthias Springer                                const AnalysisState &state) const {
748fc08d1c2SMatthias Springer     return false;
749fc08d1c2SMatthias Springer   }
750fc08d1c2SMatthias Springer 
7519597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
7529597b16aSMatthias Springer                                             const AnalysisState &state) const {
753585a8a32SMatthias Springer     return {};
754fc08d1c2SMatthias Springer   }
755fc08d1c2SMatthias Springer 
756fc08d1c2SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
757b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
758fc08d1c2SMatthias Springer     auto rankOp = cast<tensor::RankOp>(op);
7595d50f51cSMatthias Springer     FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
7605d50f51cSMatthias Springer     if (failed(v))
7615d50f51cSMatthias Springer       return failure();
762fc08d1c2SMatthias Springer     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
7635d50f51cSMatthias Springer                                                  *v);
764fc08d1c2SMatthias Springer     return success();
765fc08d1c2SMatthias Springer   }
766fc08d1c2SMatthias Springer };
767fc08d1c2SMatthias Springer 
768e287d647SAshay Rane /// Bufferization of tensor.reshape. Replace with memref.reshape.
769e287d647SAshay Rane struct ReshapeOpInterface
770e287d647SAshay Rane     : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
771e287d647SAshay Rane                                                     tensor::ReshapeOp> {
772e287d647SAshay Rane   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
773e287d647SAshay Rane                               const AnalysisState &state) const {
774e287d647SAshay Rane     if (&opOperand == &op->getOpOperand(1) /* shape */)
775e287d647SAshay Rane       return true;
776e287d647SAshay Rane     return false;
777e287d647SAshay Rane   }
778e287d647SAshay Rane 
779e287d647SAshay Rane   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
780e287d647SAshay Rane                                const AnalysisState &state) const {
781e287d647SAshay Rane     return false;
782e287d647SAshay Rane   }
783e287d647SAshay Rane 
784e287d647SAshay Rane   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
785e287d647SAshay Rane                                             const AnalysisState &state) const {
786e287d647SAshay Rane     return {op->getOpResult(0)};
787e287d647SAshay Rane   }
788e287d647SAshay Rane 
789e287d647SAshay Rane   BufferRelation bufferRelation(Operation *op, OpResult opResult,
790e287d647SAshay Rane                                 const AnalysisState &state) const {
791e287d647SAshay Rane     return BufferRelation::Equivalent;
792e287d647SAshay Rane   }
793e287d647SAshay Rane 
794e287d647SAshay Rane   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
795b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
796e287d647SAshay Rane     auto reshapeOp = cast<tensor::ReshapeOp>(op);
7975d50f51cSMatthias Springer     FailureOr<Value> srcBuffer =
7985d50f51cSMatthias Springer         getBuffer(rewriter, reshapeOp.getSource(), options);
7995d50f51cSMatthias Springer     FailureOr<Value> shapeBuffer =
8005d50f51cSMatthias Springer         getBuffer(rewriter, reshapeOp.getShape(), options);
8015d50f51cSMatthias Springer     if (failed(srcBuffer) || failed(shapeBuffer))
8025d50f51cSMatthias Springer       return failure();
803e287d647SAshay Rane     auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
804c0b0b6a0SMatthias Springer     auto resultMemRefType = getMemRefType(
805c0b0b6a0SMatthias Springer         resultTensorType, options, /*layout=*/{},
806c0b0b6a0SMatthias Springer         srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
807e287d647SAshay Rane     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
8085d50f51cSMatthias Springer         rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
809e287d647SAshay Rane     return success();
810e287d647SAshay Rane   }
811e287d647SAshay Rane };
812e287d647SAshay Rane 
813*7fbf55c9SNicolas Vasilache /// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
814*7fbf55c9SNicolas Vasilache /// equivalent operand / result and same offset/sizes/strides specification).
815*7fbf55c9SNicolas Vasilache static bool areEquivalentExtractSliceOps(const AnalysisState &state,
816*7fbf55c9SNicolas Vasilache                                          ExtractSliceOp st,
817*7fbf55c9SNicolas Vasilache                                          ParallelInsertSliceOp sti) {
818*7fbf55c9SNicolas Vasilache   if (!st || !sti)
819*7fbf55c9SNicolas Vasilache     return false;
820*7fbf55c9SNicolas Vasilache   if (st != sti &&
821*7fbf55c9SNicolas Vasilache       !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
822*7fbf55c9SNicolas Vasilache     return false;
823*7fbf55c9SNicolas Vasilache   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
824*7fbf55c9SNicolas Vasilache     return false;
825*7fbf55c9SNicolas Vasilache   return true;
826*7fbf55c9SNicolas Vasilache }
827*7fbf55c9SNicolas Vasilache 
828*7fbf55c9SNicolas Vasilache /// Return true if `value` is originating from an ExtractSliceOp that matches
829*7fbf55c9SNicolas Vasilache /// the given InsertSliceOp.
830*7fbf55c9SNicolas Vasilache static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
831*7fbf55c9SNicolas Vasilache                                       ParallelInsertSliceOp insertOp) {
832*7fbf55c9SNicolas Vasilache   auto condition = [&](Value val) {
833*7fbf55c9SNicolas Vasilache     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
834*7fbf55c9SNicolas Vasilache       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
835*7fbf55c9SNicolas Vasilache         return true;
836*7fbf55c9SNicolas Vasilache     return false;
837*7fbf55c9SNicolas Vasilache   };
838*7fbf55c9SNicolas Vasilache 
839*7fbf55c9SNicolas Vasilache   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
840*7fbf55c9SNicolas Vasilache                       condition);
841*7fbf55c9SNicolas Vasilache }
842*7fbf55c9SNicolas Vasilache 
843*7fbf55c9SNicolas Vasilache /// Analysis of ParallelInsertSliceOp.
844*7fbf55c9SNicolas Vasilache struct ParallelInsertSliceOpInterface
845*7fbf55c9SNicolas Vasilache     : public BufferizableOpInterface::ExternalModel<
846*7fbf55c9SNicolas Vasilache           ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
847*7fbf55c9SNicolas Vasilache   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
848*7fbf55c9SNicolas Vasilache                                             const AnalysisState &state) const {
849*7fbf55c9SNicolas Vasilache     if (&opOperand != &op->getOpOperand(1) /*dest*/)
850*7fbf55c9SNicolas Vasilache       return {};
851*7fbf55c9SNicolas Vasilache 
852*7fbf55c9SNicolas Vasilache     // ParallelInsertSliceOp itself has no results, query its tied op results.
853*7fbf55c9SNicolas Vasilache     auto insertOp = cast<ParallelInsertSliceOp>(op);
854*7fbf55c9SNicolas Vasilache     return {insertOp.getTiedOpResult()};
855*7fbf55c9SNicolas Vasilache   }
856*7fbf55c9SNicolas Vasilache 
857*7fbf55c9SNicolas Vasilache   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
858*7fbf55c9SNicolas Vasilache                               const AnalysisState &state) const {
859*7fbf55c9SNicolas Vasilache     return true;
860*7fbf55c9SNicolas Vasilache   }
861*7fbf55c9SNicolas Vasilache 
862*7fbf55c9SNicolas Vasilache   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
863*7fbf55c9SNicolas Vasilache                                const AnalysisState &state) const {
864*7fbf55c9SNicolas Vasilache     return &opOperand == &op->getOpOperand(1) /*dest*/;
865*7fbf55c9SNicolas Vasilache   }
866*7fbf55c9SNicolas Vasilache 
867*7fbf55c9SNicolas Vasilache   BufferRelation bufferRelation(Operation *op, OpResult opResult,
868*7fbf55c9SNicolas Vasilache                                 const AnalysisState &state) const {
869*7fbf55c9SNicolas Vasilache     return BufferRelation::Equivalent;
870*7fbf55c9SNicolas Vasilache   }
871*7fbf55c9SNicolas Vasilache 
872*7fbf55c9SNicolas Vasilache   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
873*7fbf55c9SNicolas Vasilache                                  const AnalysisState &state) const {
874*7fbf55c9SNicolas Vasilache     // This interface method is overridden because we want to set a custom
875*7fbf55c9SNicolas Vasilache     // insertion point for tensor copies. They should be inserted right before
876*7fbf55c9SNicolas Vasilache     // the ForeachThreadOp. E.g.:
877*7fbf55c9SNicolas Vasilache     //
878*7fbf55c9SNicolas Vasilache     // %r0, %r1 = foreach_thead ... {
879*7fbf55c9SNicolas Vasilache     //   ...
880*7fbf55c9SNicolas Vasilache     //   perform_concurrently {
881*7fbf55c9SNicolas Vasilache     //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
882*7fbf55c9SNicolas Vasilache     //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
883*7fbf55c9SNicolas Vasilache     //   }
884*7fbf55c9SNicolas Vasilache     // }
885*7fbf55c9SNicolas Vasilache     //
886*7fbf55c9SNicolas Vasilache     // After TensorCopyInsertion:
887*7fbf55c9SNicolas Vasilache     //
888*7fbf55c9SNicolas Vasilache     // %copy = bufferization.alloc_tensor() copy(%d)
889*7fbf55c9SNicolas Vasilache     // %r0, %r1 = foreach_thead ... {
890*7fbf55c9SNicolas Vasilache     //   ...
891*7fbf55c9SNicolas Vasilache     //   perform_concurrently {
892*7fbf55c9SNicolas Vasilache     //     parallel_insert_slice %a into %b ...
893*7fbf55c9SNicolas Vasilache     //     parallel_insert_slice %c into %copy ...
894*7fbf55c9SNicolas Vasilache     //   }
895*7fbf55c9SNicolas Vasilache     // }
896*7fbf55c9SNicolas Vasilache 
897*7fbf55c9SNicolas Vasilache     OpBuilder::InsertionGuard g(rewriter);
898*7fbf55c9SNicolas Vasilache     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
899*7fbf55c9SNicolas Vasilache     ParallelCombiningOpInterface parallelCombiningParent =
900*7fbf55c9SNicolas Vasilache         parallelInsertSliceOp.getParallelCombiningParent();
901*7fbf55c9SNicolas Vasilache     Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
902*7fbf55c9SNicolas Vasilache 
903*7fbf55c9SNicolas Vasilache     // Nothing to do if the destination tensor is inplace.
904*7fbf55c9SNicolas Vasilache     assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
905*7fbf55c9SNicolas Vasilache            "source is always in-place");
906*7fbf55c9SNicolas Vasilache     if (state.isInPlace(op->getOpOperand(1) /*dest*/))
907*7fbf55c9SNicolas Vasilache       return success();
908*7fbf55c9SNicolas Vasilache 
909*7fbf55c9SNicolas Vasilache     // Find corresponding OpResult.
910*7fbf55c9SNicolas Vasilache     OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
911*7fbf55c9SNicolas Vasilache 
912*7fbf55c9SNicolas Vasilache     // Insert tensor allocation right before the ForeachThreadOp.
913*7fbf55c9SNicolas Vasilache     rewriter.setInsertionPoint(parallelIteratingOp);
914*7fbf55c9SNicolas Vasilache     bool isYielded = state.isTensorYielded(opResult);
915*7fbf55c9SNicolas Vasilache     FailureOr<Value> alloc = allocateTensorForShapedValue(
916*7fbf55c9SNicolas Vasilache         rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
917*7fbf55c9SNicolas Vasilache         /*escape=*/isYielded, state.getOptions());
918*7fbf55c9SNicolas Vasilache     if (failed(alloc))
919*7fbf55c9SNicolas Vasilache       return failure();
920*7fbf55c9SNicolas Vasilache 
921*7fbf55c9SNicolas Vasilache     // Update destination operand.
922*7fbf55c9SNicolas Vasilache     rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
923*7fbf55c9SNicolas Vasilache       parallelInsertSliceOp.getDestMutable().assign(*alloc);
924*7fbf55c9SNicolas Vasilache     });
925*7fbf55c9SNicolas Vasilache 
926*7fbf55c9SNicolas Vasilache     return success();
927*7fbf55c9SNicolas Vasilache   }
928*7fbf55c9SNicolas Vasilache 
929*7fbf55c9SNicolas Vasilache   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
930*7fbf55c9SNicolas Vasilache                           const BufferizationOptions &options) const {
931*7fbf55c9SNicolas Vasilache     OpBuilder::InsertionGuard g(rewriter);
932*7fbf55c9SNicolas Vasilache     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
933*7fbf55c9SNicolas Vasilache     ParallelCombiningOpInterface parallelCombiningParent =
934*7fbf55c9SNicolas Vasilache         parallelInsertSliceOp.getParallelCombiningParent();
935*7fbf55c9SNicolas Vasilache     Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
936*7fbf55c9SNicolas Vasilache 
937*7fbf55c9SNicolas Vasilache     // Get destination buffer.
938*7fbf55c9SNicolas Vasilache     FailureOr<Value> destBuffer =
939*7fbf55c9SNicolas Vasilache         getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
940*7fbf55c9SNicolas Vasilache     if (failed(destBuffer))
941*7fbf55c9SNicolas Vasilache       return failure();
942*7fbf55c9SNicolas Vasilache 
943*7fbf55c9SNicolas Vasilache     // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
944*7fbf55c9SNicolas Vasilache     rewriter.setInsertionPoint(parallelCombiningParent);
945*7fbf55c9SNicolas Vasilache     FailureOr<Value> srcBuffer =
946*7fbf55c9SNicolas Vasilache         getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
947*7fbf55c9SNicolas Vasilache     if (failed(srcBuffer))
948*7fbf55c9SNicolas Vasilache       return failure();
949*7fbf55c9SNicolas Vasilache     Value subview = rewriter.create<memref::SubViewOp>(
950*7fbf55c9SNicolas Vasilache         parallelInsertSliceOp.getLoc(), *destBuffer,
951*7fbf55c9SNicolas Vasilache         parallelInsertSliceOp.getMixedOffsets(),
952*7fbf55c9SNicolas Vasilache         parallelInsertSliceOp.getMixedSizes(),
953*7fbf55c9SNicolas Vasilache         parallelInsertSliceOp.getMixedStrides());
954*7fbf55c9SNicolas Vasilache     // This memcpy will fold away if everything bufferizes in-place.
955*7fbf55c9SNicolas Vasilache     if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
956*7fbf55c9SNicolas Vasilache                                     *srcBuffer, subview)))
957*7fbf55c9SNicolas Vasilache       return failure();
958*7fbf55c9SNicolas Vasilache 
959*7fbf55c9SNicolas Vasilache     // Replace all uses of parallelIteratingOp (just the corresponding result).
960*7fbf55c9SNicolas Vasilache     rewriter.setInsertionPointAfter(parallelIteratingOp);
961*7fbf55c9SNicolas Vasilache     Value toTensorOp =
962*7fbf55c9SNicolas Vasilache         rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
963*7fbf55c9SNicolas Vasilache     // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
964*7fbf55c9SNicolas Vasilache     SmallVector<OpOperand *> resultUses = llvm::to_vector(
965*7fbf55c9SNicolas Vasilache         llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
966*7fbf55c9SNicolas Vasilache                         [](OpOperand &use) { return &use; }));
967*7fbf55c9SNicolas Vasilache     for (OpOperand *use : resultUses) {
968*7fbf55c9SNicolas Vasilache       rewriter.updateRootInPlace(use->getOwner(),
969*7fbf55c9SNicolas Vasilache                                  [&]() { use->set(toTensorOp); });
970*7fbf55c9SNicolas Vasilache     }
971*7fbf55c9SNicolas Vasilache     rewriter.eraseOp(op);
972*7fbf55c9SNicolas Vasilache     return success();
973*7fbf55c9SNicolas Vasilache   }
974*7fbf55c9SNicolas Vasilache 
975*7fbf55c9SNicolas Vasilache   // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
976*7fbf55c9SNicolas Vasilache   // the code.
977*7fbf55c9SNicolas Vasilache   bool isNotConflicting(Operation *op, OpOperand *uRead,
978*7fbf55c9SNicolas Vasilache                         OpOperand *uConflictingWrite,
979*7fbf55c9SNicolas Vasilache                         const AnalysisState &state) const {
980*7fbf55c9SNicolas Vasilache     Operation *readingOp = uRead->getOwner();
981*7fbf55c9SNicolas Vasilache     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
982*7fbf55c9SNicolas Vasilache 
983*7fbf55c9SNicolas Vasilache     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
984*7fbf55c9SNicolas Vasilache     // uRead is an InsertSliceOp...
985*7fbf55c9SNicolas Vasilache     if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
986*7fbf55c9SNicolas Vasilache       // As an example, consider the following IR.
987*7fbf55c9SNicolas Vasilache       //
988*7fbf55c9SNicolas Vasilache       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
989*7fbf55c9SNicolas Vasilache       // %1 = linalg.fill %cst, %0 {inplace= [true] }
990*7fbf55c9SNicolas Vasilache       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
991*7fbf55c9SNicolas Vasilache       //     {inplace= [true] }
992*7fbf55c9SNicolas Vasilache 
993*7fbf55c9SNicolas Vasilache       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
994*7fbf55c9SNicolas Vasilache       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
995*7fbf55c9SNicolas Vasilache           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
996*7fbf55c9SNicolas Vasilache                                     insertSliceOp))
997*7fbf55c9SNicolas Vasilache         // Case 1: The main insight is that InsertSliceOp reads only part of
998*7fbf55c9SNicolas Vasilache         // the destination tensor. The overwritten area is not read. If
999*7fbf55c9SNicolas Vasilache         // uConflictingWrite writes into exactly the memory location that is
1000*7fbf55c9SNicolas Vasilache         // being read by uRead, this is not a conflict.
1001*7fbf55c9SNicolas Vasilache         //
1002*7fbf55c9SNicolas Vasilache         // In the above example:
1003*7fbf55c9SNicolas Vasilache         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
1004*7fbf55c9SNicolas Vasilache         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
1005*7fbf55c9SNicolas Vasilache         //
1006*7fbf55c9SNicolas Vasilache         // The read of %t does not conflict with the write of the FillOp
1007*7fbf55c9SNicolas Vasilache         // (same aliases!) because the area that the FillOp operates on is
1008*7fbf55c9SNicolas Vasilache         // exactly the one that is *not* read via %t.
1009*7fbf55c9SNicolas Vasilache         return true;
1010*7fbf55c9SNicolas Vasilache 
1011*7fbf55c9SNicolas Vasilache       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
1012*7fbf55c9SNicolas Vasilache           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1013*7fbf55c9SNicolas Vasilache           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
1014*7fbf55c9SNicolas Vasilache         // Case 2: The read of the source tensor and the write to the dest
1015*7fbf55c9SNicolas Vasilache         // tensor via an InsertSliceOp is not a conflict if the read is
1016*7fbf55c9SNicolas Vasilache         // reading exactly that part of an equivalent tensor that the
1017*7fbf55c9SNicolas Vasilache         // InsertSliceOp is writing.
1018*7fbf55c9SNicolas Vasilache         //
1019*7fbf55c9SNicolas Vasilache         // In the above example:
1020*7fbf55c9SNicolas Vasilache         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
1021*7fbf55c9SNicolas Vasilache         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1022*7fbf55c9SNicolas Vasilache         return true;
1023*7fbf55c9SNicolas Vasilache     }
1024*7fbf55c9SNicolas Vasilache 
1025*7fbf55c9SNicolas Vasilache     // If uConflictingWrite is an InsertSliceOp...
1026*7fbf55c9SNicolas Vasilache     if (auto insertSliceOp =
1027*7fbf55c9SNicolas Vasilache             dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
1028*7fbf55c9SNicolas Vasilache       // As an example, consider the following IR.
1029*7fbf55c9SNicolas Vasilache       //
1030*7fbf55c9SNicolas Vasilache       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
1031*7fbf55c9SNicolas Vasilache       // %1 = linalg.fill %cst, %0 {inplace= [true] }
1032*7fbf55c9SNicolas Vasilache       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
1033*7fbf55c9SNicolas Vasilache       //     {inplace= [true] }
1034*7fbf55c9SNicolas Vasilache       // %3 = vector.transfer_read %1, %cst
1035*7fbf55c9SNicolas Vasilache       //
1036*7fbf55c9SNicolas Vasilache       // In the above example:
1037*7fbf55c9SNicolas Vasilache       // uRead             = OpOperand 0 (%1) of vector.transfer_read
1038*7fbf55c9SNicolas Vasilache       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1039*7fbf55c9SNicolas Vasilache       // lastWrite         = %1
1040*7fbf55c9SNicolas Vasilache       //
1041*7fbf55c9SNicolas Vasilache       // This is not a conflict because the InsertSliceOp overwrites the
1042*7fbf55c9SNicolas Vasilache       // memory segment of %1 with the exact same data. (Effectively, there
1043*7fbf55c9SNicolas Vasilache       // is no memory write here.)
1044*7fbf55c9SNicolas Vasilache       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1045*7fbf55c9SNicolas Vasilache           state.areEquivalentBufferizedValues(uRead->get(),
1046*7fbf55c9SNicolas Vasilache                                               insertSliceOp.getSource()) &&
1047*7fbf55c9SNicolas Vasilache           hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
1048*7fbf55c9SNicolas Vasilache                                     insertSliceOp))
1049*7fbf55c9SNicolas Vasilache         return true;
1050*7fbf55c9SNicolas Vasilache 
1051*7fbf55c9SNicolas Vasilache     return false;
1052*7fbf55c9SNicolas Vasilache   }
1053*7fbf55c9SNicolas Vasilache };
1054*7fbf55c9SNicolas Vasilache 
105549e37000SMatthias Springer } // namespace
105649e37000SMatthias Springer } // namespace tensor
105749e37000SMatthias Springer } // namespace mlir
105849e37000SMatthias Springer 
105949e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
106049e37000SMatthias Springer     DialectRegistry &registry) {
106177eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
106277eee579SRiver Riddle     CastOp::attachInterface<CastOpInterface>(*ctx);
106377eee579SRiver Riddle     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
106477eee579SRiver Riddle     DimOp::attachInterface<DimOpInterface>(*ctx);
106577eee579SRiver Riddle     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
106677eee579SRiver Riddle     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
106777eee579SRiver Riddle     ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
106877eee579SRiver Riddle     FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
106977eee579SRiver Riddle     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
107077eee579SRiver Riddle     InsertOp::attachInterface<InsertOpInterface>(*ctx);
107177eee579SRiver Riddle     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1072*7fbf55c9SNicolas Vasilache     ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1073*7fbf55c9SNicolas Vasilache         *ctx);
107477eee579SRiver Riddle     RankOp::attachInterface<RankOpInterface>(*ctx);
1075e287d647SAshay Rane     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
107677eee579SRiver Riddle   });
107749e37000SMatthias Springer }
1078