1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Operation.h"
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 using namespace mlir::tensor;
22 
23 namespace mlir {
24 namespace tensor {
25 namespace {
26 
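/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// A sketch of the rewrite (types and SSA names are illustrative assumptions):
///
///   %1 = tensor.cast %0 : tensor<?x4xf32> to tensor<4x4xf32>
///
/// bufferizes to a cast of the source buffer that keeps its layout and memory
/// space:
///
///   %m = memref.cast %buf : memref<?x4xf32> to memref<4x4xf32>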
27 struct CastOpInterface
28     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
29                                                     tensor::CastOp> {
30   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
31                               const AnalysisState &state) const {
32     return false;
33   }
34 
35   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
36                                const AnalysisState &state) const {
37     return false;
38   }
39 
40   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
41                                             const AnalysisState &state) const {
42     return {op->getResult(0)};
43   }
44 
45   BufferRelation bufferRelation(Operation *op, OpResult opResult,
46                                 const AnalysisState &state) const {
47     return BufferRelation::Equivalent;
48   }
49 
50   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51                           const BufferizationOptions &options) const {
52     auto castOp = cast<tensor::CastOp>(op);
53 
54     // The result buffer still has the old (pre-cast) type.
55     FailureOr<Value> resultBuffer =
56         getBuffer(rewriter, castOp.getSource(), options);
57     if (failed(resultBuffer))
58       return failure();
59     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
60     TensorType resultTensorType =
61         castOp.getResult().getType().cast<TensorType>();
62     MemRefLayoutAttrInterface layout;
63 
64     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
65       if (resultTensorType.isa<RankedTensorType>())
66         layout = rankedMemRefType.getLayout();
67 
68     // Compute the new memref type.
69     Type resultMemRefType =
70         getMemRefType(castOp.getResult(), options, layout,
71                       sourceMemRefType.getMemorySpaceAsInt());
72 
73     // Replace the op with a memref.cast.
74     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
75                                              resultMemRefType) &&
76            "CastOp::bufferize: cast incompatible");
77     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
78                                                  *resultBuffer);
79 
80     return success();
81   }
82 };
83 
84 /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
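///
/// A sketch of the common case (types are illustrative assumptions):
///
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
///
/// becomes, when the source buffer layout is collapsible:
///
///   %m = memref.collapse_shape %buf [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>
///
/// Otherwise, the source is first copied into a new buffer without a layout
/// map, which is then collapsed (see the out-of-place case in `bufferize`).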
85 struct CollapseShapeOpInterface
86     : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
87                                                     tensor::CollapseShapeOp> {
88   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
89                               const AnalysisState &state) const {
90     return false;
91   }
92 
93   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
94                                const AnalysisState &state) const {
95     return false;
96   }
97 
98   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
99                                             const AnalysisState &state) const {
100     if (&opOperand == &op->getOpOperand(0) /*src*/)
101       return {op->getOpResult(0)};
102     return {};
103   }
104 
105   BufferRelation bufferRelation(Operation *op, OpResult opResult,
106                                 const AnalysisState &state) const {
107     return BufferRelation::Equivalent;
108   }
109 
110   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
111                           const BufferizationOptions &options) const {
112     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
113     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
114     FailureOr<Value> maybeBuffer =
115         getBuffer(rewriter, collapseShapeOp.getSrc(), options);
116     if (failed(maybeBuffer))
117       return failure();
118     Value buffer = *maybeBuffer;
119     auto bufferType = buffer.getType().cast<MemRefType>();
120 
121     if (tensorResultType.getRank() == 0) {
122       // 0-d collapses must go through a different op builder.
123       MemRefType resultType;
124 
125       if (bufferType.getLayout().isIdentity()) {
126         // Standard layout: result type has no offset.
127         MemRefLayoutAttrInterface layout;
128         resultType = MemRefType::get({}, tensorResultType.getElementType(),
129                                      layout, bufferType.getMemorySpace());
130       } else {
131         // Source memref has a layout map: result type has the same offset as
132         // the source type.
133         SmallVector<int64_t> strides;
134         int64_t offset;
135         if (failed(getStridesAndOffset(bufferType, strides, offset)))
136           return failure();
137         AffineMap resultLayout =
138             makeStridedLinearLayoutMap({}, offset, op->getContext());
139         resultType =
140             MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
141                             bufferType.getMemorySpaceAsInt());
142       }
143 
144       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
145           rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
146       return success();
147     }
148 
149     // If the dims are not collapsible (due to an incompatible source layout
150     // map), force an out-of-place bufferization, i.e., a buffer copy. This
151     // newly allocated buffer will have no layout map and thus be collapsible.
152     bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
153         bufferType, collapseShapeOp.getReassociationIndices());
154     if (!canBeCollapsed) {
155       // TODO: Create alloc_tensor ops during TensorCopyInsertion.
156       AnalysisState analysisState(options);
157       FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
158           rewriter, op->getLoc(), collapseShapeOp.getSrc(),
159           analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
160       if (failed(tensorAlloc))
161         return failure();
162       auto memrefType =
163           MemRefType::get(collapseShapeOp.getSrcType().getShape(),
164                           collapseShapeOp.getSrcType().getElementType(),
165                           AffineMap(), bufferType.getMemorySpaceAsInt());
166       buffer = rewriter.create<bufferization::ToMemrefOp>(
167           op->getLoc(), memrefType, *tensorAlloc);
168     }
169 
170     // Result type is inferred by the builder.
171     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
172         rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
173     return success();
174   }
175 };
176 
177 /// Bufferization of tensor.dim. Replace with memref.dim.
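///
/// For example (types are illustrative assumptions):
///
///   %d = tensor.dim %t, %c0 : tensor<?x8xf32>
///
/// becomes a query on the bufferized source:
///
///   %d = memref.dim %m, %c0 : memref<?x8xf32>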
178 struct DimOpInterface
179     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
180                                                     tensor::DimOp> {
181   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
182                               const AnalysisState &state) const {
183     return true;
184   }
185 
186   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
187                                const AnalysisState &state) const {
188     return false;
189   }
190 
191   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
192                                             const AnalysisState &state) const {
193     return {};
194   }
195 
196   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
197                           const BufferizationOptions &options) const {
198     auto dimOp = cast<tensor::DimOp>(op);
199     FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
200     if (failed(v))
201       return failure();
202     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
203                                                 dimOp.getIndex());
204     return success();
205   }
206 };
207 
208 /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
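///
/// A sketch (types are illustrative assumptions):
///
///   %1 = tensor.expand_shape %0 [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
///
/// becomes:
///
///   %m = memref.expand_shape %buf [[0, 1]]
///       : memref<6xf32> into memref<2x3xf32>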
209 struct ExpandShapeOpInterface
210     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
211                                                     tensor::ExpandShapeOp> {
212   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
213                               const AnalysisState &state) const {
214     return false;
215   }
216 
217   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
218                                const AnalysisState &state) const {
219     return false;
220   }
221 
222   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
223                                             const AnalysisState &state) const {
224     if (&opOperand == &op->getOpOperand(0) /*src*/)
225       return {op->getOpResult(0)};
226     return {};
227   }
228 
229   BufferRelation bufferRelation(Operation *op, OpResult opResult,
230                                 const AnalysisState &state) const {
231     return BufferRelation::Equivalent;
232   }
233 
234   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
235                           const BufferizationOptions &options) const {
236     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
237     auto tensorResultType = expandShapeOp.getResultType();
238     FailureOr<Value> buffer =
239         getBuffer(rewriter, expandShapeOp.getSrc(), options);
240     if (failed(buffer))
241       return failure();
242 
243     // Memref result type is inferred by the builder based on reassociation
244     // indices and result shape.
245     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
246         rewriter, op, tensorResultType.getShape(), *buffer,
247         expandShapeOp.getReassociationIndices());
248     return success();
249   }
250 };
251 
252 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
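///
/// A sketch (offsets/sizes/strides and types are illustrative assumptions):
///
///   %1 = tensor.extract_slice %0[%o] [%sz] [1] : tensor<?xf32> to tensor<?xf32>
///
/// becomes a view of the source buffer, without a copy:
///
///   %v = memref.subview %buf[%o] [%sz] [1]
///       : memref<?xf32> to memref<?xf32, #map>
///
/// where #map stands for the inferred strided layout of the subview.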
253 struct ExtractSliceOpInterface
254     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
255                                                     tensor::ExtractSliceOp> {
256   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
257                               const AnalysisState &state) const {
258     return false;
259   }
260 
261   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
262                                const AnalysisState &state) const {
263     return false;
264   }
265 
266   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
267                                             const AnalysisState &state) const {
268     if (&opOperand == &op->getOpOperand(0) /*source*/)
269       return {op->getOpResult(0)};
270     return {};
271   }
272 
273   BufferRelation bufferRelation(Operation *op, OpResult opResult,
274                                 const AnalysisState &state) const {
275     return BufferRelation::None;
276   }
277 
278   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
279                           const BufferizationOptions &options) const {
280     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
281     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
282     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
283     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
284     Location loc = extractSliceOp.getLoc();
285 
286     // Get source buffer.
287     FailureOr<Value> srcMemref =
288         getBuffer(rewriter, extractSliceOp.getSource(), options);
289     if (failed(srcMemref))
290       return failure();
291     auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
292 
293     // Take a subview of the source buffer.
294     auto subviewMemRefType =
295         memref::SubViewOp::inferRankReducedResultType(
296             extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
297             mixedSizes, mixedStrides)
298             .cast<MemRefType>();
299     Value subView = rewriter.create<memref::SubViewOp>(
300         loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
301         mixedStrides);
302 
303     replaceOpWithBufferizedValues(rewriter, op, subView);
304     return success();
305   }
306 };
307 
308 /// Bufferization of tensor.extract. Replace with memref.load.
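///
/// For example (types are illustrative assumptions):
///
///   %e = tensor.extract %t[%i, %j] : tensor<4x4xf32>
///
/// becomes:
///
///   %e = memref.load %m[%i, %j] : memref<4x4xf32>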
309 struct ExtractOpInterface
310     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
311                                                     tensor::ExtractOp> {
312   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
313                               const AnalysisState &state) const {
314     return true;
315   }
316 
317   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
318                                const AnalysisState &state) const {
319     return false;
320   }
321 
322   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
323                                             const AnalysisState &state) const {
324     return {};
325   }
326 
327   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
328                           const BufferizationOptions &options) const {
329     auto extractOp = cast<tensor::ExtractOp>(op);
330     FailureOr<Value> srcMemref =
331         getBuffer(rewriter, extractOp.getTensor(), options);
332     if (failed(srcMemref))
333       return failure();
334     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
335                                                  extractOp.getIndices());
336     return success();
337   }
338 };
339 
340 // Implements a recursive traversal of the indices of the output buffer (one
341 // recursion level per dimension) while iterating over op.getElements().
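// For example, for a buffer of shape (2, 2), stores are created at indices
// (0, 0), (0, 1), (1, 0), (1, 1), consuming one element of op.getElements()
// per store.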
342 static void createStores(RewriterBase &rewriter, Location loc, int dim,
343                          Value buffer, ArrayRef<int64_t> shape,
344                          ArrayRef<Value> constants,
345                          OperandRange::iterator &elementIt,
346                          SmallVectorImpl<Value> &indices) {
347   if (dim == static_cast<int>(shape.size()) - 1) {
348     for (int i = 0; i < shape.back(); ++i) {
349       indices.back() = constants[i];
350       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
351       ++elementIt;
352     }
353     return;
354   }
355   for (int i = 0; i < shape[dim]; ++i) {
356     indices[dim] = constants[i];
357     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
358                  indices);
359   }
360 }
361 
362 /// Bufferization of tensor.from_elements.
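///
/// A sketch of the expected lowering (illustrative assumptions):
///
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
///
/// becomes an allocation plus one memref.store per element, roughly:
///
///   %m = ... allocate a memref<2xf32> ...
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>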
363 struct FromElementsOpInterface
364     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
365                                                     tensor::FromElementsOp> {
366   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
367                           const BufferizationOptions &options) const {
368     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
369 
370     // TODO: Implement memory space for this op.
371     if (options.defaultMemorySpace != static_cast<unsigned>(0))
372       return op->emitError("memory space not implemented yet");
373 
374     // Allocate a buffer for the result.
375     Location loc = op->getLoc();
376     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
377     auto shape = tensorType.getShape();
378     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
379     AnalysisState analysisState(options);
380     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
381         rewriter, loc, fromElementsOp.getResult(),
382         analysisState.isTensorYielded(fromElementsOp.getResult()), options,
383         /*copy=*/false);
384     if (failed(tensorAlloc))
385       return failure();
386     auto memrefType =
387         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
388     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
389         op->getLoc(), memrefType, *tensorAlloc);
390 
391     // Case: tensor<0xelem_type>.
392     if (fromElementsOp.getElements().empty()) {
393       replaceOpWithBufferizedValues(rewriter, op, buffer);
394       return success();
395     }
396 
397     // Case: tensor<elem_type>.
398     if (shape.empty()) {
399       rewriter.create<memref::StoreOp>(
400           loc, fromElementsOp.getElements().front(), buffer);
401       replaceOpWithBufferizedValues(rewriter, op, buffer);
402       return success();
403     }
404 
405     // Create constants for the range of possible indices [0, max{shape_i}).
406     auto maxDim = *std::max_element(shape.begin(), shape.end());
407     SmallVector<Value, 2> constants;
408     constants.reserve(maxDim);
409     for (int i = 0; i < maxDim; ++i)
410       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
411 
412     // Traverse all `elements` and create `memref.store` ops.
413     auto elementIt = fromElementsOp.getElements().begin();
414     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
415     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
416                  indices);
417 
418     replaceOpWithBufferizedValues(rewriter, op, buffer);
419     return success();
420   }
421 };
422 
423 /// Bufferization of tensor.generate.
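///
/// A sketch of the expected lowering (illustrative assumptions):
///
///   %t = tensor.generate %sz {
///   ^bb0(%i: index):
///     ...
///     tensor.yield %v : f32
///   } : tensor<?xf32>
///
/// becomes an allocation whose elements are written by an scf.parallel loop;
/// the generator body is inlined into the loop and its tensor.yield is
/// replaced by a memref.store into the allocated buffer.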
424 struct GenerateOpInterface
425     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
426                                                     tensor::GenerateOp> {
427   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
428                           const BufferizationOptions &options) const {
429     auto generateOp = cast<tensor::GenerateOp>(op);
430 
431     // TODO: Implement memory space for this op.
432     if (options.defaultMemorySpace != static_cast<unsigned>(0))
433       return op->emitError("memory space not implemented yet");
434 
435     auto tensorType = generateOp.getType().cast<RankedTensorType>();
436     // Allocate memory.
437     Location loc = op->getLoc();
438     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
439     AnalysisState analysisState(options);
440     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
441         rewriter, loc, generateOp.getResult(),
442         analysisState.isTensorYielded(generateOp.getResult()), options,
443         /*copy=*/false);
444     if (failed(tensorAlloc))
445       return failure();
446     auto memrefType =
447         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
448     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
449         op->getLoc(), memrefType, *tensorAlloc);
450 
451     // Collect loop bounds.
452     int64_t rank = memrefType.getRank();
453     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
454     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
455     SmallVector<Value, 4> lowerBounds(rank, zero);
456     SmallVector<Value, 4> steps(rank, one);
457     SmallVector<Value, 4> upperBounds;
458     int nextDynamicIndex = 0;
459     for (int i = 0; i < rank; i++) {
460       Value upperBound =
461           memrefType.isDynamicDim(i)
462               ? generateOp.getDynamicExtents()[nextDynamicIndex++]
463               : rewriter.create<arith::ConstantIndexOp>(
464                     loc, memrefType.getDimSize(i));
465       upperBounds.push_back(upperBound);
466     }
467 
468     // Generate tensor elements with a parallel loop that stores into
469     // each element of the resulting memref. We use mergeBlockBefore to "move"
470     // this op's body into the scf.parallel's body.
471     auto parallel =
472         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
473     Block *parallelBody = parallel.getBody();
474     rewriter.mergeBlockBefore(&generateOp.getBody().front(),
475                               parallelBody->getTerminator(),
476                               parallelBody->getArguments());
477     // Replace the inlined yield op with a store op. The scf.parallel's builder
478     // already populated an scf.yield at the end, so we don't need to worry
479     // about creating that.
480     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
481     rewriter.setInsertionPointAfter(elementYield);
482     rewriter.replaceOpWithNewOp<memref::StoreOp>(
483         elementYield, elementYield->getOperands()[0], buffer,
484         parallelBody->getArguments());
485 
486     replaceOpWithBufferizedValues(rewriter, op, buffer);
487     return success();
488   }
489 };
490 
491 /// Bufferization of tensor.insert. Replace with memref.store.
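///
/// For example (types are illustrative assumptions):
///
///   %1 = tensor.insert %f into %t[%i] : tensor<8xf32>
///
/// becomes a store into the destination buffer:
///
///   memref.store %f, %m[%i] : memref<8xf32>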
492 struct InsertOpInterface
493     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
494                                                     tensor::InsertOp> {
495   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
496                               const AnalysisState &state) const {
497     return true;
498   }
499 
500   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
501                                const AnalysisState &state) const {
502     return true;
503   }
504 
505   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
506                                             const AnalysisState &state) const {
507     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
508            "expected dest OpOperand");
509     return {op->getOpResult(0)};
510   }
511 
512   SmallVector<OpOperand *>
513   getAliasingOpOperand(Operation *op, OpResult opResult,
514                        const AnalysisState &state) const {
515     return {&op->getOpOperand(1) /*dest*/};
516   }
517 
518   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
519                           const BufferizationOptions &options) const {
520     auto insertOp = cast<tensor::InsertOp>(op);
521     FailureOr<Value> destMemref =
522         getBuffer(rewriter, insertOp.getDest(), options);
523     if (failed(destMemref))
524       return failure();
525     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
526                                      *destMemref, insertOp.getIndices());
527     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
528     return success();
529   }
530 
531   BufferRelation bufferRelation(Operation *op, OpResult opResult,
532                                 const AnalysisState &state) const {
533     return BufferRelation::Equivalent;
534   }
535 };
536 
537 /// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
538 /// equivalent operand / result and the same offsets/sizes/strides).
539 ///
540 /// This is one particular type of relationship between ops on tensors that
541 /// reduces to an equivalence on buffers. This should be generalized and
542 /// exposed as interfaces on the proper types.
543 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
544                                          ExtractSliceOp st, InsertSliceOp sti) {
545   if (!st || !sti)
546     return false;
547   if (st != sti &&
548       !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
549     return false;
550   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
551     return false;
552   return true;
553 }
554 
555 /// Return true if `value` originates from an ExtractSliceOp that matches the
556 /// given InsertSliceOp.
557 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
558                                       InsertSliceOp insertOp) {
559   auto condition = [&](Value val) {
560     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
561       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
562         return true;
563     return false;
564   };
565 
566   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
567                       condition);
568 }
569 
570 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
571 /// certain circumstances, this op can also be a no-op.
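///
/// A sketch of the rewrite (values and types are illustrative assumptions):
///
///   %1 = tensor.insert_slice %src into %dst[%o] [%sz] [1]
///       : tensor<?xf32> into tensor<?xf32>
///
/// becomes a subview of the destination buffer plus a buffer copy
/// (memref.copy with the default options), roughly:
///
///   %v = memref.subview %dstBuf[%o] [%sz] [1]
///       : memref<?xf32> to memref<?xf32, #map>
///   memref.copy %srcBuf, %v : memref<?xf32> to memref<?xf32, #map>
///
/// where #map stands for the inferred strided layout. The copy folds away
/// when the op has a matching tensor.extract_slice.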
572 struct InsertSliceOpInterface
573     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
574                                                     tensor::InsertSliceOp> {
575   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
576                               const AnalysisState &state) const {
577     return true;
578   }
579 
580   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
581                                const AnalysisState &state) const {
582     return &opOperand == &op->getOpOperand(1) /*dest*/;
583   }
584 
585   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
586                                             const AnalysisState &state) const {
587     if (&opOperand == &op->getOpOperand(1) /*dest*/)
588       return {op->getResult(0)};
589     return {};
590   }
591 
592   BufferRelation bufferRelation(Operation *op, OpResult opResult,
593                                 const AnalysisState &state) const {
594     return BufferRelation::Equivalent;
595   }
596 
597   bool isNotConflicting(Operation *op, OpOperand *uRead,
598                         OpOperand *uConflictingWrite,
599                         const AnalysisState &state) const {
600     Operation *readingOp = uRead->getOwner();
601     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
602 
603     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
604     // uRead is an InsertSliceOp...
605     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
606       // As an example, consider the following IR.
607       //
608       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
609       // %1 = linalg.fill %cst, %0 {inplace= [true] }
610       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
611       //     {inplace= [true] }
612 
613       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
614       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
615           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
616                                     insertSliceOp))
617         // Case 1: The main insight is that InsertSliceOp reads only part of
618         // the destination tensor. The overwritten area is not read. If
619         // uConflictingWrite writes into exactly the memory location that is
620         // being read by uRead, this is not a conflict.
621         //
622         // In the above example:
623         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
624         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
625         //
626         // The read of %t does not conflict with the write of the FillOp
627         // (same aliases!) because the area that the FillOp operates on is
628         // exactly the one that is *not* read via %t.
629         return true;
630 
631       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
632           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
633           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
634         // Case 2: The read of the source tensor and the write to the dest
635         // tensor via an InsertSliceOp is not a conflict if the read is
636         // reading exactly that part of an equivalent tensor that the
637         // InsertSliceOp is writing.
638         //
639         // In the above example:
640         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
641         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
642         return true;
643     }
644 
645     // If uConflictingWrite is an InsertSliceOp...
646     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
647       // As an example, consider the following IR.
648       //
649       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
650       // %1 = linalg.fill %cst, %0 {inplace= [true] }
651       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
652       //     {inplace= [true] }
653       // %3 = vector.transfer_read %1, %cst
654       //
655       // In the above example:
656       // uRead             = OpOperand 0 (%1) of vector.transfer_read
657       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
658       // lastWrite         = %1
659       //
660       // This is not a conflict because the InsertSliceOp overwrites the
661       // memory segment of %1 with the exact same data. (Effectively, there
662       // is no memory write here.)
663       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
664           state.areEquivalentBufferizedValues(uRead->get(),
665                                               insertSliceOp.getSource()) &&
666           hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
667                                     insertSliceOp))
668         return true;
669 
670     return false;
671   }
672 
673   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
674                           const BufferizationOptions &options) const {
675     // insert_slice ops arise from tiling and bufferizing them out-of-place is
676     // generally a deal breaker. When used with loops, this ends up cloning the
677     // whole tensor on every single iteration and is a symptom of a
678     // catastrophically bad scheduling decision.
679     // TODO: be very loud about it or even consider failing the pass.
680     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
681     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
682     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
683     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
684     Location loc = insertSliceOp.getLoc();
685 
686     // Get destination buffer.
687     FailureOr<Value> dstMemref =
688         getBuffer(rewriter, insertSliceOp.getDest(), options);
689     if (failed(dstMemref))
690       return failure();
691 
692     // Take a subview of the destination buffer.
693     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
694     auto subviewMemRefType =
695         memref::SubViewOp::inferRankReducedResultType(
696             insertSliceOp.getSourceType().getShape(), dstMemrefType,
697             mixedOffsets, mixedSizes, mixedStrides)
698             .cast<MemRefType>();
699     Value subView = rewriter.create<memref::SubViewOp>(
700         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
701         mixedStrides);
702 
703     // Copy tensor. If this tensor.insert_slice has a matching
704     // tensor.extract_slice, the copy operation will eventually fold away.
705     FailureOr<Value> srcMemref =
706         getBuffer(rewriter, insertSliceOp.getSource(), options);
707     if (failed(srcMemref))
708       return failure();
709     if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
710       return failure();
711 
712     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
713     return success();
714   }
715 };
716 
717 /// Bufferization of tensor.rank. Replace with memref.rank.
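///
/// For example:
///
///   %r = tensor.rank %t : tensor<*xf32>
///
/// becomes:
///
///   %r = memref.rank %m : memref<*xf32>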
718 struct RankOpInterface
719     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
720                                                     tensor::RankOp> {
721   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
722                               const AnalysisState &state) const {
723     return true;
724   }
725 
726   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
727                                const AnalysisState &state) const {
728     return false;
729   }
730 
731   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
732                                             const AnalysisState &state) const {
733     return {};
734   }
735 
736   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
737                           const BufferizationOptions &options) const {
738     auto rankOp = cast<tensor::RankOp>(op);
739     FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
740     if (failed(v))
741       return failure();
742     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
743                                                  *v);
744     return success();
745   }
746 };
747 
748 /// Bufferization of tensor.reshape. Replace with memref.reshape.
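///
/// A sketch (types are illustrative assumptions):
///
///   %1 = tensor.reshape %0(%shape)
///       : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
///
/// becomes, with both operands bufferized:
///
///   %m = memref.reshape %buf(%shapeBuf)
///       : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>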
749 struct ReshapeOpInterface
750     : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
751                                                     tensor::ReshapeOp> {
752   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
753                               const AnalysisState &state) const {
754     if (&opOperand == &op->getOpOperand(1) /* shape */)
755       return true;
756     return false;
757   }
758 
759   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
760                                const AnalysisState &state) const {
761     return false;
762   }
763 
764   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
765                                             const AnalysisState &state) const {
766     return {op->getOpResult(0)};
767   }
768 
769   BufferRelation bufferRelation(Operation *op, OpResult opResult,
770                                 const AnalysisState &state) const {
771     return BufferRelation::Equivalent;
772   }
773 
774   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
775                           const BufferizationOptions &options) const {
776     auto reshapeOp = cast<tensor::ReshapeOp>(op);
777     FailureOr<Value> srcBuffer =
778         getBuffer(rewriter, reshapeOp.getSource(), options);
779     FailureOr<Value> shapeBuffer =
780         getBuffer(rewriter, reshapeOp.getShape(), options);
781     if (failed(srcBuffer) || failed(shapeBuffer))
782       return failure();
783     auto resultMemRefType = getMemRefType(
784         reshapeOp.getResult(), options, /*layout=*/{},
785         srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
786     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
787         rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
788     return success();
789   }
790 };
791 
792 /// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches
793 /// (i.e., equivalent operand / result and the same offsets/sizes/strides).
794 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
795                                          ExtractSliceOp st,
796                                          ParallelInsertSliceOp sti) {
797   if (!st || !sti)
798     return false;
799   if (st != sti &&
800       !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
801     return false;
802   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
803     return false;
804   return true;
805 }
806 
807 /// Return true if `value` originates from an ExtractSliceOp that matches the
808 /// given ParallelInsertSliceOp.
809 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
810                                       ParallelInsertSliceOp insertOp) {
811   auto condition = [&](Value val) {
812     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
813       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
814         return true;
815     return false;
816   };
817 
818   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
819                       condition);
820 }
821 
822 /// Analysis and bufferization of tensor.parallel_insert_slice.
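///
/// The op itself has no results. It bufferizes to a memref.subview of the
/// destination buffer plus a copy of the source buffer into that subview; the
/// corresponding result of the surrounding parallel iterating op is then
/// replaced with a bufferization.to_tensor of the destination buffer.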
823 struct ParallelInsertSliceOpInterface
824     : public BufferizableOpInterface::ExternalModel<
825           ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
826   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
827                                             const AnalysisState &state) const {
828     if (&opOperand != &op->getOpOperand(1) /*dest*/)
829       return {};
830 
831     // ParallelInsertSliceOp itself has no results, query its tied op results.
832     auto insertOp = cast<ParallelInsertSliceOp>(op);
833     return {insertOp.getTiedOpResult()};
834   }
835 
836   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
837                               const AnalysisState &state) const {
838     return true;
839   }
840 
841   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
842                                const AnalysisState &state) const {
843     return &opOperand == &op->getOpOperand(1) /*dest*/;
844   }
845 
846   BufferRelation bufferRelation(Operation *op, OpResult opResult,
847                                 const AnalysisState &state) const {
848     return BufferRelation::Equivalent;
849   }
850 
851   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
852                                  const AnalysisState &state) const {
853     // This interface method is overridden because we want to set a custom
854     // insertion point for tensor copies. They should be inserted right before
855     // the ForeachThreadOp. E.g.:
856     //
857     // %r0, %r1 = foreach_thread ... {
858     //   ...
859     //   perform_concurrently {
860     //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
861     //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
862     //   }
863     // }
864     //
865     // After TensorCopyInsertion:
866     //
867     // %copy = bufferization.alloc_tensor() copy(%d)
868     // %r0, %r1 = foreach_thread ... {
869     //   ...
870     //   perform_concurrently {
871     //     parallel_insert_slice %a into %b ...
872     //     parallel_insert_slice %c into %copy ...
873     //   }
874     // }
875 
876     OpBuilder::InsertionGuard g(rewriter);
877     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
878     ParallelCombiningOpInterface parallelCombiningParent =
879         parallelInsertSliceOp.getParallelCombiningParent();
880     Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
881 
882     // Nothing to do if the destination tensor is inplace.
883     assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
884            "source is always in-place");
885     if (state.isInPlace(op->getOpOperand(1) /*dest*/))
886       return success();
887 
888     // Find corresponding OpResult.
889     OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
890 
891     // Insert tensor allocation right before the ForeachThreadOp.
892     rewriter.setInsertionPoint(parallelIteratingOp);
893     bool isYielded = state.isTensorYielded(opResult);
894     FailureOr<Value> alloc = allocateTensorForShapedValue(
895         rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
896         /*escape=*/isYielded, state.getOptions());
897     if (failed(alloc))
898       return failure();
899 
900     // Update destination operand.
901     rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
902       parallelInsertSliceOp.getDestMutable().assign(*alloc);
903     });
904 
905     return success();
906   }
907 
908   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
909                           const BufferizationOptions &options) const {
910     OpBuilder::InsertionGuard g(rewriter);
911     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
912     ParallelCombiningOpInterface parallelCombiningParent =
913         parallelInsertSliceOp.getParallelCombiningParent();
914     Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
915 
916     // Get destination buffer.
917     FailureOr<Value> destBuffer =
918         getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
919     if (failed(destBuffer))
920       return failure();
921 
922     // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
923     rewriter.setInsertionPoint(parallelCombiningParent);
924     FailureOr<Value> srcBuffer =
925         getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
926     if (failed(srcBuffer))
927       return failure();
928 
929     // Take a subview of the destination buffer.
930     auto destBufferType = destBuffer->getType().cast<MemRefType>();
931     auto subviewMemRefType =
932         memref::SubViewOp::inferRankReducedResultType(
933             parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
934             parallelInsertSliceOp.getMixedOffsets(),
935             parallelInsertSliceOp.getMixedSizes(),
936             parallelInsertSliceOp.getMixedStrides())
937             .cast<MemRefType>();
938     Value subview = rewriter.create<memref::SubViewOp>(
939         parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
940         parallelInsertSliceOp.getMixedOffsets(),
941         parallelInsertSliceOp.getMixedSizes(),
942         parallelInsertSliceOp.getMixedStrides());
943 
944     // This memcpy will fold away if everything bufferizes in-place.
945     if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
946                                     *srcBuffer, subview)))
947       return failure();
948 
949     // Replace all uses of parallelIteratingOp (just the corresponding result).
950     rewriter.setInsertionPointAfter(parallelIteratingOp);
951     Value toTensorOp =
952         rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
953     // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
954     SmallVector<OpOperand *> resultUses = llvm::to_vector(
955         llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
956                         [](OpOperand &use) { return &use; }));
957     for (OpOperand *use : resultUses) {
958       rewriter.updateRootInPlace(use->getOwner(),
959                                  [&]() { use->set(toTensorOp); });
960     }
961     rewriter.eraseOp(op);
962     return success();
963   }
964 
965   // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
966   // the code.
967   bool isNotConflicting(Operation *op, OpOperand *uRead,
968                         OpOperand *uConflictingWrite,
969                         const AnalysisState &state) const {
970     Operation *readingOp = uRead->getOwner();
971     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
972 
973     // Special rules for matching ExtractSliceOp/ParallelInsertSliceOp
974     // pairs. If uRead is a ParallelInsertSliceOp...
975     if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
976       // As an example, consider the following IR.
977       //
978       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
979       // %1 = linalg.fill %cst, %0 {inplace= [true] }
980       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
981       //     {inplace= [true] }
982 
983       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
984       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
985           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
986                                     insertSliceOp))
987         // Case 1: The main insight is that InsertSliceOp reads only part of
988         // the destination tensor. The overwritten area is not read. If
989         // uConflictingWrite writes into exactly the memory location that is
990         // being read by uRead, this is not a conflict.
991         //
992         // In the above example:
993         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
994         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
995         //
996         // The read of %t does not conflict with the write of the FillOp
997         // (same aliases!) because the area that the FillOp operates on is
998         // exactly the one that is *not* read via %t.
999         return true;
1000 
1001       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
1002           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1003           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
1004         // Case 2: The read of the source tensor and the write to the dest
1005         // tensor via an InsertSliceOp is not a conflict if the read is
1006         // reading exactly that part of an equivalent tensor that the
1007         // InsertSliceOp is writing.
1008         //
1009         // In the above example:
1010         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
1011         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1012         return true;
1013     }
1014 
1015     // If uConflictingWrite is a ParallelInsertSliceOp...
1016     if (auto insertSliceOp =
1017             dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
1018       // As an example, consider the following IR.
1019       //
1020       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
1021       // %1 = linalg.fill %cst, %0 {inplace= [true] }
1022       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
1023       //     {inplace= [true] }
1024       // %3 = vector.transfer_read %1, %cst
1025       //
1026       // In the above example:
1027       // uRead             = OpOperand 0 (%1) of vector.transfer_read
1028       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1029       // lastWrite         = %1
1030       //
1031       // This is not a conflict because the InsertSliceOp overwrites the
1032       // memory segment of %1 with the exact same data. (Effectively, there
1033       // is no memory write here.)
1034       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1035           state.areEquivalentBufferizedValues(uRead->get(),
1036                                               insertSliceOp.getSource()) &&
1037           hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
1038                                     insertSliceOp))
1039         return true;
1040 
1041     return false;
1042   }
1043 };
1044 
1045 } // namespace
1046 } // namespace tensor
1047 } // namespace mlir
1048 
1049 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
1050     DialectRegistry &registry) {
1051   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
1052     CastOp::attachInterface<CastOpInterface>(*ctx);
1053     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1054     DimOp::attachInterface<DimOpInterface>(*ctx);
1055     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
1056     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
1057     ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
1058     FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
1059     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
1060     InsertOp::attachInterface<InsertOpInterface>(*ctx);
1061     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1062     ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1063         *ctx);
1064     RankOp::attachInterface<RankOpInterface>(*ctx);
1065     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1066   });
1067 }
1068