149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
249e37000SMatthias Springer //
349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
649e37000SMatthias Springer //
749e37000SMatthias Springer //===----------------------------------------------------------------------===//
849e37000SMatthias Springer 
949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
1049e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1149e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1271bbb78bSMatthias Springer #include "mlir/Dialect/SCF/SCF.h"
1349e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1449e37000SMatthias Springer #include "mlir/IR/Dialect.h"
1549e37000SMatthias Springer #include "mlir/IR/Operation.h"
1649e37000SMatthias Springer 
1749e37000SMatthias Springer using namespace mlir;
1849e37000SMatthias Springer using namespace mlir::bufferization;
1949e37000SMatthias Springer using namespace mlir::tensor;
2049e37000SMatthias Springer 
2149e37000SMatthias Springer namespace mlir {
2249e37000SMatthias Springer namespace tensor {
2349e37000SMatthias Springer namespace {
2449e37000SMatthias Springer 
2549e37000SMatthias Springer struct CastOpInterface
2649e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
2749e37000SMatthias Springer                                                     tensor::CastOp> {
2849e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
299597b16aSMatthias Springer                               const AnalysisState &state) const {
3049e37000SMatthias Springer     return false;
3149e37000SMatthias Springer   }
3249e37000SMatthias Springer 
3349e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
349597b16aSMatthias Springer                                const AnalysisState &state) const {
3549e37000SMatthias Springer     return false;
3649e37000SMatthias Springer   }
3749e37000SMatthias Springer 
389597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
399597b16aSMatthias Springer                                             const AnalysisState &state) const {
40585a8a32SMatthias Springer     return {op->getResult(0)};
4149e37000SMatthias Springer   }
4249e37000SMatthias Springer 
4349e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
449597b16aSMatthias Springer                                 const AnalysisState &state) const {
4549e37000SMatthias Springer     return BufferRelation::Equivalent;
4649e37000SMatthias Springer   }
4749e37000SMatthias Springer 
4849e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
499597b16aSMatthias Springer                           BufferizationState &state) const {
5049e37000SMatthias Springer     auto castOp = cast<tensor::CastOp>(op);
5149e37000SMatthias Springer 
5249e37000SMatthias Springer     // The result buffer still has the old (pre-cast) type.
5349e37000SMatthias Springer     FailureOr<Value> resultBuffer =
5449e37000SMatthias Springer         state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
5549e37000SMatthias Springer     if (failed(resultBuffer))
5649e37000SMatthias Springer       return failure();
5749e37000SMatthias Springer     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
5849e37000SMatthias Springer     Attribute memorySpace = sourceMemRefType.getMemorySpace();
5949e37000SMatthias Springer     TensorType resultTensorType =
6049e37000SMatthias Springer         castOp.getResult().getType().cast<TensorType>();
6149e37000SMatthias Springer     MemRefLayoutAttrInterface layout;
6249e37000SMatthias Springer 
6349e37000SMatthias Springer     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
6449e37000SMatthias Springer       if (resultTensorType.isa<RankedTensorType>())
6549e37000SMatthias Springer         layout = rankedMemRefType.getLayout();
6649e37000SMatthias Springer 
6749e37000SMatthias Springer     // Compute the new memref type.
6826852423SMatthias Springer     Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
6926852423SMatthias Springer                                           layout, memorySpace);
7049e37000SMatthias Springer 
7149e37000SMatthias Springer     // Replace the op with a memref.cast.
7249e37000SMatthias Springer     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
7349e37000SMatthias Springer                                              resultMemRefType) &&
7449e37000SMatthias Springer            "CallOp::bufferize: cast incompatible");
7549e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
7649e37000SMatthias Springer                                                  *resultBuffer);
7749e37000SMatthias Springer 
7849e37000SMatthias Springer     return success();
7949e37000SMatthias Springer   }
8049e37000SMatthias Springer };
8149e37000SMatthias Springer 
82e6f69161SMatthias Springer /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
83e6f69161SMatthias Springer struct CollapseShapeOpInterface
84e6f69161SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
85e6f69161SMatthias Springer                                                     tensor::CollapseShapeOp> {
86e6f69161SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
879597b16aSMatthias Springer                               const AnalysisState &state) const {
88e6f69161SMatthias Springer     return false;
89e6f69161SMatthias Springer   }
90e6f69161SMatthias Springer 
91e6f69161SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
929597b16aSMatthias Springer                                const AnalysisState &state) const {
93e6f69161SMatthias Springer     return false;
94e6f69161SMatthias Springer   }
95e6f69161SMatthias Springer 
969597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
979597b16aSMatthias Springer                                             const AnalysisState &state) const {
98e6f69161SMatthias Springer     if (&opOperand == &op->getOpOperand(0) /*src*/)
99e6f69161SMatthias Springer       return {op->getOpResult(0)};
100e6f69161SMatthias Springer     return {};
101e6f69161SMatthias Springer   }
102e6f69161SMatthias Springer 
103e6f69161SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
1049597b16aSMatthias Springer                                 const AnalysisState &state) const {
105e6f69161SMatthias Springer     return BufferRelation::Equivalent;
106e6f69161SMatthias Springer   }
107e6f69161SMatthias Springer 
108e6f69161SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1099597b16aSMatthias Springer                           BufferizationState &state) const {
110e6f69161SMatthias Springer     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
111e6f69161SMatthias Springer     Value buffer =
112e6f69161SMatthias Springer         *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
113e6f69161SMatthias Springer     Type resultType =
114e6f69161SMatthias Springer         getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
115e6f69161SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
116e6f69161SMatthias Springer         rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
117e6f69161SMatthias Springer     return success();
118e6f69161SMatthias Springer   }
119e6f69161SMatthias Springer };
120e6f69161SMatthias Springer 
12149e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim.
12249e37000SMatthias Springer struct DimOpInterface
12349e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
12449e37000SMatthias Springer                                                     tensor::DimOp> {
12549e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1269597b16aSMatthias Springer                               const AnalysisState &state) const {
12749e37000SMatthias Springer     return true;
12849e37000SMatthias Springer   }
12949e37000SMatthias Springer 
13049e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1319597b16aSMatthias Springer                                const AnalysisState &state) const {
13249e37000SMatthias Springer     return false;
13349e37000SMatthias Springer   }
13449e37000SMatthias Springer 
1359597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1369597b16aSMatthias Springer                                             const AnalysisState &state) const {
137585a8a32SMatthias Springer     return {};
13849e37000SMatthias Springer   }
13949e37000SMatthias Springer 
14049e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1419597b16aSMatthias Springer                           BufferizationState &state) const {
14249e37000SMatthias Springer     auto dimOp = cast<tensor::DimOp>(op);
14349e37000SMatthias Springer     Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
14449e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
14549e37000SMatthias Springer     return success();
14649e37000SMatthias Springer   }
14749e37000SMatthias Springer };
14849e37000SMatthias Springer 
149e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
150e6f69161SMatthias Springer struct ExpandShapeOpInterface
151e6f69161SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
152e6f69161SMatthias Springer                                                     tensor::ExpandShapeOp> {
153e6f69161SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1549597b16aSMatthias Springer                               const AnalysisState &state) const {
155e6f69161SMatthias Springer     return false;
156e6f69161SMatthias Springer   }
157e6f69161SMatthias Springer 
158e6f69161SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1599597b16aSMatthias Springer                                const AnalysisState &state) const {
160e6f69161SMatthias Springer     return false;
161e6f69161SMatthias Springer   }
162e6f69161SMatthias Springer 
1639597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1649597b16aSMatthias Springer                                             const AnalysisState &state) const {
165e6f69161SMatthias Springer     if (&opOperand == &op->getOpOperand(0) /*src*/)
166e6f69161SMatthias Springer       return {op->getOpResult(0)};
167e6f69161SMatthias Springer     return {};
168e6f69161SMatthias Springer   }
169e6f69161SMatthias Springer 
170e6f69161SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
1719597b16aSMatthias Springer                                 const AnalysisState &state) const {
172e6f69161SMatthias Springer     return BufferRelation::Equivalent;
173e6f69161SMatthias Springer   }
174e6f69161SMatthias Springer 
175e6f69161SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1769597b16aSMatthias Springer                           BufferizationState &state) const {
177e6f69161SMatthias Springer     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
178e6f69161SMatthias Springer     Value buffer =
179e6f69161SMatthias Springer         *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
180e6f69161SMatthias Springer     Type resultType =
181e6f69161SMatthias Springer         getMemRefType(expandShapeOp.getResultType(), state.getOptions());
182e6f69161SMatthias Springer     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
183e6f69161SMatthias Springer         rewriter, op, resultType, buffer, expandShapeOp.reassociation());
184e6f69161SMatthias Springer     return success();
185e6f69161SMatthias Springer   }
186e6f69161SMatthias Springer };
187e6f69161SMatthias Springer 
18849e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview.
18949e37000SMatthias Springer struct ExtractSliceOpInterface
19049e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
19149e37000SMatthias Springer                                                     tensor::ExtractSliceOp> {
19249e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1939597b16aSMatthias Springer                               const AnalysisState &state) const {
19449e37000SMatthias Springer     return false;
19549e37000SMatthias Springer   }
19649e37000SMatthias Springer 
19749e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1989597b16aSMatthias Springer                                const AnalysisState &state) const {
19949e37000SMatthias Springer     return false;
20049e37000SMatthias Springer   }
20149e37000SMatthias Springer 
2029597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
2039597b16aSMatthias Springer                                             const AnalysisState &state) const {
204585a8a32SMatthias Springer     if (&opOperand == &op->getOpOperand(0) /*source*/)
205585a8a32SMatthias Springer       return {op->getOpResult(0)};
206585a8a32SMatthias Springer     return {};
20749e37000SMatthias Springer   }
20849e37000SMatthias Springer 
20949e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
2109597b16aSMatthias Springer                                 const AnalysisState &state) const {
21149e37000SMatthias Springer     return BufferRelation::None;
21249e37000SMatthias Springer   }
21349e37000SMatthias Springer 
21449e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
2159597b16aSMatthias Springer                           BufferizationState &state) const {
21649e37000SMatthias Springer     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
21749e37000SMatthias Springer     Location loc = extractSliceOp.getLoc();
21849e37000SMatthias Springer     Value srcMemref =
21949e37000SMatthias Springer         *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
22049e37000SMatthias Springer                          /*forceInPlace=*/true);
22149e37000SMatthias Springer     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
22249e37000SMatthias Springer     auto dstTensorType =
22349e37000SMatthias Springer         extractSliceOp.result().getType().cast<RankedTensorType>();
22449e37000SMatthias Springer 
22549e37000SMatthias Springer     // If not inplaceable, alloc.
2269597b16aSMatthias Springer     bool inplace =
2279597b16aSMatthias Springer         state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
22849e37000SMatthias Springer     Value alloc;
22949e37000SMatthias Springer     if (!inplace) {
23049e37000SMatthias Springer       FailureOr<Value> allocOrFailure =
23105e0495fSMatthias Springer           state.createAlloc(rewriter, loc, extractSliceOp.result());
23249e37000SMatthias Springer       if (failed(allocOrFailure))
23349e37000SMatthias Springer         return failure();
23449e37000SMatthias Springer       alloc = *allocOrFailure;
23549e37000SMatthias Springer     }
23649e37000SMatthias Springer 
23749e37000SMatthias Springer     // Expand offsets, sizes and strides to the full rank to handle the
23849e37000SMatthias Springer     // rank-reducing case.
23949e37000SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
24049e37000SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
24149e37000SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
24249e37000SMatthias Springer     OffsetSizeAndStrideOpInterface::expandToRank(
24349e37000SMatthias Springer         srcMemref, mixedOffsets, mixedSizes, mixedStrides,
24449e37000SMatthias Springer         [&](Value target, int64_t dim) -> OpFoldResult {
24549e37000SMatthias Springer           auto shapedType = target.getType().cast<ShapedType>();
24649e37000SMatthias Springer           if (shapedType.isDynamicDim(dim))
24749e37000SMatthias Springer             return rewriter.create<memref::DimOp>(loc, target, dim).result();
24849e37000SMatthias Springer           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
24949e37000SMatthias Springer         });
25049e37000SMatthias Springer     // Bufferize to subview.
25149e37000SMatthias Springer     auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
25249e37000SMatthias Springer                                  dstTensorType.getRank(), srcMemrefType,
25349e37000SMatthias Springer                                  mixedOffsets, mixedSizes, mixedStrides)
25449e37000SMatthias Springer                                  .cast<MemRefType>();
25549e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
25649e37000SMatthias Springer         loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
25749e37000SMatthias Springer         mixedStrides);
25849e37000SMatthias Springer 
25949e37000SMatthias Springer     // If not inplaceable, copy.
26049e37000SMatthias Springer     if (!inplace) {
26149e37000SMatthias Springer       // Do not copy if the copied data is never read.
2629597b16aSMatthias Springer       if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
26349e37000SMatthias Springer         if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
26449e37000SMatthias Springer                                 alloc, state.getOptions())))
26549e37000SMatthias Springer           return failure();
26649e37000SMatthias Springer       subView = alloc;
26749e37000SMatthias Springer     }
26849e37000SMatthias Springer 
26949e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, subView);
27049e37000SMatthias Springer     return success();
27149e37000SMatthias Springer   }
27249e37000SMatthias Springer };
27349e37000SMatthias Springer 
27449e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load.
27549e37000SMatthias Springer struct ExtractOpInterface
27649e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
27749e37000SMatthias Springer                                                     tensor::ExtractOp> {
27849e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2799597b16aSMatthias Springer                               const AnalysisState &state) const {
28049e37000SMatthias Springer     return true;
28149e37000SMatthias Springer   }
28249e37000SMatthias Springer 
28349e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
2849597b16aSMatthias Springer                                const AnalysisState &state) const {
28549e37000SMatthias Springer     return false;
28649e37000SMatthias Springer   }
28749e37000SMatthias Springer 
2889597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
2899597b16aSMatthias Springer                                             const AnalysisState &state) const {
290585a8a32SMatthias Springer     return {};
29149e37000SMatthias Springer   }
29249e37000SMatthias Springer 
29349e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
2949597b16aSMatthias Springer                           BufferizationState &state) const {
29549e37000SMatthias Springer     auto extractOp = cast<tensor::ExtractOp>(op);
29649e37000SMatthias Springer     Value srcMemref =
29749e37000SMatthias Springer         *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
29849e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
29949e37000SMatthias Springer                                                  extractOp.indices());
30049e37000SMatthias Springer     return success();
30149e37000SMatthias Springer   }
30249e37000SMatthias Springer };
30349e37000SMatthias Springer 
304d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while
305d581c94dSMatthias Springer // iterating over op.elements().
306d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim,
307d581c94dSMatthias Springer                          Value buffer, ArrayRef<int64_t> shape,
308d581c94dSMatthias Springer                          ArrayRef<Value> constants,
309d581c94dSMatthias Springer                          OperandRange::iterator &elementIt,
310d581c94dSMatthias Springer                          SmallVectorImpl<Value> &indices) {
311d581c94dSMatthias Springer   if (dim == static_cast<int>(shape.size()) - 1) {
312d581c94dSMatthias Springer     for (int i = 0; i < shape.back(); ++i) {
313d581c94dSMatthias Springer       indices.back() = constants[i];
314d581c94dSMatthias Springer       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
315d581c94dSMatthias Springer       ++elementIt;
316d581c94dSMatthias Springer     }
317d581c94dSMatthias Springer     return;
318d581c94dSMatthias Springer   }
319d581c94dSMatthias Springer   for (int i = 0; i < shape[dim]; ++i) {
320d581c94dSMatthias Springer     indices[dim] = constants[i];
321d581c94dSMatthias Springer     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
322d581c94dSMatthias Springer                  indices);
323d581c94dSMatthias Springer   }
324d581c94dSMatthias Springer }
325d581c94dSMatthias Springer 
326d581c94dSMatthias Springer /// Bufferization of tensor.from_elements.
327d581c94dSMatthias Springer struct FromElementsOpInterface
328d581c94dSMatthias Springer     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
329d581c94dSMatthias Springer                                                     tensor::FromElementsOp> {
330d581c94dSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
3319597b16aSMatthias Springer                           BufferizationState &state) const {
332d581c94dSMatthias Springer     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
333d581c94dSMatthias Springer 
334d581c94dSMatthias Springer     // Allocate a buffer for the result.
335d581c94dSMatthias Springer     Location loc = op->getLoc();
336d581c94dSMatthias Springer     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
337d581c94dSMatthias Springer     auto shape = tensorType.getShape();
338d581c94dSMatthias Springer     FailureOr<Value> maybeBuffer =
339*9e24f0f4SMatthias Springer         state.createAlloc(rewriter, loc, fromElementsOp.result());
340d581c94dSMatthias Springer     if (failed(maybeBuffer))
341d581c94dSMatthias Springer       return failure();
342d581c94dSMatthias Springer     Value buffer = *maybeBuffer;
343d581c94dSMatthias Springer 
344d581c94dSMatthias Springer     // Case: tensor<0xelem_type>.
345d581c94dSMatthias Springer     if (fromElementsOp.elements().empty()) {
346d581c94dSMatthias Springer       replaceOpWithBufferizedValues(rewriter, op, buffer);
347d581c94dSMatthias Springer       return success();
348d581c94dSMatthias Springer     }
349d581c94dSMatthias Springer 
350d581c94dSMatthias Springer     // Case: tensor<elem_type>.
351d581c94dSMatthias Springer     if (shape.empty()) {
352d581c94dSMatthias Springer       rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
353d581c94dSMatthias Springer                                        buffer);
354d581c94dSMatthias Springer       replaceOpWithBufferizedValues(rewriter, op, buffer);
355d581c94dSMatthias Springer       return success();
356d581c94dSMatthias Springer     }
357d581c94dSMatthias Springer 
358d581c94dSMatthias Springer     // Create constants for the range of possible indices [0, max{shape_i}).
359d581c94dSMatthias Springer     auto maxDim = *std::max_element(shape.begin(), shape.end());
360d581c94dSMatthias Springer     SmallVector<Value, 2> constants;
361d581c94dSMatthias Springer     constants.reserve(maxDim);
362d581c94dSMatthias Springer     for (int i = 0; i < maxDim; ++i)
363d581c94dSMatthias Springer       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
364d581c94dSMatthias Springer 
365d581c94dSMatthias Springer     // Traverse all `elements` and create `memref.store` ops.
366d581c94dSMatthias Springer     auto elementIt = fromElementsOp.elements().begin();
367d581c94dSMatthias Springer     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
368d581c94dSMatthias Springer     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
369d581c94dSMatthias Springer                  indices);
370d581c94dSMatthias Springer 
371d581c94dSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, buffer);
372d581c94dSMatthias Springer     return success();
373d581c94dSMatthias Springer   }
374d581c94dSMatthias Springer };
375d581c94dSMatthias Springer 
37671bbb78bSMatthias Springer /// Bufferization of tensor.generate.
37771bbb78bSMatthias Springer struct GenerateOpInterface
37871bbb78bSMatthias Springer     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
37971bbb78bSMatthias Springer                                                     tensor::GenerateOp> {
38071bbb78bSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
3819597b16aSMatthias Springer                           BufferizationState &state) const {
38271bbb78bSMatthias Springer     auto generateOp = cast<tensor::GenerateOp>(op);
38371bbb78bSMatthias Springer 
38471bbb78bSMatthias Springer     // Allocate memory.
38571bbb78bSMatthias Springer     Location loc = op->getLoc();
38671bbb78bSMatthias Springer     MemRefType memrefType =
38771bbb78bSMatthias Springer         getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
388*9e24f0f4SMatthias Springer     FailureOr<Value> maybeResult =
389*9e24f0f4SMatthias Springer         state.createAlloc(rewriter, loc, generateOp.result());
39071bbb78bSMatthias Springer     if (failed(maybeResult))
39171bbb78bSMatthias Springer       return failure();
39271bbb78bSMatthias Springer     Value result = *maybeResult;
39371bbb78bSMatthias Springer 
39471bbb78bSMatthias Springer     // Collect loop bounds.
39571bbb78bSMatthias Springer     int64_t rank = memrefType.getRank();
39671bbb78bSMatthias Springer     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
39771bbb78bSMatthias Springer     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
39871bbb78bSMatthias Springer     SmallVector<Value, 4> lowerBounds(rank, zero);
39971bbb78bSMatthias Springer     SmallVector<Value, 4> steps(rank, one);
40071bbb78bSMatthias Springer     SmallVector<Value, 4> upperBounds;
40171bbb78bSMatthias Springer     int nextDynamicIndex = 0;
40271bbb78bSMatthias Springer     for (int i = 0; i < rank; i++) {
40371bbb78bSMatthias Springer       Value upperBound = memrefType.isDynamicDim(i)
40471bbb78bSMatthias Springer                              ? generateOp.dynamicExtents()[nextDynamicIndex++]
40571bbb78bSMatthias Springer                              : rewriter.create<arith::ConstantIndexOp>(
40671bbb78bSMatthias Springer                                    loc, memrefType.getDimSize(i));
40771bbb78bSMatthias Springer       upperBounds.push_back(upperBound);
40871bbb78bSMatthias Springer     }
40971bbb78bSMatthias Springer 
41071bbb78bSMatthias Springer     // Generate tensor elements with a parallel loop that stores into
41171bbb78bSMatthias Springer     // each element of the resulting memref. We use mergeBlockBefore to "move"
41271bbb78bSMatthias Springer     // this op's body into the scf.parallel's body.
41371bbb78bSMatthias Springer     auto parallel =
41471bbb78bSMatthias Springer         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
41571bbb78bSMatthias Springer     Block *parallelBody = parallel.getBody();
41671bbb78bSMatthias Springer     rewriter.mergeBlockBefore(generateOp.getBody(),
41771bbb78bSMatthias Springer                               parallelBody->getTerminator(),
41871bbb78bSMatthias Springer                               parallelBody->getArguments());
41971bbb78bSMatthias Springer     // Replace the inlined yield op with a store op. The scf.parallel's builder
42071bbb78bSMatthias Springer     // already populated an scf.yield at the end, so we don't need to worry
42171bbb78bSMatthias Springer     // about creating that.
42271bbb78bSMatthias Springer     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
42371bbb78bSMatthias Springer     rewriter.setInsertionPointAfter(elementYield);
42471bbb78bSMatthias Springer     rewriter.replaceOpWithNewOp<memref::StoreOp>(
42571bbb78bSMatthias Springer         elementYield, elementYield->getOperands()[0], result,
42671bbb78bSMatthias Springer         parallelBody->getArguments());
42771bbb78bSMatthias Springer 
42871bbb78bSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, result);
42971bbb78bSMatthias Springer     return success();
43071bbb78bSMatthias Springer   }
43171bbb78bSMatthias Springer };
43271bbb78bSMatthias Springer 
43349e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store.
43449e37000SMatthias Springer struct InsertOpInterface
43549e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
43649e37000SMatthias Springer                                                     tensor::InsertOp> {
43749e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
4389597b16aSMatthias Springer                               const AnalysisState &state) const {
43949e37000SMatthias Springer     return true;
44049e37000SMatthias Springer   }
44149e37000SMatthias Springer 
44249e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
4439597b16aSMatthias Springer                                const AnalysisState &state) const {
44449e37000SMatthias Springer     return true;
44549e37000SMatthias Springer   }
44649e37000SMatthias Springer 
4479597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
4489597b16aSMatthias Springer                                             const AnalysisState &state) const {
44949e37000SMatthias Springer     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
45049e37000SMatthias Springer            "expected dest OpOperand");
451585a8a32SMatthias Springer     return {op->getOpResult(0)};
45249e37000SMatthias Springer   }
45349e37000SMatthias Springer 
45449e37000SMatthias Springer   SmallVector<OpOperand *>
45549e37000SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
4569597b16aSMatthias Springer                        const AnalysisState &state) const {
45749e37000SMatthias Springer     return {&op->getOpOperand(1) /*dest*/};
45849e37000SMatthias Springer   }
45949e37000SMatthias Springer 
46049e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
4619597b16aSMatthias Springer                           BufferizationState &state) const {
46249e37000SMatthias Springer     auto insertOp = cast<tensor::InsertOp>(op);
46349e37000SMatthias Springer     FailureOr<Value> destMemref =
46449e37000SMatthias Springer         state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
46549e37000SMatthias Springer     if (failed(destMemref))
46649e37000SMatthias Springer       return failure();
46749e37000SMatthias Springer     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
46849e37000SMatthias Springer                                      *destMemref, insertOp.indices());
46949e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
47049e37000SMatthias Springer     return success();
47149e37000SMatthias Springer   }
47249e37000SMatthias Springer 
47349e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
4749597b16aSMatthias Springer                                 const AnalysisState &state) const {
47549e37000SMatthias Springer     return BufferRelation::Equivalent;
47649e37000SMatthias Springer   }
47749e37000SMatthias Springer };
47849e37000SMatthias Springer 
47949e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
48049e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification).
48149e37000SMatthias Springer ///
48249e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that
48349e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and
48449e37000SMatthias Springer /// exposed as interfaces on the proper types.
4859597b16aSMatthias Springer static bool areEquivalentExtractSliceOps(const AnalysisState &state,
48649e37000SMatthias Springer                                          ExtractSliceOp st, InsertSliceOp sti) {
48749e37000SMatthias Springer   if (!st || !sti)
48849e37000SMatthias Springer     return false;
48949e37000SMatthias Springer   if (sti != sti &&
49049e37000SMatthias Springer       !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
49149e37000SMatthias Springer     return false;
49249e37000SMatthias Springer   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
49349e37000SMatthias Springer     return false;
49449e37000SMatthias Springer   return true;
49549e37000SMatthias Springer }
49649e37000SMatthias Springer 
49749e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches
49849e37000SMatthias Springer /// the given InsertSliceOp.
4999597b16aSMatthias Springer static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
5009597b16aSMatthias Springer                                       InsertSliceOp insertOp) {
50149e37000SMatthias Springer   auto condition = [&](Value val) {
50249e37000SMatthias Springer     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
50349e37000SMatthias Springer       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
50449e37000SMatthias Springer         return true;
50549e37000SMatthias Springer     return false;
50649e37000SMatthias Springer   };
50749e37000SMatthias Springer 
50849e37000SMatthias Springer   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
50949e37000SMatthias Springer                       condition);
51049e37000SMatthias Springer }
51149e37000SMatthias Springer 
51249e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
51349e37000SMatthias Springer /// certain circumstances, this op can also be a no-op.
51449e37000SMatthias Springer struct InsertSliceOpInterface
51549e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
51649e37000SMatthias Springer                                                     tensor::InsertSliceOp> {
51749e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
5189597b16aSMatthias Springer                               const AnalysisState &state) const {
51949e37000SMatthias Springer     return true;
52049e37000SMatthias Springer   }
52149e37000SMatthias Springer 
52249e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
5239597b16aSMatthias Springer                                const AnalysisState &state) const {
52449e37000SMatthias Springer     return &opOperand == &op->getOpOperand(1) /*dest*/;
52549e37000SMatthias Springer   }
52649e37000SMatthias Springer 
5279597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
5289597b16aSMatthias Springer                                             const AnalysisState &state) const {
529585a8a32SMatthias Springer     if (&opOperand == &op->getOpOperand(1) /*dest*/)
530585a8a32SMatthias Springer       return {op->getResult(0)};
531585a8a32SMatthias Springer     return {};
53249e37000SMatthias Springer   }
53349e37000SMatthias Springer 
53449e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
5359597b16aSMatthias Springer                                 const AnalysisState &state) const {
53649e37000SMatthias Springer     return BufferRelation::Equivalent;
53749e37000SMatthias Springer   }
53849e37000SMatthias Springer 
53949e37000SMatthias Springer   bool isNotConflicting(Operation *op, OpOperand *uRead,
54049e37000SMatthias Springer                         OpOperand *uConflictingWrite,
5419597b16aSMatthias Springer                         const AnalysisState &state) const {
54249e37000SMatthias Springer     Operation *readingOp = uRead->getOwner();
54349e37000SMatthias Springer     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
54449e37000SMatthias Springer 
54549e37000SMatthias Springer     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
54649e37000SMatthias Springer     // uRead is an InsertSliceOp...
54749e37000SMatthias Springer     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
54849e37000SMatthias Springer       // As an example, consider the following IR.
54949e37000SMatthias Springer       //
55049e37000SMatthias Springer       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
55149e37000SMatthias Springer       // %1 = linalg.fill %cst, %0 {inplace= [true] }
55249e37000SMatthias Springer       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
55349e37000SMatthias Springer       //     {inplace= [true] }
55449e37000SMatthias Springer 
55549e37000SMatthias Springer       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
55649e37000SMatthias Springer       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
55749e37000SMatthias Springer           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
55849e37000SMatthias Springer                                     insertSliceOp))
55949e37000SMatthias Springer         // Case 1: The main insight is that InsertSliceOp reads only part of
56049e37000SMatthias Springer         // the destination tensor. The overwritten area is not read. If
56149e37000SMatthias Springer         // uConflictingWrite writes into exactly the memory location that is
56249e37000SMatthias Springer         // being read by uRead, this is not a conflict.
56349e37000SMatthias Springer         //
56449e37000SMatthias Springer         // In the above example:
56549e37000SMatthias Springer         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
56649e37000SMatthias Springer         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
56749e37000SMatthias Springer         //
56849e37000SMatthias Springer         // The read of %t does not conflict with the write of the FillOp
56949e37000SMatthias Springer         // (same aliases!) because the area that the FillOp operates on is
57049e37000SMatthias Springer         // exactly the one that is *not* read via %t.
57149e37000SMatthias Springer         return true;
57249e37000SMatthias Springer 
57349e37000SMatthias Springer       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
57449e37000SMatthias Springer           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
57549e37000SMatthias Springer           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
57649e37000SMatthias Springer         // Case 2: The read of the source tensor and the write to the dest
57749e37000SMatthias Springer         // tensor via an InsertSliceOp is not a conflict if the read is
57849e37000SMatthias Springer         // reading exactly that part of an equivalent tensor that the
57949e37000SMatthias Springer         // InsertSliceOp is writing.
58049e37000SMatthias Springer         //
58149e37000SMatthias Springer         // In the above example:
58249e37000SMatthias Springer         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
58349e37000SMatthias Springer         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
58449e37000SMatthias Springer         return true;
58549e37000SMatthias Springer     }
58649e37000SMatthias Springer 
58749e37000SMatthias Springer     // If uConflictingWrite is an InsertSliceOp...
58849e37000SMatthias Springer     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
58949e37000SMatthias Springer       // As an example, consider the following IR.
59049e37000SMatthias Springer       //
59149e37000SMatthias Springer       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
59249e37000SMatthias Springer       // %1 = linalg.fill %cst, %0 {inplace= [true] }
59349e37000SMatthias Springer       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
59449e37000SMatthias Springer       //     {inplace= [true] }
59549e37000SMatthias Springer       // %3 = vector.transfer_read %1, %cst
59649e37000SMatthias Springer       //
59749e37000SMatthias Springer       // In the above example:
59849e37000SMatthias Springer       // uRead             = OpOperand 0 (%1) of vector.transfer_read
59949e37000SMatthias Springer       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
60049e37000SMatthias Springer       // lastWrite         = %1
60149e37000SMatthias Springer       //
60249e37000SMatthias Springer       // This is not a conflict because the InsertSliceOp overwrites the
60349e37000SMatthias Springer       // memory segment of %1 with the exact same data. (Effectively, there
60449e37000SMatthias Springer       // is no memory write here.)
60549e37000SMatthias Springer       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
60649e37000SMatthias Springer           state.areEquivalentBufferizedValues(uRead->get(),
60749e37000SMatthias Springer                                               insertSliceOp.source()) &&
60849e37000SMatthias Springer           hasMatchingExtractSliceOp(state, insertSliceOp.source(),
60949e37000SMatthias Springer                                     insertSliceOp))
61049e37000SMatthias Springer         return true;
61149e37000SMatthias Springer 
61249e37000SMatthias Springer     return false;
61349e37000SMatthias Springer   }
61449e37000SMatthias Springer 
61549e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
6169597b16aSMatthias Springer                           BufferizationState &state) const {
61749e37000SMatthias Springer     // insert_slice ops arise from tiling and bufferizing them out-of-place is
61849e37000SMatthias Springer     // generally a deal breaker. When used with loops, this ends up cloning the
61949e37000SMatthias Springer     // whole tensor on every single iteration and is a symptom of a
62049e37000SMatthias Springer     // catastrophically bad scheduling decision.
62149e37000SMatthias Springer     // TODO: be very loud about it or even consider failing the pass.
62249e37000SMatthias Springer     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
62349e37000SMatthias Springer     Location loc = insertSliceOp.getLoc();
62449e37000SMatthias Springer 
62549e37000SMatthias Springer     // When bufferizing out-of-place, `getResultBuffer` allocates.
62649e37000SMatthias Springer     FailureOr<Value> dstMemref =
62749e37000SMatthias Springer         state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
62849e37000SMatthias Springer     if (failed(dstMemref))
62949e37000SMatthias Springer       return failure();
63049e37000SMatthias Springer 
63149e37000SMatthias Springer     // Expand offsets, sizes and strides to the full rank to handle the
63249e37000SMatthias Springer     // rank-reducing case.
63349e37000SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
63449e37000SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
63549e37000SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
63649e37000SMatthias Springer     OffsetSizeAndStrideOpInterface::expandToRank(
63749e37000SMatthias Springer         *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
63849e37000SMatthias Springer         [&](Value target, int64_t dim) -> OpFoldResult {
63949e37000SMatthias Springer           auto shapedType = target.getType().cast<ShapedType>();
64049e37000SMatthias Springer           if (shapedType.isDynamicDim(dim))
64149e37000SMatthias Springer             return rewriter.create<memref::DimOp>(loc, target, dim).result();
64249e37000SMatthias Springer           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
64349e37000SMatthias Springer         });
64449e37000SMatthias Springer     // Take a subview of the dst.
64549e37000SMatthias Springer     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
64649e37000SMatthias Springer     auto subviewMemRefType =
64749e37000SMatthias Springer         memref::SubViewOp::inferRankReducedResultType(
64849e37000SMatthias Springer             insertSliceOp.getSourceType().getRank(), dstMemrefType,
64949e37000SMatthias Springer             mixedOffsets, mixedSizes, mixedStrides)
65049e37000SMatthias Springer             .cast<MemRefType>();
65149e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
65249e37000SMatthias Springer         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
65349e37000SMatthias Springer         mixedStrides);
65449e37000SMatthias Springer 
65549e37000SMatthias Springer     // Copy tensor. If this tensor.insert_slice has a matching
65649e37000SMatthias Springer     // tensor.extract_slice, the copy operation will eventually fold away.
65749e37000SMatthias Springer     Value srcMemref =
65849e37000SMatthias Springer         *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
65949e37000SMatthias Springer     if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
66049e37000SMatthias Springer                             state.getOptions())))
66149e37000SMatthias Springer       return failure();
66249e37000SMatthias Springer 
66349e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
66449e37000SMatthias Springer     return success();
66549e37000SMatthias Springer   }
66649e37000SMatthias Springer };
66749e37000SMatthias Springer 
668fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank.
669fc08d1c2SMatthias Springer struct RankOpInterface
670fc08d1c2SMatthias Springer     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
671fc08d1c2SMatthias Springer                                                     tensor::RankOp> {
672fc08d1c2SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
6739597b16aSMatthias Springer                               const AnalysisState &state) const {
674fc08d1c2SMatthias Springer     return true;
675fc08d1c2SMatthias Springer   }
676fc08d1c2SMatthias Springer 
677fc08d1c2SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
6789597b16aSMatthias Springer                                const AnalysisState &state) const {
679fc08d1c2SMatthias Springer     return false;
680fc08d1c2SMatthias Springer   }
681fc08d1c2SMatthias Springer 
6829597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
6839597b16aSMatthias Springer                                             const AnalysisState &state) const {
684585a8a32SMatthias Springer     return {};
685fc08d1c2SMatthias Springer   }
686fc08d1c2SMatthias Springer 
687fc08d1c2SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
6889597b16aSMatthias Springer                           BufferizationState &state) const {
689fc08d1c2SMatthias Springer     auto rankOp = cast<tensor::RankOp>(op);
690fc08d1c2SMatthias Springer     Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
691fc08d1c2SMatthias Springer     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
692fc08d1c2SMatthias Springer                                                  v);
693fc08d1c2SMatthias Springer     return success();
694fc08d1c2SMatthias Springer   }
695fc08d1c2SMatthias Springer };
696fc08d1c2SMatthias Springer 
69749e37000SMatthias Springer } // namespace
69849e37000SMatthias Springer } // namespace tensor
69949e37000SMatthias Springer } // namespace mlir
70049e37000SMatthias Springer 
70149e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
70249e37000SMatthias Springer     DialectRegistry &registry) {
70349e37000SMatthias Springer   registry.addOpInterface<CastOp, CastOpInterface>();
704e6f69161SMatthias Springer   registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
70549e37000SMatthias Springer   registry.addOpInterface<DimOp, DimOpInterface>();
706e6f69161SMatthias Springer   registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
70749e37000SMatthias Springer   registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
70849e37000SMatthias Springer   registry.addOpInterface<ExtractOp, ExtractOpInterface>();
709d581c94dSMatthias Springer   registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
71071bbb78bSMatthias Springer   registry.addOpInterface<GenerateOp, GenerateOpInterface>();
71149e37000SMatthias Springer   registry.addOpInterface<InsertOp, InsertOpInterface>();
71249e37000SMatthias Springer   registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
713fc08d1c2SMatthias Springer   registry.addOpInterface<RankOp, RankOpInterface>();
71449e37000SMatthias Springer }
715