//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

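/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Example (illustrative; layout and memory space details elided):
///   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
/// bufferizes to a cast of the source buffer:
///   %1 = memref.cast %src : memref<?xf32> to memref<4xf32>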
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
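///
/// Example (illustrative; layout maps elided):
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
/// bufferizes to:
///   %1 = memref.collapse_shape %src [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>
/// If the source layout does not allow collapsing, a copy into an
/// identity-layout buffer is inserted first (see bufferize() below).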
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
    auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      Value buffer = *state.getBuffer(rewriter, srcOperand);
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
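    // (For example, a strided, non-contiguous source view may not be
    // expressible as a single collapsed strided memref; copying into a fresh
    // identity-layout buffer guarantees that the collapse is valid.)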
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    Optional<BufferizationState::ForceInPlacability> overrideInPlace =
        canBeCollapsed
            ? None
            : Optional<BufferizationState::ForceInPlacability>(
                  BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
    Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
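///
/// Example (illustrative):
///   %d = tensor.dim %0, %c0 : tensor<?xf32>
/// bufferizes to:
///   %d = memref.dim %src, %c0 : memref<?xf32>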
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
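///
/// Example (illustrative; layout maps elided):
///   %1 = tensor.expand_shape %0 [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
/// bufferizes to:
///   %1 = memref.expand_shape %src [[0, 1]]
///       : memref<6xf32> into memref<2x3xf32>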
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
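///
/// Example (illustrative; the subview's strided result layout is inferred and
/// only sketched here as #strided):
///   %1 = tensor.extract_slice %0[%off][%sz][1]
///       : tensor<?xf32> to tensor<?xf32>
/// bufferizes (in the in-place case) to:
///   %1 = memref.subview %src[%off][%sz][1]
///       : memref<?xf32> to memref<?xf32, #strided>
/// In the out-of-place case, the subview is additionally copied into a newly
/// allocated buffer.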
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         BufferizationState::ForceInPlacability::FORCE_INPLACE);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
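///
/// Example (illustrative):
///   %e = tensor.extract %0[%i] : tensor<?xf32>
/// bufferizes to:
///   %e = memref.load %src[%i] : memref<?xf32>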
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
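// For example, for a 2x2 shape the stores are emitted in row-major order:
// indices [0, 0], [0, 1], [1, 0], [1, 1], consuming one element per store.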
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
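///
/// Lowers (illustratively; the exact allocation op depends on the
/// bufferization options) to an allocation plus one memref.store per element:
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
/// becomes:
///   %buf = memref.alloc() : memref<2xf32>
///   memref.store %a, %buf[%c0] : memref<2xf32>
///   memref.store %b, %buf[%c1] : memref<2xf32>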
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
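///
/// Lowers (illustratively; the exact allocation op depends on the
/// bufferization options) to an allocation plus an scf.parallel loop nest
/// that evaluates the generator body once per index and stores the yielded
/// value into the buffer.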
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
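///
/// Example (illustrative):
///   %1 = tensor.insert %f into %0[%i] : tensor<?xf32>
/// bufferizes (in the in-place case) to:
///   memref.store %f, %dest[%i] : memref<?xf32>
/// and %1 is replaced with the destination buffer.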
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
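///
/// E.g., `%0 = tensor.extract_slice %t[%a][%c][1]` matches
/// `tensor.insert_slice %1 into %t[%a][%c][1]` when the slices operate on
/// equivalent buffers with identical offsets, sizes and strides.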
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
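///
/// Example (illustrative; the subview's strided layout is inferred and only
/// sketched here as #strided):
///   %1 = tensor.insert_slice %0 into %t[%off][%sz][1]
///       : tensor<?xf32> into tensor<?xf32>
/// bufferizes to a subview of the destination buffer plus a copy (e.g.
/// memref.copy, depending on the bufferization options):
///   %sv = memref.subview %dest[%off][%sz][1]
///       : memref<?xf32> to memref<?xf32, #strided>
///   <copy of %src into %sv>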
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
    // When bufferizing out-of-place, `getBuffer` allocates a new buffer.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
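///
/// Example (illustrative):
///   %r = tensor.rank %0 : tensor<*xf32>
/// bufferizes to:
///   %r = memref.rank %src : memref<*xf32>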
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
  });
}