//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
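///
/// Illustrative sketch (schematic; the actual memref layout depends on the
/// bufferization options):
///
///   %0 = scf.execute_region -> tensor<?xf32> {
///     scf.yield %t : tensor<?xf32>
///   }
///
/// is rewritten so that the region yields a memref; remaining tensor users of
/// the result are served through a bufferization.to_tensor op:
///
///   %m = scf.execute_region -> memref<?xf32, #layout> {
///     %b = bufferization.to_memref %t : memref<?xf32, #layout>
///     scf.yield %b : memref<?xf32, #layout>
///   }
///   %0 = bufferization.to_tensor %m : memref<?xf32, #layout>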
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
        newResultTypes.push_back(getMemRefType(tensorType, options));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
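///
/// Illustrative sketch (schematic; the actual memref layout depends on the
/// bufferization options):
///
///   %0 = scf.if %c -> (tensor<?xf32>) {
///     scf.yield %a : tensor<?xf32>
///   } else {
///     scf.yield %b : tensor<?xf32>
///   }
///
/// becomes an scf.if with memref results; the yielded tensors are converted
/// with bufferization.to_memref and remaining tensor users of the result are
/// served through bufferization.to_tensor:
///
///   %m = scf.if %c -> (memref<?xf32, #layout>) {
///     %ma = bufferization.to_memref %a : memref<?xf32, #layout>
///     scf.yield %ma : memref<?xf32, #layout>
///   } else {
///     %mb = bufferization.to_memref %b : memref<?xf32, #layout>
///     scf.yield %mb : memref<?xf32, #layout>
///   }
///   %0 = bufferization.to_tensor %m : memref<?xf32, #layout>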
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is
  // not allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    //   scf.yield %1
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
        newTypes.push_back(getMemRefType(tensorType, options));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    SmallVector<Value> thenYieldValues;
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    SmallVector<Value> elseYieldValues;
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
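///
/// For example, for the values (%t0 : tensor<4xf32>, %i : index,
/// %t1 : tensor<f32>), this returns the set {0, 2}.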
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!bbArgs[i].getType().isa<TensorType>() ||
        !yieldedValues[i].getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
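///
/// For example (schematic), a buffer of type memref<5xf32> that is yielded
/// into an iter_arg of type memref<?xf32, #fully_dynamic_layout> is wrapped
/// in a memref.cast to the latter type; a buffer that already has the
/// requested type is returned unchanged.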
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "loop bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static SmallVector<Value> getBuffers(RewriterBase &rewriter,
                                     MutableArrayRef<OpOperand> operands,
                                     const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      Value resultBuffer = getBuffer(rewriter, opOperand.get(), options);
      result.push_back(resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition).
static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                              BaseMemRefType type,
                              const BufferizationOptions &options) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  Value yieldedVal = getBuffer(rewriter, tensor, options);
  return castBuffer(rewriter, yieldedVal, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static SmallVector<Value>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<Value(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
                                    TypeRange bufferizedTypes,
                                    const DenseSet<int64_t> &tensorIndices,
                                    const BufferizationOptions &options) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                options);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  return convertTensorValues(
      bbArgs, tensorIndices, [&](Value val, int64_t index) {
        return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
      });
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
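///
/// Illustrative sketch (schematic; the actual memref layouts depend on the
/// bufferization options):
///
///   %r = scf.for %i = %lb to %ub step %s iter_args(%t = %init)
///       -> (tensor<?xf32>) {
///     ...
///     scf.yield %t2 : tensor<?xf32>
///   }
///
/// becomes a loop over memref iter_args; the loop body keeps operating on
/// tensors via bufferization.to_tensor / bufferization.to_memref until the
/// nested ops are bufferized themselves:
///
///   %r = scf.for %i = %lb to %ub step %s iter_args(%m = %init_buffer)
///       -> (memref<?xf32, #layout>) {
///     %t = bufferization.to_tensor %m : memref<?xf32, #layout>
///     ...
///     %m2 = bufferization.to_memref %t2 : memref<?xf32, #layout>
///     scf.yield %m2 : memref<?xf32, #layout>
///   }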
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
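    //
    // Schematically, a yielded value that is not equivalent to its iter_arg
    // is replaced with a copy created by allocateTensorForShapedValue
    // (hypothetical example; attributes omitted):
    //
    //   scf.yield %v : tensor<?xf32>
    // becomes
    //   %copy = bufferization.alloc_tensor() copy(%v) : tensor<?xf32>
    //   scf.yield %copy : tensor<?xf32>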
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
    SmallVector<Value> yieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
        yieldValues.push_back(value);
        continue;
      }
      Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
                                                 value, /*escape=*/true);
      yieldValues.push_back(alloc);
    }

    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), options);

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues = getYieldedValues(
        rewriter, yieldOp.getResults(), initArgsTypes, indices, options);
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs and copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
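///
/// Illustrative sketch (schematic; the actual memref types depend on the
/// bufferization options):
///
///   %r = scf.while (%arg0 = %init) : (tensor<?xf32>) -> tensor<?xf32> {
///     %cond = ...
///     scf.condition(%cond) %arg0 : tensor<?xf32>
///   } do {
///   ^bb0(%arg1: tensor<?xf32>):
///     ...
///     scf.yield %next : tensor<?xf32>
///   }
///
/// becomes an scf.while whose bbArgs, scf.condition operands and scf.yield
/// operands are memrefs; the old blocks keep operating on tensors by wrapping
/// the new memref bbArgs in bufferization.to_tensor ops.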
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // An OpOperand does not necessarily have a matching OpResult: the number
    // of OpResults and OpOperands can differ, and the operand and the result
    // at the same index may not even have the same type.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
    auto yieldOp = whileOp.getYieldOp();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!indicesBefore.contains(idx) ||
          equivalentYieldsBefore.contains(idx)) {
        beforeYieldValues.push_back(value);
        continue;
      }
      Value alloc = allocateTensorForShapedValue(rewriter, conditionOp.getLoc(),
                                                 value, /*escape=*/true);
      beforeYieldValues.push_back(alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    // Update "after" region.
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> afterYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
        afterYieldValues.push_back(value);
        continue;
      }
      Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
                                                 value, /*escape=*/true);
      afterYieldValues.push_back(alloc);
    }
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().assign(afterYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), options);

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          return getBufferType(bbArg, options).cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    SmallVector<Value> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, options);
    newConditionOp.getArgsMutable().assign(newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    SmallVector<Value> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, options);
    newYieldOp.getResultsMutable().assign(newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs and copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. These ops are bufferized as part of their
/// enclosing ops, so this interface is used for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should avoid yielding new
    // allocations whenever possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

using tensor::ExtractSliceOp;

/// Return the destinations that a ForeachThreadOp is inserting into. One per
/// ParallelInsertSliceOp.
static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
  SmallVector<OpOperand *> result;
  terminator.walk([&](ParallelInsertSliceOp insertOp) {
    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
  });
  return result;
}

/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of
/// the region. There are op interfaces for the terminators
/// (PerformConcurrentlyOp and ParallelInsertSliceOp), but these are used only
/// during analysis, not for bufferization.
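///
/// Schematically, each ParallelInsertSliceOp in the terminator turns into a
/// memref.subview of the bufferized destination followed by a copy of the
/// bufferized source into that subview; the copy folds away when everything
/// bufferizes in-place.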
struct ForeachThreadOpInterface
    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                    ForeachThreadOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
  }

  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // This op is a memory write. Stop lookup here to avoid finding false
    // conflicts involving this op and one of the ops in the region. This is
    // similar to how scf.if ops are analyzed.
    return true;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    OpBuilder::InsertionGuard g(rewriter);
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    for (OpResult opResult : foreachThreadOp->getOpResults()) {
      SmallVector<OpOperand *> destOperands =
          state.getAliasingOpOperand(opResult);
      assert(destOperands.size() == 1 &&
             "expected exactly one aliasing OpOperand");
      assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
             "expected ParallelInsertSliceOp");

      // Nothing to do if there is no conflict.
      if (state.isInPlace(*destOperands.front()))
        continue;

      // Insert tensor allocation.
      bool isYielded = state.isTensorYielded(opResult);
      Value alloc = allocateTensorForShapedValue(rewriter, op->getLoc(),
                                                 destOperands.front()->get(),
                                                 /*escape=*/isYielded);

      // Update terminator operand.
      rewriter.updateRootInPlace(destOperands.front()->getOwner(),
                                 [&]() { destOperands.front()->set(alloc); });
    }

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(b);
    auto foreachThreadOp = cast<ForeachThreadOp>(op);

    // Gather new results of the ForeachThreadOp.
    SmallVector<Value> newResults;
    for (OpResult opResult : foreachThreadOp->getOpResults()) {
      OpOperand *insertDest =
          getInsertionDest(foreachThreadOp)[opResult.getResultNumber()];
      // Insert copies right before the PerformConcurrentlyOp terminator. They
      // should not be inside the terminator (which would be the default
      // insertion point).
      Value buffer = getBuffer(b, insertDest->get(), options);
      newResults.push_back(buffer);
    }

    // Create new ForeachThreadOp without any results and drop the
    // automatically introduced terminator.
    TypeRange newResultTypes;
    auto newForeachThreadOp =
        b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes,
                                  foreachThreadOp.getNumThreads());
    newForeachThreadOp.getBody()->getTerminator()->erase();

    // Move over block contents of the old op.
    b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(),
                  {newForeachThreadOp.getBody()->getArguments()});

    // Bufferize terminator.
    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(
        newForeachThreadOp.getBody()->getTerminator());
    b.setInsertionPoint(performConcurrentlyOp);
    unsigned resultCounter = 0;
    WalkResult walkResult =
        performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
          Location loc = insertOp.getLoc();
          Type srcType = getMemRefType(
              insertOp.getSource().getType().cast<RankedTensorType>(), options);
          // ParallelInsertSliceOp bufferizes to a copy.
          auto srcMemref = b.create<bufferization::ToMemrefOp>(
              loc, srcType, insertOp.getSource());
          Value destMemref = newResults[resultCounter++];
          Value subview = b.create<memref::SubViewOp>(
              loc, destMemref, insertOp.getMixedOffsets(),
              insertOp.getMixedSizes(), insertOp.getMixedStrides());
          // This memcpy will fold away if everything bufferizes in-place.
          if (failed(options.createMemCpy(b, insertOp.getLoc(), srcMemref,
                                          subview)))
            return WalkResult::interrupt();
          b.eraseOp(insertOp);
          return WalkResult::advance();
        });
    if (walkResult.wasInterrupted())
      return failure();

    // Replace the op.
    replaceOpWithBufferizedValues(b, op, newResults);

    return success();
  }
};

/// Nothing to do for PerformConcurrentlyOp.
struct PerformConcurrentlyOpInterface
    : public BufferizableOpInterface::ExternalModel<
          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches
/// (i.e., equivalent operand / result and same offset/sizes/strides
/// specification).
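///
/// Hypothetical example of a matching pair (op names abbreviated), assuming
/// the extract's source and the insert's destination bufferize to equivalent
/// buffers and the offsets/sizes/strides are identical:
///   %0 = extract_slice %t[%a][%c][1]
///   ...
///   parallel_insert_slice %1 into %t[%a][%c][1]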
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st,
                                         ParallelInsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given ParallelInsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      ParallelInsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand != &op->getOpOperand(1) /*dest*/)
      return {};

    // ParallelInsertSliceOp itself has no results. Tensors are returned via
    // the parent op.
    auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
    assert(foreachThreadOp &&
           "could not find valid owner of parallel_insert_slice");

    // The destination of the i-th ParallelInsertSliceOp in the terminator is
    // returned via the i-th OpResult of the parent ForeachThreadOp.
    Block *block = op->getBlock();
    unsigned int opIdx = 0;
    for (ParallelInsertSliceOp insertOp :
         block->getOps<ParallelInsertSliceOp>()) {
      if (insertOp.getOperation() == op)
        break;
      ++opIdx;
    }
    assert(opIdx < foreachThreadOp->getNumResults() &&
           "could not find op inside terminator op");

    return {foreachThreadOp->getResult(opIdx)};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    // Will be bufferized as part of ForeachThreadOp.
    return failure();
  }

  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
  // the code.
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp =
            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }
};

} // namespace
} // namespace scf
} // namespace mlir

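// Typical usage (illustrative): clients register these external models on
// their DialectRegistry before creating the MLIRContext, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect, bufferization::BufferizationDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);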
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
        *ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}