149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
249e37000SMatthias Springer //
349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
649e37000SMatthias Springer //
749e37000SMatthias Springer //===----------------------------------------------------------------------===//
849e37000SMatthias Springer 
949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1149e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1249e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1371bbb78bSMatthias Springer #include "mlir/Dialect/SCF/SCF.h"
1449e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1549e37000SMatthias Springer #include "mlir/IR/Dialect.h"
1649e37000SMatthias Springer #include "mlir/IR/Operation.h"
1749e37000SMatthias Springer 
1849e37000SMatthias Springer using namespace mlir;
1949e37000SMatthias Springer using namespace mlir::bufferization;
2049e37000SMatthias Springer using namespace mlir::tensor;
2149e37000SMatthias Springer 
2249e37000SMatthias Springer namespace mlir {
2349e37000SMatthias Springer namespace tensor {
2449e37000SMatthias Springer namespace {
2549e37000SMatthias Springer 
2649e37000SMatthias Springer struct CastOpInterface
2749e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
2849e37000SMatthias Springer                                                     tensor::CastOp> {
2949e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
309597b16aSMatthias Springer                               const AnalysisState &state) const {
3149e37000SMatthias Springer     return false;
3249e37000SMatthias Springer   }
3349e37000SMatthias Springer 
3449e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
359597b16aSMatthias Springer                                const AnalysisState &state) const {
3649e37000SMatthias Springer     return false;
3749e37000SMatthias Springer   }
3849e37000SMatthias Springer 
399597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
409597b16aSMatthias Springer                                             const AnalysisState &state) const {
41585a8a32SMatthias Springer     return {op->getResult(0)};
4249e37000SMatthias Springer   }
4349e37000SMatthias Springer 
4449e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
459597b16aSMatthias Springer                                 const AnalysisState &state) const {
4649e37000SMatthias Springer     return BufferRelation::Equivalent;
4749e37000SMatthias Springer   }
4849e37000SMatthias Springer 
4949e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
509597b16aSMatthias Springer                           BufferizationState &state) const {
5149e37000SMatthias Springer     auto castOp = cast<tensor::CastOp>(op);
5249e37000SMatthias Springer 
5349e37000SMatthias Springer     // The result buffer still has the old (pre-cast) type.
5449e37000SMatthias Springer     FailureOr<Value> resultBuffer =
5549e37000SMatthias Springer         state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
5649e37000SMatthias Springer     if (failed(resultBuffer))
5749e37000SMatthias Springer       return failure();
5849e37000SMatthias Springer     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
5949e37000SMatthias Springer     Attribute memorySpace = sourceMemRefType.getMemorySpace();
6049e37000SMatthias Springer     TensorType resultTensorType =
6149e37000SMatthias Springer         castOp.getResult().getType().cast<TensorType>();
6249e37000SMatthias Springer     MemRefLayoutAttrInterface layout;
6349e37000SMatthias Springer 
6449e37000SMatthias Springer     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
6549e37000SMatthias Springer       if (resultTensorType.isa<RankedTensorType>())
6649e37000SMatthias Springer         layout = rankedMemRefType.getLayout();
6749e37000SMatthias Springer 
6849e37000SMatthias Springer     // Compute the new memref type.
6926852423SMatthias Springer     Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
7026852423SMatthias Springer                                           layout, memorySpace);
7149e37000SMatthias Springer 
7249e37000SMatthias Springer     // Replace the op with a memref.cast.
7349e37000SMatthias Springer     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
7449e37000SMatthias Springer                                              resultMemRefType) &&
7549e37000SMatthias Springer            "CallOp::bufferize: cast incompatible");
7649e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
7749e37000SMatthias Springer                                                  *resultBuffer);
7849e37000SMatthias Springer 
7949e37000SMatthias Springer     return success();
8049e37000SMatthias Springer   }
8149e37000SMatthias Springer };
8249e37000SMatthias Springer 
83e6f69161SMatthias Springer /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
84e6f69161SMatthias Springer struct CollapseShapeOpInterface
85e6f69161SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
86e6f69161SMatthias Springer                                                     tensor::CollapseShapeOp> {
87e6f69161SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
889597b16aSMatthias Springer                               const AnalysisState &state) const {
89e6f69161SMatthias Springer     return false;
90e6f69161SMatthias Springer   }
91e6f69161SMatthias Springer 
92e6f69161SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
939597b16aSMatthias Springer                                const AnalysisState &state) const {
94e6f69161SMatthias Springer     return false;
95e6f69161SMatthias Springer   }
96e6f69161SMatthias Springer 
979597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
989597b16aSMatthias Springer                                             const AnalysisState &state) const {
99e6f69161SMatthias Springer     if (&opOperand == &op->getOpOperand(0) /*src*/)
100e6f69161SMatthias Springer       return {op->getOpResult(0)};
101e6f69161SMatthias Springer     return {};
102e6f69161SMatthias Springer   }
103e6f69161SMatthias Springer 
104e6f69161SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
1059597b16aSMatthias Springer                                 const AnalysisState &state) const {
106e6f69161SMatthias Springer     return BufferRelation::Equivalent;
107e6f69161SMatthias Springer   }
108e6f69161SMatthias Springer 
109e6f69161SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1109597b16aSMatthias Springer                           BufferizationState &state) const {
111e6f69161SMatthias Springer     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
11251df6238SMatthias Springer     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
113d7a9bf91SMatthias Springer     OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
114d7a9bf91SMatthias Springer     auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();
11551df6238SMatthias Springer 
11651df6238SMatthias Springer     if (tensorResultType.getRank() == 0) {
11751df6238SMatthias Springer       // 0-d collapses must go through a different op builder.
118d7a9bf91SMatthias Springer       Value buffer = *state.getBuffer(rewriter, srcOperand);
11973c0333dSMatthias Springer       MemRefType resultType;
12073c0333dSMatthias Springer 
12173c0333dSMatthias Springer       if (bufferType.getLayout().isIdentity()) {
12273c0333dSMatthias Springer         // Standard layout: result type has no offset.
12351df6238SMatthias Springer         MemRefLayoutAttrInterface layout;
12473c0333dSMatthias Springer         resultType = MemRefType::get({}, tensorResultType.getElementType(),
12551df6238SMatthias Springer                                      layout, bufferType.getMemorySpace());
12673c0333dSMatthias Springer       } else {
12773c0333dSMatthias Springer         // Source memref has a layout map: result type has the same offset as
12873c0333dSMatthias Springer         // the source type.
12973c0333dSMatthias Springer         SmallVector<int64_t> strides;
13073c0333dSMatthias Springer         int64_t offset;
13173c0333dSMatthias Springer         if (failed(getStridesAndOffset(bufferType, strides, offset)))
13273c0333dSMatthias Springer           return failure();
13373c0333dSMatthias Springer         AffineMap resultLayout =
13473c0333dSMatthias Springer             makeStridedLinearLayoutMap({}, offset, op->getContext());
13573c0333dSMatthias Springer         resultType =
13673c0333dSMatthias Springer             MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
13773c0333dSMatthias Springer                             bufferType.getMemorySpaceAsInt());
13873c0333dSMatthias Springer       }
13973c0333dSMatthias Springer 
140e6f69161SMatthias Springer       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
141e6f69161SMatthias Springer           rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
142e6f69161SMatthias Springer       return success();
143e6f69161SMatthias Springer     }
14451df6238SMatthias Springer 
145d7a9bf91SMatthias Springer     // If the dims are not collapsible (due to an incompatible source layout
146d7a9bf91SMatthias Springer     // map), force an out-of-place bufferization, i.e., a buffer copy. This
147d7a9bf91SMatthias Springer     // newly allocated buffer will have no layout map and thus be collapsible.
148a74e5a89SAdrian Kuegel     bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
149d7a9bf91SMatthias Springer         bufferType, collapseShapeOp.getReassociationIndices());
150d7a9bf91SMatthias Springer     Optional<BufferizationState::ForceInPlacability> overrideInPlace =
151d7a9bf91SMatthias Springer         canBeCollapsed
152d7a9bf91SMatthias Springer             ? None
153d7a9bf91SMatthias Springer             : Optional<BufferizationState::ForceInPlacability>(
154d7a9bf91SMatthias Springer                   BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
155d7a9bf91SMatthias Springer     Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);
156d7a9bf91SMatthias Springer 
15751df6238SMatthias Springer     // Result type is inferred by the builder.
15851df6238SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
15951df6238SMatthias Springer         rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
16051df6238SMatthias Springer     return success();
16151df6238SMatthias Springer   }
162e6f69161SMatthias Springer };
163e6f69161SMatthias Springer 
16449e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim.
16549e37000SMatthias Springer struct DimOpInterface
16649e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
16749e37000SMatthias Springer                                                     tensor::DimOp> {
16849e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1699597b16aSMatthias Springer                               const AnalysisState &state) const {
17049e37000SMatthias Springer     return true;
17149e37000SMatthias Springer   }
17249e37000SMatthias Springer 
17349e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1749597b16aSMatthias Springer                                const AnalysisState &state) const {
17549e37000SMatthias Springer     return false;
17649e37000SMatthias Springer   }
17749e37000SMatthias Springer 
1789597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1799597b16aSMatthias Springer                                             const AnalysisState &state) const {
180585a8a32SMatthias Springer     return {};
18149e37000SMatthias Springer   }
18249e37000SMatthias Springer 
18349e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1849597b16aSMatthias Springer                           BufferizationState &state) const {
18549e37000SMatthias Springer     auto dimOp = cast<tensor::DimOp>(op);
18649e37000SMatthias Springer     Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
18749e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
18849e37000SMatthias Springer     return success();
18949e37000SMatthias Springer   }
19049e37000SMatthias Springer };
19149e37000SMatthias Springer 
192e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
193e6f69161SMatthias Springer struct ExpandShapeOpInterface
194e6f69161SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
195e6f69161SMatthias Springer                                                     tensor::ExpandShapeOp> {
196e6f69161SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1979597b16aSMatthias Springer                               const AnalysisState &state) const {
198e6f69161SMatthias Springer     return false;
199e6f69161SMatthias Springer   }
200e6f69161SMatthias Springer 
201e6f69161SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
2029597b16aSMatthias Springer                                const AnalysisState &state) const {
203e6f69161SMatthias Springer     return false;
204e6f69161SMatthias Springer   }
205e6f69161SMatthias Springer 
2069597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
2079597b16aSMatthias Springer                                             const AnalysisState &state) const {
208e6f69161SMatthias Springer     if (&opOperand == &op->getOpOperand(0) /*src*/)
209e6f69161SMatthias Springer       return {op->getOpResult(0)};
210e6f69161SMatthias Springer     return {};
211e6f69161SMatthias Springer   }
212e6f69161SMatthias Springer 
213e6f69161SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
2149597b16aSMatthias Springer                                 const AnalysisState &state) const {
215e6f69161SMatthias Springer     return BufferRelation::Equivalent;
216e6f69161SMatthias Springer   }
217e6f69161SMatthias Springer 
218e6f69161SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
2199597b16aSMatthias Springer                           BufferizationState &state) const {
220e6f69161SMatthias Springer     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
22151df6238SMatthias Springer     auto tensorResultType = expandShapeOp.getResultType();
222e6f69161SMatthias Springer     Value buffer =
223e6f69161SMatthias Springer         *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
22451df6238SMatthias Springer 
22551df6238SMatthias Springer     // Memref result type is inferred by the builder based on reassociation
22651df6238SMatthias Springer     // indices and result shape.
227e6f69161SMatthias Springer     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
22851df6238SMatthias Springer         rewriter, op, tensorResultType.getShape(), buffer,
22951df6238SMatthias Springer         expandShapeOp.getReassociationIndices());
230e6f69161SMatthias Springer     return success();
231e6f69161SMatthias Springer   }
232e6f69161SMatthias Springer };
233e6f69161SMatthias Springer 
23449e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview.
23549e37000SMatthias Springer struct ExtractSliceOpInterface
23649e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
23749e37000SMatthias Springer                                                     tensor::ExtractSliceOp> {
23849e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2399597b16aSMatthias Springer                               const AnalysisState &state) const {
24049e37000SMatthias Springer     return false;
24149e37000SMatthias Springer   }
24249e37000SMatthias Springer 
24349e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
2449597b16aSMatthias Springer                                const AnalysisState &state) const {
24549e37000SMatthias Springer     return false;
24649e37000SMatthias Springer   }
24749e37000SMatthias Springer 
2489597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
2499597b16aSMatthias Springer                                             const AnalysisState &state) const {
250585a8a32SMatthias Springer     if (&opOperand == &op->getOpOperand(0) /*source*/)
251585a8a32SMatthias Springer       return {op->getOpResult(0)};
252585a8a32SMatthias Springer     return {};
25349e37000SMatthias Springer   }
25449e37000SMatthias Springer 
25549e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
2569597b16aSMatthias Springer                                 const AnalysisState &state) const {
25749e37000SMatthias Springer     return BufferRelation::None;
25849e37000SMatthias Springer   }
25949e37000SMatthias Springer 
26049e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
2619597b16aSMatthias Springer                           BufferizationState &state) const {
26249e37000SMatthias Springer     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
26349e37000SMatthias Springer     Location loc = extractSliceOp.getLoc();
264d7a9bf91SMatthias Springer 
265d7a9bf91SMatthias Springer     // Even if this op was decided to bufferize out-of-place, do not insert the
266d7a9bf91SMatthias Springer     // buffer copy yet. This is done later in this function.
26749e37000SMatthias Springer     Value srcMemref =
26849e37000SMatthias Springer         *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
269d7a9bf91SMatthias Springer                          BufferizationState::ForceInPlacability::FORCE_INPLACE);
27049e37000SMatthias Springer     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
27149e37000SMatthias Springer     auto dstTensorType =
27249e37000SMatthias Springer         extractSliceOp.result().getType().cast<RankedTensorType>();
27349e37000SMatthias Springer 
27449e37000SMatthias Springer     // If not inplaceable, alloc.
2759597b16aSMatthias Springer     bool inplace =
2769597b16aSMatthias Springer         state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
27749e37000SMatthias Springer     Value alloc;
27849e37000SMatthias Springer     if (!inplace) {
27949e37000SMatthias Springer       FailureOr<Value> allocOrFailure =
28005e0495fSMatthias Springer           state.createAlloc(rewriter, loc, extractSliceOp.result());
28149e37000SMatthias Springer       if (failed(allocOrFailure))
28249e37000SMatthias Springer         return failure();
28349e37000SMatthias Springer       alloc = *allocOrFailure;
28449e37000SMatthias Springer     }
28549e37000SMatthias Springer 
28649e37000SMatthias Springer     // Expand offsets, sizes and strides to the full rank to handle the
28749e37000SMatthias Springer     // rank-reducing case.
28849e37000SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
28949e37000SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
29049e37000SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
29149e37000SMatthias Springer     OffsetSizeAndStrideOpInterface::expandToRank(
29249e37000SMatthias Springer         srcMemref, mixedOffsets, mixedSizes, mixedStrides,
29349e37000SMatthias Springer         [&](Value target, int64_t dim) -> OpFoldResult {
29449e37000SMatthias Springer           auto shapedType = target.getType().cast<ShapedType>();
29549e37000SMatthias Springer           if (shapedType.isDynamicDim(dim))
29649e37000SMatthias Springer             return rewriter.create<memref::DimOp>(loc, target, dim).result();
29749e37000SMatthias Springer           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
29849e37000SMatthias Springer         });
29949e37000SMatthias Springer     // Bufferize to subview.
30049e37000SMatthias Springer     auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
30149e37000SMatthias Springer                                  dstTensorType.getRank(), srcMemrefType,
30249e37000SMatthias Springer                                  mixedOffsets, mixedSizes, mixedStrides)
30349e37000SMatthias Springer                                  .cast<MemRefType>();
30449e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
30549e37000SMatthias Springer         loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
30649e37000SMatthias Springer         mixedStrides);
30749e37000SMatthias Springer 
30849e37000SMatthias Springer     // If not inplaceable, copy.
30949e37000SMatthias Springer     if (!inplace) {
31049e37000SMatthias Springer       // Do not copy if the copied data is never read.
3119597b16aSMatthias Springer       if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
31249e37000SMatthias Springer         if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
31349e37000SMatthias Springer                                 alloc, state.getOptions())))
31449e37000SMatthias Springer           return failure();
31549e37000SMatthias Springer       subView = alloc;
31649e37000SMatthias Springer     }
31749e37000SMatthias Springer 
31849e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, subView);
31949e37000SMatthias Springer     return success();
32049e37000SMatthias Springer   }
32149e37000SMatthias Springer };
32249e37000SMatthias Springer 
32349e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load.
32449e37000SMatthias Springer struct ExtractOpInterface
32549e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
32649e37000SMatthias Springer                                                     tensor::ExtractOp> {
32749e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
3289597b16aSMatthias Springer                               const AnalysisState &state) const {
32949e37000SMatthias Springer     return true;
33049e37000SMatthias Springer   }
33149e37000SMatthias Springer 
33249e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
3339597b16aSMatthias Springer                                const AnalysisState &state) const {
33449e37000SMatthias Springer     return false;
33549e37000SMatthias Springer   }
33649e37000SMatthias Springer 
3379597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
3389597b16aSMatthias Springer                                             const AnalysisState &state) const {
339585a8a32SMatthias Springer     return {};
34049e37000SMatthias Springer   }
34149e37000SMatthias Springer 
34249e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
3439597b16aSMatthias Springer                           BufferizationState &state) const {
34449e37000SMatthias Springer     auto extractOp = cast<tensor::ExtractOp>(op);
34549e37000SMatthias Springer     Value srcMemref =
34649e37000SMatthias Springer         *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
34749e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
34849e37000SMatthias Springer                                                  extractOp.indices());
34949e37000SMatthias Springer     return success();
35049e37000SMatthias Springer   }
35149e37000SMatthias Springer };
35249e37000SMatthias Springer 
353d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while
354d581c94dSMatthias Springer // iterating over op.elements().
355d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim,
356d581c94dSMatthias Springer                          Value buffer, ArrayRef<int64_t> shape,
357d581c94dSMatthias Springer                          ArrayRef<Value> constants,
358d581c94dSMatthias Springer                          OperandRange::iterator &elementIt,
359d581c94dSMatthias Springer                          SmallVectorImpl<Value> &indices) {
360d581c94dSMatthias Springer   if (dim == static_cast<int>(shape.size()) - 1) {
361d581c94dSMatthias Springer     for (int i = 0; i < shape.back(); ++i) {
362d581c94dSMatthias Springer       indices.back() = constants[i];
363d581c94dSMatthias Springer       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
364d581c94dSMatthias Springer       ++elementIt;
365d581c94dSMatthias Springer     }
366d581c94dSMatthias Springer     return;
367d581c94dSMatthias Springer   }
368d581c94dSMatthias Springer   for (int i = 0; i < shape[dim]; ++i) {
369d581c94dSMatthias Springer     indices[dim] = constants[i];
370d581c94dSMatthias Springer     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
371d581c94dSMatthias Springer                  indices);
372d581c94dSMatthias Springer   }
373d581c94dSMatthias Springer }
374d581c94dSMatthias Springer 
375d581c94dSMatthias Springer /// Bufferization of tensor.from_elements.
376d581c94dSMatthias Springer struct FromElementsOpInterface
377d581c94dSMatthias Springer     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
378d581c94dSMatthias Springer                                                     tensor::FromElementsOp> {
379d581c94dSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
3809597b16aSMatthias Springer                           BufferizationState &state) const {
381d581c94dSMatthias Springer     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
382d581c94dSMatthias Springer 
383d581c94dSMatthias Springer     // Allocate a buffer for the result.
384d581c94dSMatthias Springer     Location loc = op->getLoc();
385d581c94dSMatthias Springer     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
386d581c94dSMatthias Springer     auto shape = tensorType.getShape();
387d581c94dSMatthias Springer     FailureOr<Value> maybeBuffer =
3889e24f0f4SMatthias Springer         state.createAlloc(rewriter, loc, fromElementsOp.result());
389d581c94dSMatthias Springer     if (failed(maybeBuffer))
390d581c94dSMatthias Springer       return failure();
391d581c94dSMatthias Springer     Value buffer = *maybeBuffer;
392d581c94dSMatthias Springer 
393d581c94dSMatthias Springer     // Case: tensor<0xelem_type>.
394d581c94dSMatthias Springer     if (fromElementsOp.elements().empty()) {
395d581c94dSMatthias Springer       replaceOpWithBufferizedValues(rewriter, op, buffer);
396d581c94dSMatthias Springer       return success();
397d581c94dSMatthias Springer     }
398d581c94dSMatthias Springer 
399d581c94dSMatthias Springer     // Case: tensor<elem_type>.
400d581c94dSMatthias Springer     if (shape.empty()) {
401d581c94dSMatthias Springer       rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
402d581c94dSMatthias Springer                                        buffer);
403d581c94dSMatthias Springer       replaceOpWithBufferizedValues(rewriter, op, buffer);
404d581c94dSMatthias Springer       return success();
405d581c94dSMatthias Springer     }
406d581c94dSMatthias Springer 
407d581c94dSMatthias Springer     // Create constants for the range of possible indices [0, max{shape_i}).
408d581c94dSMatthias Springer     auto maxDim = *std::max_element(shape.begin(), shape.end());
409d581c94dSMatthias Springer     SmallVector<Value, 2> constants;
410d581c94dSMatthias Springer     constants.reserve(maxDim);
411d581c94dSMatthias Springer     for (int i = 0; i < maxDim; ++i)
412d581c94dSMatthias Springer       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
413d581c94dSMatthias Springer 
414d581c94dSMatthias Springer     // Traverse all `elements` and create `memref.store` ops.
415d581c94dSMatthias Springer     auto elementIt = fromElementsOp.elements().begin();
416d581c94dSMatthias Springer     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
417d581c94dSMatthias Springer     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
418d581c94dSMatthias Springer                  indices);
419d581c94dSMatthias Springer 
420d581c94dSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, buffer);
421d581c94dSMatthias Springer     return success();
422d581c94dSMatthias Springer   }
423d581c94dSMatthias Springer };
424d581c94dSMatthias Springer 
42571bbb78bSMatthias Springer /// Bufferization of tensor.generate.
42671bbb78bSMatthias Springer struct GenerateOpInterface
42771bbb78bSMatthias Springer     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
42871bbb78bSMatthias Springer                                                     tensor::GenerateOp> {
42971bbb78bSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
4309597b16aSMatthias Springer                           BufferizationState &state) const {
43171bbb78bSMatthias Springer     auto generateOp = cast<tensor::GenerateOp>(op);
43271bbb78bSMatthias Springer 
43371bbb78bSMatthias Springer     // Allocate memory.
43471bbb78bSMatthias Springer     Location loc = op->getLoc();
43571bbb78bSMatthias Springer     MemRefType memrefType =
43671bbb78bSMatthias Springer         getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
4379e24f0f4SMatthias Springer     FailureOr<Value> maybeResult =
4389e24f0f4SMatthias Springer         state.createAlloc(rewriter, loc, generateOp.result());
43971bbb78bSMatthias Springer     if (failed(maybeResult))
44071bbb78bSMatthias Springer       return failure();
44171bbb78bSMatthias Springer     Value result = *maybeResult;
44271bbb78bSMatthias Springer 
44371bbb78bSMatthias Springer     // Collect loop bounds.
44471bbb78bSMatthias Springer     int64_t rank = memrefType.getRank();
44571bbb78bSMatthias Springer     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
44671bbb78bSMatthias Springer     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
44771bbb78bSMatthias Springer     SmallVector<Value, 4> lowerBounds(rank, zero);
44871bbb78bSMatthias Springer     SmallVector<Value, 4> steps(rank, one);
44971bbb78bSMatthias Springer     SmallVector<Value, 4> upperBounds;
45071bbb78bSMatthias Springer     int nextDynamicIndex = 0;
45171bbb78bSMatthias Springer     for (int i = 0; i < rank; i++) {
45271bbb78bSMatthias Springer       Value upperBound = memrefType.isDynamicDim(i)
45371bbb78bSMatthias Springer                              ? generateOp.dynamicExtents()[nextDynamicIndex++]
45471bbb78bSMatthias Springer                              : rewriter.create<arith::ConstantIndexOp>(
45571bbb78bSMatthias Springer                                    loc, memrefType.getDimSize(i));
45671bbb78bSMatthias Springer       upperBounds.push_back(upperBound);
45771bbb78bSMatthias Springer     }
45871bbb78bSMatthias Springer 
45971bbb78bSMatthias Springer     // Generate tensor elements with a parallel loop that stores into
46071bbb78bSMatthias Springer     // each element of the resulting memref. We use mergeBlockBefore to "move"
46171bbb78bSMatthias Springer     // this op's body into the scf.parallel's body.
46271bbb78bSMatthias Springer     auto parallel =
46371bbb78bSMatthias Springer         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
46471bbb78bSMatthias Springer     Block *parallelBody = parallel.getBody();
46571bbb78bSMatthias Springer     rewriter.mergeBlockBefore(generateOp.getBody(),
46671bbb78bSMatthias Springer                               parallelBody->getTerminator(),
46771bbb78bSMatthias Springer                               parallelBody->getArguments());
46871bbb78bSMatthias Springer     // Replace the inlined yield op with a store op. The scf.parallel's builder
46971bbb78bSMatthias Springer     // already populated an scf.yield at the end, so we don't need to worry
47071bbb78bSMatthias Springer     // about creating that.
47171bbb78bSMatthias Springer     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
47271bbb78bSMatthias Springer     rewriter.setInsertionPointAfter(elementYield);
47371bbb78bSMatthias Springer     rewriter.replaceOpWithNewOp<memref::StoreOp>(
47471bbb78bSMatthias Springer         elementYield, elementYield->getOperands()[0], result,
47571bbb78bSMatthias Springer         parallelBody->getArguments());
47671bbb78bSMatthias Springer 
47771bbb78bSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, result);
47871bbb78bSMatthias Springer     return success();
47971bbb78bSMatthias Springer   }
48071bbb78bSMatthias Springer };
48171bbb78bSMatthias Springer 
48249e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store.
48349e37000SMatthias Springer struct InsertOpInterface
48449e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
48549e37000SMatthias Springer                                                     tensor::InsertOp> {
48649e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
4879597b16aSMatthias Springer                               const AnalysisState &state) const {
48849e37000SMatthias Springer     return true;
48949e37000SMatthias Springer   }
49049e37000SMatthias Springer 
49149e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
4929597b16aSMatthias Springer                                const AnalysisState &state) const {
49349e37000SMatthias Springer     return true;
49449e37000SMatthias Springer   }
49549e37000SMatthias Springer 
4969597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
4979597b16aSMatthias Springer                                             const AnalysisState &state) const {
49849e37000SMatthias Springer     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
49949e37000SMatthias Springer            "expected dest OpOperand");
500585a8a32SMatthias Springer     return {op->getOpResult(0)};
50149e37000SMatthias Springer   }
50249e37000SMatthias Springer 
50349e37000SMatthias Springer   SmallVector<OpOperand *>
50449e37000SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
5059597b16aSMatthias Springer                        const AnalysisState &state) const {
50649e37000SMatthias Springer     return {&op->getOpOperand(1) /*dest*/};
50749e37000SMatthias Springer   }
50849e37000SMatthias Springer 
50949e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
5109597b16aSMatthias Springer                           BufferizationState &state) const {
51149e37000SMatthias Springer     auto insertOp = cast<tensor::InsertOp>(op);
51249e37000SMatthias Springer     FailureOr<Value> destMemref =
51349e37000SMatthias Springer         state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
51449e37000SMatthias Springer     if (failed(destMemref))
51549e37000SMatthias Springer       return failure();
51649e37000SMatthias Springer     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
51749e37000SMatthias Springer                                      *destMemref, insertOp.indices());
51849e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
51949e37000SMatthias Springer     return success();
52049e37000SMatthias Springer   }
52149e37000SMatthias Springer 
52249e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
5239597b16aSMatthias Springer                                 const AnalysisState &state) const {
52449e37000SMatthias Springer     return BufferRelation::Equivalent;
52549e37000SMatthias Springer   }
52649e37000SMatthias Springer };
52749e37000SMatthias Springer 
52849e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
52949e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification).
53049e37000SMatthias Springer ///
53149e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that
53249e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and
53349e37000SMatthias Springer /// exposed as interfaces on the proper types.
5349597b16aSMatthias Springer static bool areEquivalentExtractSliceOps(const AnalysisState &state,
53549e37000SMatthias Springer                                          ExtractSliceOp st, InsertSliceOp sti) {
53649e37000SMatthias Springer   if (!st || !sti)
53749e37000SMatthias Springer     return false;
53849e37000SMatthias Springer   if (sti != sti &&
53949e37000SMatthias Springer       !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
54049e37000SMatthias Springer     return false;
54149e37000SMatthias Springer   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
54249e37000SMatthias Springer     return false;
54349e37000SMatthias Springer   return true;
54449e37000SMatthias Springer }
54549e37000SMatthias Springer 
54649e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches
54749e37000SMatthias Springer /// the given InsertSliceOp.
5489597b16aSMatthias Springer static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
5499597b16aSMatthias Springer                                       InsertSliceOp insertOp) {
55049e37000SMatthias Springer   auto condition = [&](Value val) {
55149e37000SMatthias Springer     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
55249e37000SMatthias Springer       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
55349e37000SMatthias Springer         return true;
55449e37000SMatthias Springer     return false;
55549e37000SMatthias Springer   };
55649e37000SMatthias Springer 
55749e37000SMatthias Springer   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
55849e37000SMatthias Springer                       condition);
55949e37000SMatthias Springer }
56049e37000SMatthias Springer 
56149e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
56249e37000SMatthias Springer /// certain circumstances, this op can also be a no-op.
56349e37000SMatthias Springer struct InsertSliceOpInterface
56449e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
56549e37000SMatthias Springer                                                     tensor::InsertSliceOp> {
56649e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
5679597b16aSMatthias Springer                               const AnalysisState &state) const {
56849e37000SMatthias Springer     return true;
56949e37000SMatthias Springer   }
57049e37000SMatthias Springer 
57149e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
5729597b16aSMatthias Springer                                const AnalysisState &state) const {
57349e37000SMatthias Springer     return &opOperand == &op->getOpOperand(1) /*dest*/;
57449e37000SMatthias Springer   }
57549e37000SMatthias Springer 
5769597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
5779597b16aSMatthias Springer                                             const AnalysisState &state) const {
578585a8a32SMatthias Springer     if (&opOperand == &op->getOpOperand(1) /*dest*/)
579585a8a32SMatthias Springer       return {op->getResult(0)};
580585a8a32SMatthias Springer     return {};
58149e37000SMatthias Springer   }
58249e37000SMatthias Springer 
58349e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
5849597b16aSMatthias Springer                                 const AnalysisState &state) const {
58549e37000SMatthias Springer     return BufferRelation::Equivalent;
58649e37000SMatthias Springer   }
58749e37000SMatthias Springer 
58849e37000SMatthias Springer   bool isNotConflicting(Operation *op, OpOperand *uRead,
58949e37000SMatthias Springer                         OpOperand *uConflictingWrite,
5909597b16aSMatthias Springer                         const AnalysisState &state) const {
59149e37000SMatthias Springer     Operation *readingOp = uRead->getOwner();
59249e37000SMatthias Springer     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
59349e37000SMatthias Springer 
59449e37000SMatthias Springer     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
59549e37000SMatthias Springer     // uRead is an InsertSliceOp...
59649e37000SMatthias Springer     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
59749e37000SMatthias Springer       // As an example, consider the following IR.
59849e37000SMatthias Springer       //
59949e37000SMatthias Springer       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
60049e37000SMatthias Springer       // %1 = linalg.fill %cst, %0 {inplace= [true] }
60149e37000SMatthias Springer       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
60249e37000SMatthias Springer       //     {inplace= [true] }
60349e37000SMatthias Springer 
60449e37000SMatthias Springer       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
60549e37000SMatthias Springer       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
60649e37000SMatthias Springer           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
60749e37000SMatthias Springer                                     insertSliceOp))
60849e37000SMatthias Springer         // Case 1: The main insight is that InsertSliceOp reads only part of
60949e37000SMatthias Springer         // the destination tensor. The overwritten area is not read. If
61049e37000SMatthias Springer         // uConflictingWrite writes into exactly the memory location that is
61149e37000SMatthias Springer         // being read by uRead, this is not a conflict.
61249e37000SMatthias Springer         //
61349e37000SMatthias Springer         // In the above example:
61449e37000SMatthias Springer         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
61549e37000SMatthias Springer         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
61649e37000SMatthias Springer         //
61749e37000SMatthias Springer         // The read of %t does not conflict with the write of the FillOp
61849e37000SMatthias Springer         // (same aliases!) because the area that the FillOp operates on is
61949e37000SMatthias Springer         // exactly the one that is *not* read via %t.
62049e37000SMatthias Springer         return true;
62149e37000SMatthias Springer 
62249e37000SMatthias Springer       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
62349e37000SMatthias Springer           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
62449e37000SMatthias Springer           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
62549e37000SMatthias Springer         // Case 2: The read of the source tensor and the write to the dest
62649e37000SMatthias Springer         // tensor via an InsertSliceOp is not a conflict if the read is
62749e37000SMatthias Springer         // reading exactly that part of an equivalent tensor that the
62849e37000SMatthias Springer         // InsertSliceOp is writing.
62949e37000SMatthias Springer         //
63049e37000SMatthias Springer         // In the above example:
63149e37000SMatthias Springer         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
63249e37000SMatthias Springer         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
63349e37000SMatthias Springer         return true;
63449e37000SMatthias Springer     }
63549e37000SMatthias Springer 
63649e37000SMatthias Springer     // If uConflictingWrite is an InsertSliceOp...
63749e37000SMatthias Springer     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
63849e37000SMatthias Springer       // As an example, consider the following IR.
63949e37000SMatthias Springer       //
64049e37000SMatthias Springer       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
64149e37000SMatthias Springer       // %1 = linalg.fill %cst, %0 {inplace= [true] }
64249e37000SMatthias Springer       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
64349e37000SMatthias Springer       //     {inplace= [true] }
64449e37000SMatthias Springer       // %3 = vector.transfer_read %1, %cst
64549e37000SMatthias Springer       //
64649e37000SMatthias Springer       // In the above example:
64749e37000SMatthias Springer       // uRead             = OpOperand 0 (%1) of vector.transfer_read
64849e37000SMatthias Springer       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
64949e37000SMatthias Springer       // lastWrite         = %1
65049e37000SMatthias Springer       //
65149e37000SMatthias Springer       // This is not a conflict because the InsertSliceOp overwrites the
65249e37000SMatthias Springer       // memory segment of %1 with the exact same data. (Effectively, there
65349e37000SMatthias Springer       // is no memory write here.)
65449e37000SMatthias Springer       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
65549e37000SMatthias Springer           state.areEquivalentBufferizedValues(uRead->get(),
65649e37000SMatthias Springer                                               insertSliceOp.source()) &&
65749e37000SMatthias Springer           hasMatchingExtractSliceOp(state, insertSliceOp.source(),
65849e37000SMatthias Springer                                     insertSliceOp))
65949e37000SMatthias Springer         return true;
66049e37000SMatthias Springer 
66149e37000SMatthias Springer     return false;
66249e37000SMatthias Springer   }
66349e37000SMatthias Springer 
66449e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
6659597b16aSMatthias Springer                           BufferizationState &state) const {
66649e37000SMatthias Springer     // insert_slice ops arise from tiling and bufferizing them out-of-place is
66749e37000SMatthias Springer     // generally a deal breaker. When used with loops, this ends up cloning the
66849e37000SMatthias Springer     // whole tensor on every single iteration and is a symptom of a
66949e37000SMatthias Springer     // catastrophically bad scheduling decision.
67049e37000SMatthias Springer     // TODO: be very loud about it or even consider failing the pass.
67149e37000SMatthias Springer     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
67249e37000SMatthias Springer     Location loc = insertSliceOp.getLoc();
67349e37000SMatthias Springer 
67449e37000SMatthias Springer     // When bufferizing out-of-place, `getResultBuffer` allocates.
67549e37000SMatthias Springer     FailureOr<Value> dstMemref =
67649e37000SMatthias Springer         state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
67749e37000SMatthias Springer     if (failed(dstMemref))
67849e37000SMatthias Springer       return failure();
67949e37000SMatthias Springer 
68049e37000SMatthias Springer     // Expand offsets, sizes and strides to the full rank to handle the
68149e37000SMatthias Springer     // rank-reducing case.
68249e37000SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
68349e37000SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
68449e37000SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
68549e37000SMatthias Springer     OffsetSizeAndStrideOpInterface::expandToRank(
68649e37000SMatthias Springer         *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
68749e37000SMatthias Springer         [&](Value target, int64_t dim) -> OpFoldResult {
68849e37000SMatthias Springer           auto shapedType = target.getType().cast<ShapedType>();
68949e37000SMatthias Springer           if (shapedType.isDynamicDim(dim))
69049e37000SMatthias Springer             return rewriter.create<memref::DimOp>(loc, target, dim).result();
69149e37000SMatthias Springer           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
69249e37000SMatthias Springer         });
69349e37000SMatthias Springer     // Take a subview of the dst.
69449e37000SMatthias Springer     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
69549e37000SMatthias Springer     auto subviewMemRefType =
69649e37000SMatthias Springer         memref::SubViewOp::inferRankReducedResultType(
69749e37000SMatthias Springer             insertSliceOp.getSourceType().getRank(), dstMemrefType,
69849e37000SMatthias Springer             mixedOffsets, mixedSizes, mixedStrides)
69949e37000SMatthias Springer             .cast<MemRefType>();
70049e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
70149e37000SMatthias Springer         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
70249e37000SMatthias Springer         mixedStrides);
70349e37000SMatthias Springer 
70449e37000SMatthias Springer     // Copy tensor. If this tensor.insert_slice has a matching
70549e37000SMatthias Springer     // tensor.extract_slice, the copy operation will eventually fold away.
70649e37000SMatthias Springer     Value srcMemref =
70749e37000SMatthias Springer         *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
70849e37000SMatthias Springer     if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
70949e37000SMatthias Springer                             state.getOptions())))
71049e37000SMatthias Springer       return failure();
71149e37000SMatthias Springer 
71249e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
71349e37000SMatthias Springer     return success();
71449e37000SMatthias Springer   }
71549e37000SMatthias Springer };
71649e37000SMatthias Springer 
717fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank.
718fc08d1c2SMatthias Springer struct RankOpInterface
719fc08d1c2SMatthias Springer     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
720fc08d1c2SMatthias Springer                                                     tensor::RankOp> {
721fc08d1c2SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
7229597b16aSMatthias Springer                               const AnalysisState &state) const {
723fc08d1c2SMatthias Springer     return true;
724fc08d1c2SMatthias Springer   }
725fc08d1c2SMatthias Springer 
726fc08d1c2SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
7279597b16aSMatthias Springer                                const AnalysisState &state) const {
728fc08d1c2SMatthias Springer     return false;
729fc08d1c2SMatthias Springer   }
730fc08d1c2SMatthias Springer 
7319597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
7329597b16aSMatthias Springer                                             const AnalysisState &state) const {
733585a8a32SMatthias Springer     return {};
734fc08d1c2SMatthias Springer   }
735fc08d1c2SMatthias Springer 
736fc08d1c2SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
7379597b16aSMatthias Springer                           BufferizationState &state) const {
738fc08d1c2SMatthias Springer     auto rankOp = cast<tensor::RankOp>(op);
739fc08d1c2SMatthias Springer     Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
740fc08d1c2SMatthias Springer     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
741fc08d1c2SMatthias Springer                                                  v);
742fc08d1c2SMatthias Springer     return success();
743fc08d1c2SMatthias Springer   }
744fc08d1c2SMatthias Springer };
745fc08d1c2SMatthias Springer 
746*e287d647SAshay Rane /// Bufferization of tensor.reshape. Replace with memref.reshape.
747*e287d647SAshay Rane struct ReshapeOpInterface
748*e287d647SAshay Rane     : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
749*e287d647SAshay Rane                                                     tensor::ReshapeOp> {
750*e287d647SAshay Rane   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
751*e287d647SAshay Rane                               const AnalysisState &state) const {
752*e287d647SAshay Rane     if (&opOperand == &op->getOpOperand(1) /* shape */)
753*e287d647SAshay Rane       return true;
754*e287d647SAshay Rane     return false;
755*e287d647SAshay Rane   }
756*e287d647SAshay Rane 
757*e287d647SAshay Rane   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
758*e287d647SAshay Rane                                const AnalysisState &state) const {
759*e287d647SAshay Rane     return false;
760*e287d647SAshay Rane   }
761*e287d647SAshay Rane 
762*e287d647SAshay Rane   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
763*e287d647SAshay Rane                                             const AnalysisState &state) const {
764*e287d647SAshay Rane     return {op->getOpResult(0)};
765*e287d647SAshay Rane   }
766*e287d647SAshay Rane 
767*e287d647SAshay Rane   BufferRelation bufferRelation(Operation *op, OpResult opResult,
768*e287d647SAshay Rane                                 const AnalysisState &state) const {
769*e287d647SAshay Rane     return BufferRelation::Equivalent;
770*e287d647SAshay Rane   }
771*e287d647SAshay Rane 
772*e287d647SAshay Rane   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
773*e287d647SAshay Rane                           BufferizationState &state) const {
774*e287d647SAshay Rane     auto reshapeOp = cast<tensor::ReshapeOp>(op);
775*e287d647SAshay Rane     auto &srcOperand = reshapeOp->getOpOperand(0);
776*e287d647SAshay Rane     auto srcBuffer = state.getBuffer(rewriter, srcOperand);
777*e287d647SAshay Rane     if (failed(srcBuffer))
778*e287d647SAshay Rane       return failure();
779*e287d647SAshay Rane 
780*e287d647SAshay Rane     auto &shapeOperand = reshapeOp->getOpOperand(1);
781*e287d647SAshay Rane     auto shapeBuffer = state.getBuffer(rewriter, shapeOperand);
782*e287d647SAshay Rane     if (failed(shapeBuffer))
783*e287d647SAshay Rane       return failure();
784*e287d647SAshay Rane 
785*e287d647SAshay Rane     auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
786*e287d647SAshay Rane     auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());
787*e287d647SAshay Rane 
788*e287d647SAshay Rane     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
789*e287d647SAshay Rane         rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
790*e287d647SAshay Rane     return success();
791*e287d647SAshay Rane   }
792*e287d647SAshay Rane };
793*e287d647SAshay Rane 
79449e37000SMatthias Springer } // namespace
79549e37000SMatthias Springer } // namespace tensor
79649e37000SMatthias Springer } // namespace mlir
79749e37000SMatthias Springer 
79849e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
79949e37000SMatthias Springer     DialectRegistry &registry) {
80077eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
80177eee579SRiver Riddle     CastOp::attachInterface<CastOpInterface>(*ctx);
80277eee579SRiver Riddle     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
80377eee579SRiver Riddle     DimOp::attachInterface<DimOpInterface>(*ctx);
80477eee579SRiver Riddle     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
80577eee579SRiver Riddle     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
80677eee579SRiver Riddle     ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
80777eee579SRiver Riddle     FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
80877eee579SRiver Riddle     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
80977eee579SRiver Riddle     InsertOp::attachInterface<InsertOpInterface>(*ctx);
81077eee579SRiver Riddle     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
81177eee579SRiver Riddle     RankOp::attachInterface<RankOpInterface>(*ctx);
812*e287d647SAshay Rane     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
81377eee579SRiver Riddle   });
81449e37000SMatthias Springer }
815