//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

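/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Illustrative sketch, assuming a simple ranked cast (layouts elided; not
/// verified output):
///   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
/// bufferizes roughly to
///   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>
/// where %m is the buffer of %0.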
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(castOp.getResult(), options, layout,
                      sourceMemRefType.getMemorySpaceAsInt());

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
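///
/// Illustrative sketch, assuming an identity source layout (not verified
/// output):
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
/// bufferizes roughly to
///   %1 = memref.collapse_shape %m [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>
/// If the source layout makes the dims non-collapsible, the implementation
/// below first copies into a freshly allocated identity-layout buffer.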
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(),
          analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
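///
/// Illustrative sketch (not verified output):
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// bufferizes roughly to
///   %d = memref.dim %m, %c0 : memref<?xf32>
/// where %m is the buffer of %t.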
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
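///
/// Illustrative sketch (not verified output):
///   %1 = tensor.expand_shape %0 [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
/// bufferizes roughly to
///   %1 = memref.expand_shape %m [[0, 1]]
///       : memref<6xf32> into memref<2x3xf32>
/// where %m is the buffer of %0.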
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
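///
/// Illustrative sketch, with the result layout written as a placeholder
/// (#layout); not verified output:
///   %1 = tensor.extract_slice %0[%off][%sz][1]
///       : tensor<?xf32> to tensor<?xf32>
/// bufferizes roughly to
///   %1 = memref.subview %m[%off][%sz][1]
///       : memref<?xf32> to memref<?xf32, #layout>
/// The exact result layout is inferred by memref::SubViewOp.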
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();

    // Take a subview of the source buffer.
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
            mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
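///
/// Illustrative sketch (not verified output):
///   %e = tensor.extract %t[%i] : tensor<?xf32>
/// bufferizes roughly to
///   %e = memref.load %m[%i] : memref<?xf32>
/// where %m is the buffer of %t.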
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
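///
/// Illustrative sketch (not verified output): for
///   %t = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
/// a buffer is allocated and the elements are stored in row-major order,
/// roughly:
///   memref.store %a, %buf[%c0, %c0] : memref<2x2xf32>
///   memref.store %b, %buf[%c0, %c1] : memref<2x2xf32>
///   memref.store %c, %buf[%c1, %c0] : memref<2x2xf32>
///   memref.store %d, %buf[%c1, %c1] : memref<2x2xf32>
/// The index order is produced by the createStores helper above.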
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {

  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    // Should the buffer be deallocated?
    bool dealloc = shouldDeallocateOpResult(
        fromElementsOp.getResult().cast<OpResult>(), options);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc =
        allocateTensorForShapedValue(rewriter, loc, fromElementsOp.getResult(),
                                     /*escape=*/!dealloc, options,
                                     /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Bufferization of tensor.generate.
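///
/// Illustrative sketch (not verified output):
///   %t = tensor.generate %sz {
///   ^bb0(%i: index):
///     ...
///     tensor.yield %v : f32
///   } : tensor<?xf32>
/// bufferizes roughly to an allocation plus
///   scf.parallel (%i) = (%c0) to (%sz) step (%c1) {
///     ...
///     memref.store %v, %buf[%i] : memref<?xf32>
///   }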
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {

  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);
    // Should the buffer be deallocated?
    bool dealloc = shouldDeallocateOpResult(
        generateOp.getResult().cast<OpResult>(), options);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc =
        allocateTensorForShapedValue(rewriter, loc, generateOp.getResult(),
                                     /*escape=*/!dealloc, options,
                                     /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          memrefType.isDynamicDim(i)
              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
              : rewriter.create<arith::ConstantIndexOp>(
                    loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
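///
/// Illustrative sketch (not verified output):
///   %1 = tensor.insert %v into %t[%i] : tensor<?xf32>
/// bufferizes roughly to
///   memref.store %v, %m[%i] : memref<?xf32>
/// where %m is the (in-place) buffer of %t, which also becomes the buffer of
/// %1.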
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches, i.e., if
/// they have equivalent operands/results and the same offset/sizes/strides
/// specification.
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
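///
/// Illustrative example of a matching pair (not verified output):
///   %0 = tensor.extract_slice %t[%a][%c][1] : tensor<?xf32> to tensor<?xf32>
///   ...
///   %2 = tensor.insert_slice %1 into %t[%a][%c][1]
///       : tensor<?xf32> into tensor<?xf32>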
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
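///
/// Illustrative sketch, with the subview layout written as a placeholder
/// (#layout); not verified output:
///   %1 = tensor.insert_slice %s into %t[%off][%sz][1]
///       : tensor<?xf32> into tensor<?xf32>
/// bufferizes roughly to
///   %sv = memref.subview %tm[%off][%sz][1]
///       : memref<?xf32> to memref<?xf32, #layout>
///   memref.copy %sm, %sv
/// where %sm / %tm are the buffers of %s / %t. (The actual copy op depends on
/// BufferizationOptions::createMemCpy.)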
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
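///
/// Illustrative sketch (not verified output):
///   %r = tensor.rank %t : tensor<*xf32>
/// bufferizes roughly to
///   %r = memref.rank %m : memref<*xf32>
/// where %m is the buffer of %t.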
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
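///
/// Illustrative sketch (not verified output):
///   %1 = tensor.reshape %0(%shape)
///       : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
/// bufferizes roughly to
///   %1 = memref.reshape %m(%shape_m)
///       : (memref<?xf32>, memref<1xindex>) -> memref<?xf32>
/// where %m / %shape_m are the buffers of %0 / %shape.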
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*shape*/)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto resultMemRefType = getMemRefType(
        reshapeOp.getResult(), options, /*layout=*/{},
        srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
    return success();
  }
};

/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches,
/// i.e., if they have equivalent operands/results and the same
/// offset/sizes/strides specification.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st,
                                         ParallelInsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given ParallelInsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      ParallelInsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Analysis of ParallelInsertSliceOp.
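///
/// Illustrative sketch of the bufferized form (not verified output): the op
/// becomes a subview of the destination buffer plus a copy, placed outside
/// the parallel combining region, and uses of the parent op's tied result are
/// rewired to the destination buffer:
///   %sv = memref.subview %dest_buf[...][...][...]
///   memref.copy %src_buf, %sv
///   ...
///   %t = bufferization.to_tensor %dest_buf : memref<...>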
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand != &op->getOpOperand(1) /*dest*/)
      return {};

    // ParallelInsertSliceOp itself has no results; query its tied op result
    // on the parent op instead.
    auto insertOp = cast<ParallelInsertSliceOp>(op);
    return {insertOp.getTiedOpResult()};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    // This interface method is overridden because we want to set a custom
    // insertion point for tensor copies. They should be inserted right before
    // the ForeachThreadOp. E.g.:
    //
    // %r0, %r1 = foreach_thread ... {
    //   ...
    //   perform_concurrently {
    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
    //   }
    // }
    //
    // After TensorCopyInsertion:
    //
    // %copy = bufferization.alloc_tensor() copy(%d)
    // %r0, %r1 = foreach_thread ... {
    //   ...
    //   perform_concurrently {
    //     parallel_insert_slice %a into %b ...
    //     parallel_insert_slice %c into %copy ...
    //   }
    // }

    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();
    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

    // Nothing to do if the destination tensor is inplace.
    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
           "source is always in-place");
    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
      return success();

    // Find corresponding OpResult.
    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();

    // Insert tensor allocation right before the ForeachThreadOp.
    rewriter.setInsertionPoint(parallelIteratingOp);
    bool isYielded = state.isTensorYielded(opResult);
    FailureOr<Value> alloc = allocateTensorForShapedValue(
        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
        /*escape=*/isYielded, state.getOptions());
    if (failed(alloc))
      return failure();

    // Update destination operand.
    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
      parallelInsertSliceOp.getDestMutable().assign(*alloc);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();
    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

    // Get destination buffer.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
    if (failed(destBuffer))
      return failure();

    // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
    rewriter.setInsertionPoint(parallelCombiningParent);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = destBuffer->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides())
            .cast<MemRefType>();
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // Replace all uses of parallelIteratingOp (just the corresponding result).
    rewriter.setInsertionPointAfter(parallelIteratingOp);
    Value toTensorOp =
        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
    SmallVector<OpOperand *> resultUses = llvm::to_vector(
        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
                        [](OpOperand &use) { return &use; }));
    for (OpOperand *use : resultUses) {
      rewriter.updateRootInPlace(use->getOwner(),
                                 [&]() { use->set(toTensorOp); });
    }
    rewriter.eraseOp(op);
    return success();
  }

  // TODO: This duplicates InsertSliceOpInterface::isNotConflicting above.
  // Find a way to share the code.
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp =
            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithmeticDialect, scf::SCFDialect>();
  });
}
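
// Typical registration (illustrative sketch; the surrounding setup is an
// assumption, not part of this file):
//   DialectRegistry registry;
//   mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);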