149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
249e37000SMatthias Springer //
349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
649e37000SMatthias Springer //
749e37000SMatthias Springer //===----------------------------------------------------------------------===//
849e37000SMatthias Springer 
949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1149e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12b3ebe3beSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1349e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
14*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
1549e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1649e37000SMatthias Springer #include "mlir/IR/Dialect.h"
1749e37000SMatthias Springer #include "mlir/IR/Operation.h"
1849e37000SMatthias Springer 
1949e37000SMatthias Springer using namespace mlir;
2049e37000SMatthias Springer using namespace mlir::bufferization;
2149e37000SMatthias Springer using namespace mlir::tensor;
2249e37000SMatthias Springer 
2349e37000SMatthias Springer namespace mlir {
2449e37000SMatthias Springer namespace tensor {
2549e37000SMatthias Springer namespace {
2649e37000SMatthias Springer 
2749e37000SMatthias Springer struct CastOpInterface
2849e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
2949e37000SMatthias Springer                                                     tensor::CastOp> {
3049e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
319597b16aSMatthias Springer                               const AnalysisState &state) const {
3249e37000SMatthias Springer     return false;
3349e37000SMatthias Springer   }
3449e37000SMatthias Springer 
3549e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
369597b16aSMatthias Springer                                const AnalysisState &state) const {
3749e37000SMatthias Springer     return false;
3849e37000SMatthias Springer   }
3949e37000SMatthias Springer 
409597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
419597b16aSMatthias Springer                                             const AnalysisState &state) const {
42585a8a32SMatthias Springer     return {op->getResult(0)};
4349e37000SMatthias Springer   }
4449e37000SMatthias Springer 
4549e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
469597b16aSMatthias Springer                                 const AnalysisState &state) const {
4749e37000SMatthias Springer     return BufferRelation::Equivalent;
4849e37000SMatthias Springer   }
4949e37000SMatthias Springer 
5049e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
5249e37000SMatthias Springer     auto castOp = cast<tensor::CastOp>(op);
5349e37000SMatthias Springer 
5449e37000SMatthias Springer     // The result buffer still has the old (pre-cast) type.
558df54a6aSJacques Pienaar     Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
56b3ebe3beSMatthias Springer     auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
5749e37000SMatthias Springer     Attribute memorySpace = sourceMemRefType.getMemorySpace();
5849e37000SMatthias Springer     TensorType resultTensorType =
5949e37000SMatthias Springer         castOp.getResult().getType().cast<TensorType>();
6049e37000SMatthias Springer     MemRefLayoutAttrInterface layout;
6149e37000SMatthias Springer 
6249e37000SMatthias Springer     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
6349e37000SMatthias Springer       if (resultTensorType.isa<RankedTensorType>())
6449e37000SMatthias Springer         layout = rankedMemRefType.getLayout();
6549e37000SMatthias Springer 
6649e37000SMatthias Springer     // Compute the new memref type.
67b55d55ecSMatthias Springer     Type resultMemRefType =
68b55d55ecSMatthias Springer         getMemRefType(resultTensorType, options, layout, memorySpace);
6949e37000SMatthias Springer 
7049e37000SMatthias Springer     // Replace the op with a memref.cast.
71b3ebe3beSMatthias Springer     assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
7249e37000SMatthias Springer                                              resultMemRefType) &&
7349e37000SMatthias Springer            "CallOp::bufferize: cast incompatible");
7449e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
75b3ebe3beSMatthias Springer                                                  resultBuffer);
7649e37000SMatthias Springer 
7749e37000SMatthias Springer     return success();
7849e37000SMatthias Springer   }
7949e37000SMatthias Springer };
8049e37000SMatthias Springer 
81e6f69161SMatthias Springer /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
82e6f69161SMatthias Springer struct CollapseShapeOpInterface
83e6f69161SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
84e6f69161SMatthias Springer                                                     tensor::CollapseShapeOp> {
85e6f69161SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
869597b16aSMatthias Springer                               const AnalysisState &state) const {
87e6f69161SMatthias Springer     return false;
88e6f69161SMatthias Springer   }
89e6f69161SMatthias Springer 
90e6f69161SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
919597b16aSMatthias Springer                                const AnalysisState &state) const {
92e6f69161SMatthias Springer     return false;
93e6f69161SMatthias Springer   }
94e6f69161SMatthias Springer 
959597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
969597b16aSMatthias Springer                                             const AnalysisState &state) const {
97e6f69161SMatthias Springer     if (&opOperand == &op->getOpOperand(0) /*src*/)
98e6f69161SMatthias Springer       return {op->getOpResult(0)};
99e6f69161SMatthias Springer     return {};
100e6f69161SMatthias Springer   }
101e6f69161SMatthias Springer 
102e6f69161SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
1039597b16aSMatthias Springer                                 const AnalysisState &state) const {
104e6f69161SMatthias Springer     return BufferRelation::Equivalent;
105e6f69161SMatthias Springer   }
106e6f69161SMatthias Springer 
107e6f69161SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
108b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
109e6f69161SMatthias Springer     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
11051df6238SMatthias Springer     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
1118df54a6aSJacques Pienaar     Value buffer = getBuffer(rewriter, collapseShapeOp.getSrc(), options);
112b3ebe3beSMatthias Springer     auto bufferType = buffer.getType().cast<MemRefType>();
11351df6238SMatthias Springer 
11451df6238SMatthias Springer     if (tensorResultType.getRank() == 0) {
11551df6238SMatthias Springer       // 0-d collapses must go through a different op builder.
11673c0333dSMatthias Springer       MemRefType resultType;
11773c0333dSMatthias Springer 
11873c0333dSMatthias Springer       if (bufferType.getLayout().isIdentity()) {
11973c0333dSMatthias Springer         // Standard layout: result type has no offset.
12051df6238SMatthias Springer         MemRefLayoutAttrInterface layout;
12173c0333dSMatthias Springer         resultType = MemRefType::get({}, tensorResultType.getElementType(),
12251df6238SMatthias Springer                                      layout, bufferType.getMemorySpace());
12373c0333dSMatthias Springer       } else {
12473c0333dSMatthias Springer         // Source memref has a layout map: result type has the same offset as
12573c0333dSMatthias Springer         // the source type.
12673c0333dSMatthias Springer         SmallVector<int64_t> strides;
12773c0333dSMatthias Springer         int64_t offset;
12873c0333dSMatthias Springer         if (failed(getStridesAndOffset(bufferType, strides, offset)))
12973c0333dSMatthias Springer           return failure();
13073c0333dSMatthias Springer         AffineMap resultLayout =
13173c0333dSMatthias Springer             makeStridedLinearLayoutMap({}, offset, op->getContext());
13273c0333dSMatthias Springer         resultType =
13373c0333dSMatthias Springer             MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
13473c0333dSMatthias Springer                             bufferType.getMemorySpaceAsInt());
13573c0333dSMatthias Springer       }
13673c0333dSMatthias Springer 
137e6f69161SMatthias Springer       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
1388df54a6aSJacques Pienaar           rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
139e6f69161SMatthias Springer       return success();
140e6f69161SMatthias Springer     }
14151df6238SMatthias Springer 
142d7a9bf91SMatthias Springer     // If the dims are not collapsible (due to an incompatible source layout
143d7a9bf91SMatthias Springer     // map), force an out-of-place bufferization, i.e., a buffer copy. This
144d7a9bf91SMatthias Springer     // newly allocated buffer will have no layout map and thus be collapsible.
145a74e5a89SAdrian Kuegel     bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
146d7a9bf91SMatthias Springer         bufferType, collapseShapeOp.getReassociationIndices());
147b3ebe3beSMatthias Springer     if (!canBeCollapsed) {
148b3ebe3beSMatthias Springer       // TODO: Create alloc_tensor ops during TensorCopyInsertion.
149b55d55ecSMatthias Springer       AnalysisState analysisState(options);
150b3ebe3beSMatthias Springer       Value tensorAlloc = allocateTensorForShapedValue(
1518df54a6aSJacques Pienaar           rewriter, op->getLoc(), collapseShapeOp.getSrc(),
1528df54a6aSJacques Pienaar           analysisState.isTensorYielded(collapseShapeOp.getResult()));
153b3ebe3beSMatthias Springer       auto memrefType =
154b3ebe3beSMatthias Springer           MemRefType::get(collapseShapeOp.getSrcType().getShape(),
155b3ebe3beSMatthias Springer                           collapseShapeOp.getSrcType().getElementType(),
156b3ebe3beSMatthias Springer                           AffineMap(), bufferType.getMemorySpaceAsInt());
157b3ebe3beSMatthias Springer       buffer = rewriter.create<bufferization::ToMemrefOp>(
158b3ebe3beSMatthias Springer           op->getLoc(), memrefType, tensorAlloc);
159b3ebe3beSMatthias Springer     }
160d7a9bf91SMatthias Springer 
16151df6238SMatthias Springer     // Result type is inferred by the builder.
16251df6238SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
163b3ebe3beSMatthias Springer         rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
16451df6238SMatthias Springer     return success();
16551df6238SMatthias Springer   }
166e6f69161SMatthias Springer };
167e6f69161SMatthias Springer 
16849e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim.
16949e37000SMatthias Springer struct DimOpInterface
17049e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
17149e37000SMatthias Springer                                                     tensor::DimOp> {
17249e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1739597b16aSMatthias Springer                               const AnalysisState &state) const {
17449e37000SMatthias Springer     return true;
17549e37000SMatthias Springer   }
17649e37000SMatthias Springer 
17749e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1789597b16aSMatthias Springer                                const AnalysisState &state) const {
17949e37000SMatthias Springer     return false;
18049e37000SMatthias Springer   }
18149e37000SMatthias Springer 
1829597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1839597b16aSMatthias Springer                                             const AnalysisState &state) const {
184585a8a32SMatthias Springer     return {};
18549e37000SMatthias Springer   }
18649e37000SMatthias Springer 
18749e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
188b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
18949e37000SMatthias Springer     auto dimOp = cast<tensor::DimOp>(op);
1908df54a6aSJacques Pienaar     auto v = getBuffer(rewriter, dimOp.getSource(), options);
1918df54a6aSJacques Pienaar     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v,
1928df54a6aSJacques Pienaar                                                 dimOp.getIndex());
19349e37000SMatthias Springer     return success();
19449e37000SMatthias Springer   }
19549e37000SMatthias Springer };
19649e37000SMatthias Springer 
197e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
198e6f69161SMatthias Springer struct ExpandShapeOpInterface
199e6f69161SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
200e6f69161SMatthias Springer                                                     tensor::ExpandShapeOp> {
201e6f69161SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2029597b16aSMatthias Springer                               const AnalysisState &state) const {
203e6f69161SMatthias Springer     return false;
204e6f69161SMatthias Springer   }
205e6f69161SMatthias Springer 
206e6f69161SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
2079597b16aSMatthias Springer                                const AnalysisState &state) const {
208e6f69161SMatthias Springer     return false;
209e6f69161SMatthias Springer   }
210e6f69161SMatthias Springer 
2119597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
2129597b16aSMatthias Springer                                             const AnalysisState &state) const {
213e6f69161SMatthias Springer     if (&opOperand == &op->getOpOperand(0) /*src*/)
214e6f69161SMatthias Springer       return {op->getOpResult(0)};
215e6f69161SMatthias Springer     return {};
216e6f69161SMatthias Springer   }
217e6f69161SMatthias Springer 
218e6f69161SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
2199597b16aSMatthias Springer                                 const AnalysisState &state) const {
220e6f69161SMatthias Springer     return BufferRelation::Equivalent;
221e6f69161SMatthias Springer   }
222e6f69161SMatthias Springer 
223e6f69161SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
224b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
225e6f69161SMatthias Springer     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
22651df6238SMatthias Springer     auto tensorResultType = expandShapeOp.getResultType();
2278df54a6aSJacques Pienaar     auto buffer = getBuffer(rewriter, expandShapeOp.getSrc(), options);
22851df6238SMatthias Springer 
22951df6238SMatthias Springer     // Memref result type is inferred by the builder based on reassociation
23051df6238SMatthias Springer     // indices and result shape.
231e6f69161SMatthias Springer     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
232b3ebe3beSMatthias Springer         rewriter, op, tensorResultType.getShape(), buffer,
23351df6238SMatthias Springer         expandShapeOp.getReassociationIndices());
234e6f69161SMatthias Springer     return success();
235e6f69161SMatthias Springer   }
236e6f69161SMatthias Springer };
237e6f69161SMatthias Springer 
23849e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview.
23949e37000SMatthias Springer struct ExtractSliceOpInterface
24049e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
24149e37000SMatthias Springer                                                     tensor::ExtractSliceOp> {
24249e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2439597b16aSMatthias Springer                               const AnalysisState &state) const {
24449e37000SMatthias Springer     return false;
24549e37000SMatthias Springer   }
24649e37000SMatthias Springer 
24749e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
2489597b16aSMatthias Springer                                const AnalysisState &state) const {
24949e37000SMatthias Springer     return false;
25049e37000SMatthias Springer   }
25149e37000SMatthias Springer 
2529597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
2539597b16aSMatthias Springer                                             const AnalysisState &state) const {
254585a8a32SMatthias Springer     if (&opOperand == &op->getOpOperand(0) /*source*/)
255585a8a32SMatthias Springer       return {op->getOpResult(0)};
256585a8a32SMatthias Springer     return {};
25749e37000SMatthias Springer   }
25849e37000SMatthias Springer 
25949e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
2609597b16aSMatthias Springer                                 const AnalysisState &state) const {
26149e37000SMatthias Springer     return BufferRelation::None;
26249e37000SMatthias Springer   }
26349e37000SMatthias Springer 
26449e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
265b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
26649e37000SMatthias Springer     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
26749e37000SMatthias Springer     Location loc = extractSliceOp.getLoc();
268d7a9bf91SMatthias Springer 
269d7a9bf91SMatthias Springer     // Even if this op was decided to bufferize out-of-place, do not insert the
270d7a9bf91SMatthias Springer     // buffer copy yet. This is done later in this function.
2718df54a6aSJacques Pienaar     auto srcMemref = getBuffer(rewriter, extractSliceOp.getSource(), options);
272b3ebe3beSMatthias Springer     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
27349e37000SMatthias Springer     auto dstTensorType =
2748df54a6aSJacques Pienaar         extractSliceOp.getResult().getType().cast<RankedTensorType>();
27549e37000SMatthias Springer 
27649e37000SMatthias Springer     // Expand offsets, sizes and strides to the full rank to handle the
27749e37000SMatthias Springer     // rank-reducing case.
27849e37000SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
27949e37000SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
28049e37000SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
28149e37000SMatthias Springer     OffsetSizeAndStrideOpInterface::expandToRank(
282b3ebe3beSMatthias Springer         srcMemref, mixedOffsets, mixedSizes, mixedStrides,
28349e37000SMatthias Springer         [&](Value target, int64_t dim) -> OpFoldResult {
28449e37000SMatthias Springer           auto shapedType = target.getType().cast<ShapedType>();
28549e37000SMatthias Springer           if (shapedType.isDynamicDim(dim))
28649e37000SMatthias Springer             return rewriter.create<memref::DimOp>(loc, target, dim).result();
28749e37000SMatthias Springer           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
28849e37000SMatthias Springer         });
28949e37000SMatthias Springer     // Bufferize to subview.
29049e37000SMatthias Springer     auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
29149e37000SMatthias Springer                                  dstTensorType.getRank(), srcMemrefType,
29249e37000SMatthias Springer                                  mixedOffsets, mixedSizes, mixedStrides)
29349e37000SMatthias Springer                                  .cast<MemRefType>();
29449e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
295b3ebe3beSMatthias Springer         loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
29649e37000SMatthias Springer         mixedStrides);
29749e37000SMatthias Springer 
29849e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, subView);
29949e37000SMatthias Springer     return success();
30049e37000SMatthias Springer   }
30149e37000SMatthias Springer };
30249e37000SMatthias Springer 
30349e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load.
30449e37000SMatthias Springer struct ExtractOpInterface
30549e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
30649e37000SMatthias Springer                                                     tensor::ExtractOp> {
30749e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
3089597b16aSMatthias Springer                               const AnalysisState &state) const {
30949e37000SMatthias Springer     return true;
31049e37000SMatthias Springer   }
31149e37000SMatthias Springer 
31249e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
3139597b16aSMatthias Springer                                const AnalysisState &state) const {
31449e37000SMatthias Springer     return false;
31549e37000SMatthias Springer   }
31649e37000SMatthias Springer 
3179597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
3189597b16aSMatthias Springer                                             const AnalysisState &state) const {
319585a8a32SMatthias Springer     return {};
32049e37000SMatthias Springer   }
32149e37000SMatthias Springer 
32249e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
323b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
32449e37000SMatthias Springer     auto extractOp = cast<tensor::ExtractOp>(op);
3258df54a6aSJacques Pienaar     Value srcMemref = getBuffer(rewriter, extractOp.getTensor(), options);
326b3ebe3beSMatthias Springer     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
3278df54a6aSJacques Pienaar                                                  extractOp.getIndices());
32849e37000SMatthias Springer     return success();
32949e37000SMatthias Springer   }
33049e37000SMatthias Springer };
33149e37000SMatthias Springer 
332d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while
333d581c94dSMatthias Springer // iterating over op.elements().
334d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim,
335d581c94dSMatthias Springer                          Value buffer, ArrayRef<int64_t> shape,
336d581c94dSMatthias Springer                          ArrayRef<Value> constants,
337d581c94dSMatthias Springer                          OperandRange::iterator &elementIt,
338d581c94dSMatthias Springer                          SmallVectorImpl<Value> &indices) {
339d581c94dSMatthias Springer   if (dim == static_cast<int>(shape.size()) - 1) {
340d581c94dSMatthias Springer     for (int i = 0; i < shape.back(); ++i) {
341d581c94dSMatthias Springer       indices.back() = constants[i];
342d581c94dSMatthias Springer       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
343d581c94dSMatthias Springer       ++elementIt;
344d581c94dSMatthias Springer     }
345d581c94dSMatthias Springer     return;
346d581c94dSMatthias Springer   }
347d581c94dSMatthias Springer   for (int i = 0; i < shape[dim]; ++i) {
348d581c94dSMatthias Springer     indices[dim] = constants[i];
349d581c94dSMatthias Springer     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
350d581c94dSMatthias Springer                  indices);
351d581c94dSMatthias Springer   }
352d581c94dSMatthias Springer }
353d581c94dSMatthias Springer 
354d581c94dSMatthias Springer /// Bufferization of tensor.from_elements.
355d581c94dSMatthias Springer struct FromElementsOpInterface
356d581c94dSMatthias Springer     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
357d581c94dSMatthias Springer                                                     tensor::FromElementsOp> {
358d581c94dSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
359b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
360d581c94dSMatthias Springer     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
361d581c94dSMatthias Springer 
362d581c94dSMatthias Springer     // Allocate a buffer for the result.
363d581c94dSMatthias Springer     Location loc = op->getLoc();
364d581c94dSMatthias Springer     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
365d581c94dSMatthias Springer     auto shape = tensorType.getShape();
366b3ebe3beSMatthias Springer     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
367b55d55ecSMatthias Springer     AnalysisState analysisState(options);
368b3ebe3beSMatthias Springer     Value tensorAlloc = allocateTensorForShapedValue(
3698df54a6aSJacques Pienaar         rewriter, loc, fromElementsOp.getResult(),
3708df54a6aSJacques Pienaar         analysisState.isTensorYielded(fromElementsOp.getResult()),
371b3ebe3beSMatthias Springer         /*copy=*/false);
372b3ebe3beSMatthias Springer     auto memrefType =
373b3ebe3beSMatthias Springer         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
374b3ebe3beSMatthias Springer     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
375b3ebe3beSMatthias Springer         op->getLoc(), memrefType, tensorAlloc);
376d581c94dSMatthias Springer 
377d581c94dSMatthias Springer     // Case: tensor<0xelem_type>.
3788df54a6aSJacques Pienaar     if (fromElementsOp.getElements().empty()) {
379d581c94dSMatthias Springer       replaceOpWithBufferizedValues(rewriter, op, buffer);
380d581c94dSMatthias Springer       return success();
381d581c94dSMatthias Springer     }
382d581c94dSMatthias Springer 
383d581c94dSMatthias Springer     // Case: tensor<elem_type>.
384d581c94dSMatthias Springer     if (shape.empty()) {
3858df54a6aSJacques Pienaar       rewriter.create<memref::StoreOp>(
3868df54a6aSJacques Pienaar           loc, fromElementsOp.getElements().front(), buffer);
387d581c94dSMatthias Springer       replaceOpWithBufferizedValues(rewriter, op, buffer);
388d581c94dSMatthias Springer       return success();
389d581c94dSMatthias Springer     }
390d581c94dSMatthias Springer 
391d581c94dSMatthias Springer     // Create constants for the range of possible indices [0, max{shape_i}).
392d581c94dSMatthias Springer     auto maxDim = *std::max_element(shape.begin(), shape.end());
393d581c94dSMatthias Springer     SmallVector<Value, 2> constants;
394d581c94dSMatthias Springer     constants.reserve(maxDim);
395d581c94dSMatthias Springer     for (int i = 0; i < maxDim; ++i)
396d581c94dSMatthias Springer       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
397d581c94dSMatthias Springer 
398d581c94dSMatthias Springer     // Traverse all `elements` and create `memref.store` ops.
3998df54a6aSJacques Pienaar     auto elementIt = fromElementsOp.getElements().begin();
400d581c94dSMatthias Springer     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
401d581c94dSMatthias Springer     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
402d581c94dSMatthias Springer                  indices);
403d581c94dSMatthias Springer 
404d581c94dSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, buffer);
405d581c94dSMatthias Springer     return success();
406d581c94dSMatthias Springer   }
407d581c94dSMatthias Springer };
408d581c94dSMatthias Springer 
40971bbb78bSMatthias Springer /// Bufferization of tensor.generate.
41071bbb78bSMatthias Springer struct GenerateOpInterface
41171bbb78bSMatthias Springer     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
41271bbb78bSMatthias Springer                                                     tensor::GenerateOp> {
41371bbb78bSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
414b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
41571bbb78bSMatthias Springer     auto generateOp = cast<tensor::GenerateOp>(op);
416b3ebe3beSMatthias Springer     auto tensorType = generateOp.getType().cast<RankedTensorType>();
41771bbb78bSMatthias Springer     // Allocate memory.
41871bbb78bSMatthias Springer     Location loc = op->getLoc();
419b3ebe3beSMatthias Springer     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
420b55d55ecSMatthias Springer     AnalysisState analysisState(options);
421b3ebe3beSMatthias Springer     Value tensorAlloc = allocateTensorForShapedValue(
4228df54a6aSJacques Pienaar         rewriter, loc, generateOp.getResult(),
4238df54a6aSJacques Pienaar         analysisState.isTensorYielded(generateOp.getResult()),
424b3ebe3beSMatthias Springer         /*copy=*/false);
425b3ebe3beSMatthias Springer     auto memrefType =
426b3ebe3beSMatthias Springer         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
427b3ebe3beSMatthias Springer     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
428b3ebe3beSMatthias Springer         op->getLoc(), memrefType, tensorAlloc);
42971bbb78bSMatthias Springer 
43071bbb78bSMatthias Springer     // Collect loop bounds.
43171bbb78bSMatthias Springer     int64_t rank = memrefType.getRank();
43271bbb78bSMatthias Springer     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
43371bbb78bSMatthias Springer     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
43471bbb78bSMatthias Springer     SmallVector<Value, 4> lowerBounds(rank, zero);
43571bbb78bSMatthias Springer     SmallVector<Value, 4> steps(rank, one);
43671bbb78bSMatthias Springer     SmallVector<Value, 4> upperBounds;
43771bbb78bSMatthias Springer     int nextDynamicIndex = 0;
43871bbb78bSMatthias Springer     for (int i = 0; i < rank; i++) {
4398df54a6aSJacques Pienaar       Value upperBound =
4408df54a6aSJacques Pienaar           memrefType.isDynamicDim(i)
4418df54a6aSJacques Pienaar               ? generateOp.getDynamicExtents()[nextDynamicIndex++]
44271bbb78bSMatthias Springer               : rewriter.create<arith::ConstantIndexOp>(
44371bbb78bSMatthias Springer                     loc, memrefType.getDimSize(i));
44471bbb78bSMatthias Springer       upperBounds.push_back(upperBound);
44571bbb78bSMatthias Springer     }
44671bbb78bSMatthias Springer 
44771bbb78bSMatthias Springer     // Generate tensor elements with a parallel loop that stores into
44871bbb78bSMatthias Springer     // each element of the resulting memref. We use mergeBlockBefore to "move"
44971bbb78bSMatthias Springer     // this op's body into the scf.parallel's body.
45071bbb78bSMatthias Springer     auto parallel =
45171bbb78bSMatthias Springer         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
45271bbb78bSMatthias Springer     Block *parallelBody = parallel.getBody();
453eca86cb2SJacques Pienaar     rewriter.mergeBlockBefore(&generateOp.getBody().front(),
45471bbb78bSMatthias Springer                               parallelBody->getTerminator(),
45571bbb78bSMatthias Springer                               parallelBody->getArguments());
45671bbb78bSMatthias Springer     // Replace the inlined yield op with a store op. The scf.parallel's builder
45771bbb78bSMatthias Springer     // already populated an scf.yield at the end, so we don't need to worry
45871bbb78bSMatthias Springer     // about creating that.
45971bbb78bSMatthias Springer     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
46071bbb78bSMatthias Springer     rewriter.setInsertionPointAfter(elementYield);
46171bbb78bSMatthias Springer     rewriter.replaceOpWithNewOp<memref::StoreOp>(
462b3ebe3beSMatthias Springer         elementYield, elementYield->getOperands()[0], buffer,
46371bbb78bSMatthias Springer         parallelBody->getArguments());
46471bbb78bSMatthias Springer 
465b3ebe3beSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, buffer);
46671bbb78bSMatthias Springer     return success();
46771bbb78bSMatthias Springer   }
46871bbb78bSMatthias Springer };
46971bbb78bSMatthias Springer 
47049e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store.
47149e37000SMatthias Springer struct InsertOpInterface
47249e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
47349e37000SMatthias Springer                                                     tensor::InsertOp> {
47449e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
4759597b16aSMatthias Springer                               const AnalysisState &state) const {
47649e37000SMatthias Springer     return true;
47749e37000SMatthias Springer   }
47849e37000SMatthias Springer 
47949e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
4809597b16aSMatthias Springer                                const AnalysisState &state) const {
48149e37000SMatthias Springer     return true;
48249e37000SMatthias Springer   }
48349e37000SMatthias Springer 
4849597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
4859597b16aSMatthias Springer                                             const AnalysisState &state) const {
48649e37000SMatthias Springer     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
48749e37000SMatthias Springer            "expected dest OpOperand");
488585a8a32SMatthias Springer     return {op->getOpResult(0)};
48949e37000SMatthias Springer   }
49049e37000SMatthias Springer 
49149e37000SMatthias Springer   SmallVector<OpOperand *>
49249e37000SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
4939597b16aSMatthias Springer                        const AnalysisState &state) const {
49449e37000SMatthias Springer     return {&op->getOpOperand(1) /*dest*/};
49549e37000SMatthias Springer   }
49649e37000SMatthias Springer 
49749e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
498b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
49949e37000SMatthias Springer     auto insertOp = cast<tensor::InsertOp>(op);
5008df54a6aSJacques Pienaar     Value destMemref = getBuffer(rewriter, insertOp.getDest(), options);
5018df54a6aSJacques Pienaar     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
5028df54a6aSJacques Pienaar                                      destMemref, insertOp.getIndices());
503b3ebe3beSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, destMemref);
50449e37000SMatthias Springer     return success();
50549e37000SMatthias Springer   }
50649e37000SMatthias Springer 
50749e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
5089597b16aSMatthias Springer                                 const AnalysisState &state) const {
50949e37000SMatthias Springer     return BufferRelation::Equivalent;
51049e37000SMatthias Springer   }
51149e37000SMatthias Springer };
51249e37000SMatthias Springer 
51349e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
51449e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification).
51549e37000SMatthias Springer ///
51649e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that
51749e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and
51849e37000SMatthias Springer /// exposed as interfaces on the proper types.
5199597b16aSMatthias Springer static bool areEquivalentExtractSliceOps(const AnalysisState &state,
52049e37000SMatthias Springer                                          ExtractSliceOp st, InsertSliceOp sti) {
52149e37000SMatthias Springer   if (!st || !sti)
52249e37000SMatthias Springer     return false;
52349e37000SMatthias Springer   if (sti != sti &&
5248df54a6aSJacques Pienaar       !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
52549e37000SMatthias Springer     return false;
52649e37000SMatthias Springer   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
52749e37000SMatthias Springer     return false;
52849e37000SMatthias Springer   return true;
52949e37000SMatthias Springer }
53049e37000SMatthias Springer 
53149e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches
53249e37000SMatthias Springer /// the given InsertSliceOp.
5339597b16aSMatthias Springer static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
5349597b16aSMatthias Springer                                       InsertSliceOp insertOp) {
53549e37000SMatthias Springer   auto condition = [&](Value val) {
53649e37000SMatthias Springer     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
53749e37000SMatthias Springer       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
53849e37000SMatthias Springer         return true;
53949e37000SMatthias Springer     return false;
54049e37000SMatthias Springer   };
54149e37000SMatthias Springer 
54249e37000SMatthias Springer   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
54349e37000SMatthias Springer                       condition);
54449e37000SMatthias Springer }
54549e37000SMatthias Springer 
54649e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
54749e37000SMatthias Springer /// certain circumstances, this op can also be a no-op.
54849e37000SMatthias Springer struct InsertSliceOpInterface
54949e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
55049e37000SMatthias Springer                                                     tensor::InsertSliceOp> {
55149e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
5529597b16aSMatthias Springer                               const AnalysisState &state) const {
55349e37000SMatthias Springer     return true;
55449e37000SMatthias Springer   }
55549e37000SMatthias Springer 
55649e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
5579597b16aSMatthias Springer                                const AnalysisState &state) const {
55849e37000SMatthias Springer     return &opOperand == &op->getOpOperand(1) /*dest*/;
55949e37000SMatthias Springer   }
56049e37000SMatthias Springer 
5619597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
5629597b16aSMatthias Springer                                             const AnalysisState &state) const {
563585a8a32SMatthias Springer     if (&opOperand == &op->getOpOperand(1) /*dest*/)
564585a8a32SMatthias Springer       return {op->getResult(0)};
565585a8a32SMatthias Springer     return {};
56649e37000SMatthias Springer   }
56749e37000SMatthias Springer 
56849e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
5699597b16aSMatthias Springer                                 const AnalysisState &state) const {
57049e37000SMatthias Springer     return BufferRelation::Equivalent;
57149e37000SMatthias Springer   }
57249e37000SMatthias Springer 
57349e37000SMatthias Springer   bool isNotConflicting(Operation *op, OpOperand *uRead,
57449e37000SMatthias Springer                         OpOperand *uConflictingWrite,
5759597b16aSMatthias Springer                         const AnalysisState &state) const {
57649e37000SMatthias Springer     Operation *readingOp = uRead->getOwner();
57749e37000SMatthias Springer     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
57849e37000SMatthias Springer 
57949e37000SMatthias Springer     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
58049e37000SMatthias Springer     // uRead is an InsertSliceOp...
58149e37000SMatthias Springer     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
58249e37000SMatthias Springer       // As an example, consider the following IR.
58349e37000SMatthias Springer       //
58449e37000SMatthias Springer       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
58549e37000SMatthias Springer       // %1 = linalg.fill %cst, %0 {inplace= [true] }
58649e37000SMatthias Springer       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
58749e37000SMatthias Springer       //     {inplace= [true] }
58849e37000SMatthias Springer 
58949e37000SMatthias Springer       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
59049e37000SMatthias Springer       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
59149e37000SMatthias Springer           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
59249e37000SMatthias Springer                                     insertSliceOp))
59349e37000SMatthias Springer         // Case 1: The main insight is that InsertSliceOp reads only part of
59449e37000SMatthias Springer         // the destination tensor. The overwritten area is not read. If
59549e37000SMatthias Springer         // uConflictingWrite writes into exactly the memory location that is
59649e37000SMatthias Springer         // being read by uRead, this is not a conflict.
59749e37000SMatthias Springer         //
59849e37000SMatthias Springer         // In the above example:
59949e37000SMatthias Springer         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
60049e37000SMatthias Springer         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
60149e37000SMatthias Springer         //
60249e37000SMatthias Springer         // The read of %t does not conflict with the write of the FillOp
60349e37000SMatthias Springer         // (same aliases!) because the area that the FillOp operates on is
60449e37000SMatthias Springer         // exactly the one that is *not* read via %t.
60549e37000SMatthias Springer         return true;
60649e37000SMatthias Springer 
60749e37000SMatthias Springer       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
60849e37000SMatthias Springer           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
60949e37000SMatthias Springer           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
61049e37000SMatthias Springer         // Case 2: The read of the source tensor and the write to the dest
61149e37000SMatthias Springer         // tensor via an InsertSliceOp is not a conflict if the read is
61249e37000SMatthias Springer         // reading exactly that part of an equivalent tensor that the
61349e37000SMatthias Springer         // InsertSliceOp is writing.
61449e37000SMatthias Springer         //
61549e37000SMatthias Springer         // In the above example:
61649e37000SMatthias Springer         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
61749e37000SMatthias Springer         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
61849e37000SMatthias Springer         return true;
61949e37000SMatthias Springer     }
62049e37000SMatthias Springer 
62149e37000SMatthias Springer     // If uConflictingWrite is an InsertSliceOp...
62249e37000SMatthias Springer     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
62349e37000SMatthias Springer       // As an example, consider the following IR.
62449e37000SMatthias Springer       //
62549e37000SMatthias Springer       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
62649e37000SMatthias Springer       // %1 = linalg.fill %cst, %0 {inplace= [true] }
62749e37000SMatthias Springer       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
62849e37000SMatthias Springer       //     {inplace= [true] }
62949e37000SMatthias Springer       // %3 = vector.transfer_read %1, %cst
63049e37000SMatthias Springer       //
63149e37000SMatthias Springer       // In the above example:
63249e37000SMatthias Springer       // uRead             = OpOperand 0 (%1) of vector.transfer_read
63349e37000SMatthias Springer       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
63449e37000SMatthias Springer       // lastWrite         = %1
63549e37000SMatthias Springer       //
63649e37000SMatthias Springer       // This is not a conflict because the InsertSliceOp overwrites the
63749e37000SMatthias Springer       // memory segment of %1 with the exact same data. (Effectively, there
63849e37000SMatthias Springer       // is no memory write here.)
63949e37000SMatthias Springer       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
64049e37000SMatthias Springer           state.areEquivalentBufferizedValues(uRead->get(),
6418df54a6aSJacques Pienaar                                               insertSliceOp.getSource()) &&
6428df54a6aSJacques Pienaar           hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
64349e37000SMatthias Springer                                     insertSliceOp))
64449e37000SMatthias Springer         return true;
64549e37000SMatthias Springer 
64649e37000SMatthias Springer     return false;
64749e37000SMatthias Springer   }
64849e37000SMatthias Springer 
64949e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
650b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
65149e37000SMatthias Springer     // insert_slice ops arise from tiling and bufferizing them out-of-place is
65249e37000SMatthias Springer     // generally a deal breaker. When used with loops, this ends up cloning the
65349e37000SMatthias Springer     // whole tensor on every single iteration and is a symptom of a
65449e37000SMatthias Springer     // catastrophically bad scheduling decision.
65549e37000SMatthias Springer     // TODO: be very loud about it or even consider failing the pass.
65649e37000SMatthias Springer     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
65749e37000SMatthias Springer     Location loc = insertSliceOp.getLoc();
6588df54a6aSJacques Pienaar     Value dstMemref = getBuffer(rewriter, insertSliceOp.getDest(), options);
65949e37000SMatthias Springer 
66049e37000SMatthias Springer     // Expand offsets, sizes and strides to the full rank to handle the
66149e37000SMatthias Springer     // rank-reducing case.
66249e37000SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
66349e37000SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
66449e37000SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
66549e37000SMatthias Springer     OffsetSizeAndStrideOpInterface::expandToRank(
666b3ebe3beSMatthias Springer         dstMemref, mixedOffsets, mixedSizes, mixedStrides,
66749e37000SMatthias Springer         [&](Value target, int64_t dim) -> OpFoldResult {
66849e37000SMatthias Springer           auto shapedType = target.getType().cast<ShapedType>();
66949e37000SMatthias Springer           if (shapedType.isDynamicDim(dim))
67049e37000SMatthias Springer             return rewriter.create<memref::DimOp>(loc, target, dim).result();
67149e37000SMatthias Springer           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
67249e37000SMatthias Springer         });
67349e37000SMatthias Springer     // Take a subview of the dst.
674b3ebe3beSMatthias Springer     auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
67549e37000SMatthias Springer     auto subviewMemRefType =
67649e37000SMatthias Springer         memref::SubViewOp::inferRankReducedResultType(
67749e37000SMatthias Springer             insertSliceOp.getSourceType().getRank(), dstMemrefType,
67849e37000SMatthias Springer             mixedOffsets, mixedSizes, mixedStrides)
67949e37000SMatthias Springer             .cast<MemRefType>();
68049e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
681b3ebe3beSMatthias Springer         loc, subviewMemRefType, dstMemref, mixedOffsets, mixedSizes,
68249e37000SMatthias Springer         mixedStrides);
68349e37000SMatthias Springer 
68449e37000SMatthias Springer     // Copy tensor. If this tensor.insert_slice has a matching
68549e37000SMatthias Springer     // tensor.extract_slice, the copy operation will eventually fold away.
6868df54a6aSJacques Pienaar     auto srcMemref = getBuffer(rewriter, insertSliceOp.getSource(), options);
687b55d55ecSMatthias Springer     if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
68849e37000SMatthias Springer       return failure();
68949e37000SMatthias Springer 
690b3ebe3beSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, dstMemref);
69149e37000SMatthias Springer     return success();
69249e37000SMatthias Springer   }
69349e37000SMatthias Springer };
69449e37000SMatthias Springer 
695fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank.
696fc08d1c2SMatthias Springer struct RankOpInterface
697fc08d1c2SMatthias Springer     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
698fc08d1c2SMatthias Springer                                                     tensor::RankOp> {
699fc08d1c2SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
7009597b16aSMatthias Springer                               const AnalysisState &state) const {
701fc08d1c2SMatthias Springer     return true;
702fc08d1c2SMatthias Springer   }
703fc08d1c2SMatthias Springer 
704fc08d1c2SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
7059597b16aSMatthias Springer                                const AnalysisState &state) const {
706fc08d1c2SMatthias Springer     return false;
707fc08d1c2SMatthias Springer   }
708fc08d1c2SMatthias Springer 
7099597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
7109597b16aSMatthias Springer                                             const AnalysisState &state) const {
711585a8a32SMatthias Springer     return {};
712fc08d1c2SMatthias Springer   }
713fc08d1c2SMatthias Springer 
714fc08d1c2SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
715b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
716fc08d1c2SMatthias Springer     auto rankOp = cast<tensor::RankOp>(op);
7178df54a6aSJacques Pienaar     auto v = getBuffer(rewriter, rankOp.getTensor(), options);
718fc08d1c2SMatthias Springer     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
719b3ebe3beSMatthias Springer                                                  v);
720fc08d1c2SMatthias Springer     return success();
721fc08d1c2SMatthias Springer   }
722fc08d1c2SMatthias Springer };
723fc08d1c2SMatthias Springer 
724e287d647SAshay Rane /// Bufferization of tensor.reshape. Replace with memref.reshape.
725e287d647SAshay Rane struct ReshapeOpInterface
726e287d647SAshay Rane     : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
727e287d647SAshay Rane                                                     tensor::ReshapeOp> {
728e287d647SAshay Rane   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
729e287d647SAshay Rane                               const AnalysisState &state) const {
730e287d647SAshay Rane     if (&opOperand == &op->getOpOperand(1) /* shape */)
731e287d647SAshay Rane       return true;
732e287d647SAshay Rane     return false;
733e287d647SAshay Rane   }
734e287d647SAshay Rane 
735e287d647SAshay Rane   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
736e287d647SAshay Rane                                const AnalysisState &state) const {
737e287d647SAshay Rane     return false;
738e287d647SAshay Rane   }
739e287d647SAshay Rane 
740e287d647SAshay Rane   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
741e287d647SAshay Rane                                             const AnalysisState &state) const {
742e287d647SAshay Rane     return {op->getOpResult(0)};
743e287d647SAshay Rane   }
744e287d647SAshay Rane 
745e287d647SAshay Rane   BufferRelation bufferRelation(Operation *op, OpResult opResult,
746e287d647SAshay Rane                                 const AnalysisState &state) const {
747e287d647SAshay Rane     return BufferRelation::Equivalent;
748e287d647SAshay Rane   }
749e287d647SAshay Rane 
750e287d647SAshay Rane   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
751b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
752e287d647SAshay Rane     auto reshapeOp = cast<tensor::ReshapeOp>(op);
7538df54a6aSJacques Pienaar     Value srcBuffer = getBuffer(rewriter, reshapeOp.getSource(), options);
7548df54a6aSJacques Pienaar     Value shapeBuffer = getBuffer(rewriter, reshapeOp.getShape(), options);
755e287d647SAshay Rane     auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
756b55d55ecSMatthias Springer     auto resultMemRefType = getMemRefType(resultTensorType, options);
757e287d647SAshay Rane     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
758b3ebe3beSMatthias Springer         rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
759e287d647SAshay Rane     return success();
760e287d647SAshay Rane   }
761e287d647SAshay Rane };
762e287d647SAshay Rane 
76349e37000SMatthias Springer } // namespace
76449e37000SMatthias Springer } // namespace tensor
76549e37000SMatthias Springer } // namespace mlir
76649e37000SMatthias Springer 
76749e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
76849e37000SMatthias Springer     DialectRegistry &registry) {
76977eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
77077eee579SRiver Riddle     CastOp::attachInterface<CastOpInterface>(*ctx);
77177eee579SRiver Riddle     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
77277eee579SRiver Riddle     DimOp::attachInterface<DimOpInterface>(*ctx);
77377eee579SRiver Riddle     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
77477eee579SRiver Riddle     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
77577eee579SRiver Riddle     ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
77677eee579SRiver Riddle     FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
77777eee579SRiver Riddle     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
77877eee579SRiver Riddle     InsertOp::attachInterface<InsertOpInterface>(*ctx);
77977eee579SRiver Riddle     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
78077eee579SRiver Riddle     RankOp::attachInterface<RankOpInterface>(*ctx);
781e287d647SAshay Rane     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
78277eee579SRiver Riddle   });
78349e37000SMatthias Springer }
784