//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
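///
/// As an illustrative sketch (made-up IR, not taken from a test case), a
/// tensor-yielding op such as
///
///   %0 = scf.execute_region -> tensor<5xf32> {
///     scf.yield %t : tensor<5xf32>
///   }
///
/// is rewritten by `bufferize` below into an op that yields a memref: the
/// yielded tensor is wrapped in a bufferization.to_memref op and uses of the
/// old result are replaced with a bufferization.to_tensor op. The exact memref
/// types (in particular their layout maps) depend on the bufferization
/// options.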
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
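///
/// As an illustrative sketch (made-up IR, not taken from a test case):
///
///   %r = scf.if %cond -> (tensor<?xf32>) {
///     scf.yield %a : tensor<?xf32>
///   } else {
///     scf.yield %b : tensor<?xf32>
///   }
///
/// becomes an scf.if with a memref result; the yielded tensors are wrapped in
/// bufferization.to_memref ops, and uses of the old op are replaced with a
/// bufferization.to_tensor op (the exact memref type depends on the
/// bufferization options).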
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is
  // not allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // were not considered a "write", the analysis would detect the following
    // conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!bbArgs[i].getType().isa<TensorType>() ||
        !yieldedValues[i].getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
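///
/// For example (illustrative only), a buffer of type memref<5xf32> that must
/// be yielded as memref<?xf32> is wrapped in
/// `memref.cast %buf : memref<5xf32> to memref<?xf32>`.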
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static SmallVector<Value> getBuffers(RewriterBase &rewriter,
                                     MutableArrayRef<OpOperand> operands,
                                     BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
      if (failed(resultBuffer))
        return {};
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition). If the given tensor
/// is equivalent to the corresponding block argument (as indicated by
/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
/// copy must be yielded.
///
/// According to the `BufferizableOpInterface` implementation of scf loops, a
/// bufferized OpResult may alias only with the corresponding bufferized
/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
/// the i-th init_arg, but not with any other OpOperand. If a corresponding
/// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by
/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
/// alias with any buffer.)
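///
/// As an illustrative sketch (made-up IR, not taken from a test case): in
///
///   %r = scf.for ... iter_args(%t = %A) -> (tensor<?xf32>) {
///     scf.yield %B : tensor<?xf32>   // %B not equivalent to %t
///   }
///
/// the buffer of %B may alias with something other than the buffer of %A, so
/// a new buffer is allocated, the buffer of %B is copied into it, and the
/// allocation is yielded instead.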
static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                              BaseMemRefType type, bool isEquivalent,
                              BufferizationState &state) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  Value yieldedVal =
      bufferization::lookupBuffer(rewriter, tensor, state.getOptions());

  if (isEquivalent)
    // Yielded value is equivalent to the corresponding iter_arg bbArg.
    // Yield the value directly. Most IR should be like that. Everything
    // else must be resolved with copies and is potentially inefficient.
    // By default, such problematic IR would already have been rejected
    // during `verifyAnalysis`, unless `allow-return-allocs` is enabled.
    return castBuffer(rewriter, yieldedVal, type);

  // It is not certain that the yielded value and the iter_arg bbArg
  // have the same buffer. Allocate a new buffer and copy. The yielded
  // buffer will get deallocated by `deallocateBuffers`.

  // TODO: There are cases in which it is not necessary to return a new
  // buffer allocation. E.g., when equivalent values are yielded in a
  // different order. This could be resolved with copies.
  Optional<Value> yieldedAlloc = state.createAlloc(
      rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
  // TODO: We should rollback, but for now just assume that this always
  // succeeds.
  assert(yieldedAlloc.hasValue() && "could not create alloc");
  LogicalResult copyStatus = state.getOptions().createMemCpy(
      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc);
  (void)copyStatus;
  assert(succeeded(copyStatus) && "could not create memcpy");

  // The iter_arg memref type may have a layout map. Cast the new buffer
  // to the same type if needed.
  return castBuffer(rewriter, *yieldedAlloc, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static SmallVector<Value>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<Value(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
                                    TypeRange bufferizedTypes,
                                    const DenseSet<int64_t> &tensorIndices,
                                    const DenseSet<int64_t> &equivalentTensors,
                                    BufferizationState &state) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                equivalentTensors.contains(index), state);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  return convertTensorValues(
      bbArgs, tensorIndices, [&](Value val, int64_t index) {
        return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
      });
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
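///
/// As an illustrative sketch (made-up IR, not taken from a test case):
///
///   %r = scf.for %iv = %lb to %ub step %s iter_args(%t = %init)
///       -> (tensor<?xf32>) {
///     ...
///     scf.yield %t2 : tensor<?xf32>
///   }
///
/// becomes an scf.for whose iter_arg and result are memrefs. Inside the loop,
/// the memref iter_arg is wrapped in a bufferization.to_tensor op so that the
/// existing tensor loop body can be reused, and the yielded tensor is
/// converted back to a memref (the exact memref types depend on the
/// bufferization options).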
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto oldYieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields =
        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
                             state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), state);

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
                         equivalentYields, state);
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
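  ///
  /// As an illustrative sketch (made-up IR, not taken from a test case),
  /// without `allow-return-allocs` a loop such as
  ///
  ///   %r = scf.for ... iter_args(%t = %A) -> (tensor<?xf32>) {
  ///     scf.yield %B : tensor<?xf32>   // %B not equivalent to %t
  ///   }
  ///
  /// is rejected with "Yield operand #0 is not equivalent to the
  /// corresponding iter bbArg".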
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
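///
/// As an illustrative sketch (made-up IR, not taken from a test case):
///
///   %r = scf.while (%arg0 = %init) : (tensor<?xf32>) -> tensor<?xf32> {
///     %cond = ...
///     scf.condition(%cond) %arg0 : tensor<?xf32>
///   } do {
///   ^bb0(%arg1: tensor<?xf32>):
///     ...
///     scf.yield %next : tensor<?xf32>
///   }
///
/// becomes an scf.while whose bbArgs and results are memrefs. In both regions
/// the memref bbArgs are wrapped in bufferization.to_tensor ops so that the
/// existing tensor bodies can be reused (the exact memref types depend on the
/// bufferization options).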
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
        state.getAnalysisState());
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
        state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), state);

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          return state.getBufferType(bbArg).cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    SmallVector<Value> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, equivalentYieldsBefore, state);
    newConditionOp.getArgsMutable().assign(newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    SmallVector<Value> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, equivalentYieldsAfter, state);
    newYieldOp.getResultsMutable().assign(newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of the enclosing op, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should avoid returning/yielding
    // allocations whenever possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

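// Note: A client that wants these external models to be available (e.g., a
// pass pipeline that bufferizes SCF ops) is expected to call the function
// below when setting up its DialectRegistry, roughly as sketched here
// (illustrative only):
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);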
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}