149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
249e37000SMatthias Springer //
349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
649e37000SMatthias Springer //
749e37000SMatthias Springer //===----------------------------------------------------------------------===//
849e37000SMatthias Springer 
949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
1049e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1149e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1271bbb78bSMatthias Springer #include "mlir/Dialect/SCF/SCF.h"
1349e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1449e37000SMatthias Springer #include "mlir/IR/Dialect.h"
1549e37000SMatthias Springer #include "mlir/IR/Operation.h"
1649e37000SMatthias Springer 
1749e37000SMatthias Springer using namespace mlir;
1849e37000SMatthias Springer using namespace mlir::bufferization;
1949e37000SMatthias Springer using namespace mlir::tensor;
2049e37000SMatthias Springer 
2149e37000SMatthias Springer namespace mlir {
2249e37000SMatthias Springer namespace tensor {
2349e37000SMatthias Springer namespace {
2449e37000SMatthias Springer 
2549e37000SMatthias Springer struct CastOpInterface
2649e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
2749e37000SMatthias Springer                                                     tensor::CastOp> {
2849e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2949e37000SMatthias Springer                               const BufferizationState &state) const {
3049e37000SMatthias Springer     return false;
3149e37000SMatthias Springer   }
3249e37000SMatthias Springer 
3349e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
3449e37000SMatthias Springer                                const BufferizationState &state) const {
3549e37000SMatthias Springer     return false;
3649e37000SMatthias Springer   }
3749e37000SMatthias Springer 
3849e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
3949e37000SMatthias Springer                                const BufferizationState &state) const {
4049e37000SMatthias Springer     return op->getResult(0);
4149e37000SMatthias Springer   }
4249e37000SMatthias Springer 
4349e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
4449e37000SMatthias Springer                                 const BufferizationState &state) const {
4549e37000SMatthias Springer     return BufferRelation::Equivalent;
4649e37000SMatthias Springer   }
4749e37000SMatthias Springer 
4849e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
4949e37000SMatthias Springer                           const BufferizationState &state) const {
5049e37000SMatthias Springer     auto castOp = cast<tensor::CastOp>(op);
5149e37000SMatthias Springer 
5249e37000SMatthias Springer     // The result buffer still has the old (pre-cast) type.
5349e37000SMatthias Springer     FailureOr<Value> resultBuffer =
5449e37000SMatthias Springer         state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
5549e37000SMatthias Springer     if (failed(resultBuffer))
5649e37000SMatthias Springer       return failure();
5749e37000SMatthias Springer     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
5849e37000SMatthias Springer     Attribute memorySpace = sourceMemRefType.getMemorySpace();
5949e37000SMatthias Springer     TensorType resultTensorType =
6049e37000SMatthias Springer         castOp.getResult().getType().cast<TensorType>();
6149e37000SMatthias Springer     MemRefLayoutAttrInterface layout;
6249e37000SMatthias Springer 
6349e37000SMatthias Springer     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
6449e37000SMatthias Springer       if (resultTensorType.isa<RankedTensorType>())
6549e37000SMatthias Springer         layout = rankedMemRefType.getLayout();
6649e37000SMatthias Springer 
6749e37000SMatthias Springer     // Compute the new memref type.
6849e37000SMatthias Springer     Type resultMemRefType;
6949e37000SMatthias Springer     if (resultTensorType.isa<RankedTensorType>()) {
7049e37000SMatthias Springer       resultMemRefType =
7149e37000SMatthias Springer           getContiguousMemRefType(resultTensorType, layout, memorySpace);
7249e37000SMatthias Springer     } else {
7349e37000SMatthias Springer       resultMemRefType =
7449e37000SMatthias Springer           getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace);
7549e37000SMatthias Springer     }
7649e37000SMatthias Springer 
7749e37000SMatthias Springer     // Replace the op with a memref.cast.
7849e37000SMatthias Springer     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
7949e37000SMatthias Springer                                              resultMemRefType) &&
8049e37000SMatthias Springer            "CallOp::bufferize: cast incompatible");
8149e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
8249e37000SMatthias Springer                                                  *resultBuffer);
8349e37000SMatthias Springer 
8449e37000SMatthias Springer     return success();
8549e37000SMatthias Springer   }
8649e37000SMatthias Springer };
8749e37000SMatthias Springer 
8849e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim.
8949e37000SMatthias Springer struct DimOpInterface
9049e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
9149e37000SMatthias Springer                                                     tensor::DimOp> {
9249e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
9349e37000SMatthias Springer                               const BufferizationState &state) const {
9449e37000SMatthias Springer     return true;
9549e37000SMatthias Springer   }
9649e37000SMatthias Springer 
9749e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
9849e37000SMatthias Springer                                const BufferizationState &state) const {
9949e37000SMatthias Springer     return false;
10049e37000SMatthias Springer   }
10149e37000SMatthias Springer 
10249e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
10349e37000SMatthias Springer                                const BufferizationState &state) const {
10449e37000SMatthias Springer     return OpResult();
10549e37000SMatthias Springer   }
10649e37000SMatthias Springer 
10749e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
10849e37000SMatthias Springer                           const BufferizationState &state) const {
10949e37000SMatthias Springer     auto dimOp = cast<tensor::DimOp>(op);
11049e37000SMatthias Springer     Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
11149e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
11249e37000SMatthias Springer     return success();
11349e37000SMatthias Springer   }
11449e37000SMatthias Springer };
11549e37000SMatthias Springer 
11649e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview.
11749e37000SMatthias Springer struct ExtractSliceOpInterface
11849e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
11949e37000SMatthias Springer                                                     tensor::ExtractSliceOp> {
12049e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
12149e37000SMatthias Springer                               const BufferizationState &state) const {
12249e37000SMatthias Springer     return false;
12349e37000SMatthias Springer   }
12449e37000SMatthias Springer 
12549e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
12649e37000SMatthias Springer                                const BufferizationState &state) const {
12749e37000SMatthias Springer     return false;
12849e37000SMatthias Springer   }
12949e37000SMatthias Springer 
13049e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
13149e37000SMatthias Springer                                const BufferizationState &state) const {
13249e37000SMatthias Springer     return &opOperand == &op->getOpOperand(0) /*source*/
13349e37000SMatthias Springer                ? op->getResult(0)
13449e37000SMatthias Springer                : OpResult();
13549e37000SMatthias Springer   }
13649e37000SMatthias Springer 
13749e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
13849e37000SMatthias Springer                                 const BufferizationState &state) const {
13949e37000SMatthias Springer     return BufferRelation::None;
14049e37000SMatthias Springer   }
14149e37000SMatthias Springer 
14249e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
14349e37000SMatthias Springer                           const BufferizationState &state) const {
14449e37000SMatthias Springer     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
14549e37000SMatthias Springer     Location loc = extractSliceOp.getLoc();
14649e37000SMatthias Springer     Value srcMemref =
14749e37000SMatthias Springer         *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
14849e37000SMatthias Springer                          /*forceInPlace=*/true);
14949e37000SMatthias Springer     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
15049e37000SMatthias Springer     auto dstTensorType =
15149e37000SMatthias Springer         extractSliceOp.result().getType().cast<RankedTensorType>();
15249e37000SMatthias Springer 
15349e37000SMatthias Springer     // If not inplaceable, alloc.
15449e37000SMatthias Springer     bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
15549e37000SMatthias Springer     Value alloc;
15649e37000SMatthias Springer     if (!inplace) {
15749e37000SMatthias Springer       FailureOr<Value> allocOrFailure =
15849e37000SMatthias Springer           createAlloc(rewriter, loc, extractSliceOp.result(),
15949e37000SMatthias Springer                       state.getOptions().createDeallocs, state.getOptions());
16049e37000SMatthias Springer       if (failed(allocOrFailure))
16149e37000SMatthias Springer         return failure();
16249e37000SMatthias Springer       alloc = *allocOrFailure;
16349e37000SMatthias Springer     }
16449e37000SMatthias Springer 
16549e37000SMatthias Springer     // Expand offsets, sizes and strides to the full rank to handle the
16649e37000SMatthias Springer     // rank-reducing case.
16749e37000SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
16849e37000SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
16949e37000SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
17049e37000SMatthias Springer     OffsetSizeAndStrideOpInterface::expandToRank(
17149e37000SMatthias Springer         srcMemref, mixedOffsets, mixedSizes, mixedStrides,
17249e37000SMatthias Springer         [&](Value target, int64_t dim) -> OpFoldResult {
17349e37000SMatthias Springer           auto shapedType = target.getType().cast<ShapedType>();
17449e37000SMatthias Springer           if (shapedType.isDynamicDim(dim))
17549e37000SMatthias Springer             return rewriter.create<memref::DimOp>(loc, target, dim).result();
17649e37000SMatthias Springer           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
17749e37000SMatthias Springer         });
17849e37000SMatthias Springer     // Bufferize to subview.
17949e37000SMatthias Springer     auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
18049e37000SMatthias Springer                                  dstTensorType.getRank(), srcMemrefType,
18149e37000SMatthias Springer                                  mixedOffsets, mixedSizes, mixedStrides)
18249e37000SMatthias Springer                                  .cast<MemRefType>();
18349e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
18449e37000SMatthias Springer         loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
18549e37000SMatthias Springer         mixedStrides);
18649e37000SMatthias Springer 
18749e37000SMatthias Springer     // If not inplaceable, copy.
18849e37000SMatthias Springer     if (!inplace) {
18949e37000SMatthias Springer       // Do not copy if the copied data is never read.
19049e37000SMatthias Springer       if (state.isValueRead(extractSliceOp.result()))
19149e37000SMatthias Springer         if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
19249e37000SMatthias Springer                                 alloc, state.getOptions())))
19349e37000SMatthias Springer           return failure();
19449e37000SMatthias Springer       subView = alloc;
19549e37000SMatthias Springer     }
19649e37000SMatthias Springer 
19749e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, subView);
19849e37000SMatthias Springer     return success();
19949e37000SMatthias Springer   }
20049e37000SMatthias Springer };
20149e37000SMatthias Springer 
20249e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load.
20349e37000SMatthias Springer struct ExtractOpInterface
20449e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
20549e37000SMatthias Springer                                                     tensor::ExtractOp> {
20649e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
20749e37000SMatthias Springer                               const BufferizationState &state) const {
20849e37000SMatthias Springer     return true;
20949e37000SMatthias Springer   }
21049e37000SMatthias Springer 
21149e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
21249e37000SMatthias Springer                                const BufferizationState &state) const {
21349e37000SMatthias Springer     return false;
21449e37000SMatthias Springer   }
21549e37000SMatthias Springer 
21649e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
21749e37000SMatthias Springer                                const BufferizationState &state) const {
21849e37000SMatthias Springer     return OpResult();
21949e37000SMatthias Springer   }
22049e37000SMatthias Springer 
22149e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
22249e37000SMatthias Springer                           const BufferizationState &state) const {
22349e37000SMatthias Springer     auto extractOp = cast<tensor::ExtractOp>(op);
22449e37000SMatthias Springer     Value srcMemref =
22549e37000SMatthias Springer         *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
22649e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
22749e37000SMatthias Springer                                                  extractOp.indices());
22849e37000SMatthias Springer     return success();
22949e37000SMatthias Springer   }
23049e37000SMatthias Springer };
23149e37000SMatthias Springer 
232*d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while
233*d581c94dSMatthias Springer // iterating over op.elements().
234*d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim,
235*d581c94dSMatthias Springer                          Value buffer, ArrayRef<int64_t> shape,
236*d581c94dSMatthias Springer                          ArrayRef<Value> constants,
237*d581c94dSMatthias Springer                          OperandRange::iterator &elementIt,
238*d581c94dSMatthias Springer                          SmallVectorImpl<Value> &indices) {
239*d581c94dSMatthias Springer   if (dim == static_cast<int>(shape.size()) - 1) {
240*d581c94dSMatthias Springer     for (int i = 0; i < shape.back(); ++i) {
241*d581c94dSMatthias Springer       indices.back() = constants[i];
242*d581c94dSMatthias Springer       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
243*d581c94dSMatthias Springer       ++elementIt;
244*d581c94dSMatthias Springer     }
245*d581c94dSMatthias Springer     return;
246*d581c94dSMatthias Springer   }
247*d581c94dSMatthias Springer   for (int i = 0; i < shape[dim]; ++i) {
248*d581c94dSMatthias Springer     indices[dim] = constants[i];
249*d581c94dSMatthias Springer     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
250*d581c94dSMatthias Springer                  indices);
251*d581c94dSMatthias Springer   }
252*d581c94dSMatthias Springer }
253*d581c94dSMatthias Springer 
254*d581c94dSMatthias Springer /// Bufferization of tensor.from_elements.
255*d581c94dSMatthias Springer struct FromElementsOpInterface
256*d581c94dSMatthias Springer     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
257*d581c94dSMatthias Springer                                                     tensor::FromElementsOp> {
258*d581c94dSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
259*d581c94dSMatthias Springer                           const BufferizationState &state) const {
260*d581c94dSMatthias Springer     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
261*d581c94dSMatthias Springer 
262*d581c94dSMatthias Springer     // Allocate a buffer for the result.
263*d581c94dSMatthias Springer     Location loc = op->getLoc();
264*d581c94dSMatthias Springer     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
265*d581c94dSMatthias Springer     auto shape = tensorType.getShape();
266*d581c94dSMatthias Springer     MemRefType resultType =
267*d581c94dSMatthias Springer         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
268*d581c94dSMatthias Springer     FailureOr<Value> maybeBuffer =
269*d581c94dSMatthias Springer         createAlloc(rewriter, loc, resultType, {},
270*d581c94dSMatthias Springer                     /*deallocMemref=*/state.getOptions().createDeallocs,
271*d581c94dSMatthias Springer                     state.getOptions());
272*d581c94dSMatthias Springer     if (failed(maybeBuffer))
273*d581c94dSMatthias Springer       return failure();
274*d581c94dSMatthias Springer     Value buffer = *maybeBuffer;
275*d581c94dSMatthias Springer 
276*d581c94dSMatthias Springer     // Case: tensor<0xelem_type>.
277*d581c94dSMatthias Springer     if (fromElementsOp.elements().empty()) {
278*d581c94dSMatthias Springer       replaceOpWithBufferizedValues(rewriter, op, buffer);
279*d581c94dSMatthias Springer       return success();
280*d581c94dSMatthias Springer     }
281*d581c94dSMatthias Springer 
282*d581c94dSMatthias Springer     // Case: tensor<elem_type>.
283*d581c94dSMatthias Springer     if (shape.empty()) {
284*d581c94dSMatthias Springer       rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
285*d581c94dSMatthias Springer                                        buffer);
286*d581c94dSMatthias Springer       replaceOpWithBufferizedValues(rewriter, op, buffer);
287*d581c94dSMatthias Springer       return success();
288*d581c94dSMatthias Springer     }
289*d581c94dSMatthias Springer 
290*d581c94dSMatthias Springer     // Create constants for the range of possible indices [0, max{shape_i}).
291*d581c94dSMatthias Springer     auto maxDim = *std::max_element(shape.begin(), shape.end());
292*d581c94dSMatthias Springer     SmallVector<Value, 2> constants;
293*d581c94dSMatthias Springer     constants.reserve(maxDim);
294*d581c94dSMatthias Springer     for (int i = 0; i < maxDim; ++i)
295*d581c94dSMatthias Springer       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
296*d581c94dSMatthias Springer 
297*d581c94dSMatthias Springer     // Traverse all `elements` and create `memref.store` ops.
298*d581c94dSMatthias Springer     auto elementIt = fromElementsOp.elements().begin();
299*d581c94dSMatthias Springer     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
300*d581c94dSMatthias Springer     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
301*d581c94dSMatthias Springer                  indices);
302*d581c94dSMatthias Springer 
303*d581c94dSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, buffer);
304*d581c94dSMatthias Springer     return success();
305*d581c94dSMatthias Springer   }
306*d581c94dSMatthias Springer };
307*d581c94dSMatthias Springer 
30871bbb78bSMatthias Springer /// Bufferization of tensor.generate.
30971bbb78bSMatthias Springer struct GenerateOpInterface
31071bbb78bSMatthias Springer     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
31171bbb78bSMatthias Springer                                                     tensor::GenerateOp> {
31271bbb78bSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
31371bbb78bSMatthias Springer                           const BufferizationState &state) const {
31471bbb78bSMatthias Springer     auto generateOp = cast<tensor::GenerateOp>(op);
31571bbb78bSMatthias Springer 
31671bbb78bSMatthias Springer     // Allocate memory.
31771bbb78bSMatthias Springer     Location loc = op->getLoc();
31871bbb78bSMatthias Springer     MemRefType memrefType =
31971bbb78bSMatthias Springer         getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
32071bbb78bSMatthias Springer     FailureOr<Value> maybeResult =
32171bbb78bSMatthias Springer         createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
32271bbb78bSMatthias Springer                     /*deallocMemref=*/state.getOptions().createDeallocs,
32371bbb78bSMatthias Springer                     state.getOptions());
32471bbb78bSMatthias Springer     if (failed(maybeResult))
32571bbb78bSMatthias Springer       return failure();
32671bbb78bSMatthias Springer     Value result = *maybeResult;
32771bbb78bSMatthias Springer 
32871bbb78bSMatthias Springer     // Collect loop bounds.
32971bbb78bSMatthias Springer     int64_t rank = memrefType.getRank();
33071bbb78bSMatthias Springer     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
33171bbb78bSMatthias Springer     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
33271bbb78bSMatthias Springer     SmallVector<Value, 4> lowerBounds(rank, zero);
33371bbb78bSMatthias Springer     SmallVector<Value, 4> steps(rank, one);
33471bbb78bSMatthias Springer     SmallVector<Value, 4> upperBounds;
33571bbb78bSMatthias Springer     int nextDynamicIndex = 0;
33671bbb78bSMatthias Springer     for (int i = 0; i < rank; i++) {
33771bbb78bSMatthias Springer       Value upperBound = memrefType.isDynamicDim(i)
33871bbb78bSMatthias Springer                              ? generateOp.dynamicExtents()[nextDynamicIndex++]
33971bbb78bSMatthias Springer                              : rewriter.create<arith::ConstantIndexOp>(
34071bbb78bSMatthias Springer                                    loc, memrefType.getDimSize(i));
34171bbb78bSMatthias Springer       upperBounds.push_back(upperBound);
34271bbb78bSMatthias Springer     }
34371bbb78bSMatthias Springer 
34471bbb78bSMatthias Springer     // Generate tensor elements with a parallel loop that stores into
34571bbb78bSMatthias Springer     // each element of the resulting memref. We use mergeBlockBefore to "move"
34671bbb78bSMatthias Springer     // this op's body into the scf.parallel's body.
34771bbb78bSMatthias Springer     auto parallel =
34871bbb78bSMatthias Springer         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
34971bbb78bSMatthias Springer     Block *parallelBody = parallel.getBody();
35071bbb78bSMatthias Springer     rewriter.mergeBlockBefore(generateOp.getBody(),
35171bbb78bSMatthias Springer                               parallelBody->getTerminator(),
35271bbb78bSMatthias Springer                               parallelBody->getArguments());
35371bbb78bSMatthias Springer     // Replace the inlined yield op with a store op. The scf.parallel's builder
35471bbb78bSMatthias Springer     // already populated an scf.yield at the end, so we don't need to worry
35571bbb78bSMatthias Springer     // about creating that.
35671bbb78bSMatthias Springer     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
35771bbb78bSMatthias Springer     rewriter.setInsertionPointAfter(elementYield);
35871bbb78bSMatthias Springer     rewriter.replaceOpWithNewOp<memref::StoreOp>(
35971bbb78bSMatthias Springer         elementYield, elementYield->getOperands()[0], result,
36071bbb78bSMatthias Springer         parallelBody->getArguments());
36171bbb78bSMatthias Springer 
36271bbb78bSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, result);
36371bbb78bSMatthias Springer     return success();
36471bbb78bSMatthias Springer   }
36571bbb78bSMatthias Springer };
36671bbb78bSMatthias Springer 
36749e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store.
36849e37000SMatthias Springer struct InsertOpInterface
36949e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
37049e37000SMatthias Springer                                                     tensor::InsertOp> {
37149e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
37249e37000SMatthias Springer                               const BufferizationState &state) const {
37349e37000SMatthias Springer     return true;
37449e37000SMatthias Springer   }
37549e37000SMatthias Springer 
37649e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
37749e37000SMatthias Springer                                const BufferizationState &state) const {
37849e37000SMatthias Springer     return true;
37949e37000SMatthias Springer   }
38049e37000SMatthias Springer 
38149e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
38249e37000SMatthias Springer                                const BufferizationState &state) const {
38349e37000SMatthias Springer     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
38449e37000SMatthias Springer            "expected dest OpOperand");
38549e37000SMatthias Springer     return op->getOpResult(0);
38649e37000SMatthias Springer   }
38749e37000SMatthias Springer 
38849e37000SMatthias Springer   SmallVector<OpOperand *>
38949e37000SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
39049e37000SMatthias Springer                        const BufferizationState &state) const {
39149e37000SMatthias Springer     return {&op->getOpOperand(1) /*dest*/};
39249e37000SMatthias Springer   }
39349e37000SMatthias Springer 
39449e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
39549e37000SMatthias Springer                           const BufferizationState &state) const {
39649e37000SMatthias Springer     auto insertOp = cast<tensor::InsertOp>(op);
39749e37000SMatthias Springer     FailureOr<Value> destMemref =
39849e37000SMatthias Springer         state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
39949e37000SMatthias Springer     if (failed(destMemref))
40049e37000SMatthias Springer       return failure();
40149e37000SMatthias Springer     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
40249e37000SMatthias Springer                                      *destMemref, insertOp.indices());
40349e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
40449e37000SMatthias Springer     return success();
40549e37000SMatthias Springer   }
40649e37000SMatthias Springer 
40749e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
40849e37000SMatthias Springer                                 const BufferizationState &state) const {
40949e37000SMatthias Springer     return BufferRelation::Equivalent;
41049e37000SMatthias Springer   }
41149e37000SMatthias Springer };
41249e37000SMatthias Springer 
41349e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
41449e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification).
41549e37000SMatthias Springer ///
41649e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that
41749e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and
41849e37000SMatthias Springer /// exposed as interfaces on the proper types.
41949e37000SMatthias Springer static bool areEquivalentExtractSliceOps(const BufferizationState &state,
42049e37000SMatthias Springer                                          ExtractSliceOp st, InsertSliceOp sti) {
42149e37000SMatthias Springer   if (!st || !sti)
42249e37000SMatthias Springer     return false;
42349e37000SMatthias Springer   if (sti != sti &&
42449e37000SMatthias Springer       !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
42549e37000SMatthias Springer     return false;
42649e37000SMatthias Springer   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
42749e37000SMatthias Springer     return false;
42849e37000SMatthias Springer   return true;
42949e37000SMatthias Springer }
43049e37000SMatthias Springer 
43149e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches
43249e37000SMatthias Springer /// the given InsertSliceOp.
43349e37000SMatthias Springer static bool hasMatchingExtractSliceOp(const BufferizationState &state,
43449e37000SMatthias Springer                                       Value value, InsertSliceOp insertOp) {
43549e37000SMatthias Springer   auto condition = [&](Value val) {
43649e37000SMatthias Springer     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
43749e37000SMatthias Springer       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
43849e37000SMatthias Springer         return true;
43949e37000SMatthias Springer     return false;
44049e37000SMatthias Springer   };
44149e37000SMatthias Springer 
44249e37000SMatthias Springer   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
44349e37000SMatthias Springer                       condition);
44449e37000SMatthias Springer }
44549e37000SMatthias Springer 
44649e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
44749e37000SMatthias Springer /// certain circumstances, this op can also be a no-op.
44849e37000SMatthias Springer struct InsertSliceOpInterface
44949e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
45049e37000SMatthias Springer                                                     tensor::InsertSliceOp> {
45149e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
45249e37000SMatthias Springer                               const BufferizationState &state) const {
45349e37000SMatthias Springer     return true;
45449e37000SMatthias Springer   }
45549e37000SMatthias Springer 
45649e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
45749e37000SMatthias Springer                                const BufferizationState &state) const {
45849e37000SMatthias Springer     return &opOperand == &op->getOpOperand(1) /*dest*/;
45949e37000SMatthias Springer   }
46049e37000SMatthias Springer 
46149e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
46249e37000SMatthias Springer                                const BufferizationState &state) const {
46349e37000SMatthias Springer     return &opOperand == &op->getOpOperand(1) /*dest*/
46449e37000SMatthias Springer                ? op->getResult(0)
46549e37000SMatthias Springer                : OpResult();
46649e37000SMatthias Springer   }
46749e37000SMatthias Springer 
46849e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
46949e37000SMatthias Springer                                 const BufferizationState &state) const {
47049e37000SMatthias Springer     return BufferRelation::Equivalent;
47149e37000SMatthias Springer   }
47249e37000SMatthias Springer 
47349e37000SMatthias Springer   bool isNotConflicting(Operation *op, OpOperand *uRead,
47449e37000SMatthias Springer                         OpOperand *uConflictingWrite,
47549e37000SMatthias Springer                         const BufferizationState &state) const {
47649e37000SMatthias Springer     Operation *readingOp = uRead->getOwner();
47749e37000SMatthias Springer     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
47849e37000SMatthias Springer 
47949e37000SMatthias Springer     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
48049e37000SMatthias Springer     // uRead is an InsertSliceOp...
48149e37000SMatthias Springer     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
48249e37000SMatthias Springer       // As an example, consider the following IR.
48349e37000SMatthias Springer       //
48449e37000SMatthias Springer       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
48549e37000SMatthias Springer       // %1 = linalg.fill %cst, %0 {inplace= [true] }
48649e37000SMatthias Springer       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
48749e37000SMatthias Springer       //     {inplace= [true] }
48849e37000SMatthias Springer 
48949e37000SMatthias Springer       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
49049e37000SMatthias Springer       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
49149e37000SMatthias Springer           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
49249e37000SMatthias Springer                                     insertSliceOp))
49349e37000SMatthias Springer         // Case 1: The main insight is that InsertSliceOp reads only part of
49449e37000SMatthias Springer         // the destination tensor. The overwritten area is not read. If
49549e37000SMatthias Springer         // uConflictingWrite writes into exactly the memory location that is
49649e37000SMatthias Springer         // being read by uRead, this is not a conflict.
49749e37000SMatthias Springer         //
49849e37000SMatthias Springer         // In the above example:
49949e37000SMatthias Springer         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
50049e37000SMatthias Springer         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
50149e37000SMatthias Springer         //
50249e37000SMatthias Springer         // The read of %t does not conflict with the write of the FillOp
50349e37000SMatthias Springer         // (same aliases!) because the area that the FillOp operates on is
50449e37000SMatthias Springer         // exactly the one that is *not* read via %t.
50549e37000SMatthias Springer         return true;
50649e37000SMatthias Springer 
50749e37000SMatthias Springer       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
50849e37000SMatthias Springer           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
50949e37000SMatthias Springer           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
51049e37000SMatthias Springer         // Case 2: The read of the source tensor and the write to the dest
51149e37000SMatthias Springer         // tensor via an InsertSliceOp is not a conflict if the read is
51249e37000SMatthias Springer         // reading exactly that part of an equivalent tensor that the
51349e37000SMatthias Springer         // InsertSliceOp is writing.
51449e37000SMatthias Springer         //
51549e37000SMatthias Springer         // In the above example:
51649e37000SMatthias Springer         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
51749e37000SMatthias Springer         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
51849e37000SMatthias Springer         return true;
51949e37000SMatthias Springer     }
52049e37000SMatthias Springer 
52149e37000SMatthias Springer     // If uConflictingWrite is an InsertSliceOp...
52249e37000SMatthias Springer     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
52349e37000SMatthias Springer       // As an example, consider the following IR.
52449e37000SMatthias Springer       //
52549e37000SMatthias Springer       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
52649e37000SMatthias Springer       // %1 = linalg.fill %cst, %0 {inplace= [true] }
52749e37000SMatthias Springer       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
52849e37000SMatthias Springer       //     {inplace= [true] }
52949e37000SMatthias Springer       // %3 = vector.transfer_read %1, %cst
53049e37000SMatthias Springer       //
53149e37000SMatthias Springer       // In the above example:
53249e37000SMatthias Springer       // uRead             = OpOperand 0 (%1) of vector.transfer_read
53349e37000SMatthias Springer       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
53449e37000SMatthias Springer       // lastWrite         = %1
53549e37000SMatthias Springer       //
53649e37000SMatthias Springer       // This is not a conflict because the InsertSliceOp overwrites the
53749e37000SMatthias Springer       // memory segment of %1 with the exact same data. (Effectively, there
53849e37000SMatthias Springer       // is no memory write here.)
53949e37000SMatthias Springer       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
54049e37000SMatthias Springer           state.areEquivalentBufferizedValues(uRead->get(),
54149e37000SMatthias Springer                                               insertSliceOp.source()) &&
54249e37000SMatthias Springer           hasMatchingExtractSliceOp(state, insertSliceOp.source(),
54349e37000SMatthias Springer                                     insertSliceOp))
54449e37000SMatthias Springer         return true;
54549e37000SMatthias Springer 
54649e37000SMatthias Springer     return false;
54749e37000SMatthias Springer   }
54849e37000SMatthias Springer 
54949e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
55049e37000SMatthias Springer                           const BufferizationState &state) const {
55149e37000SMatthias Springer     // insert_slice ops arise from tiling and bufferizing them out-of-place is
55249e37000SMatthias Springer     // generally a deal breaker. When used with loops, this ends up cloning the
55349e37000SMatthias Springer     // whole tensor on every single iteration and is a symptom of a
55449e37000SMatthias Springer     // catastrophically bad scheduling decision.
55549e37000SMatthias Springer     // TODO: be very loud about it or even consider failing the pass.
55649e37000SMatthias Springer     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
55749e37000SMatthias Springer     Location loc = insertSliceOp.getLoc();
55849e37000SMatthias Springer 
55949e37000SMatthias Springer     // When bufferizing out-of-place, `getResultBuffer` allocates.
56049e37000SMatthias Springer     FailureOr<Value> dstMemref =
56149e37000SMatthias Springer         state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
56249e37000SMatthias Springer     if (failed(dstMemref))
56349e37000SMatthias Springer       return failure();
56449e37000SMatthias Springer 
56549e37000SMatthias Springer     // Expand offsets, sizes and strides to the full rank to handle the
56649e37000SMatthias Springer     // rank-reducing case.
56749e37000SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
56849e37000SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
56949e37000SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
57049e37000SMatthias Springer     OffsetSizeAndStrideOpInterface::expandToRank(
57149e37000SMatthias Springer         *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
57249e37000SMatthias Springer         [&](Value target, int64_t dim) -> OpFoldResult {
57349e37000SMatthias Springer           auto shapedType = target.getType().cast<ShapedType>();
57449e37000SMatthias Springer           if (shapedType.isDynamicDim(dim))
57549e37000SMatthias Springer             return rewriter.create<memref::DimOp>(loc, target, dim).result();
57649e37000SMatthias Springer           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
57749e37000SMatthias Springer         });
57849e37000SMatthias Springer     // Take a subview of the dst.
57949e37000SMatthias Springer     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
58049e37000SMatthias Springer     auto subviewMemRefType =
58149e37000SMatthias Springer         memref::SubViewOp::inferRankReducedResultType(
58249e37000SMatthias Springer             insertSliceOp.getSourceType().getRank(), dstMemrefType,
58349e37000SMatthias Springer             mixedOffsets, mixedSizes, mixedStrides)
58449e37000SMatthias Springer             .cast<MemRefType>();
58549e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
58649e37000SMatthias Springer         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
58749e37000SMatthias Springer         mixedStrides);
58849e37000SMatthias Springer 
58949e37000SMatthias Springer     // Copy tensor. If this tensor.insert_slice has a matching
59049e37000SMatthias Springer     // tensor.extract_slice, the copy operation will eventually fold away.
59149e37000SMatthias Springer     Value srcMemref =
59249e37000SMatthias Springer         *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
59349e37000SMatthias Springer     if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
59449e37000SMatthias Springer                             state.getOptions())))
59549e37000SMatthias Springer       return failure();
59649e37000SMatthias Springer 
59749e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
59849e37000SMatthias Springer     return success();
59949e37000SMatthias Springer   }
60049e37000SMatthias Springer };
60149e37000SMatthias Springer 
602fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank.
603fc08d1c2SMatthias Springer struct RankOpInterface
604fc08d1c2SMatthias Springer     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
605fc08d1c2SMatthias Springer                                                     tensor::RankOp> {
606fc08d1c2SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
607fc08d1c2SMatthias Springer                               const BufferizationState &state) const {
608fc08d1c2SMatthias Springer     return true;
609fc08d1c2SMatthias Springer   }
610fc08d1c2SMatthias Springer 
611fc08d1c2SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
612fc08d1c2SMatthias Springer                                const BufferizationState &state) const {
613fc08d1c2SMatthias Springer     return false;
614fc08d1c2SMatthias Springer   }
615fc08d1c2SMatthias Springer 
616fc08d1c2SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
617fc08d1c2SMatthias Springer                                const BufferizationState &state) const {
618fc08d1c2SMatthias Springer     return OpResult();
619fc08d1c2SMatthias Springer   }
620fc08d1c2SMatthias Springer 
621fc08d1c2SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
622fc08d1c2SMatthias Springer                           const BufferizationState &state) const {
623fc08d1c2SMatthias Springer     auto rankOp = cast<tensor::RankOp>(op);
624fc08d1c2SMatthias Springer     Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
625fc08d1c2SMatthias Springer     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
626fc08d1c2SMatthias Springer                                                  v);
627fc08d1c2SMatthias Springer     return success();
628fc08d1c2SMatthias Springer   }
629fc08d1c2SMatthias Springer };
630fc08d1c2SMatthias Springer 
63149e37000SMatthias Springer } // namespace
63249e37000SMatthias Springer } // namespace tensor
63349e37000SMatthias Springer } // namespace mlir
63449e37000SMatthias Springer 
63549e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
63649e37000SMatthias Springer     DialectRegistry &registry) {
63749e37000SMatthias Springer   registry.addOpInterface<CastOp, CastOpInterface>();
63849e37000SMatthias Springer   registry.addOpInterface<DimOp, DimOpInterface>();
63949e37000SMatthias Springer   registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
64049e37000SMatthias Springer   registry.addOpInterface<ExtractOp, ExtractOpInterface>();
641*d581c94dSMatthias Springer   registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
64271bbb78bSMatthias Springer   registry.addOpInterface<GenerateOp, GenerateOpInterface>();
64349e37000SMatthias Springer   registry.addOpInterface<InsertOp, InsertOpInterface>();
64449e37000SMatthias Springer   registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
645fc08d1c2SMatthias Springer   registry.addOpInterface<RankOp, RankOpInterface>();
64649e37000SMatthias Springer }
647