1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/Dialect.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace mlir;
22 using namespace mlir::bufferization;
23 using namespace mlir::scf;
24 
25 namespace mlir {
26 namespace scf {
27 namespace {
28 
/// bufferization.to_memref is not allowed to change the rank.
30 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
31 #ifndef NDEBUG
32   auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
33   assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
34                                 rankedTensorType.getRank())) &&
35          "to_memref would be invalid: mismatching ranks");
36 #endif
37 }
38 
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.
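///
/// A rough sketch of the rewrite (illustrative types; `#layout` stands for a
/// fully dynamic layout map chosen by the bufferization options):
///
///   %r = scf.execute_region -> tensor<?xf32> {
///     ...
///     scf.yield %t : tensor<?xf32>
///   }
///
/// becomes an scf.execute_region that yields a memref, wrapped in a
/// bufferization.to_tensor op that typically folds away later:
///
///   %0 = scf.execute_region -> memref<?xf32, #layout> {
///     ...
///     scf.yield %m : memref<?xf32, #layout>
///   }
///   %r = bufferization.to_tensor %0 : memref<?xf32, #layout>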
41 struct ExecuteRegionOpInterface
42     : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
43                                                     scf::ExecuteRegionOp> {
44   SmallVector<OpOperand *>
45   getAliasingOpOperand(Operation *op, OpResult opResult,
46                        const AnalysisState &state) const {
47     // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
48     // any SSA value that is in scope. To allow for use-def chain traversal
49     // through ExecuteRegionOps in the analysis, the corresponding yield value
50     // is considered to be aliasing with the result.
51     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
52     size_t resultNum = std::distance(op->getOpResults().begin(),
53                                      llvm::find(op->getOpResults(), opResult));
54     // TODO: Support multiple blocks.
55     assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
56            "expected exactly 1 block");
57     auto yieldOp = dyn_cast<scf::YieldOp>(
58         executeRegionOp.getRegion().front().getTerminator());
59     assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
60     return {&yieldOp->getOpOperand(resultNum)};
61   }
62 
63   // TODO: For better bufferization results, this could return `true` only if
64   // there is a memory write in the region.
65   bool isMemoryWrite(Operation *op, OpResult opResult,
66                      const AnalysisState &state) const {
67     // Similar to scf.if, results of this op are always considered memory writes
68     // in the analysis. This is a useful pattern for all ops that have tensor
69     // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
70     // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
71     // ops without OpOperands.
72     return true;
73   }
74 
75   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
76                           const BufferizationOptions &options) const {
77     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
78     assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
79            "only 1 block supported");
80     auto yieldOp =
81         cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
82     TypeRange newResultTypes(yieldOp.getResults());
83 
84     // Create new op and move over region.
85     auto newOp =
86         rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
87     newOp.getRegion().takeBody(executeRegionOp.getRegion());
88 
89     // Update all uses of the old op.
90     rewriter.setInsertionPointAfter(newOp);
91     SmallVector<Value> newResults;
92     for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
93       if (it.value().isa<TensorType>()) {
94         newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
95             executeRegionOp.getLoc(), newOp->getResult(it.index())));
96       } else {
97         newResults.push_back(newOp->getResult(it.index()));
98       }
99     }
100 
101     // Replace old op.
102     rewriter.replaceOp(executeRegionOp, newResults);
103 
104     return success();
105   }
106 
107   BufferRelation bufferRelation(Operation *op, OpResult opResult,
108                                 const AnalysisState &state) const {
109     return BufferRelation::Equivalent;
110   }
111 };
112 
113 /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
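///
/// A rough sketch of the rewrite (illustrative types; `#layout` stands for a
/// fully dynamic layout map):
///
///   %r = scf.if %c -> (tensor<?xf32>) {
///     scf.yield %t0 : tensor<?xf32>
///   } else {
///     scf.yield %t1 : tensor<?xf32>
///   }
///
/// becomes, roughly:
///
///   %0 = scf.if %c -> (memref<?xf32, #layout>) {
///     scf.yield %m0 : memref<?xf32, #layout>
///   } else {
///     scf.yield %m1 : memref<?xf32, #layout>
///   }
///   %r = bufferization.to_tensor %0 : memref<?xf32, #layout>
///
/// If the two yielded buffer types differ, memref.cast ops to a fully dynamic
/// layout are inserted first (see `bufferize` below).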
114 struct IfOpInterface
115     : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
116   SmallVector<OpOperand *>
117   getAliasingOpOperand(Operation *op, OpResult opResult,
118                        const AnalysisState &state) const {
119     // IfOps do not have tensor OpOperands. The yielded value can be any SSA
120     // value that is in scope. To allow for use-def chain traversal through
121     // IfOps in the analysis, both corresponding yield values from the then/else
122     // branches are considered to be aliasing with the result.
123     auto ifOp = cast<scf::IfOp>(op);
124     size_t resultNum = std::distance(op->getOpResults().begin(),
125                                      llvm::find(op->getOpResults(), opResult));
126     return {&ifOp.thenYield()->getOpOperand(resultNum),
127             &ifOp.elseYield()->getOpOperand(resultNum)};
128   }
129 
130   // TODO: For better bufferization results, this could return `true` only if
131   // there is a memory write in one (or both) of the branches. Since this is not
132   // allowed at the moment, we should never encounter scf.ifs that yield
133   // unmodified tensors. Such scf.yield ops could just fold away.
134   bool isMemoryWrite(Operation *op, OpResult opResult,
135                      const AnalysisState &state) const {
136     // IfOp results are always considered memory writes in the analysis. This
137     // design decision simplifies the analysis considerably. E.g., consider the
138     // following test case:
139     //
140     // %0 = "some_writing_op" : tensor<?xf32>
141     // %r = scf.if %c -> (tensor<?xf32>) {
142     //   scf.yield %0
143     // } else {
144     //   %1 = "another_writing_op"(%0) : tensor<?xf32>
145     // }
146     // "some_reading_op"(%r)
147     //
148     // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // were not considered a "write", the analysis would detect the
151     // following conflict:
152     //
153     // * read = some_reading_op
154     // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
155     // * conflictingWrite = %1
156     //
157     // For more details, check the "scf.IfOp" section of the design document.
158     return true;
159   }
160 
161   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
162                           const BufferizationOptions &options) const {
163     OpBuilder::InsertionGuard g(rewriter);
164     auto ifOp = cast<scf::IfOp>(op);
165     auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
166     auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
167 
168     // Reconcile type mismatches between then/else branches by inserting memref
169     // casts.
170     SmallVector<Value> thenResults, elseResults;
171     bool insertedCast = false;
172     for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
173       Value thenValue = thenYieldOp.getResults()[i];
174       Value elseValue = elseYieldOp.getResults()[i];
175       if (thenValue.getType() == elseValue.getType()) {
176         thenResults.push_back(thenValue);
177         elseResults.push_back(elseValue);
178         continue;
179       }
180 
181       // Type mismatch between then/else yield value. Cast both to a memref type
182       // with a fully dynamic layout map.
183       auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
184       auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
185       if (thenMemrefType.getMemorySpaceAsInt() !=
186           elseMemrefType.getMemorySpaceAsInt())
187         return op->emitError("inconsistent memory space on then/else branches");
188       rewriter.setInsertionPoint(thenYieldOp);
189       BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
190           ifOp.getResultTypes()[i].cast<TensorType>(),
191           thenMemrefType.getMemorySpaceAsInt());
192       thenResults.push_back(rewriter.create<memref::CastOp>(
193           thenYieldOp.getLoc(), memrefType, thenValue));
194       rewriter.setInsertionPoint(elseYieldOp);
195       elseResults.push_back(rewriter.create<memref::CastOp>(
196           elseYieldOp.getLoc(), memrefType, elseValue));
197       insertedCast = true;
198     }
199 
200     if (insertedCast) {
201       rewriter.setInsertionPoint(thenYieldOp);
202       rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
203       rewriter.setInsertionPoint(elseYieldOp);
204       rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
205     }
206 
207     // Create new op.
208     rewriter.setInsertionPoint(ifOp);
209     ValueRange resultsValueRange(thenResults);
210     TypeRange newTypes(resultsValueRange);
211     auto newIfOp =
212         rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
213                                    /*withElseRegion=*/true);
214 
215     // Move over then/else blocks.
216     rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
217     rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
218 
219     // Replace op results.
220     replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
221 
222     return success();
223   }
224 
225   BufferRelation bufferRelation(Operation *op, OpResult opResult,
226                                 const AnalysisState &state) const {
227     // IfOp results are equivalent to their corresponding yield values if both
228     // yield values are equivalent to each other.
229     auto bufferizableOp = cast<BufferizableOpInterface>(op);
230     SmallVector<OpOperand *> yieldValues =
231         bufferizableOp.getAliasingOpOperand(opResult, state);
232     assert(yieldValues.size() == 2 && "expected 2 yield values");
233     bool equivalentYields = state.areEquivalentBufferizedValues(
234         yieldValues[0]->get(), yieldValues[1]->get());
235     return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
236   }
237 };
238 
239 /// Helper function for loop bufferization. Return the indices of all values
240 /// that have a tensor type.
241 static DenseSet<int64_t> getTensorIndices(ValueRange values) {
242   DenseSet<int64_t> result;
243   for (const auto &it : llvm::enumerate(values))
244     if (it.value().getType().isa<TensorType>())
245       result.insert(it.index());
246   return result;
247 }
248 
249 /// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
251 DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
252                                        ValueRange yieldedValues,
253                                        const AnalysisState &state) {
254   unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
255   DenseSet<int64_t> result;
256   for (unsigned int i = 0; i < minSize; ++i) {
257     if (!bbArgs[i].getType().isa<TensorType>() ||
258         !yieldedValues[i].getType().isa<TensorType>())
259       continue;
260     if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
261       result.insert(i);
262   }
263   return result;
264 }
265 
266 /// Helper function for loop bufferization. Cast the given buffer to the given
267 /// memref type.
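///
/// A minimal sketch (illustrative types): cast a buffer with a static shape
/// and identity layout to the fully dynamic layout expected by a loop
/// iter_arg, where `#layout` stands for that fully dynamic layout map.
///
///   %0 = memref.cast %buf : memref<4xf32> to memref<?xf32, #layout>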
268 static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
269   assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
270   assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
271   // If the buffer already has the correct type, no cast is needed.
272   if (buffer.getType() == type)
273     return buffer;
274   // TODO: In case `type` has a layout map that is not the fully dynamic
275   // one, we may not be able to cast the buffer. In that case, the loop
276   // iter_arg's layout map must be changed (see uses of `castBuffer`).
277   assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
278          "scf.while op bufferization: cast incompatible");
279   return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
280 }
281 
282 /// Helper function for loop bufferization. Return the bufferized values of the
283 /// given OpOperands. If an operand is not a tensor, return the original value.
284 static FailureOr<SmallVector<Value>>
285 getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
286            const BufferizationOptions &options) {
287   SmallVector<Value> result;
288   for (OpOperand &opOperand : operands) {
289     if (opOperand.get().getType().isa<TensorType>()) {
290       FailureOr<Value> resultBuffer =
291           getBuffer(rewriter, opOperand.get(), options);
292       if (failed(resultBuffer))
293         return failure();
294       result.push_back(*resultBuffer);
295     } else {
296       result.push_back(opOperand.get());
297     }
298   }
299   return result;
300 }
301 
302 /// Helper function for loop bufferization. Compute the buffer that should be
303 /// yielded from a loop block (loop body or loop condition).
304 static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
305                                          BaseMemRefType type,
306                                          const BufferizationOptions &options) {
307   assert(tensor.getType().isa<TensorType>() && "expected tensor");
308   ensureToMemrefOpIsValid(tensor, type);
309   FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
310   if (failed(yieldedVal))
311     return failure();
312   return castBuffer(rewriter, *yieldedVal, type);
313 }
314 
315 /// Helper function for loop bufferization. Given a range of values, apply
316 /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
317 /// value in the result vector.
318 static FailureOr<SmallVector<Value>>
319 convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
320                     llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
321   SmallVector<Value> result;
322   for (const auto &it : llvm::enumerate(values)) {
323     size_t idx = it.index();
324     Value val = it.value();
325     if (tensorIndices.contains(idx)) {
326       FailureOr<Value> maybeVal = func(val, idx);
327       if (failed(maybeVal))
328         return failure();
329       result.push_back(*maybeVal);
330     } else {
331       result.push_back(val);
332     }
333   }
334   return result;
335 }
336 
337 /// Helper function for loop bufferization. Given a list of pre-bufferization
338 /// yielded values, compute the list of bufferized yielded values.
339 FailureOr<SmallVector<Value>>
340 getYieldedValues(RewriterBase &rewriter, ValueRange values,
341                  TypeRange bufferizedTypes,
342                  const DenseSet<int64_t> &tensorIndices,
343                  const BufferizationOptions &options) {
344   return convertTensorValues(
345       values, tensorIndices, [&](Value val, int64_t index) {
346         return getYieldedBuffer(rewriter, val,
347                                 bufferizedTypes[index].cast<BaseMemRefType>(),
348                                 options);
349       });
350 }
351 
352 /// Helper function for loop bufferization. Given a list of bbArgs of the new
353 /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
354 /// ToTensorOps, so that the block body can be moved over to the new op.
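///
/// A rough sketch of the resulting block prologue for an scf.for loop
/// (illustrative types; the induction variable is passed through unchanged):
///
///   ^bb0(%iv: index, %m: memref<?xf32, #layout>):
///     %t = bufferization.to_tensor %m : memref<?xf32, #layout>
///     ... // old block body, now using %t instead of the old tensor bbArg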
355 SmallVector<Value>
356 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
357                      const DenseSet<int64_t> &tensorIndices) {
358   SmallVector<Value> result;
359   for (const auto &it : llvm::enumerate(bbArgs)) {
360     size_t idx = it.index();
361     Value val = it.value();
362     if (tensorIndices.contains(idx)) {
363       result.push_back(
364           rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
365               .getResult());
366     } else {
367       result.push_back(val);
368     }
369   }
370   return result;
371 }
372 
373 /// Bufferization of scf.for. Replace with a new scf.for that operates on
374 /// memrefs.
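///
/// A rough sketch of the rewrite (illustrative types; `#layout` stands for a
/// fully dynamic layout map):
///
///   %r = scf.for %iv = %lb to %ub step %s iter_args(%t = %t0)
///       -> (tensor<?xf32>) {
///     ...
///     scf.yield %t2 : tensor<?xf32>
///   }
///
/// becomes, roughly:
///
///   %m0 = bufferization.to_memref %t0 : memref<?xf32, #layout>
///   %0 = scf.for %iv = %lb to %ub step %s iter_args(%m = %m0)
///       -> (memref<?xf32, #layout>) {
///     %t = bufferization.to_tensor %m : memref<?xf32, #layout>
///     ...
///     scf.yield %m2 : memref<?xf32, #layout>
///   }
///   %r = bufferization.to_tensor %0 : memref<?xf32, #layout>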
375 struct ForOpInterface
376     : public BufferizableOpInterface::ExternalModel<ForOpInterface,
377                                                     scf::ForOp> {
378   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
379                               const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
382     auto forOp = cast<scf::ForOp>(op);
383     return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
384   }
385 
386   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
387                                const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered a write.
389     return true;
390   }
391 
392   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
393                                             const AnalysisState &state) const {
394     auto forOp = cast<scf::ForOp>(op);
395     return {forOp.getResultForOpOperand(opOperand)};
396   }
397 
398   BufferRelation bufferRelation(Operation *op, OpResult opResult,
399                                 const AnalysisState &state) const {
400     // ForOp results are equivalent to their corresponding init_args if the
401     // corresponding iter_args and yield values are equivalent.
402     auto forOp = cast<scf::ForOp>(op);
403     OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
404     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
405     auto yieldOp =
406         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
407     bool equivalentYield = state.areEquivalentBufferizedValues(
408         bbArg, yieldOp->getOperand(opResult.getResultNumber()));
409     return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
410   }
411 
412   bool isWritable(Operation *op, Value value,
413                   const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under it:
416     //   1. Either the matching iter operand is not bufferized inplace and an
417     //      alloc + optional copy makes the bbArg itself inplaceable.
418     //   2. Or the matching iter operand is bufferized inplace and bbArg just
419     //      bufferizes to that too.
420     return true;
421   }
422 
423   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
424                                  const AnalysisState &state) const {
425     auto bufferizableOp = cast<BufferizableOpInterface>(op);
426     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
427       return failure();
428 
429     if (!state.getOptions().enforceAliasingInvariants)
430       return success();
431 
432     // According to the `getAliasing...` implementations, a bufferized OpResult
433     // may alias only with the corresponding bufferized init_arg and with no
434     // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
435     // but not with any other OpOperand. If a corresponding OpResult/init_arg
436     // pair bufferizes to equivalent buffers, this aliasing requirement is
437     // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
438     // (New buffer copies do not alias with any buffer.)
439     auto forOp = cast<scf::ForOp>(op);
440     auto yieldOp =
441         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
442     OpBuilder::InsertionGuard g(rewriter);
443     rewriter.setInsertionPoint(yieldOp);
444 
445     // Indices of all iter_args that have tensor type. These are the ones that
446     // are bufferized.
447     DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
448     // For every yielded value, is the value equivalent to its corresponding
449     // bbArg?
450     DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
451         forOp.getRegionIterArgs(), yieldOp.getResults(), state);
452     SmallVector<Value> yieldValues;
453     for (int64_t idx = 0;
454          idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
455       Value value = yieldOp.getResults()[idx];
456       if (!indices.contains(idx) || equivalentYields.contains(idx)) {
457         yieldValues.push_back(value);
458         continue;
459       }
460       Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
461                                                  value, /*escape=*/true);
462       yieldValues.push_back(alloc);
463     }
464 
465     rewriter.updateRootInPlace(
466         yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
467     return success();
468   }
469 
470   FailureOr<BaseMemRefType>
471   getBufferType(Operation *op, BlockArgument bbArg,
472                 const BufferizationOptions &options) const {
473     auto forOp = cast<scf::ForOp>(op);
474     return bufferization::getBufferType(
475         forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
476   }
477 
478   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
479                           const BufferizationOptions &options) const {
480     auto forOp = cast<scf::ForOp>(op);
481     Block *oldLoopBody = &forOp.getLoopBody().front();
482 
483     // Indices of all iter_args that have tensor type. These are the ones that
484     // are bufferized.
485     DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
486 
487     // The new memref init_args of the loop.
488     FailureOr<SmallVector<Value>> maybeInitArgs =
489         getBuffers(rewriter, forOp.getIterOpOperands(), options);
490     if (failed(maybeInitArgs))
491       return failure();
492     SmallVector<Value> initArgs = *maybeInitArgs;
493 
494     // Construct a new scf.for op with memref instead of tensor values.
495     auto newForOp = rewriter.create<scf::ForOp>(
496         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
497         forOp.getStep(), initArgs);
498     newForOp->setAttrs(forOp->getAttrs());
499     ValueRange initArgsRange(initArgs);
500     TypeRange initArgsTypes(initArgsRange);
501     Block *loopBody = &newForOp.getLoopBody().front();
502 
503     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
504     // iter_args of the new loop in ToTensorOps.
505     rewriter.setInsertionPointToStart(loopBody);
506     SmallVector<Value> iterArgs =
507         getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
508     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
509 
510     // Move loop body to new loop.
511     rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
512 
513     // Replace loop results.
514     replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
515 
516     return success();
517   }
518 
519   /// Assert that yielded values of an scf.for op are equivalent to their
520   /// corresponding bbArgs. In that case, the buffer relations of the
521   /// corresponding OpResults are "Equivalent".
522   ///
  /// If this is not the case, allocs and copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
526   LogicalResult verifyAnalysis(Operation *op,
527                                const AnalysisState &state) const {
528     const auto &options =
529         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
530     if (options.allowReturnAllocs)
531       return success();
532 
533     auto forOp = cast<scf::ForOp>(op);
534     auto yieldOp =
535         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
536     for (OpResult opResult : op->getOpResults()) {
537       if (!opResult.getType().isa<TensorType>())
538         continue;
539 
540       // Note: This is overly strict. We should check for aliasing bufferized
541       // values. But we don't have a "must-alias" analysis yet.
542       if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
543         return yieldOp->emitError()
544                << "Yield operand #" << opResult.getResultNumber()
545                << " is not equivalent to the corresponding iter bbArg";
546     }
547 
548     return success();
549   }
550 };
551 
552 /// Bufferization of scf.while. Replace with a new scf.while that operates on
553 /// memrefs.
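///
/// A rough sketch of the rewrite (illustrative types; `#layout` stands for a
/// fully dynamic layout map):
///
///   %r = scf.while (%arg0 = %t0) : (tensor<?xf32>) -> tensor<?xf32> {
///     %cond = ...
///     scf.condition(%cond) %arg0 : tensor<?xf32>
///   } do {
///   ^bb0(%arg1: tensor<?xf32>):
///     ...
///     scf.yield %t1 : tensor<?xf32>
///   }
///
/// becomes, roughly, the same loop on memrefs, with to_tensor ops wrapping
/// the bbArgs inside both regions and the loop results on the outside:
///
///   %m0 = bufferization.to_memref %t0 : memref<?xf32, #layout>
///   %0 = scf.while (%b0 = %m0) : (memref<?xf32, #layout>)
///       -> memref<?xf32, #layout> {
///     ...
///     scf.condition(%cond) %b0 : memref<?xf32, #layout>
///   } do {
///   ^bb0(%b1: memref<?xf32, #layout>):
///     ...
///     scf.yield %m1 : memref<?xf32, #layout>
///   }
///   %r = bufferization.to_tensor %0 : memref<?xf32, #layout>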
554 struct WhileOpInterface
555     : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
556                                                     scf::WhileOp> {
557   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
558                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered a read.
560     return true;
561   }
562 
563   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
564                                const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered a write.
566     return true;
567   }
568 
569   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
570                                             const AnalysisState &state) const {
571     auto whileOp = cast<scf::WhileOp>(op);
572     unsigned int idx = opOperand.getOperandNumber();
573 
574     // The OpResults and OpOperands may not match. They may not even have the
575     // same type. The number of OpResults and OpOperands can also differ.
576     if (idx >= op->getNumResults() ||
577         opOperand.get().getType() != op->getResult(idx).getType())
578       return {};
579 
580     // The only aliasing OpResult may be the one at the same index.
581     return {whileOp->getResult(idx)};
582   }
583 
584   BufferRelation bufferRelation(Operation *op, OpResult opResult,
585                                 const AnalysisState &state) const {
586     // WhileOp results are equivalent to their corresponding init_args if the
587     // corresponding iter_args and yield values are equivalent (for both the
588     // "before" and the "after" block).
589     unsigned int resultNumber = opResult.getResultNumber();
590     auto whileOp = cast<scf::WhileOp>(op);
591 
592     // The "before" region bbArgs and the OpResults may not match.
593     if (resultNumber >= whileOp.getBeforeArguments().size())
594       return BufferRelation::None;
595     if (opResult.getType() !=
596         whileOp.getBeforeArguments()[resultNumber].getType())
597       return BufferRelation::None;
598 
599     auto conditionOp = whileOp.getConditionOp();
600     BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
601     Value conditionOperand = conditionOp.getArgs()[resultNumber];
602     bool equivCondition =
603         state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
604 
605     auto yieldOp = whileOp.getYieldOp();
606     BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
607     Value yieldOperand = yieldOp.getOperand(resultNumber);
608     bool equivYield =
609         state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
610 
611     return equivCondition && equivYield ? BufferRelation::Equivalent
612                                         : BufferRelation::None;
613   }
614 
615   bool isWritable(Operation *op, Value value,
616                   const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under it:
619     //   1. Either the matching iter operand is not bufferized inplace and an
620     //      alloc + optional copy makes the bbArg itself inplaceable.
621     //   2. Or the matching iter operand is bufferized inplace and bbArg just
622     //      bufferizes to that too.
623     return true;
624   }
625 
626   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
627                                  const AnalysisState &state) const {
628     auto bufferizableOp = cast<BufferizableOpInterface>(op);
629     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
630       return failure();
631 
632     if (!state.getOptions().enforceAliasingInvariants)
633       return success();
634 
635     // According to the `getAliasing...` implementations, a bufferized OpResult
636     // may alias only with the corresponding bufferized init_arg and with no
637     // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
638     // but not with any other OpOperand. If a corresponding OpResult/init_arg
639     // pair bufferizes to equivalent buffers, this aliasing requirement is
640     // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
641     // (New buffer copies do not alias with any buffer.)
642     OpBuilder::InsertionGuard g(rewriter);
643     auto whileOp = cast<scf::WhileOp>(op);
644     auto conditionOp = whileOp.getConditionOp();
645     auto yieldOp = whileOp.getYieldOp();
646 
647     // Indices of all bbArgs that have tensor type. These are the ones that
648     // are bufferized. The "before" and "after" regions may have different args.
649     DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
650     DenseSet<int64_t> indicesAfter =
651         getTensorIndices(whileOp.getAfterArguments());
652 
653     // For every yielded value, is the value equivalent to its corresponding
654     // bbArg?
655     DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
656         whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
657     DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
658         whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
659 
660     // Update "before" region.
661     rewriter.setInsertionPoint(conditionOp);
662     SmallVector<Value> beforeYieldValues;
663     for (int64_t idx = 0;
664          idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
665       Value value = conditionOp.getArgs()[idx];
666       if (!indicesBefore.contains(idx) ||
667           equivalentYieldsBefore.contains(idx)) {
668         beforeYieldValues.push_back(value);
669         continue;
670       }
671       Value alloc = allocateTensorForShapedValue(rewriter, conditionOp.getLoc(),
672                                                  value, /*escape=*/true);
673       beforeYieldValues.push_back(alloc);
674     }
675     rewriter.updateRootInPlace(conditionOp, [&]() {
676       conditionOp.getArgsMutable().assign(beforeYieldValues);
677     });
678 
679     // Update "after" region.
680     rewriter.setInsertionPoint(yieldOp);
681     SmallVector<Value> afterYieldValues;
682     for (int64_t idx = 0;
683          idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
684       Value value = yieldOp.getResults()[idx];
685       if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
686         afterYieldValues.push_back(value);
687         continue;
688       }
689       Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
690                                                  value, /*escape=*/true);
691       afterYieldValues.push_back(alloc);
692     }
693     rewriter.updateRootInPlace(yieldOp, [&]() {
694       yieldOp.getResultsMutable().assign(afterYieldValues);
695     });
696 
697     return success();
698   }
699 
700   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
701                           const BufferizationOptions &options) const {
702     auto whileOp = cast<scf::WhileOp>(op);
703 
704     assert(whileOp.getBefore().getBlocks().size() == 1 &&
705            "regions with multiple blocks not supported");
706     Block *beforeBody = &whileOp.getBefore().front();
707     assert(whileOp.getAfter().getBlocks().size() == 1 &&
708            "regions with multiple blocks not supported");
709     Block *afterBody = &whileOp.getAfter().front();
710 
711     // Indices of all bbArgs that have tensor type. These are the ones that
712     // are bufferized. The "before" and "after" regions may have different args.
713     DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
714     DenseSet<int64_t> indicesAfter =
715         getTensorIndices(whileOp.getAfterArguments());
716 
717     // The new memref init_args of the loop.
718     FailureOr<SmallVector<Value>> maybeInitArgs =
719         getBuffers(rewriter, whileOp->getOpOperands(), options);
720     if (failed(maybeInitArgs))
721       return failure();
722     SmallVector<Value> initArgs = *maybeInitArgs;
723 
724     // The result types of a WhileOp are the same as the "after" bbArg types.
725     SmallVector<Type> argsTypesAfter = llvm::to_vector(
726         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
727           // TODO: error handling
728           return bufferization::getBufferType(bbArg, options)->cast<Type>();
729         }));
730 
731     // Construct a new scf.while op with memref instead of tensor values.
732     ValueRange argsRangeBefore(initArgs);
733     TypeRange argsTypesBefore(argsRangeBefore);
734     auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
735                                                     argsTypesAfter, initArgs);
736 
737     // Add before/after regions to the new op.
738     SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
739     SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
740                                          whileOp.getLoc());
741     Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
742     newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
743     Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
744     newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
745 
746     // Set up new iter_args and move the loop condition block to the new op.
747     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
748     // in ToTensorOps.
749     rewriter.setInsertionPointToStart(newBeforeBody);
750     SmallVector<Value> newBeforeArgs = getBbArgReplacements(
751         rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
752     rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
753 
754     // Update scf.condition of new loop.
755     auto newConditionOp = newWhileOp.getConditionOp();
756     rewriter.setInsertionPoint(newConditionOp);
757     // Only equivalent buffers or new buffer allocations may be yielded to the
758     // "after" region.
759     // TODO: This could be relaxed for better bufferization results.
760     FailureOr<SmallVector<Value>> newConditionArgs =
761         getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
762                          indicesAfter, options);
763     if (failed(newConditionArgs))
764       return failure();
765     newConditionOp.getArgsMutable().assign(*newConditionArgs);
766 
767     // Set up new iter_args and move the loop body block to the new op.
768     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
769     // in ToTensorOps.
770     rewriter.setInsertionPointToStart(newAfterBody);
771     SmallVector<Value> newAfterArgs = getBbArgReplacements(
772         rewriter, newWhileOp.getAfterArguments(), indicesAfter);
773     rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
774 
775     // Update scf.yield of the new loop.
776     auto newYieldOp = newWhileOp.getYieldOp();
777     rewriter.setInsertionPoint(newYieldOp);
778     // Only equivalent buffers or new buffer allocations may be yielded to the
779     // "before" region.
780     // TODO: This could be relaxed for better bufferization results.
781     FailureOr<SmallVector<Value>> newYieldValues =
782         getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
783                          indicesBefore, options);
784     if (failed(newYieldValues))
785       return failure();
786     newYieldOp.getResultsMutable().assign(*newYieldValues);
787 
788     // Replace loop results.
789     replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
790 
791     return success();
792   }
793 
794   /// Assert that yielded values of an scf.while op are equivalent to their
795   /// corresponding bbArgs. In that case, the buffer relations of the
796   /// corresponding OpResults are "Equivalent".
797   ///
  /// If this is not the case, allocs and copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
804   LogicalResult verifyAnalysis(Operation *op,
805                                const AnalysisState &state) const {
806     auto whileOp = cast<scf::WhileOp>(op);
807     const auto &options =
808         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
809     if (options.allowReturnAllocs)
810       return success();
811 
812     auto conditionOp = whileOp.getConditionOp();
813     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
814       if (!it.value().getType().isa<TensorType>())
815         continue;
816       if (!state.areEquivalentBufferizedValues(
817               it.value(), conditionOp->getBlock()->getArgument(it.index())))
818         return conditionOp->emitError()
819                << "Condition arg #" << it.index()
820                << " is not equivalent to the corresponding iter bbArg";
821     }
822 
823     auto yieldOp = whileOp.getYieldOp();
824     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
825       if (!it.value().getType().isa<TensorType>())
826         continue;
827       if (!state.areEquivalentBufferizedValues(
828               it.value(), yieldOp->getBlock()->getArgument(it.index())))
829         return yieldOp->emitError()
830                << "Yield operand #" << it.index()
831                << " is not equivalent to the corresponding iter bbArg";
832     }
833 
834     return success();
835   }
836 };
837 
838 /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
839 /// this is for analysis only.
840 struct YieldOpInterface
841     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
842                                                     scf::YieldOp> {
843   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
844                               const AnalysisState &state) const {
845     return true;
846   }
847 
848   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
849                                const AnalysisState &state) const {
850     return false;
851   }
852 
853   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
854                                             const AnalysisState &state) const {
855     if (isa<scf::IfOp>(op->getParentOp()))
856       return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
857     if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
858       return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
859     return {};
860   }
861 
862   bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
863                             const AnalysisState &state) const {
864     // Yield operands always bufferize inplace. Otherwise, an alloc + copy
865     // may be generated inside the block. We should not return/yield allocations
866     // when possible.
867     return true;
868   }
869 
870   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
871                           const BufferizationOptions &options) const {
872     auto yieldOp = cast<scf::YieldOp>(op);
873     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
874             yieldOp->getParentOp()))
875       return yieldOp->emitError("unsupported scf::YieldOp parent");
876 
877     // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
878     // together with scf.while.)
879     if (isa<scf::WhileOp>(yieldOp->getParentOp()))
880       return success();
881 
882     SmallVector<Value> newResults;
883     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
884       Value value = it.value();
885       if (value.getType().isa<TensorType>()) {
886         FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
887         if (failed(maybeBuffer))
888           return failure();
889         Value buffer = *maybeBuffer;
890         if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
891           FailureOr<BaseMemRefType> resultType =
892               cast<BufferizableOpInterface>(forOp.getOperation())
893                   .getBufferType(forOp.getRegionIterArgs()[it.index()],
894                                  options);
895           if (failed(resultType))
896             return failure();
897           buffer = castBuffer(rewriter, buffer, *resultType);
898         }
899         newResults.push_back(buffer);
900       } else {
901         newResults.push_back(value);
902       }
903     }
904 
905     replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
906     return success();
907   }
908 };
909 
910 using tensor::ExtractSliceOp;
911 
/// Return the destinations that a ForeachThreadOp is inserting into. One per
913 /// ParallelInsertSliceOp.
914 static SmallVector<OpOperand *>
915 getInsertionDest(ForeachThreadOp foreachThreadOp) {
916   PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
917   SmallVector<OpOperand *> result;
918   terminator.walk([&](ParallelInsertSliceOp insertOp) {
919     result.push_back(&insertOp->getOpOperand(1) /*dest*/);
920   });
921   return result;
922 }
923 
924 /// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
925 /// region. There are op interfaces for the terminators (PerformConcurrentlyOp
/// and ParallelInsertSliceOp), but these are only used during analysis, not
/// for bufferization.
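///
/// A rough sketch of the input IR (op syntax written from memory and
/// abbreviated; see the SCF dialect documentation for the exact assembly
/// format):
///
///   %r = scf.foreach_thread (%tid) in (%n) -> (tensor<?xf32>) {
///     %slice = tensor.extract_slice %t[%o][%s][1] ...
///     %res = "some_op"(%slice) ...
///     scf.foreach_thread.perform_concurrently {
///       scf.foreach_thread.parallel_insert_slice %res into %t[%o][%s][1] ...
///     }
///   }
///
/// Bufferization recreates the op without results. Each ParallelInsertSliceOp
/// in the terminator becomes a memref.subview + copy on the destination
/// buffer (see ParallelInsertSliceOpInterface below), and all uses of %r are
/// replaced with a bufferization.to_tensor of that buffer.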
928 struct ForeachThreadOpInterface
929     : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
930                                                     ForeachThreadOp> {
931   SmallVector<OpOperand *>
932   getAliasingOpOperand(Operation *op, OpResult opResult,
933                        const AnalysisState &state) const {
934     // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
935     auto foreachThreadOp = cast<ForeachThreadOp>(op);
936     return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
937   }
938 
939   bool isMemoryWrite(Operation *op, OpResult opResult,
940                      const AnalysisState &state) const {
941     // This op is a memory write. Stop lookup here to avoid finding false
942     // conflicts involving this op and one of the ops in the region. This is
943     // similar to how scf.if ops are analyzed.
944     return true;
945   }
946 
947   BufferRelation bufferRelation(Operation *op, OpResult opResult,
948                                 const AnalysisState &state) const {
949     return BufferRelation::Equivalent;
950   }
951 
952   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
953                                  const AnalysisState &state) const {
954     auto bufferizableOp = cast<BufferizableOpInterface>(op);
955     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
956       return failure();
957 
958     OpBuilder::InsertionGuard g(rewriter);
959     auto foreachThreadOp = cast<ForeachThreadOp>(op);
960     for (OpResult opResult : foreachThreadOp->getOpResults()) {
961       SmallVector<OpOperand *> destOperands =
962           state.getAliasingOpOperand(opResult);
963       assert(destOperands.size() == 1 &&
964              "expected exactly one aliasing OpOperand");
965       assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
966              "expected ParallelInsertSliceOp");
967 
968       // Nothing to do if there is no conflict.
969       if (state.isInPlace(*destOperands.front()))
970         continue;
971 
972       // Insert tensor allocation.
973       bool isYielded = state.isTensorYielded(opResult);
974       Value alloc = allocateTensorForShapedValue(rewriter, op->getLoc(),
975                                                  destOperands.front()->get(),
976                                                  /*escape=*/isYielded);
977 
978       // Update terminator operand.
979       rewriter.updateRootInPlace(destOperands.front()->getOwner(),
980                                  [&]() { destOperands.front()->set(alloc); });
981     }
982 
983     return success();
984   }
985 
986   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
987                           const BufferizationOptions &options) const {
988     auto foreachThreadOp = cast<ForeachThreadOp>(op);
989 
990 #ifndef NDEBUG
991     // ParallelInsertSliceOpInterface replaces all uses.
992     for (OpResult opResult : foreachThreadOp->getOpResults())
993       assert(opResult.getUses().empty() &&
994              "expected that all uses were already replaced");
995 #endif // NDEBUG
996 
997     // Create new ForeachThreadOp without any results and drop the automatically
998     // introduced terminator.
999     TypeRange newResultTypes;
1000     auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
1001         foreachThreadOp.getLoc(), newResultTypes,
1002         foreachThreadOp.getNumThreads());
1003     newForeachThreadOp.getBody()->getTerminator()->erase();
1004 
1005     // Move over block contents of the old op.
1006     rewriter.mergeBlocks(foreachThreadOp.getBody(),
1007                          newForeachThreadOp.getBody(),
1008                          {newForeachThreadOp.getBody()->getArguments()});
1009 
1010     // Remove the old op.
1011     rewriter.eraseOp(op);
1012 
1013     return success();
1014   }
1015 };
1016 
1017 /// Nothing to do for PerformConcurrentlyOp.
1018 struct PerformConcurrentlyOpInterface
1019     : public BufferizableOpInterface::ExternalModel<
1020           PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
1021   LogicalResult bufferize(Operation *op, RewriterBase &b,
1022                           const BufferizationOptions &options) const {
1023     llvm_unreachable("op does not have any tensor OpOperands / OpResults");
1024     return failure();
1025   }
1026 };
1027 
1028 /// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offsets/sizes/strides specification).
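///
/// For example (illustrative IR, op names abbreviated), the pair below
/// matches: the extract's source and the insert's dest are equivalent buffers
/// (here both are %t) and the offsets/sizes/strides are identical.
///
///   %0 = extract_slice %t[%a][%c][1]
///   ...
///   parallel_insert_slice %1 into %t[%a][%c][1]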
1030 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
1031                                          ExtractSliceOp st,
1032                                          ParallelInsertSliceOp sti) {
1033   if (!st || !sti)
1034     return false;
1035   if (st != sti &&
1036       !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
1037     return false;
1038   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
1039     return false;
1040   return true;
1041 }
1042 
/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given ParallelInsertSliceOp.
1045 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
1046                                       ParallelInsertSliceOp insertOp) {
1047   auto condition = [&](Value val) {
1048     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
1049       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
1050         return true;
1051     return false;
1052   };
1053 
1054   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
1055                       condition);
1056 }
1057 
/// Analysis and bufferization of ParallelInsertSliceOp.
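///
/// A rough sketch of what `bufferize` below generates (illustrative
/// pseudo-IR; the actual copy op depends on `options.createMemCpy`):
///
///   // Buffer of the insert destination, computed before the enclosing
///   // scf.foreach_thread op.
///   %dest_buffer = ...
///   %subview = memref.subview %dest_buffer[offsets][sizes][strides]
///   <memcpy>(%src_buffer, %subview) // folds away if everything is in-place
///   // All uses of the corresponding ForeachThreadOp result are replaced by:
///   %replacement = bufferization.to_tensor %dest_buffer : ...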
1059 struct ParallelInsertSliceOpInterface
1060     : public BufferizableOpInterface::ExternalModel<
1061           ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
1062   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1063                                             const AnalysisState &state) const {
1064     if (&opOperand != &op->getOpOperand(1) /*dest*/)
1065       return {};
1066 
1067     // ParallelInsertSliceOp itself has no results. Tensors are returned via
1068     // the parent op.
1069     auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
1070     assert(foreachThreadOp &&
1071            "could not find valid owner of parallel_insert_slice");
1072 
    // The tensor written by the i-th ParallelInsertSliceOp is returned via the
    // i-th OpResult of the parent ForeachThreadOp.
1075     Block *block = op->getBlock();
1076     unsigned int opIdx = 0;
1077     for (ParallelInsertSliceOp insertOp :
1078          block->getOps<ParallelInsertSliceOp>()) {
1079       if (insertOp.getOperation() == op)
1080         break;
1081       ++opIdx;
1082     }
1083     assert(opIdx < foreachThreadOp->getNumResults() &&
1084            "could not find op inside terminator op");
1085 
1086     return {foreachThreadOp->getResult(opIdx)};
1087   }
1088 
1089   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1090                               const AnalysisState &state) const {
1091     return true;
1092   }
1093 
1094   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1095                                const AnalysisState &state) const {
1096     return &opOperand == &op->getOpOperand(1) /*dest*/;
1097   }
1098 
1099   BufferRelation bufferRelation(Operation *op, OpResult opResult,
1100                                 const AnalysisState &state) const {
1101     return BufferRelation::Equivalent;
1102   }
1103 
1104   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
1105                                  const AnalysisState &state) const {
1106     return success();
1107   }
1108 
1109   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1110                           const BufferizationOptions &options) const {
1111     OpBuilder::InsertionGuard g(rewriter);
1112     auto insertOp = cast<ParallelInsertSliceOp>(op);
1113     auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
1114     auto foreachThreadOp =
1115         cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
1116 
1117     // If the op bufferizes out-of-place, allocate the copy before the
1118     // ForeachThreadOp.
1119     rewriter.setInsertionPoint(foreachThreadOp);
1120     FailureOr<Value> destBuffer =
1121         getBuffer(rewriter, insertOp.getDest(), options);
1122     if (failed(destBuffer))
1123       return failure();
1124 
1125     // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
1126     rewriter.setInsertionPoint(performConcurrentlyOp);
1127     FailureOr<Value> srcBuffer =
1128         getBuffer(rewriter, insertOp.getSource(), options);
1129     if (failed(srcBuffer))
1130       return failure();
1131     Value subview = rewriter.create<memref::SubViewOp>(
1132         insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
1133         insertOp.getMixedSizes(), insertOp.getMixedStrides());
1134     // This memcpy will fold away if everything bufferizes in-place.
1135     if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
1136                                     subview)))
1137       return failure();
1138     rewriter.eraseOp(op);
1139 
1140     // Replace all uses of ForeachThreadOp (just the corresponding result).
1141     rewriter.setInsertionPointAfter(foreachThreadOp);
1142     Value toTensorOp =
1143         rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
1144     unsigned resultNum = 0;
1145     for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
1146       if (&nextOp == op)
1147         break;
1148       resultNum++;
1149     }
1150     assert(resultNum < foreachThreadOp->getNumResults() &&
1151            "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
1152     SmallVector<OpOperand *> resultUses = llvm::to_vector(
1153         llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
1154                         [](OpOperand &use) { return &use; }));
1155     for (OpOperand *use : resultUses) {
1156       rewriter.updateRootInPlace(use->getOwner(),
1157                                  [&]() { use->set(toTensorOp); });
1158     }
1159     return success();
1160   }
1161 
1162   // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
1163   // the code.
1164   bool isNotConflicting(Operation *op, OpOperand *uRead,
1165                         OpOperand *uConflictingWrite,
1166                         const AnalysisState &state) const {
1167     Operation *readingOp = uRead->getOwner();
1168     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
1169 
1170     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
1171     // uRead is an InsertSliceOp...
1172     if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
1173       // As an example, consider the following IR.
1174       //
1175       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
1176       // %1 = linalg.fill %cst, %0 {inplace= [true] }
1177       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
1178       //     {inplace= [true] }
1179 
1180       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
1181       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1182           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
1183                                     insertSliceOp))
1184         // Case 1: The main insight is that InsertSliceOp reads only part of
1185         // the destination tensor. The overwritten area is not read. If
1186         // uConflictingWrite writes into exactly the memory location that is
1187         // being read by uRead, this is not a conflict.
1188         //
1189         // In the above example:
1190         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
1191         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
1192         //
1193         // The read of %t does not conflict with the write of the FillOp
1194         // (same aliases!) because the area that the FillOp operates on is
1195         // exactly the one that is *not* read via %t.
1196         return true;
1197 
1198       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
1199           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1200           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
1201         // Case 2: The read of the source tensor and the write to the dest
1202         // tensor via an InsertSliceOp is not a conflict if the read is
1203         // reading exactly that part of an equivalent tensor that the
1204         // InsertSliceOp is writing.
1205         //
1206         // In the above example:
1207         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
1208         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1209         return true;
1210     }
1211 
1212     // If uConflictingWrite is an InsertSliceOp...
1213     if (auto insertSliceOp =
1214             dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
1215       // As an example, consider the following IR.
1216       //
1217       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
1218       // %1 = linalg.fill %cst, %0 {inplace= [true] }
1219       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
1220       //     {inplace= [true] }
1221       // %3 = vector.transfer_read %1, %cst
1222       //
1223       // In the above example:
1224       // uRead             = OpOperand 0 (%1) of vector.transfer_read
1225       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1226       // lastWrite         = %1
1227       //
1228       // This is not a conflict because the InsertSliceOp overwrites the
1229       // memory segment of %1 with the exact same data. (Effectively, there
1230       // is no memory write here.)
1231       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1232           state.areEquivalentBufferizedValues(uRead->get(),
1233                                               insertSliceOp.getSource()) &&
1234           hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
1235                                     insertSliceOp))
1236         return true;
1237 
1238     return false;
1239   }
1240 };
1241 
1242 } // namespace
1243 } // namespace scf
1244 } // namespace mlir
1245 
1246 void mlir::scf::registerBufferizableOpInterfaceExternalModels(
1247     DialectRegistry &registry) {
1248   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
1249     ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
1250     ForOp::attachInterface<ForOpInterface>(*ctx);
1251     IfOp::attachInterface<IfOpInterface>(*ctx);
1252     ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
1253     ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1254         *ctx);
1255     PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
1256         *ctx);
1257     WhileOp::attachInterface<WhileOpInterface>(*ctx);
1258     YieldOp::attachInterface<YieldOpInterface>(*ctx);
1259   });
1260 }
1261