//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
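///
/// Rough sketch of the rewrite (memref layout maps and exact types omitted;
/// for illustration only): an op such as
///
///   %r = scf.execute_region -> tensor<5xf32> {
///     ...
///     scf.yield %t : tensor<5xf32>
///   }
///
/// is recreated with memref result types. The scf.yield is updated to yield a
/// memref (via bufferization.to_memref) and the old tensor results are
/// replaced with bufferization.to_tensor ops that wrap the new memref results.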
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
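///
/// Rough sketch of the rewrite (memref layout maps omitted; for illustration
/// only):
///
///   %r = scf.if %c -> (tensor<5xf32>) {
///     scf.yield %t0 : tensor<5xf32>
///   } else {
///     scf.yield %t1 : tensor<5xf32>
///   }
///
/// becomes approximately:
///
///   %0 = scf.if %c -> (memref<5xf32>) {
///     %m0 = bufferization.to_memref %t0 : memref<5xf32>
///     scf.yield %m0 : memref<5xf32>
///   } else {
///     %m1 = bufferization.to_memref %t1 : memref<5xf32>
///     scf.yield %m1 : memref<5xf32>
///   }
///
/// The tensor result %r is then expressed in terms of %0 (wrapped in a
/// bufferization.to_tensor op by replaceOpWithBufferizedValues).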
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    SmallVector<Value> thenYieldValues;
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    SmallVector<Value> elseYieldValues;
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!bbArgs[i].getType().isa<TensorType>() ||
        !yieldedValues[i].getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
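///
/// Illustrative example (the `#layout` symbol stands for a fully dynamic
/// layout map, as typically attached to loop iter_arg memref types; shown
/// symbolically here):
///
///   %0 = memref.cast %buffer : memref<5xf32> to memref<5xf32, #layout>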
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static SmallVector<Value> getBuffers(RewriterBase &rewriter,
                                     MutableArrayRef<OpOperand> operands,
                                     BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
      if (failed(resultBuffer))
        return {};
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition). If the given tensor
/// is equivalent to the corresponding block argument (as indicated by
/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
/// copy must be yielded.
///
/// According to the `BufferizableOpInterface` implementation of scf loops, a
/// bufferized OpResult may alias only with the corresponding bufferized
/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
/// the i-th init_arg, but not with any other OpOperand. If a corresponding
/// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by
/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
/// alias with any buffer.)
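///
/// Conceptual sketch of the non-equivalent case (the names below refer to the
/// helpers called in this function, not to ops in the IR):
///
///   %yielded = lookupBuffer(tensor)
///   %alloc   = createAlloc(<same shape as %yielded>)
///   createMemCpy(%yielded, %alloc)
///   // Yield (a cast of) %alloc instead of %yielded.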
static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                              BaseMemRefType type, bool isEquivalent,
                              BufferizationState &state) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  Value yieldedVal =
      bufferization::lookupBuffer(rewriter, tensor, state.getOptions());

  if (isEquivalent)
    // Yielded value is equivalent to the corresponding iter_arg bbArg.
    // Yield the value directly. Most IR should be like that. Everything
    // else must be resolved with copies and is potentially inefficient.
    // By default, such problematic IR would already have been rejected
    // during `verifyAnalysis`, unless `allow-return-allocs`.
    return castBuffer(rewriter, yieldedVal, type);

  // It is not certain that the yielded value and the iter_arg bbArg
  // have the same buffer. Allocate a new buffer and copy. The yielded
  // buffer will get deallocated by `deallocateBuffers`.

  // TODO: There are cases in which it is not necessary to return a new
  // buffer allocation. E.g., when equivalent values are yielded in a
  // different order. This could be resolved with copies.
  Optional<Value> yieldedAlloc = state.createAlloc(
      rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
  // TODO: We should rollback, but for now just assume that this always
  // succeeds.
  assert(yieldedAlloc.hasValue() && "could not create alloc");
  LogicalResult copyStatus = state.getOptions().createMemCpy(
      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc);
  (void)copyStatus;
  assert(succeeded(copyStatus) && "could not create memcpy");

  // The iter_arg memref type may have a layout map. Cast the new buffer
  // to the same type if needed.
  return castBuffer(rewriter, *yieldedAlloc, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static SmallVector<Value>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<Value(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
                                    TypeRange bufferizedTypes,
                                    const DenseSet<int64_t> &tensorIndices,
                                    const DenseSet<int64_t> &equivalentTensors,
                                    BufferizationState &state) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                equivalentTensors.contains(index), state);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  return convertTensorValues(
      bbArgs, tensorIndices, [&](Value val, int64_t index) {
        return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
      });
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
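///
/// Rough sketch (memref layout maps omitted; for illustration only): a loop
/// such as
///
///   %r = scf.for %iv = %lb to %ub step %s iter_args(%a = %t)
///       -> (tensor<5xf32>) {
///     %u = "some_use"(%a) : (tensor<5xf32>) -> tensor<5xf32>
///     scf.yield %u : tensor<5xf32>
///   }
///
/// is recreated with memref init_args. Inside the new loop body, the memref
/// iter_arg is wrapped in a bufferization.to_tensor op so that the existing
/// tensor-based body can be moved over unchanged; the scf.yield is then
/// rewritten to yield a memref again (an equivalent buffer or a new
/// allocation, see getYieldedBuffer above).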
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto oldYieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields =
        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
                             state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), state);

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
                         equivalentYields, state);
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
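  ///
  /// Example of a loop that is rejected by this check (unless
  /// `allow-return-allocs` is enabled), because the yielded value does not
  /// alias the iter_arg bbArg:
  ///
  ///   %r = scf.for ... iter_args(%a = %t) -> (tensor<5xf32>) {
  ///     %0 = "defines_a_new_tensor"() : () -> (tensor<5xf32>)
  ///     scf.yield %0 : tensor<5xf32>
  ///   }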
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
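///
/// Rough sketch (types abbreviated; for illustration only): a loop such as
///
///   %r = scf.while (%a = %t) : (tensor<5xf32>) -> tensor<5xf32> {
///     %cond = ...
///     scf.condition(%cond) %a : tensor<5xf32>
///   } do {
///   ^bb0(%b: tensor<5xf32>):
///     %u = "some_use"(%b) : (tensor<5xf32>) -> tensor<5xf32>
///     scf.yield %u : tensor<5xf32>
///   }
///
/// is recreated with memref init_args and result types. In both regions, the
/// memref bbArgs are wrapped in bufferization.to_tensor ops so that the
/// existing tensor-based bodies can be moved over unchanged; the scf.condition
/// args and the scf.yield operands are then rewritten to memrefs, mirroring
/// the scf.for bufferization above.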
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
        state.getAnalysisState());
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
        state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), state);

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          return state.getBufferType(bbArg).cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    SmallVector<Value> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, equivalentYieldsBefore, state);
    newConditionOp.getArgsMutable().assign(newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    SmallVector<Value> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, equivalentYieldsAfter, state);
    newYieldOp.getResultsMutable().assign(newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

using tensor::ExtractSliceOp;

/// Return the destinations that a ForeachThreadOp is inserting into. One per
/// ParallelInsertSliceOp.
static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
  SmallVector<OpOperand *> result;
  terminator.walk([&](ParallelInsertSliceOp insertOp) {
    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
  });
  return result;
}

/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of
/// the region. There are op interfaces for the terminators
/// (PerformConcurrentlyOp and ParallelInsertSliceOp), but these are only used
/// during analysis, not for bufferization.
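///
/// Conceptual sketch of the terminator rewrite (op syntax abbreviated; for
/// illustration only): for each parallel insert of %src into
/// %dest[%offsets][%sizes][%strides], roughly
///
///   %src_m = bufferization.to_memref %src
///   %sv = memref.subview <buffer of %dest> [%offsets] [%sizes] [%strides]
///   <memcpy from %src_m into %sv>
///
/// The copy folds away if everything bufferizes in-place.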
struct ForeachThreadOpInterface
    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                    ForeachThreadOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
  }

  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // This op is a memory write. Stop lookup here to avoid finding false
    // conflicts involving this op and one of the ops in the region. This is
    // similar to how scf.if ops are analyzed.
    return true;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(b);
    auto foreachThreadOp = cast<ForeachThreadOp>(op);

    // Gather new results of the ForeachThreadOp.
    SmallVector<Value> newResults;
    for (OpResult opResult : foreachThreadOp->getOpResults()) {
      SmallVector<OpOperand *> insertDestOperands =
          state.getAnalysisState().getAliasingOpOperand(opResult);
      assert(insertDestOperands.size() == 1 &&
             "expected exactly one aliasing OpOperand");
      // Insert copies right before the PerformConcurrentlyOp terminator. They
      // should not be inside terminator (which would be the default insertion
      // point).
      Value buffer = *state.getBuffer(b, *insertDestOperands.front(),
                                      /*forceInPlace=*/llvm::None,
                                      /*customCopyInsertionPoint=*/op);
      newResults.push_back(buffer);
    }

    // Create new ForeachThreadOp without any results and drop the automatically
    // introduced terminator.
    TypeRange newResultTypes;
    auto newForeachThreadOp =
        b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes,
                                  foreachThreadOp.getNumThreads());
    newForeachThreadOp.getBody()->getTerminator()->erase();

    // Move over block contents of the old op.
    b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(),
                  {newForeachThreadOp.getBody()->getArguments()});

    // Bufferize terminator.
    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(
        newForeachThreadOp.getBody()->getTerminator());
    b.setInsertionPoint(performConcurrentlyOp);
    unsigned resultCounter = 0;
    WalkResult walkResult =
        performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
          Location loc = insertOp.getLoc();
          Type srcType = getMemRefType(
              insertOp.getSource().getType().cast<RankedTensorType>(),
              state.getOptions());
          // ParallelInsertSliceOp bufferizes to a copy.
          auto srcMemref = b.create<bufferization::ToMemrefOp>(
              loc, srcType, insertOp.getSource());
          Value destMemref = newResults[resultCounter++];
          Value subview = b.create<memref::SubViewOp>(
              loc, destMemref, insertOp.getMixedOffsets(),
              insertOp.getMixedSizes(), insertOp.getMixedStrides());
          // This memcpy will fold away if everything bufferizes in-place.
          if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(),
                                                     srcMemref, subview)))
            return WalkResult::interrupt();
          b.eraseOp(insertOp);
          return WalkResult::advance();
        });
    if (walkResult.wasInterrupted())
      return failure();

    // Replace the op.
    replaceOpWithBufferizedValues(b, op, newResults);

    return success();
  }
};

/// Nothing to do for PerformConcurrentlyOp.
struct PerformConcurrentlyOpInterface
    : public BufferizableOpInterface::ExternalModel<
          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          BufferizationState &state) const {
    assert(false && "op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st,
                                         ParallelInsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.source(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given ParallelInsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      ParallelInsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand != &op->getOpOperand(1) /*dest*/)
      return {};

    // ParallelInsertSliceOp itself has no results. Tensors are returned via
    // the parent op.
    auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
    assert(foreachThreadOp &&
           "could not find valid owner of parallel_insert_slice");

    // The i-th ParallelInsertSliceOp result is returned via the i-th OpResult
    // of the parent ForeachThreadOp.
    Block *block = op->getBlock();
    unsigned int opIdx = 0;
    for (ParallelInsertSliceOp insertOp :
         block->getOps<ParallelInsertSliceOp>()) {
      if (insertOp.getOperation() == op)
        break;
      ++opIdx;
    }
    assert(opIdx < foreachThreadOp->getNumResults() &&
           "could not find op inside terminator op");

    return {foreachThreadOp->getResult(opIdx)};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          BufferizationState &state) const {
    // Will be bufferized as part of ForeachThreadOp.
    return failure();
  }

  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
  // the code.
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp =
            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
        *ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}