//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization
/// currently only supports single-block regions.
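///
/// Illustrative example (buffer types and layout maps are placeholders):
///
///   %r = scf.execute_region -> tensor<?xf32> {
///     ...
///     scf.yield %t : tensor<?xf32>
///   }
///
/// is replaced with an op that yields the bufferized value; remaining tensor
/// uses are reconnected through bufferization.to_tensor:
///
///   %m = scf.execute_region -> memref<?xf32, #layout> {
///     ...
///     scf.yield %buf : memref<?xf32, #layout>
///   }
///   %r = bufferization.to_tensor %m : memref<?xf32, #layout>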
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    auto yieldOp =
        cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
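///
/// Illustrative example (buffer types and layout maps are placeholders):
///
///   %r = scf.if %c -> (tensor<?xf32>) {
///     scf.yield %t0 : tensor<?xf32>
///   } else {
///     scf.yield %t1 : tensor<?xf32>
///   }
///
/// is rewritten to an scf.if that yields the bufferized values. If the buffer
/// types of the two branches differ, both are casted to a memref type with a
/// fully dynamic layout map:
///
///   %m = scf.if %c -> (memref<?xf32, #layout>) {
///     scf.yield %buf0 : memref<?xf32, #layout>
///   } else {
///     scf.yield %buf1 : memref<?xf32, #layout>
///   }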
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Reconcile type mismatches between then/else branches by inserting memref
    // casts.
    SmallVector<Value> thenResults, elseResults;
    bool insertedCast = false;
    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
      Value thenValue = thenYieldOp.getResults()[i];
      Value elseValue = elseYieldOp.getResults()[i];
      if (thenValue.getType() == elseValue.getType()) {
        thenResults.push_back(thenValue);
        elseResults.push_back(elseValue);
        continue;
      }

      // Type mismatch between then/else yield value. Cast both to a memref type
      // with a fully dynamic layout map.
      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
      if (thenMemrefType.getMemorySpaceAsInt() !=
          elseMemrefType.getMemorySpaceAsInt())
        return op->emitError("inconsistent memory space on then/else branches");
      rewriter.setInsertionPoint(thenYieldOp);
      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
          ifOp.getResultTypes()[i].cast<TensorType>(),
          thenMemrefType.getMemorySpaceAsInt());
      thenResults.push_back(rewriter.create<memref::CastOp>(
          thenYieldOp.getLoc(), memrefType, thenValue));
      rewriter.setInsertionPoint(elseYieldOp);
      elseResults.push_back(rewriter.create<memref::CastOp>(
          elseYieldOp.getLoc(), memrefType, elseValue));
      insertedCast = true;
    }

    if (insertedCast) {
      rewriter.setInsertionPoint(thenYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
      rewriter.setInsertionPoint(elseYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    ValueRange resultsValueRange(thenResults);
    TypeRange newTypes(resultsValueRange);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!bbArgs[i].getType().isa<TensorType>() ||
        !yieldedValues[i].getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
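///
/// Illustrative example (the layout attribute is a placeholder for a fully
/// dynamic layout map):
///
///   %c = memref.cast %buffer : memref<5xf32> to memref<?xf32, #layout>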
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition).
static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                                         BaseMemRefType type,
                                         const BufferizationOptions &options) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
  if (failed(yieldedVal))
    return failure();
  return castBuffer(rewriter, *yieldedVal, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static FailureOr<SmallVector<Value>>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      FailureOr<Value> maybeVal = func(val, idx);
      if (failed(maybeVal))
        return failure();
      result.push_back(*maybeVal);
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
FailureOr<SmallVector<Value>>
getYieldedValues(RewriterBase &rewriter, ValueRange values,
                 TypeRange bufferizedTypes,
                 const DenseSet<int64_t> &tensorIndices,
                 const BufferizationOptions &options) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                options);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
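///
/// Illustrative example (buffer types and layout maps are placeholders):
///
///   %r = scf.for %iv = %lb to %ub step %s iter_args(%a = %init)
///       -> (tensor<?xf32>) {
///     ...
///     scf.yield %next : tensor<?xf32>
///   }
///
/// becomes a loop over buffers whose memref iter_args are wrapped in
/// bufferization.to_tensor ops, so that the (still tensor-based) loop body can
/// be moved over unchanged:
///
///   %m = scf.for %iv = %lb to %ub step %s iter_args(%b = %initBuffer)
///       -> (memref<?xf32, #layout>) {
///     %a = bufferization.to_tensor %b : memref<?xf32, #layout>
///     ...
///   }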
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under it:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
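    //
    // Illustrative example: if a yielded value is not equivalent to its
    // iter_arg, e.g.:
    //
    //   %r = scf.for ... iter_args(%a = %init) -> (tensor<?xf32>) {
    //     %t = "some_op"() : () -> tensor<?xf32>
    //     scf.yield %t : tensor<?xf32>
    //   }
    //
    // a new tensor copy (a bufferization.alloc_tensor initialized with a copy
    // of %t) is created and yielded instead of %t.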
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
    SmallVector<Value> yieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
        yieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, BlockArgument bbArg,
                const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    return bufferization::getBufferType(
        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
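///
/// Illustrative example (buffer types and layout maps are placeholders):
///
///   %r = scf.while (%a = %init) : (tensor<?xf32>) -> tensor<?xf32> {
///     %cond = ...
///     scf.condition(%cond) %a : tensor<?xf32>
///   } do {
///   ^bb0(%b: tensor<?xf32>):
///     ...
///     scf.yield %next : tensor<?xf32>
///   }
///
/// becomes a loop over buffers. The bbArgs of both regions are wrapped in
/// bufferization.to_tensor ops, and the scf.condition / scf.yield operands
/// are casted (if necessary) to the expected buffer types.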
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under it:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
    auto yieldOp = whileOp.getYieldOp();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!indicesBefore.contains(idx) ||
          equivalentYieldsBefore.contains(idx)) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    // Update "after" region.
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> afterYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
        afterYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      afterYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().assign(afterYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          // TODO: error handling
          return bufferization::getBufferType(bbArg, options)->cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, options);
    if (failed(newConditionArgs))
      return failure();
    newConditionOp.getArgsMutable().assign(*newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, options);
    if (failed(newYieldValues))
      return failure();
    newYieldOp.getResultsMutable().assign(*newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Replaces the yielded tensor values with their
/// bufferized counterparts. scf.yield ops inside scf.while are bufferized
/// together with the enclosing loop op.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should avoid returning/yielding
    // allocations whenever possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
    // together with scf.while.)
    if (isa<scf::WhileOp>(yieldOp->getParentOp()))
      return success();

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (value.getType().isa<TensorType>()) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType =
              cast<BufferizableOpInterface>(forOp.getOperation())
                  .getBufferType(forOp.getRegionIterArgs()[it.index()],
                                 options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

using tensor::ExtractSliceOp;

/// Return the destinations that a ForeachThreadOp is inserting into. One per
/// ParallelInsertSliceOp.
static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
  SmallVector<OpOperand *> result;
  terminator.walk([&](ParallelInsertSliceOp insertOp) {
    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
  });
  return result;
}

/// Bufferization of ForeachThreadOp. The terminator ops
/// (PerformConcurrentlyOp and ParallelInsertSliceOp) have their own op
/// interfaces. The ParallelInsertSliceOp interface performs the actual
/// rewriting and replaces all uses of the ForeachThreadOp results, so
/// bufferization here only needs to recreate the ForeachThreadOp without
/// results.
struct ForeachThreadOpInterface
    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                    ForeachThreadOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
  }

  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // This op is a memory write. Stop lookup here to avoid finding false
    // conflicts involving this op and one of the ops in the region. This is
    // similar to how scf.if ops are analyzed.
    return true;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    OpBuilder::InsertionGuard g(rewriter);
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    for (OpResult opResult : foreachThreadOp->getOpResults()) {
      SmallVector<OpOperand *> destOperands =
          state.getAliasingOpOperand(opResult);
      assert(destOperands.size() == 1 &&
             "expected exactly one aliasing OpOperand");
      assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
             "expected ParallelInsertSliceOp");

      // Nothing to do if there is no conflict.
      if (state.isInPlace(*destOperands.front()))
        continue;

      // Insert tensor allocation.
      bool isYielded = state.isTensorYielded(opResult);
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), destOperands.front()->get(),
          /*escape=*/isYielded, state.getOptions());
      if (failed(alloc))
        return failure();

      // Update terminator operand.
      rewriter.updateRootInPlace(destOperands.front()->getOwner(),
                                 [&]() { destOperands.front()->set(*alloc); });
    }

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto foreachThreadOp = cast<ForeachThreadOp>(op);

#ifndef NDEBUG
    // ParallelInsertSliceOpInterface replaces all uses.
    for (OpResult opResult : foreachThreadOp->getOpResults())
      assert(opResult.getUses().empty() &&
             "expected that all uses were already replaced");
#endif // NDEBUG

    // Create new ForeachThreadOp without any results and drop the
    // automatically introduced terminator.
    TypeRange newResultTypes;
    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
        foreachThreadOp.getLoc(), newResultTypes,
        foreachThreadOp.getNumThreads(),
        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
    newForeachThreadOp.getBody()->getTerminator()->erase();

    // Move over block contents of the old op.
    rewriter.mergeBlocks(foreachThreadOp.getBody(),
                         newForeachThreadOp.getBody(),
                         {newForeachThreadOp.getBody()->getArguments()});

    // Remove the old op.
    rewriter.eraseOp(op);

    return success();
  }
};

/// Nothing to do for PerformConcurrentlyOp.
struct PerformConcurrentlyOpInterface
    : public BufferizableOpInterface::ExternalModel<
          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches
/// (i.e. equivalent operand / result and same offset/sizes/strides
/// specification).
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st,
                                         ParallelInsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given ParallelInsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      ParallelInsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Analysis and bufferization of ParallelInsertSliceOp.
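///
/// Illustrative sketch of the bufferization (the exact copy op depends on
/// BufferizationOptions; memref.copy is just an example): the op becomes a
/// subview of the bufferized destination plus a copy of the bufferized source
/// into that subview,
///
///   %sv = memref.subview %destBuffer[%offsets] [%sizes] [%strides]
///   memref.copy %srcBuffer, %sv
///
/// and all uses of the corresponding ForeachThreadOp result are replaced with
/// a bufferization.to_tensor of the destination buffer.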
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand != &op->getOpOperand(1) /*dest*/)
      return {};

    // ParallelInsertSliceOp itself has no results. Tensors are returned via
    // the parent op.
    auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
    assert(foreachThreadOp &&
           "could not find valid owner of parallel_insert_slice");

    // The destination updated by the i-th ParallelInsertSliceOp is returned
    // as the i-th OpResult of the parent ForeachThreadOp.
    Block *block = op->getBlock();
    unsigned int opIdx = 0;
    for (ParallelInsertSliceOp insertOp :
         block->getOps<ParallelInsertSliceOp>()) {
      if (insertOp.getOperation() == op)
        break;
      ++opIdx;
    }
    assert(opIdx < foreachThreadOp->getNumResults() &&
           "could not find op inside terminator op");

    return {foreachThreadOp->getResult(opIdx)};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto insertOp = cast<ParallelInsertSliceOp>(op);
    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
    auto foreachThreadOp =
        cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());

    // If the op bufferizes out-of-place, allocate the copy before the
    // ForeachThreadOp.
    rewriter.setInsertionPoint(foreachThreadOp);
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destBuffer))
      return failure();

    // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
    rewriter.setInsertionPoint(performConcurrentlyOp);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, insertOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();
    Value subview = rewriter.create<memref::SubViewOp>(
        insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
                                    subview)))
      return failure();
    rewriter.eraseOp(op);

    // Replace all uses of the corresponding ForeachThreadOp result.
    rewriter.setInsertionPointAfter(foreachThreadOp);
    Value toTensorOp =
        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
    unsigned resultNum = 0;
    for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
      if (&nextOp == op)
        break;
      resultNum++;
    }
    assert(resultNum < foreachThreadOp->getNumResults() &&
           "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
    SmallVector<OpOperand *> resultUses = llvm::to_vector(
        llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
                        [](OpOperand &use) { return &use; }));
    for (OpOperand *use : resultUses) {
      rewriter.updateRootInPlace(use->getOwner(),
                                 [&]() { use->set(toTensorOp); });
    }
    return success();
  }

  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
  // the code.
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp =
            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }
};

} // namespace
} // namespace scf
} // namespace mlir

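/// Note: Illustrative usage sketch (the surrounding setup below is an
/// assumption, not part of this file):
///
///   DialectRegistry registry;
///   registry.insert<scf::SCFDialect>();
///   scf::registerBufferizableOpInterfaceExternalModels(registry);
///   MLIRContext context(registry);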
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
        *ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}