1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Utils/StaticValueUtils.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/IR/PatternMatch.h"
21 
22 using namespace mlir;
23 using namespace mlir::bufferization;
24 using namespace mlir::scf;
25 
26 namespace mlir {
27 namespace scf {
28 namespace {
29 
30 // bufferization.to_memref is not allowed to change the rank.
31 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
32 #ifndef NDEBUG
33   auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
34   assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
35                                 rankedTensorType.getRank())) &&
36          "to_memref would be invalid: mismatching ranks");
37 #endif
38 }
39 
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.
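///
/// Example (illustrative only): an op such as
///
///   %0 = scf.execute_region -> tensor<?xf32> {
///     ...
///     scf.yield %t : tensor<?xf32>
///   }
///
/// is roughly rebuilt so that the region yields a buffer, and the old tensor
/// result is replaced with a bufferization.to_tensor of the new result.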
42 struct ExecuteRegionOpInterface
43     : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
44                                                     scf::ExecuteRegionOp> {
45   SmallVector<OpOperand *>
46   getAliasingOpOperand(Operation *op, OpResult opResult,
47                        const AnalysisState &state) const {
48     // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
49     // any SSA value that is in scope. To allow for use-def chain traversal
50     // through ExecuteRegionOps in the analysis, the corresponding yield value
51     // is considered to be aliasing with the result.
52     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
53     size_t resultNum = std::distance(op->getOpResults().begin(),
54                                      llvm::find(op->getOpResults(), opResult));
55     // TODO: Support multiple blocks.
56     assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
57            "expected exactly 1 block");
58     auto yieldOp = dyn_cast<scf::YieldOp>(
59         executeRegionOp.getRegion().front().getTerminator());
60     assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
61     return {&yieldOp->getOpOperand(resultNum)};
62   }
63 
64   // TODO: For better bufferization results, this could return `true` only if
65   // there is a memory write in the region.
66   bool isMemoryWrite(Operation *op, OpResult opResult,
67                      const AnalysisState &state) const {
68     // Similar to scf.if, results of this op are always considered memory writes
69     // in the analysis. This is a useful pattern for all ops that have tensor
70     // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
71     // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
72     // ops without OpOperands.
73     return true;
74   }
75 
76   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
77                           const BufferizationOptions &options) const {
78     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
79     assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
80            "only 1 block supported");
81     auto yieldOp =
82         cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
83     TypeRange newResultTypes(yieldOp.getResults());
84 
85     // Create new op and move over region.
86     auto newOp =
87         rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
88     newOp.getRegion().takeBody(executeRegionOp.getRegion());
89 
90     // Update all uses of the old op.
91     rewriter.setInsertionPointAfter(newOp);
92     SmallVector<Value> newResults;
93     for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
94       if (it.value().isa<TensorType>()) {
95         newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
96             executeRegionOp.getLoc(), newOp->getResult(it.index())));
97       } else {
98         newResults.push_back(newOp->getResult(it.index()));
99       }
100     }
101 
102     // Replace old op.
103     rewriter.replaceOp(executeRegionOp, newResults);
104 
105     return success();
106   }
107 
108   BufferRelation bufferRelation(Operation *op, OpResult opResult,
109                                 const AnalysisState &state) const {
110     return BufferRelation::Equivalent;
111   }
112 };
113 
114 /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
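///
/// Example (illustrative only):
///
///   %r = scf.if %c -> (tensor<?xf32>) {
///     scf.yield %a : tensor<?xf32>
///   } else {
///     scf.yield %b : tensor<?xf32>
///   }
///
/// roughly becomes an scf.if with a memref result type that yields the
/// bufferized values of %a and %b, reconciled to a common memref type with
/// memref.cast if necessary (see `bufferize` below).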
115 struct IfOpInterface
116     : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
117   SmallVector<OpOperand *>
118   getAliasingOpOperand(Operation *op, OpResult opResult,
119                        const AnalysisState &state) const {
120     // IfOps do not have tensor OpOperands. The yielded value can be any SSA
121     // value that is in scope. To allow for use-def chain traversal through
122     // IfOps in the analysis, both corresponding yield values from the then/else
123     // branches are considered to be aliasing with the result.
124     auto ifOp = cast<scf::IfOp>(op);
125     size_t resultNum = std::distance(op->getOpResults().begin(),
126                                      llvm::find(op->getOpResults(), opResult));
127     return {&ifOp.thenYield()->getOpOperand(resultNum),
128             &ifOp.elseYield()->getOpOperand(resultNum)};
129   }
130 
131   // TODO: For better bufferization results, this could return `true` only if
132   // there is a memory write in one (or both) of the branches. Since this is not
133   // allowed at the moment, we should never encounter scf.ifs that yield
134   // unmodified tensors. Such scf.yield ops could just fold away.
135   bool isMemoryWrite(Operation *op, OpResult opResult,
136                      const AnalysisState &state) const {
137     // IfOp results are always considered memory writes in the analysis. This
138     // design decision simplifies the analysis considerably. E.g., consider the
139     // following test case:
140     //
141     // %0 = "some_writing_op" : tensor<?xf32>
142     // %r = scf.if %c -> (tensor<?xf32>) {
143     //   scf.yield %0
144     // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    //   scf.yield %1
146     // }
147     // "some_reading_op"(%r)
148     //
149     // "another_writing_op" in the above example should be able to bufferize
150     // inplace in the absence of another read of %0. However, if the scf.if op
151     // would not be considered a "write", the analysis would detect the
152     // following conflict:
153     //
154     // * read = some_reading_op
155     // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
156     // * conflictingWrite = %1
157     //
158     // For more details, check the "scf.IfOp" section of the design document.
159     return true;
160   }
161 
162   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
163                           const BufferizationOptions &options) const {
164     OpBuilder::InsertionGuard g(rewriter);
165     auto ifOp = cast<scf::IfOp>(op);
166     auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
167     auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
168 
169     // Reconcile type mismatches between then/else branches by inserting memref
170     // casts.
171     SmallVector<Value> thenResults, elseResults;
172     bool insertedCast = false;
173     for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
174       Value thenValue = thenYieldOp.getResults()[i];
175       Value elseValue = elseYieldOp.getResults()[i];
176       if (thenValue.getType() == elseValue.getType()) {
177         thenResults.push_back(thenValue);
178         elseResults.push_back(elseValue);
179         continue;
180       }
181 
182       // Type mismatch between then/else yield value. Cast both to a memref type
183       // with a fully dynamic layout map.
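      // E.g. (illustrative): if one branch yields a `memref<5xf32>` and the
      // other a `memref<5xf32>` with a different layout, both are cast to a
      // memref with the same shape/element type and a fully dynamic layout
      // map, so that the scf.if has a single result type.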
184       auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
185       auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
186       if (thenMemrefType.getMemorySpaceAsInt() !=
187           elseMemrefType.getMemorySpaceAsInt())
188         return op->emitError("inconsistent memory space on then/else branches");
189       rewriter.setInsertionPoint(thenYieldOp);
190       BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
191           ifOp.getResultTypes()[i].cast<TensorType>(),
192           thenMemrefType.getMemorySpaceAsInt());
193       thenResults.push_back(rewriter.create<memref::CastOp>(
194           thenYieldOp.getLoc(), memrefType, thenValue));
195       rewriter.setInsertionPoint(elseYieldOp);
196       elseResults.push_back(rewriter.create<memref::CastOp>(
197           elseYieldOp.getLoc(), memrefType, elseValue));
198       insertedCast = true;
199     }
200 
201     if (insertedCast) {
202       rewriter.setInsertionPoint(thenYieldOp);
203       rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
204       rewriter.setInsertionPoint(elseYieldOp);
205       rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
206     }
207 
208     // Create new op.
209     rewriter.setInsertionPoint(ifOp);
210     ValueRange resultsValueRange(thenResults);
211     TypeRange newTypes(resultsValueRange);
212     auto newIfOp =
213         rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
214                                    /*withElseRegion=*/true);
215 
216     // Move over then/else blocks.
217     rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
218     rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
219 
220     // Replace op results.
221     replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
222 
223     return success();
224   }
225 
226   BufferRelation bufferRelation(Operation *op, OpResult opResult,
227                                 const AnalysisState &state) const {
228     // IfOp results are equivalent to their corresponding yield values if both
229     // yield values are equivalent to each other.
230     auto bufferizableOp = cast<BufferizableOpInterface>(op);
231     SmallVector<OpOperand *> yieldValues =
232         bufferizableOp.getAliasingOpOperand(opResult, state);
233     assert(yieldValues.size() == 2 && "expected 2 yield values");
234     bool equivalentYields = state.areEquivalentBufferizedValues(
235         yieldValues[0]->get(), yieldValues[1]->get());
236     return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
237   }
238 };
239 
240 /// Helper function for loop bufferization. Return the indices of all values
241 /// that have a tensor type.
242 static DenseSet<int64_t> getTensorIndices(ValueRange values) {
243   DenseSet<int64_t> result;
244   for (const auto &it : llvm::enumerate(values))
245     if (it.value().getType().isa<TensorType>())
246       result.insert(it.index());
247   return result;
248 }
249 
250 /// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
252 DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
253                                        ValueRange yieldedValues,
254                                        const AnalysisState &state) {
255   unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
256   DenseSet<int64_t> result;
257   for (unsigned int i = 0; i < minSize; ++i) {
258     if (!bbArgs[i].getType().isa<TensorType>() ||
259         !yieldedValues[i].getType().isa<TensorType>())
260       continue;
261     if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
262       result.insert(i);
263   }
264   return result;
265 }
266 
267 /// Helper function for loop bufferization. Cast the given buffer to the given
268 /// memref type.
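///
/// Example (illustrative): a `memref<5xf32>` buffer that must be yielded in a
/// position whose bufferized type is a rank-1 memref with a fully dynamic
/// layout map is wrapped in a memref.cast to that type; if the types already
/// match, the buffer is returned unchanged.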
269 static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
270   assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
271   assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
272   // If the buffer already has the correct type, no cast is needed.
273   if (buffer.getType() == type)
274     return buffer;
275   // TODO: In case `type` has a layout map that is not the fully dynamic
276   // one, we may not be able to cast the buffer. In that case, the loop
277   // iter_arg's layout map must be changed (see uses of `castBuffer`).
278   assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
279          "scf.while op bufferization: cast incompatible");
280   return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
281 }
282 
283 /// Helper function for loop bufferization. Return the bufferized values of the
284 /// given OpOperands. If an operand is not a tensor, return the original value.
285 static FailureOr<SmallVector<Value>>
286 getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
287            const BufferizationOptions &options) {
288   SmallVector<Value> result;
289   for (OpOperand &opOperand : operands) {
290     if (opOperand.get().getType().isa<TensorType>()) {
291       FailureOr<Value> resultBuffer =
292           getBuffer(rewriter, opOperand.get(), options);
293       if (failed(resultBuffer))
294         return failure();
295       result.push_back(*resultBuffer);
296     } else {
297       result.push_back(opOperand.get());
298     }
299   }
300   return result;
301 }
302 
303 /// Helper function for loop bufferization. Compute the buffer that should be
304 /// yielded from a loop block (loop body or loop condition).
305 static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
306                                          BaseMemRefType type,
307                                          const BufferizationOptions &options) {
308   assert(tensor.getType().isa<TensorType>() && "expected tensor");
309   ensureToMemrefOpIsValid(tensor, type);
310   FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
311   if (failed(yieldedVal))
312     return failure();
313   return castBuffer(rewriter, *yieldedVal, type);
314 }
315 
316 /// Helper function for loop bufferization. Given a range of values, apply
317 /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
318 /// value in the result vector.
319 static FailureOr<SmallVector<Value>>
320 convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
321                     llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
322   SmallVector<Value> result;
323   for (const auto &it : llvm::enumerate(values)) {
324     size_t idx = it.index();
325     Value val = it.value();
326     if (tensorIndices.contains(idx)) {
327       FailureOr<Value> maybeVal = func(val, idx);
328       if (failed(maybeVal))
329         return failure();
330       result.push_back(*maybeVal);
331     } else {
332       result.push_back(val);
333     }
334   }
335   return result;
336 }
337 
338 /// Helper function for loop bufferization. Given a list of pre-bufferization
339 /// yielded values, compute the list of bufferized yielded values.
340 FailureOr<SmallVector<Value>>
341 getYieldedValues(RewriterBase &rewriter, ValueRange values,
342                  TypeRange bufferizedTypes,
343                  const DenseSet<int64_t> &tensorIndices,
344                  const BufferizationOptions &options) {
345   return convertTensorValues(
346       values, tensorIndices, [&](Value val, int64_t index) {
347         return getYieldedBuffer(rewriter, val,
348                                 bufferizedTypes[index].cast<BaseMemRefType>(),
349                                 options);
350       });
351 }
352 
353 /// Helper function for loop bufferization. Given a list of bbArgs of the new
354 /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
355 /// ToTensorOps, so that the block body can be moved over to the new op.
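///
/// Example (illustrative): a new memref bbArg `%arg1` is exposed to the
/// moved-over body as `bufferization.to_tensor %arg1`, so that the existing
/// tensor-typed uses in the body remain valid.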
356 SmallVector<Value>
357 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
358                      const DenseSet<int64_t> &tensorIndices) {
359   SmallVector<Value> result;
360   for (const auto &it : llvm::enumerate(bbArgs)) {
361     size_t idx = it.index();
362     Value val = it.value();
363     if (tensorIndices.contains(idx)) {
364       result.push_back(
365           rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
366               .getResult());
367     } else {
368       result.push_back(val);
369     }
370   }
371   return result;
372 }
373 
374 /// Bufferization of scf.for. Replace with a new scf.for that operates on
375 /// memrefs.
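///
/// Example (illustrative only):
///
///   %r = scf.for %iv = %lb to %ub step %s iter_args(%t = %init)
///       -> (tensor<?xf32>) {
///     ...
///     scf.yield %t2 : tensor<?xf32>
///   }
///
/// roughly becomes a loop on memrefs, where `%init_m` is the bufferized value
/// of %init:
///
///   %r_m = scf.for %iv = %lb to %ub step %s iter_args(%b = %init_m)
///       -> (memref<?xf32, ...>) {
///     // The body still operates on tensors via bufferization.to_tensor %b.
///     ...
///     scf.yield %b2 : memref<?xf32, ...>
///   }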
376 struct ForOpInterface
377     : public BufferizableOpInterface::ExternalModel<ForOpInterface,
378                                                     scf::ForOp> {
379   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
380                               const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
383     auto forOp = cast<scf::ForOp>(op);
384     return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
385   }
386 
387   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
388                                const AnalysisState &state) const {
389     // Tensor iter_args of scf::ForOps are always considered as a write.
390     return true;
391   }
392 
393   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
394                                             const AnalysisState &state) const {
395     auto forOp = cast<scf::ForOp>(op);
396     return {forOp.getResultForOpOperand(opOperand)};
397   }
398 
399   BufferRelation bufferRelation(Operation *op, OpResult opResult,
400                                 const AnalysisState &state) const {
401     // ForOp results are equivalent to their corresponding init_args if the
402     // corresponding iter_args and yield values are equivalent.
403     auto forOp = cast<scf::ForOp>(op);
404     OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
405     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
406     auto yieldOp =
407         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
408     bool equivalentYield = state.areEquivalentBufferizedValues(
409         bbArg, yieldOp->getOperand(opResult.getResultNumber()));
410     return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
411   }
412 
413   bool isWritable(Operation *op, Value value,
414                   const AnalysisState &state) const {
415     // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under it:
417     //   1. Either the matching iter operand is not bufferized inplace and an
418     //      alloc + optional copy makes the bbArg itself inplaceable.
419     //   2. Or the matching iter operand is bufferized inplace and bbArg just
420     //      bufferizes to that too.
421     return true;
422   }
423 
424   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
425                                  const AnalysisState &state) const {
426     auto bufferizableOp = cast<BufferizableOpInterface>(op);
427     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
428       return failure();
429 
430     if (!state.getOptions().enforceAliasingInvariants)
431       return success();
432 
433     // According to the `getAliasing...` implementations, a bufferized OpResult
434     // may alias only with the corresponding bufferized init_arg and with no
435     // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
436     // but not with any other OpOperand. If a corresponding OpResult/init_arg
437     // pair bufferizes to equivalent buffers, this aliasing requirement is
438     // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
439     // (New buffer copies do not alias with any buffer.)
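    //
    // Illustrative example: if an scf.yield operand is not equivalent to the
    // corresponding bbArg, the operand is replaced with a new tensor
    // allocation (bufferization.alloc_tensor) initialized from the yielded
    // value, so that the yielded buffer does not alias any other buffer.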
440     auto forOp = cast<scf::ForOp>(op);
441     auto yieldOp =
442         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
443     OpBuilder::InsertionGuard g(rewriter);
444     rewriter.setInsertionPoint(yieldOp);
445 
446     // Indices of all iter_args that have tensor type. These are the ones that
447     // are bufferized.
448     DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
449     // For every yielded value, is the value equivalent to its corresponding
450     // bbArg?
451     DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
452         forOp.getRegionIterArgs(), yieldOp.getResults(), state);
453     SmallVector<Value> yieldValues;
454     for (int64_t idx = 0;
455          idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
456       Value value = yieldOp.getResults()[idx];
457       if (!indices.contains(idx) || equivalentYields.contains(idx)) {
458         yieldValues.push_back(value);
459         continue;
460       }
461       Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
462                                                  value, /*escape=*/true);
463       yieldValues.push_back(alloc);
464     }
465 
466     rewriter.updateRootInPlace(
467         yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
468     return success();
469   }
470 
471   FailureOr<BaseMemRefType>
472   getBufferType(Operation *op, BlockArgument bbArg,
473                 const BufferizationOptions &options) const {
474     auto forOp = cast<scf::ForOp>(op);
475     return bufferization::getBufferType(
476         forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
477   }
478 
479   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
480                           const BufferizationOptions &options) const {
481     auto forOp = cast<scf::ForOp>(op);
482     Block *oldLoopBody = &forOp.getLoopBody().front();
483 
484     // Indices of all iter_args that have tensor type. These are the ones that
485     // are bufferized.
486     DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
487 
488     // The new memref init_args of the loop.
489     FailureOr<SmallVector<Value>> maybeInitArgs =
490         getBuffers(rewriter, forOp.getIterOpOperands(), options);
491     if (failed(maybeInitArgs))
492       return failure();
493     SmallVector<Value> initArgs = *maybeInitArgs;
494 
495     // Construct a new scf.for op with memref instead of tensor values.
496     auto newForOp = rewriter.create<scf::ForOp>(
497         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
498         forOp.getStep(), initArgs);
499     newForOp->setAttrs(forOp->getAttrs());
502     Block *loopBody = &newForOp.getLoopBody().front();
503 
504     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
505     // iter_args of the new loop in ToTensorOps.
506     rewriter.setInsertionPointToStart(loopBody);
507     SmallVector<Value> iterArgs =
508         getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
509     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
510 
511     // Move loop body to new loop.
512     rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
513 
514     // Replace loop results.
515     replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
516 
517     return success();
518   }
519 
520   /// Assert that yielded values of an scf.for op are equivalent to their
521   /// corresponding bbArgs. In that case, the buffer relations of the
522   /// corresponding OpResults are "Equivalent".
523   ///
  /// If this is not the case, allocs and copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
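  ///
  /// Example (illustrative) of IR that is rejected unless allow-return-allocs
  /// is set:
  ///
  ///   %r = scf.for ... iter_args(%t = %init) -> (tensor<?xf32>) {
  ///     %u = "some_op"(%other) : (tensor<?xf32>) -> tensor<?xf32>
  ///     scf.yield %u : tensor<?xf32>  // not equivalent to %t
  ///   }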
527   LogicalResult verifyAnalysis(Operation *op,
528                                const AnalysisState &state) const {
529     const auto &options =
530         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
531     if (options.allowReturnAllocs)
532       return success();
533 
534     auto forOp = cast<scf::ForOp>(op);
535     auto yieldOp =
536         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
537     for (OpResult opResult : op->getOpResults()) {
538       if (!opResult.getType().isa<TensorType>())
539         continue;
540 
541       // Note: This is overly strict. We should check for aliasing bufferized
542       // values. But we don't have a "must-alias" analysis yet.
543       if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
544         return yieldOp->emitError()
545                << "Yield operand #" << opResult.getResultNumber()
546                << " is not equivalent to the corresponding iter bbArg";
547     }
548 
549     return success();
550   }
551 };
552 
553 /// Bufferization of scf.while. Replace with a new scf.while that operates on
554 /// memrefs.
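///
/// Example (illustrative only): both the "before" and the "after" region are
/// rebuilt with memref-typed arguments, e.g.
///
///   %r = scf.while (%t = %init) : (tensor<?xf32>) -> tensor<?xf32> { ... }
///
/// roughly becomes
///
///   %r_m = scf.while (%b = %init_m)
///       : (memref<?xf32, ...>) -> memref<?xf32, ...> { ... }
///
/// with bufferization.to_tensor ops wrapping the memref bbArgs inside both
/// regions.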
555 struct WhileOpInterface
556     : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
557                                                     scf::WhileOp> {
558   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
559                               const AnalysisState &state) const {
560     // Tensor iter_args of scf::WhileOps are always considered as a read.
561     return true;
562   }
563 
564   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
565                                const AnalysisState &state) const {
566     // Tensor iter_args of scf::WhileOps are always considered as a write.
567     return true;
568   }
569 
570   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
571                                             const AnalysisState &state) const {
572     auto whileOp = cast<scf::WhileOp>(op);
573     unsigned int idx = opOperand.getOperandNumber();
574 
    // For scf.while, the number of OpOperands and OpResults may differ, and
    // an operand and the result at the same index may not have the same type.
577     if (idx >= op->getNumResults() ||
578         opOperand.get().getType() != op->getResult(idx).getType())
579       return {};
580 
581     // The only aliasing OpResult may be the one at the same index.
582     return {whileOp->getResult(idx)};
583   }
584 
585   BufferRelation bufferRelation(Operation *op, OpResult opResult,
586                                 const AnalysisState &state) const {
587     // WhileOp results are equivalent to their corresponding init_args if the
588     // corresponding iter_args and yield values are equivalent (for both the
589     // "before" and the "after" block).
590     unsigned int resultNumber = opResult.getResultNumber();
591     auto whileOp = cast<scf::WhileOp>(op);
592 
593     // The "before" region bbArgs and the OpResults may not match.
594     if (resultNumber >= whileOp.getBeforeArguments().size())
595       return BufferRelation::None;
596     if (opResult.getType() !=
597         whileOp.getBeforeArguments()[resultNumber].getType())
598       return BufferRelation::None;
599 
600     auto conditionOp = whileOp.getConditionOp();
601     BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
602     Value conditionOperand = conditionOp.getArgs()[resultNumber];
603     bool equivCondition =
604         state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
605 
606     auto yieldOp = whileOp.getYieldOp();
607     BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
608     Value yieldOperand = yieldOp.getOperand(resultNumber);
609     bool equivYield =
610         state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
611 
612     return equivCondition && equivYield ? BufferRelation::Equivalent
613                                         : BufferRelation::None;
614   }
615 
616   bool isWritable(Operation *op, Value value,
617                   const AnalysisState &state) const {
618     // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under it:
620     //   1. Either the matching iter operand is not bufferized inplace and an
621     //      alloc + optional copy makes the bbArg itself inplaceable.
622     //   2. Or the matching iter operand is bufferized inplace and bbArg just
623     //      bufferizes to that too.
624     return true;
625   }
626 
627   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
628                                  const AnalysisState &state) const {
629     auto bufferizableOp = cast<BufferizableOpInterface>(op);
630     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
631       return failure();
632 
633     if (!state.getOptions().enforceAliasingInvariants)
634       return success();
635 
636     // According to the `getAliasing...` implementations, a bufferized OpResult
637     // may alias only with the corresponding bufferized init_arg and with no
638     // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
639     // but not with any other OpOperand. If a corresponding OpResult/init_arg
640     // pair bufferizes to equivalent buffers, this aliasing requirement is
641     // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
642     // (New buffer copies do not alias with any buffer.)
643     OpBuilder::InsertionGuard g(rewriter);
644     auto whileOp = cast<scf::WhileOp>(op);
645     auto conditionOp = whileOp.getConditionOp();
646     auto yieldOp = whileOp.getYieldOp();
647 
648     // Indices of all bbArgs that have tensor type. These are the ones that
649     // are bufferized. The "before" and "after" regions may have different args.
650     DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
651     DenseSet<int64_t> indicesAfter =
652         getTensorIndices(whileOp.getAfterArguments());
653 
654     // For every yielded value, is the value equivalent to its corresponding
655     // bbArg?
656     DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
657         whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
658     DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
659         whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
660 
661     // Update "before" region.
662     rewriter.setInsertionPoint(conditionOp);
663     SmallVector<Value> beforeYieldValues;
664     for (int64_t idx = 0;
665          idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
666       Value value = conditionOp.getArgs()[idx];
667       if (!indicesBefore.contains(idx) ||
668           equivalentYieldsBefore.contains(idx)) {
669         beforeYieldValues.push_back(value);
670         continue;
671       }
672       Value alloc = allocateTensorForShapedValue(rewriter, conditionOp.getLoc(),
673                                                  value, /*escape=*/true);
674       beforeYieldValues.push_back(alloc);
675     }
676     rewriter.updateRootInPlace(conditionOp, [&]() {
677       conditionOp.getArgsMutable().assign(beforeYieldValues);
678     });
679 
680     // Update "after" region.
681     rewriter.setInsertionPoint(yieldOp);
682     SmallVector<Value> afterYieldValues;
683     for (int64_t idx = 0;
684          idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
685       Value value = yieldOp.getResults()[idx];
686       if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
687         afterYieldValues.push_back(value);
688         continue;
689       }
690       Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
691                                                  value, /*escape=*/true);
692       afterYieldValues.push_back(alloc);
693     }
694     rewriter.updateRootInPlace(yieldOp, [&]() {
695       yieldOp.getResultsMutable().assign(afterYieldValues);
696     });
697 
698     return success();
699   }
700 
701   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
702                           const BufferizationOptions &options) const {
703     auto whileOp = cast<scf::WhileOp>(op);
704 
705     assert(whileOp.getBefore().getBlocks().size() == 1 &&
706            "regions with multiple blocks not supported");
707     Block *beforeBody = &whileOp.getBefore().front();
708     assert(whileOp.getAfter().getBlocks().size() == 1 &&
709            "regions with multiple blocks not supported");
710     Block *afterBody = &whileOp.getAfter().front();
711 
712     // Indices of all bbArgs that have tensor type. These are the ones that
713     // are bufferized. The "before" and "after" regions may have different args.
714     DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
715     DenseSet<int64_t> indicesAfter =
716         getTensorIndices(whileOp.getAfterArguments());
717 
718     // The new memref init_args of the loop.
719     FailureOr<SmallVector<Value>> maybeInitArgs =
720         getBuffers(rewriter, whileOp->getOpOperands(), options);
721     if (failed(maybeInitArgs))
722       return failure();
723     SmallVector<Value> initArgs = *maybeInitArgs;
724 
725     // The result types of a WhileOp are the same as the "after" bbArg types.
726     SmallVector<Type> argsTypesAfter = llvm::to_vector(
727         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
728           // TODO: error handling
729           return bufferization::getBufferType(bbArg, options)->cast<Type>();
730         }));
731 
732     // Construct a new scf.while op with memref instead of tensor values.
733     ValueRange argsRangeBefore(initArgs);
734     TypeRange argsTypesBefore(argsRangeBefore);
735     auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
736                                                     argsTypesAfter, initArgs);
737 
738     // Add before/after regions to the new op.
739     SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
740     SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
741                                          whileOp.getLoc());
742     Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
743     newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
744     Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
745     newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
746 
747     // Set up new iter_args and move the loop condition block to the new op.
748     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
749     // in ToTensorOps.
750     rewriter.setInsertionPointToStart(newBeforeBody);
751     SmallVector<Value> newBeforeArgs = getBbArgReplacements(
752         rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
753     rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
754 
755     // Update scf.condition of new loop.
756     auto newConditionOp = newWhileOp.getConditionOp();
757     rewriter.setInsertionPoint(newConditionOp);
758     // Only equivalent buffers or new buffer allocations may be yielded to the
759     // "after" region.
760     // TODO: This could be relaxed for better bufferization results.
761     FailureOr<SmallVector<Value>> newConditionArgs =
762         getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
763                          indicesAfter, options);
764     if (failed(newConditionArgs))
765       return failure();
766     newConditionOp.getArgsMutable().assign(*newConditionArgs);
767 
768     // Set up new iter_args and move the loop body block to the new op.
769     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
770     // in ToTensorOps.
771     rewriter.setInsertionPointToStart(newAfterBody);
772     SmallVector<Value> newAfterArgs = getBbArgReplacements(
773         rewriter, newWhileOp.getAfterArguments(), indicesAfter);
774     rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
775 
776     // Update scf.yield of the new loop.
777     auto newYieldOp = newWhileOp.getYieldOp();
778     rewriter.setInsertionPoint(newYieldOp);
779     // Only equivalent buffers or new buffer allocations may be yielded to the
780     // "before" region.
781     // TODO: This could be relaxed for better bufferization results.
782     FailureOr<SmallVector<Value>> newYieldValues =
783         getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
784                          indicesBefore, options);
785     if (failed(newYieldValues))
786       return failure();
787     newYieldOp.getResultsMutable().assign(*newYieldValues);
788 
789     // Replace loop results.
790     replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
791 
792     return success();
793   }
794 
795   /// Assert that yielded values of an scf.while op are equivalent to their
796   /// corresponding bbArgs. In that case, the buffer relations of the
797   /// corresponding OpResults are "Equivalent".
798   ///
  /// If this is not the case, allocs and copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
804   /// equivalence condition must be checked for both.
805   LogicalResult verifyAnalysis(Operation *op,
806                                const AnalysisState &state) const {
807     auto whileOp = cast<scf::WhileOp>(op);
808     const auto &options =
809         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
810     if (options.allowReturnAllocs)
811       return success();
812 
813     auto conditionOp = whileOp.getConditionOp();
814     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
815       if (!it.value().getType().isa<TensorType>())
816         continue;
817       if (!state.areEquivalentBufferizedValues(
818               it.value(), conditionOp->getBlock()->getArgument(it.index())))
819         return conditionOp->emitError()
820                << "Condition arg #" << it.index()
821                << " is not equivalent to the corresponding iter bbArg";
822     }
823 
824     auto yieldOp = whileOp.getYieldOp();
825     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
826       if (!it.value().getType().isa<TensorType>())
827         continue;
828       if (!state.areEquivalentBufferizedValues(
829               it.value(), yieldOp->getBlock()->getArgument(it.index())))
830         return yieldOp->emitError()
831                << "Yield operand #" << it.index()
832                << " is not equivalent to the corresponding iter bbArg";
833     }
834 
835     return success();
836   }
837 };
838 
/// Bufferization of scf.yield. scf.yield ops are bufferized as part of their
/// enclosing ops, so this is for analysis only.
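///
/// Example (illustrative): inside a bufferized scf.for, `scf.yield %t` with a
/// tensor operand is rebuilt to yield the bufferized value of %t, cast to the
/// iter_arg's memref type if necessary (see `castBuffer`).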
841 struct YieldOpInterface
842     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
843                                                     scf::YieldOp> {
844   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
845                               const AnalysisState &state) const {
846     return true;
847   }
848 
849   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
850                                const AnalysisState &state) const {
851     return false;
852   }
853 
854   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
855                                             const AnalysisState &state) const {
856     if (isa<scf::IfOp>(op->getParentOp()))
857       return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
858     if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
859       return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
860     return {};
861   }
862 
863   bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
864                             const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should avoid returning/yielding
    // allocations whenever possible.
868     return true;
869   }
870 
871   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
872                           const BufferizationOptions &options) const {
873     auto yieldOp = cast<scf::YieldOp>(op);
874     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
875             yieldOp->getParentOp()))
876       return yieldOp->emitError("unsupported scf::YieldOp parent");
877 
878     // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
879     // together with scf.while.)
880     if (isa<scf::WhileOp>(yieldOp->getParentOp()))
881       return success();
882 
883     SmallVector<Value> newResults;
884     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
885       Value value = it.value();
886       if (value.getType().isa<TensorType>()) {
887         FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
888         if (failed(maybeBuffer))
889           return failure();
890         Value buffer = *maybeBuffer;
891         if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
892           FailureOr<BaseMemRefType> resultType =
893               cast<BufferizableOpInterface>(forOp.getOperation())
894                   .getBufferType(forOp.getRegionIterArgs()[it.index()],
895                                  options);
896           if (failed(resultType))
897             return failure();
898           buffer = castBuffer(rewriter, buffer, *resultType);
899         }
900         newResults.push_back(buffer);
901       } else {
902         newResults.push_back(value);
903       }
904     }
905 
906     replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
907     return success();
908   }
909 };
910 
911 using tensor::ExtractSliceOp;
912 
/// Return the destinations that a ForeachThreadOp is inserting into. One per
914 /// ParallelInsertSliceOp.
915 static SmallVector<OpOperand *>
916 getInsertionDest(ForeachThreadOp foreachThreadOp) {
917   PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
918   SmallVector<OpOperand *> result;
919   terminator.walk([&](ParallelInsertSliceOp insertOp) {
920     result.push_back(&insertOp->getOpOperand(1) /*dest*/);
921   });
922   return result;
923 }
924 
/// Bufferization of ForeachThreadOp. The op is rebuilt without results. The
/// enclosed ParallelInsertSliceOps are bufferized through their own interface
/// (see ParallelInsertSliceOpInterface below), which also replaces all uses
/// of the ForeachThreadOp results; PerformConcurrentlyOp itself requires no
/// separate bufferization.
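///
/// Example (illustrative only; IR mnemonics abbreviated): the i-th tensor
/// result of the ForeachThreadOp corresponds to the destination of the i-th
/// ParallelInsertSliceOp in the terminator:
///
///   %r = scf.foreach_thread (%tid) in (%n) -> (tensor<?xf32>) {
///     ...
///     perform_concurrently {
///       parallel_insert_slice %a into %dest[...][...][...]   // -> %r
///     }
///   }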
929 struct ForeachThreadOpInterface
930     : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
931                                                     ForeachThreadOp> {
932   SmallVector<OpOperand *>
933   getAliasingOpOperand(Operation *op, OpResult opResult,
934                        const AnalysisState &state) const {
935     // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
936     auto foreachThreadOp = cast<ForeachThreadOp>(op);
937     return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
938   }
939 
940   bool isMemoryWrite(Operation *op, OpResult opResult,
941                      const AnalysisState &state) const {
942     // This op is a memory write. Stop lookup here to avoid finding false
943     // conflicts involving this op and one of the ops in the region. This is
944     // similar to how scf.if ops are analyzed.
945     return true;
946   }
947 
948   BufferRelation bufferRelation(Operation *op, OpResult opResult,
949                                 const AnalysisState &state) const {
950     return BufferRelation::Equivalent;
951   }
952 
953   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
954                                  const AnalysisState &state) const {
955     auto bufferizableOp = cast<BufferizableOpInterface>(op);
956     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
957       return failure();
958 
959     OpBuilder::InsertionGuard g(rewriter);
960     auto foreachThreadOp = cast<ForeachThreadOp>(op);
961     for (OpResult opResult : foreachThreadOp->getOpResults()) {
962       SmallVector<OpOperand *> destOperands =
963           state.getAliasingOpOperand(opResult);
964       assert(destOperands.size() == 1 &&
965              "expected exactly one aliasing OpOperand");
966       assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
967              "expected ParallelInsertSliceOp");
968 
969       // Nothing to do if there is no conflict.
970       if (state.isInPlace(*destOperands.front()))
971         continue;
972 
973       // Insert tensor allocation.
974       bool isYielded = state.isTensorYielded(opResult);
975       Value alloc = allocateTensorForShapedValue(rewriter, op->getLoc(),
976                                                  destOperands.front()->get(),
977                                                  /*escape=*/isYielded);
978 
979       // Update terminator operand.
980       rewriter.updateRootInPlace(destOperands.front()->getOwner(),
981                                  [&]() { destOperands.front()->set(alloc); });
982     }
983 
984     return success();
985   }
986 
987   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
988                           const BufferizationOptions &options) const {
989     auto foreachThreadOp = cast<ForeachThreadOp>(op);
990 
991 #ifndef NDEBUG
992     // ParallelInsertSliceOpInterface replaces all uses.
993     for (OpResult opResult : foreachThreadOp->getOpResults())
994       assert(opResult.getUses().empty() &&
995              "expected that all uses were already replaced");
996 #endif // NDEBUG
997 
998     // Create new ForeachThreadOp without any results and drop the automatically
999     // introduced terminator.
1000     TypeRange newResultTypes;
1001     auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
1002         foreachThreadOp.getLoc(), newResultTypes,
1003         foreachThreadOp.getNumThreads(),
1004         extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
1005     newForeachThreadOp.getBody()->getTerminator()->erase();
1006 
1007     // Move over block contents of the old op.
1008     rewriter.mergeBlocks(foreachThreadOp.getBody(),
1009                          newForeachThreadOp.getBody(),
1010                          {newForeachThreadOp.getBody()->getArguments()});
1011 
1012     // Remove the old op.
1013     rewriter.eraseOp(op);
1014 
1015     return success();
1016   }
1017 };
1018 
1019 /// Nothing to do for PerformConcurrentlyOp.
1020 struct PerformConcurrentlyOpInterface
1021     : public BufferizableOpInterface::ExternalModel<
1022           PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
1023   LogicalResult bufferize(Operation *op, RewriterBase &b,
1024                           const BufferizationOptions &options) const {
1025     llvm_unreachable("op does not have any tensor OpOperands / OpResults");
1026     return failure();
1027   }
1028 };
1029 
/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches,
/// i.e., equivalent operand / result and the same offsets/sizes/strides
/// specification.
1032 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
1033                                          ExtractSliceOp st,
1034                                          ParallelInsertSliceOp sti) {
1035   if (!st || !sti)
1036     return false;
1037   if (st != sti &&
1038       !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
1039     return false;
1040   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
1041     return false;
1042   return true;
1043 }
1044 
/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given ParallelInsertSliceOp.
1047 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
1048                                       ParallelInsertSliceOp insertOp) {
1049   auto condition = [&](Value val) {
1050     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
1051       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
1052         return true;
1053     return false;
1054   };
1055 
1056   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
1057                       condition);
1058 }
1059 
1060 /// Analysis of ParallelInsertSliceOp.
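///
/// During bufferization (illustrative sketch), inserting %src into %dest at
/// some offsets/sizes/strides becomes, outside of the terminator:
///
///   %dest_m = <bufferized %dest>   // allocated before the ForeachThreadOp
///                                  // if the op is out-of-place
///   %sv = memref.subview %dest_m [offsets][sizes][strides]
///   <copy %src_m into %sv>         // via options.createMemCpy; folds away
///                                  // if everything bufferizes in-place
///
/// Uses of the corresponding ForeachThreadOp result are then replaced with a
/// bufferization.to_tensor of %dest_m.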
1061 struct ParallelInsertSliceOpInterface
1062     : public BufferizableOpInterface::ExternalModel<
1063           ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
1064   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1065                                             const AnalysisState &state) const {
1066     if (&opOperand != &op->getOpOperand(1) /*dest*/)
1067       return {};
1068 
1069     // ParallelInsertSliceOp itself has no results. Tensors are returned via
1070     // the parent op.
1071     auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
1072     assert(foreachThreadOp &&
1073            "could not find valid owner of parallel_insert_slice");
1074 
    // The destination of the i-th ParallelInsertSliceOp in the terminator is
    // returned via the i-th OpResult of the parent ForeachThreadOp.
1077     Block *block = op->getBlock();
1078     unsigned int opIdx = 0;
1079     for (ParallelInsertSliceOp insertOp :
1080          block->getOps<ParallelInsertSliceOp>()) {
1081       if (insertOp.getOperation() == op)
1082         break;
1083       ++opIdx;
1084     }
1085     assert(opIdx < foreachThreadOp->getNumResults() &&
1086            "could not find op inside terminator op");
1087 
1088     return {foreachThreadOp->getResult(opIdx)};
1089   }
1090 
1091   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1092                               const AnalysisState &state) const {
1093     return true;
1094   }
1095 
1096   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1097                                const AnalysisState &state) const {
1098     return &opOperand == &op->getOpOperand(1) /*dest*/;
1099   }
1100 
1101   BufferRelation bufferRelation(Operation *op, OpResult opResult,
1102                                 const AnalysisState &state) const {
1103     return BufferRelation::Equivalent;
1104   }
1105 
1106   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
1107                                  const AnalysisState &state) const {
1108     return success();
1109   }
1110 
1111   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1112                           const BufferizationOptions &options) const {
1113     OpBuilder::InsertionGuard g(rewriter);
1114     auto insertOp = cast<ParallelInsertSliceOp>(op);
1115     auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
1116     auto foreachThreadOp =
1117         cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
1118 
1119     // If the op bufferizes out-of-place, allocate the copy before the
1120     // ForeachThreadOp.
1121     rewriter.setInsertionPoint(foreachThreadOp);
1122     FailureOr<Value> destBuffer =
1123         getBuffer(rewriter, insertOp.getDest(), options);
1124     if (failed(destBuffer))
1125       return failure();
1126 
1127     // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
1128     rewriter.setInsertionPoint(performConcurrentlyOp);
1129     FailureOr<Value> srcBuffer =
1130         getBuffer(rewriter, insertOp.getSource(), options);
1131     if (failed(srcBuffer))
1132       return failure();
1133     Value subview = rewriter.create<memref::SubViewOp>(
1134         insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
1135         insertOp.getMixedSizes(), insertOp.getMixedStrides());
1136     // This memcpy will fold away if everything bufferizes in-place.
1137     if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
1138                                     subview)))
1139       return failure();
1140     rewriter.eraseOp(op);
1141 
1142     // Replace all uses of ForeachThreadOp (just the corresponding result).
1143     rewriter.setInsertionPointAfter(foreachThreadOp);
1144     Value toTensorOp =
1145         rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
1146     unsigned resultNum = 0;
1147     for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
1148       if (&nextOp == op)
1149         break;
1150       resultNum++;
1151     }
1152     assert(resultNum < foreachThreadOp->getNumResults() &&
1153            "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
1154     SmallVector<OpOperand *> resultUses = llvm::to_vector(
1155         llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
1156                         [](OpOperand &use) { return &use; }));
1157     for (OpOperand *use : resultUses) {
1158       rewriter.updateRootInPlace(use->getOwner(),
1159                                  [&]() { use->set(toTensorOp); });
1160     }
1161     return success();
1162   }
1163 
1164   // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
1165   // the code.
1166   bool isNotConflicting(Operation *op, OpOperand *uRead,
1167                         OpOperand *uConflictingWrite,
1168                         const AnalysisState &state) const {
1169     Operation *readingOp = uRead->getOwner();
1170     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
1171 
1172     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
1173     // uRead is an InsertSliceOp...
1174     if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
1175       // As an example, consider the following IR.
1176       //
1177       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
1178       // %1 = linalg.fill %cst, %0 {inplace= [true] }
1179       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
1180       //     {inplace= [true] }
1181 
1182       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
1183       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1184           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
1185                                     insertSliceOp))
1186         // Case 1: The main insight is that InsertSliceOp reads only part of
1187         // the destination tensor. The overwritten area is not read. If
1188         // uConflictingWrite writes into exactly the memory location that is
1189         // being read by uRead, this is not a conflict.
1190         //
1191         // In the above example:
1192         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
1193         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
1194         //
1195         // The read of %t does not conflict with the write of the FillOp
1196         // (same aliases!) because the area that the FillOp operates on is
1197         // exactly the one that is *not* read via %t.
1198         return true;
1199 
1200       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
1201           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1202           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
1203         // Case 2: The read of the source tensor and the write to the dest
1204         // tensor via an InsertSliceOp is not a conflict if the read is
1205         // reading exactly that part of an equivalent tensor that the
1206         // InsertSliceOp is writing.
1207         //
1208         // In the above example:
1209         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
1210         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1211         return true;
1212     }
1213 
1214     // If uConflictingWrite is an InsertSliceOp...
1215     if (auto insertSliceOp =
1216             dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
1217       // As an example, consider the following IR.
1218       //
1219       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
1220       // %1 = linalg.fill %cst, %0 {inplace= [true] }
1221       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
1222       //     {inplace= [true] }
1223       // %3 = vector.transfer_read %1, %cst
1224       //
1225       // In the above example:
1226       // uRead             = OpOperand 0 (%1) of vector.transfer_read
1227       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1228       // lastWrite         = %1
1229       //
1230       // This is not a conflict because the InsertSliceOp overwrites the
1231       // memory segment of %1 with the exact same data. (Effectively, there
1232       // is no memory write here.)
1233       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1234           state.areEquivalentBufferizedValues(uRead->get(),
1235                                               insertSliceOp.getSource()) &&
1236           hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
1237                                     insertSliceOp))
1238         return true;
1239 
1240     return false;
1241   }
1242 };
1243 
1244 } // namespace
1245 } // namespace scf
1246 } // namespace mlir
1247 
1248 void mlir::scf::registerBufferizableOpInterfaceExternalModels(
1249     DialectRegistry &registry) {
1250   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
1251     ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
1252     ForOp::attachInterface<ForOpInterface>(*ctx);
1253     IfOp::attachInterface<IfOpInterface>(*ctx);
1254     ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
1255     ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1256         *ctx);
1257     PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
1258         *ctx);
1259     WhileOp::attachInterface<WhileOpInterface>(*ctx);
1260     YieldOp::attachInterface<YieldOpInterface>(*ctx);
1261   });
1262 }
1263