//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

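/// Bufferization of tensor.cast. Replace with memref.cast. A sketch of the
/// rewrite, where %m denotes the bufferized form of %t (values and shapes are
/// hypothetical):
///
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///
/// becomes roughly:
///
///   %0 = memref.cast %m : memref<4xf32> to memref<?xf32>
///
/// The memory space (and, for ranked types, the layout) of the source buffer
/// is carried over to the result type (see `bufferize` below).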
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return op->getResult(0);
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType;
    if (resultTensorType.isa<RankedTensorType>()) {
      resultMemRefType =
          getContiguousMemRefType(resultTensorType, layout, memorySpace);
    } else {
      resultMemRefType =
          getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace);
    }

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
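/// A sketch of the rewrite, where %m denotes the bufferized form of %t
/// (values and shapes are hypothetical):
///
///   %d = tensor.dim %t, %i : tensor<?xf32>
///
/// becomes roughly:
///
///   %d = memref.dim %m, %i : memref<?xf32>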
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
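/// A sketch of the in-place case, where %m denotes the bufferized form of %t
/// (offsets/sizes/strides and shapes are hypothetical):
///
///   %0 = tensor.extract_slice %t[4][8][1] : tensor<16xf32> to tensor<8xf32>
///
/// becomes roughly:
///
///   %0 = memref.subview %m[4][8][1] : memref<16xf32> to memref<8xf32, #map>
///
/// where #map encodes the offset into the source buffer. In the out-of-place
/// case, the subview is additionally copied into a fresh allocation (see
/// `bufferize` below).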
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(0) /*source*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
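/// A sketch of the rewrite, where %m denotes the bufferized form of %t
/// (values and shapes are hypothetical):
///
///   %e = tensor.extract %t[%i] : tensor<?xf32>
///
/// becomes roughly:
///
///   %e = memref.load %m[%i] : memref<?xf32>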
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
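/// A sketch of the in-place case, where %m denotes the bufferized form of %t
/// (values and shapes are hypothetical):
///
///   %0 = tensor.insert %v into %t[%i] : tensor<?xf32>
///
/// becomes roughly:
///
///   memref.store %v, %m[%i] : memref<?xf32>
///
/// and all uses of %0 are replaced with the (possibly copied) destination
/// buffer.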
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return true;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return op->getOpResult(0);
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
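///
/// For example, the following pair matches, assuming the source of the
/// extract_slice and the dest of the insert_slice bufferize to equivalent
/// buffers (%t, %x, %a, %c are hypothetical values):
///
///   %0 = tensor.extract_slice %t[%a][%c][1]
///   ...
///   %1 = tensor.insert_slice %x into %t[%a][%c][1]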
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
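///
/// A sketch of the rewrite, where %dst and %src denote the bufferized forms
/// of %t and %s (offsets/sizes/strides and shapes are hypothetical):
///
///   %0 = tensor.insert_slice %s into %t[4][8][1]
///       : tensor<8xf32> into tensor<16xf32>
///
/// becomes roughly:
///
///   %sv = memref.subview %dst[4][8][1]
///       : memref<16xf32> to memref<8xf32, #map>
///   <copy of %src into %sv>
///
/// The exact copy op depends on the bufferization options. If the source was
/// produced by a matching tensor.extract_slice, the copy is a self-copy and
/// can later fold away.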
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // not read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
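/// A sketch of the rewrite, where %m denotes the bufferized form of %t
/// (values and shapes are hypothetical):
///
///   %r = tensor.rank %t : tensor<*xf32>
///
/// becomes roughly:
///
///   %r = memref.rank %m : memref<*xf32>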
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

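// Typical usage (a sketch; the surrounding setup is hypothetical): call this
// while assembling the DialectRegistry, before running bufferization:
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);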
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}