149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
249e37000SMatthias Springer //
349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
649e37000SMatthias Springer //
749e37000SMatthias Springer //===----------------------------------------------------------------------===//
849e37000SMatthias Springer 
949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
1049e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1149e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1249e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1349e37000SMatthias Springer #include "mlir/IR/Dialect.h"
1449e37000SMatthias Springer #include "mlir/IR/Operation.h"
1549e37000SMatthias Springer 
1649e37000SMatthias Springer using namespace mlir;
1749e37000SMatthias Springer using namespace mlir::bufferization;
1849e37000SMatthias Springer using namespace mlir::tensor;
1949e37000SMatthias Springer 
2049e37000SMatthias Springer namespace mlir {
2149e37000SMatthias Springer namespace tensor {
2249e37000SMatthias Springer namespace {
2349e37000SMatthias Springer 
2449e37000SMatthias Springer struct CastOpInterface
2549e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
2649e37000SMatthias Springer                                                     tensor::CastOp> {
2749e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2849e37000SMatthias Springer                               const BufferizationState &state) const {
2949e37000SMatthias Springer     return false;
3049e37000SMatthias Springer   }
3149e37000SMatthias Springer 
3249e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
3349e37000SMatthias Springer                                const BufferizationState &state) const {
3449e37000SMatthias Springer     return false;
3549e37000SMatthias Springer   }
3649e37000SMatthias Springer 
3749e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
3849e37000SMatthias Springer                                const BufferizationState &state) const {
3949e37000SMatthias Springer     return op->getResult(0);
4049e37000SMatthias Springer   }
4149e37000SMatthias Springer 
4249e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
4349e37000SMatthias Springer                                 const BufferizationState &state) const {
4449e37000SMatthias Springer     return BufferRelation::Equivalent;
4549e37000SMatthias Springer   }
4649e37000SMatthias Springer 
4749e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
4849e37000SMatthias Springer                           const BufferizationState &state) const {
4949e37000SMatthias Springer     auto castOp = cast<tensor::CastOp>(op);
5049e37000SMatthias Springer 
5149e37000SMatthias Springer     // The result buffer still has the old (pre-cast) type.
5249e37000SMatthias Springer     FailureOr<Value> resultBuffer =
5349e37000SMatthias Springer         state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
5449e37000SMatthias Springer     if (failed(resultBuffer))
5549e37000SMatthias Springer       return failure();
5649e37000SMatthias Springer     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
5749e37000SMatthias Springer     Attribute memorySpace = sourceMemRefType.getMemorySpace();
5849e37000SMatthias Springer     TensorType resultTensorType =
5949e37000SMatthias Springer         castOp.getResult().getType().cast<TensorType>();
6049e37000SMatthias Springer     MemRefLayoutAttrInterface layout;
6149e37000SMatthias Springer 
6249e37000SMatthias Springer     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
6349e37000SMatthias Springer       if (resultTensorType.isa<RankedTensorType>())
6449e37000SMatthias Springer         layout = rankedMemRefType.getLayout();
6549e37000SMatthias Springer 
6649e37000SMatthias Springer     // Compute the new memref type.
6749e37000SMatthias Springer     Type resultMemRefType;
6849e37000SMatthias Springer     if (resultTensorType.isa<RankedTensorType>()) {
6949e37000SMatthias Springer       resultMemRefType =
7049e37000SMatthias Springer           getContiguousMemRefType(resultTensorType, layout, memorySpace);
7149e37000SMatthias Springer     } else {
7249e37000SMatthias Springer       resultMemRefType =
7349e37000SMatthias Springer           getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace);
7449e37000SMatthias Springer     }
7549e37000SMatthias Springer 
7649e37000SMatthias Springer     // Replace the op with a memref.cast.
7749e37000SMatthias Springer     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
7849e37000SMatthias Springer                                              resultMemRefType) &&
7949e37000SMatthias Springer            "CallOp::bufferize: cast incompatible");
8049e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
8149e37000SMatthias Springer                                                  *resultBuffer);
8249e37000SMatthias Springer 
8349e37000SMatthias Springer     return success();
8449e37000SMatthias Springer   }
8549e37000SMatthias Springer };
8649e37000SMatthias Springer 
8749e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim.
8849e37000SMatthias Springer struct DimOpInterface
8949e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
9049e37000SMatthias Springer                                                     tensor::DimOp> {
9149e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
9249e37000SMatthias Springer                               const BufferizationState &state) const {
9349e37000SMatthias Springer     return true;
9449e37000SMatthias Springer   }
9549e37000SMatthias Springer 
9649e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
9749e37000SMatthias Springer                                const BufferizationState &state) const {
9849e37000SMatthias Springer     return false;
9949e37000SMatthias Springer   }
10049e37000SMatthias Springer 
10149e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
10249e37000SMatthias Springer                                const BufferizationState &state) const {
10349e37000SMatthias Springer     return OpResult();
10449e37000SMatthias Springer   }
10549e37000SMatthias Springer 
10649e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
10749e37000SMatthias Springer                           const BufferizationState &state) const {
10849e37000SMatthias Springer     auto dimOp = cast<tensor::DimOp>(op);
10949e37000SMatthias Springer     Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
11049e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
11149e37000SMatthias Springer     return success();
11249e37000SMatthias Springer   }
11349e37000SMatthias Springer };
11449e37000SMatthias Springer 
11549e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview.
11649e37000SMatthias Springer struct ExtractSliceOpInterface
11749e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
11849e37000SMatthias Springer                                                     tensor::ExtractSliceOp> {
11949e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
12049e37000SMatthias Springer                               const BufferizationState &state) const {
12149e37000SMatthias Springer     return false;
12249e37000SMatthias Springer   }
12349e37000SMatthias Springer 
12449e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
12549e37000SMatthias Springer                                const BufferizationState &state) const {
12649e37000SMatthias Springer     return false;
12749e37000SMatthias Springer   }
12849e37000SMatthias Springer 
12949e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
13049e37000SMatthias Springer                                const BufferizationState &state) const {
13149e37000SMatthias Springer     return &opOperand == &op->getOpOperand(0) /*source*/
13249e37000SMatthias Springer                ? op->getResult(0)
13349e37000SMatthias Springer                : OpResult();
13449e37000SMatthias Springer   }
13549e37000SMatthias Springer 
13649e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
13749e37000SMatthias Springer                                 const BufferizationState &state) const {
13849e37000SMatthias Springer     return BufferRelation::None;
13949e37000SMatthias Springer   }
14049e37000SMatthias Springer 
14149e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
14249e37000SMatthias Springer                           const BufferizationState &state) const {
14349e37000SMatthias Springer     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
14449e37000SMatthias Springer     Location loc = extractSliceOp.getLoc();
14549e37000SMatthias Springer     Value srcMemref =
14649e37000SMatthias Springer         *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
14749e37000SMatthias Springer                          /*forceInPlace=*/true);
14849e37000SMatthias Springer     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
14949e37000SMatthias Springer     auto dstTensorType =
15049e37000SMatthias Springer         extractSliceOp.result().getType().cast<RankedTensorType>();
15149e37000SMatthias Springer 
15249e37000SMatthias Springer     // If not inplaceable, alloc.
15349e37000SMatthias Springer     bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
15449e37000SMatthias Springer     Value alloc;
15549e37000SMatthias Springer     if (!inplace) {
15649e37000SMatthias Springer       FailureOr<Value> allocOrFailure =
15749e37000SMatthias Springer           createAlloc(rewriter, loc, extractSliceOp.result(),
15849e37000SMatthias Springer                       state.getOptions().createDeallocs, state.getOptions());
15949e37000SMatthias Springer       if (failed(allocOrFailure))
16049e37000SMatthias Springer         return failure();
16149e37000SMatthias Springer       alloc = *allocOrFailure;
16249e37000SMatthias Springer     }
16349e37000SMatthias Springer 
16449e37000SMatthias Springer     // Expand offsets, sizes and strides to the full rank to handle the
16549e37000SMatthias Springer     // rank-reducing case.
16649e37000SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
16749e37000SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
16849e37000SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
16949e37000SMatthias Springer     OffsetSizeAndStrideOpInterface::expandToRank(
17049e37000SMatthias Springer         srcMemref, mixedOffsets, mixedSizes, mixedStrides,
17149e37000SMatthias Springer         [&](Value target, int64_t dim) -> OpFoldResult {
17249e37000SMatthias Springer           auto shapedType = target.getType().cast<ShapedType>();
17349e37000SMatthias Springer           if (shapedType.isDynamicDim(dim))
17449e37000SMatthias Springer             return rewriter.create<memref::DimOp>(loc, target, dim).result();
17549e37000SMatthias Springer           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
17649e37000SMatthias Springer         });
17749e37000SMatthias Springer     // Bufferize to subview.
17849e37000SMatthias Springer     auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
17949e37000SMatthias Springer                                  dstTensorType.getRank(), srcMemrefType,
18049e37000SMatthias Springer                                  mixedOffsets, mixedSizes, mixedStrides)
18149e37000SMatthias Springer                                  .cast<MemRefType>();
18249e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
18349e37000SMatthias Springer         loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
18449e37000SMatthias Springer         mixedStrides);
18549e37000SMatthias Springer 
18649e37000SMatthias Springer     // If not inplaceable, copy.
18749e37000SMatthias Springer     if (!inplace) {
18849e37000SMatthias Springer       // Do not copy if the copied data is never read.
18949e37000SMatthias Springer       if (state.isValueRead(extractSliceOp.result()))
19049e37000SMatthias Springer         if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
19149e37000SMatthias Springer                                 alloc, state.getOptions())))
19249e37000SMatthias Springer           return failure();
19349e37000SMatthias Springer       subView = alloc;
19449e37000SMatthias Springer     }
19549e37000SMatthias Springer 
19649e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, subView);
19749e37000SMatthias Springer     return success();
19849e37000SMatthias Springer   }
19949e37000SMatthias Springer };
20049e37000SMatthias Springer 
20149e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load.
20249e37000SMatthias Springer struct ExtractOpInterface
20349e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
20449e37000SMatthias Springer                                                     tensor::ExtractOp> {
20549e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
20649e37000SMatthias Springer                               const BufferizationState &state) const {
20749e37000SMatthias Springer     return true;
20849e37000SMatthias Springer   }
20949e37000SMatthias Springer 
21049e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
21149e37000SMatthias Springer                                const BufferizationState &state) const {
21249e37000SMatthias Springer     return false;
21349e37000SMatthias Springer   }
21449e37000SMatthias Springer 
21549e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
21649e37000SMatthias Springer                                const BufferizationState &state) const {
21749e37000SMatthias Springer     return OpResult();
21849e37000SMatthias Springer   }
21949e37000SMatthias Springer 
22049e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
22149e37000SMatthias Springer                           const BufferizationState &state) const {
22249e37000SMatthias Springer     auto extractOp = cast<tensor::ExtractOp>(op);
22349e37000SMatthias Springer     Value srcMemref =
22449e37000SMatthias Springer         *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
22549e37000SMatthias Springer     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
22649e37000SMatthias Springer                                                  extractOp.indices());
22749e37000SMatthias Springer     return success();
22849e37000SMatthias Springer   }
22949e37000SMatthias Springer };
23049e37000SMatthias Springer 
23149e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store.
23249e37000SMatthias Springer struct InsertOpInterface
23349e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
23449e37000SMatthias Springer                                                     tensor::InsertOp> {
23549e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
23649e37000SMatthias Springer                               const BufferizationState &state) const {
23749e37000SMatthias Springer     return true;
23849e37000SMatthias Springer   }
23949e37000SMatthias Springer 
24049e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
24149e37000SMatthias Springer                                const BufferizationState &state) const {
24249e37000SMatthias Springer     return true;
24349e37000SMatthias Springer   }
24449e37000SMatthias Springer 
24549e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
24649e37000SMatthias Springer                                const BufferizationState &state) const {
24749e37000SMatthias Springer     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
24849e37000SMatthias Springer            "expected dest OpOperand");
24949e37000SMatthias Springer     return op->getOpResult(0);
25049e37000SMatthias Springer   }
25149e37000SMatthias Springer 
25249e37000SMatthias Springer   SmallVector<OpOperand *>
25349e37000SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
25449e37000SMatthias Springer                        const BufferizationState &state) const {
25549e37000SMatthias Springer     return {&op->getOpOperand(1) /*dest*/};
25649e37000SMatthias Springer   }
25749e37000SMatthias Springer 
25849e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
25949e37000SMatthias Springer                           const BufferizationState &state) const {
26049e37000SMatthias Springer     auto insertOp = cast<tensor::InsertOp>(op);
26149e37000SMatthias Springer     FailureOr<Value> destMemref =
26249e37000SMatthias Springer         state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
26349e37000SMatthias Springer     if (failed(destMemref))
26449e37000SMatthias Springer       return failure();
26549e37000SMatthias Springer     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
26649e37000SMatthias Springer                                      *destMemref, insertOp.indices());
26749e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
26849e37000SMatthias Springer     return success();
26949e37000SMatthias Springer   }
27049e37000SMatthias Springer 
27149e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
27249e37000SMatthias Springer                                 const BufferizationState &state) const {
27349e37000SMatthias Springer     return BufferRelation::Equivalent;
27449e37000SMatthias Springer   }
27549e37000SMatthias Springer };
27649e37000SMatthias Springer 
27749e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
27849e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification).
27949e37000SMatthias Springer ///
28049e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that
28149e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and
28249e37000SMatthias Springer /// exposed as interfaces on the proper types.
28349e37000SMatthias Springer static bool areEquivalentExtractSliceOps(const BufferizationState &state,
28449e37000SMatthias Springer                                          ExtractSliceOp st, InsertSliceOp sti) {
28549e37000SMatthias Springer   if (!st || !sti)
28649e37000SMatthias Springer     return false;
28749e37000SMatthias Springer   if (sti != sti &&
28849e37000SMatthias Springer       !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
28949e37000SMatthias Springer     return false;
29049e37000SMatthias Springer   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
29149e37000SMatthias Springer     return false;
29249e37000SMatthias Springer   return true;
29349e37000SMatthias Springer }
29449e37000SMatthias Springer 
29549e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches
29649e37000SMatthias Springer /// the given InsertSliceOp.
29749e37000SMatthias Springer static bool hasMatchingExtractSliceOp(const BufferizationState &state,
29849e37000SMatthias Springer                                       Value value, InsertSliceOp insertOp) {
29949e37000SMatthias Springer   auto condition = [&](Value val) {
30049e37000SMatthias Springer     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
30149e37000SMatthias Springer       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
30249e37000SMatthias Springer         return true;
30349e37000SMatthias Springer     return false;
30449e37000SMatthias Springer   };
30549e37000SMatthias Springer 
30649e37000SMatthias Springer   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
30749e37000SMatthias Springer                       condition);
30849e37000SMatthias Springer }
30949e37000SMatthias Springer 
31049e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
31149e37000SMatthias Springer /// certain circumstances, this op can also be a no-op.
31249e37000SMatthias Springer struct InsertSliceOpInterface
31349e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
31449e37000SMatthias Springer                                                     tensor::InsertSliceOp> {
31549e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
31649e37000SMatthias Springer                               const BufferizationState &state) const {
31749e37000SMatthias Springer     return true;
31849e37000SMatthias Springer   }
31949e37000SMatthias Springer 
32049e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
32149e37000SMatthias Springer                                const BufferizationState &state) const {
32249e37000SMatthias Springer     return &opOperand == &op->getOpOperand(1) /*dest*/;
32349e37000SMatthias Springer   }
32449e37000SMatthias Springer 
32549e37000SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
32649e37000SMatthias Springer                                const BufferizationState &state) const {
32749e37000SMatthias Springer     return &opOperand == &op->getOpOperand(1) /*dest*/
32849e37000SMatthias Springer                ? op->getResult(0)
32949e37000SMatthias Springer                : OpResult();
33049e37000SMatthias Springer   }
33149e37000SMatthias Springer 
33249e37000SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
33349e37000SMatthias Springer                                 const BufferizationState &state) const {
33449e37000SMatthias Springer     return BufferRelation::Equivalent;
33549e37000SMatthias Springer   }
33649e37000SMatthias Springer 
33749e37000SMatthias Springer   bool isNotConflicting(Operation *op, OpOperand *uRead,
33849e37000SMatthias Springer                         OpOperand *uConflictingWrite,
33949e37000SMatthias Springer                         const BufferizationState &state) const {
34049e37000SMatthias Springer     Operation *readingOp = uRead->getOwner();
34149e37000SMatthias Springer     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
34249e37000SMatthias Springer 
34349e37000SMatthias Springer     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
34449e37000SMatthias Springer     // uRead is an InsertSliceOp...
34549e37000SMatthias Springer     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
34649e37000SMatthias Springer       // As an example, consider the following IR.
34749e37000SMatthias Springer       //
34849e37000SMatthias Springer       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
34949e37000SMatthias Springer       // %1 = linalg.fill %cst, %0 {inplace= [true] }
35049e37000SMatthias Springer       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
35149e37000SMatthias Springer       //     {inplace= [true] }
35249e37000SMatthias Springer 
35349e37000SMatthias Springer       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
35449e37000SMatthias Springer       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
35549e37000SMatthias Springer           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
35649e37000SMatthias Springer                                     insertSliceOp))
35749e37000SMatthias Springer         // Case 1: The main insight is that InsertSliceOp reads only part of
35849e37000SMatthias Springer         // the destination tensor. The overwritten area is not read. If
35949e37000SMatthias Springer         // uConflictingWrite writes into exactly the memory location that is
36049e37000SMatthias Springer         // being read by uRead, this is not a conflict.
36149e37000SMatthias Springer         //
36249e37000SMatthias Springer         // In the above example:
36349e37000SMatthias Springer         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
36449e37000SMatthias Springer         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
36549e37000SMatthias Springer         //
36649e37000SMatthias Springer         // The read of %t does not conflict with the write of the FillOp
36749e37000SMatthias Springer         // (same aliases!) because the area that the FillOp operates on is
36849e37000SMatthias Springer         // exactly the one that is *not* read via %t.
36949e37000SMatthias Springer         return true;
37049e37000SMatthias Springer 
37149e37000SMatthias Springer       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
37249e37000SMatthias Springer           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
37349e37000SMatthias Springer           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
37449e37000SMatthias Springer         // Case 2: The read of the source tensor and the write to the dest
37549e37000SMatthias Springer         // tensor via an InsertSliceOp is not a conflict if the read is
37649e37000SMatthias Springer         // reading exactly that part of an equivalent tensor that the
37749e37000SMatthias Springer         // InsertSliceOp is writing.
37849e37000SMatthias Springer         //
37949e37000SMatthias Springer         // In the above example:
38049e37000SMatthias Springer         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
38149e37000SMatthias Springer         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
38249e37000SMatthias Springer         return true;
38349e37000SMatthias Springer     }
38449e37000SMatthias Springer 
38549e37000SMatthias Springer     // If uConflictingWrite is an InsertSliceOp...
38649e37000SMatthias Springer     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
38749e37000SMatthias Springer       // As an example, consider the following IR.
38849e37000SMatthias Springer       //
38949e37000SMatthias Springer       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
39049e37000SMatthias Springer       // %1 = linalg.fill %cst, %0 {inplace= [true] }
39149e37000SMatthias Springer       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
39249e37000SMatthias Springer       //     {inplace= [true] }
39349e37000SMatthias Springer       // %3 = vector.transfer_read %1, %cst
39449e37000SMatthias Springer       //
39549e37000SMatthias Springer       // In the above example:
39649e37000SMatthias Springer       // uRead             = OpOperand 0 (%1) of vector.transfer_read
39749e37000SMatthias Springer       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
39849e37000SMatthias Springer       // lastWrite         = %1
39949e37000SMatthias Springer       //
40049e37000SMatthias Springer       // This is not a conflict because the InsertSliceOp overwrites the
40149e37000SMatthias Springer       // memory segment of %1 with the exact same data. (Effectively, there
40249e37000SMatthias Springer       // is no memory write here.)
40349e37000SMatthias Springer       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
40449e37000SMatthias Springer           state.areEquivalentBufferizedValues(uRead->get(),
40549e37000SMatthias Springer                                               insertSliceOp.source()) &&
40649e37000SMatthias Springer           hasMatchingExtractSliceOp(state, insertSliceOp.source(),
40749e37000SMatthias Springer                                     insertSliceOp))
40849e37000SMatthias Springer         return true;
40949e37000SMatthias Springer 
41049e37000SMatthias Springer     return false;
41149e37000SMatthias Springer   }
41249e37000SMatthias Springer 
41349e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
41449e37000SMatthias Springer                           const BufferizationState &state) const {
41549e37000SMatthias Springer     // insert_slice ops arise from tiling and bufferizing them out-of-place is
41649e37000SMatthias Springer     // generally a deal breaker. When used with loops, this ends up cloning the
41749e37000SMatthias Springer     // whole tensor on every single iteration and is a symptom of a
41849e37000SMatthias Springer     // catastrophically bad scheduling decision.
41949e37000SMatthias Springer     // TODO: be very loud about it or even consider failing the pass.
42049e37000SMatthias Springer     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
42149e37000SMatthias Springer     Location loc = insertSliceOp.getLoc();
42249e37000SMatthias Springer 
42349e37000SMatthias Springer     // When bufferizing out-of-place, `getResultBuffer` allocates.
42449e37000SMatthias Springer     FailureOr<Value> dstMemref =
42549e37000SMatthias Springer         state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
42649e37000SMatthias Springer     if (failed(dstMemref))
42749e37000SMatthias Springer       return failure();
42849e37000SMatthias Springer 
42949e37000SMatthias Springer     // Expand offsets, sizes and strides to the full rank to handle the
43049e37000SMatthias Springer     // rank-reducing case.
43149e37000SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
43249e37000SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
43349e37000SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
43449e37000SMatthias Springer     OffsetSizeAndStrideOpInterface::expandToRank(
43549e37000SMatthias Springer         *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
43649e37000SMatthias Springer         [&](Value target, int64_t dim) -> OpFoldResult {
43749e37000SMatthias Springer           auto shapedType = target.getType().cast<ShapedType>();
43849e37000SMatthias Springer           if (shapedType.isDynamicDim(dim))
43949e37000SMatthias Springer             return rewriter.create<memref::DimOp>(loc, target, dim).result();
44049e37000SMatthias Springer           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
44149e37000SMatthias Springer         });
44249e37000SMatthias Springer     // Take a subview of the dst.
44349e37000SMatthias Springer     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
44449e37000SMatthias Springer     auto subviewMemRefType =
44549e37000SMatthias Springer         memref::SubViewOp::inferRankReducedResultType(
44649e37000SMatthias Springer             insertSliceOp.getSourceType().getRank(), dstMemrefType,
44749e37000SMatthias Springer             mixedOffsets, mixedSizes, mixedStrides)
44849e37000SMatthias Springer             .cast<MemRefType>();
44949e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
45049e37000SMatthias Springer         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
45149e37000SMatthias Springer         mixedStrides);
45249e37000SMatthias Springer 
45349e37000SMatthias Springer     // Copy tensor. If this tensor.insert_slice has a matching
45449e37000SMatthias Springer     // tensor.extract_slice, the copy operation will eventually fold away.
45549e37000SMatthias Springer     Value srcMemref =
45649e37000SMatthias Springer         *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
45749e37000SMatthias Springer     if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
45849e37000SMatthias Springer                             state.getOptions())))
45949e37000SMatthias Springer       return failure();
46049e37000SMatthias Springer 
46149e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
46249e37000SMatthias Springer     return success();
46349e37000SMatthias Springer   }
46449e37000SMatthias Springer };
46549e37000SMatthias Springer 
466*fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank.
467*fc08d1c2SMatthias Springer struct RankOpInterface
468*fc08d1c2SMatthias Springer     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
469*fc08d1c2SMatthias Springer                                                     tensor::RankOp> {
470*fc08d1c2SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
471*fc08d1c2SMatthias Springer                               const BufferizationState &state) const {
472*fc08d1c2SMatthias Springer     return true;
473*fc08d1c2SMatthias Springer   }
474*fc08d1c2SMatthias Springer 
475*fc08d1c2SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
476*fc08d1c2SMatthias Springer                                const BufferizationState &state) const {
477*fc08d1c2SMatthias Springer     return false;
478*fc08d1c2SMatthias Springer   }
479*fc08d1c2SMatthias Springer 
480*fc08d1c2SMatthias Springer   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
481*fc08d1c2SMatthias Springer                                const BufferizationState &state) const {
482*fc08d1c2SMatthias Springer     return OpResult();
483*fc08d1c2SMatthias Springer   }
484*fc08d1c2SMatthias Springer 
485*fc08d1c2SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
486*fc08d1c2SMatthias Springer                           const BufferizationState &state) const {
487*fc08d1c2SMatthias Springer     auto rankOp = cast<tensor::RankOp>(op);
488*fc08d1c2SMatthias Springer     Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
489*fc08d1c2SMatthias Springer     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
490*fc08d1c2SMatthias Springer                                                  v);
491*fc08d1c2SMatthias Springer     return success();
492*fc08d1c2SMatthias Springer   }
493*fc08d1c2SMatthias Springer };
494*fc08d1c2SMatthias Springer 
49549e37000SMatthias Springer } // namespace
49649e37000SMatthias Springer } // namespace tensor
49749e37000SMatthias Springer } // namespace mlir
49849e37000SMatthias Springer 
49949e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
50049e37000SMatthias Springer     DialectRegistry &registry) {
50149e37000SMatthias Springer   registry.addOpInterface<CastOp, CastOpInterface>();
50249e37000SMatthias Springer   registry.addOpInterface<DimOp, DimOpInterface>();
50349e37000SMatthias Springer   registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
50449e37000SMatthias Springer   registry.addOpInterface<ExtractOp, ExtractOpInterface>();
50549e37000SMatthias Springer   registry.addOpInterface<InsertOp, InsertOpInterface>();
50649e37000SMatthias Springer   registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
507*fc08d1c2SMatthias Springer   registry.addOpInterface<RankOp, RankOpInterface>();
50849e37000SMatthias Springer }
509