//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
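///
/// For example,
///
/// %0 = scf.execute_region -> tensor<5xf32> {
///   ...
///   scf.yield %t : tensor<5xf32>
/// }
///
/// may bufferize (roughly) to
///
/// %0 = scf.execute_region -> memref<5xf32> {
///   ...
///   %m = bufferization.to_memref %t : memref<5xf32>
///   scf.yield %m : memref<5xf32>
/// }
/// %1 = bufferization.to_tensor %0 : memref<5xf32>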
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
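///
/// For example,
///
/// %0 = scf.if %c -> (tensor<5xf32>) {
///   scf.yield %a : tensor<5xf32>
/// } else {
///   scf.yield %b : tensor<5xf32>
/// }
///
/// may bufferize (roughly) to a new scf.if whose branches yield
/// bufferization.to_memref'd values of type memref<5xf32>; uses of the old
/// tensor result are then replaced via a bufferization.to_tensor op.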
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
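/// E.g., for the values (%t : tensor<?xf32>, %i : index, %u : tensor<5xf32>),
/// this returns the set {0, 2}.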
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  DenseSet<int64_t> result;
  int64_t counter = 0;
  for (const auto &it : llvm::zip(bbArgs, yieldedValues)) {
    // `counter` is the absolute index of the current bbArg/yielded value pair,
    // matching the indices computed by `getTensorIndices`.
    if (std::get<0>(it).getType().isa<TensorType>() &&
        state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it)))
      result.insert(counter);
    counter++;
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
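/// E.g., a buffer of type memref<5xf32> can be casted to memref<?xf32> with a
/// memref.cast op, but not to memref<10xf32> (the types are not
/// cast-compatible).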
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf loop bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static SmallVector<Value> getBuffers(RewriterBase &rewriter,
                                     MutableArrayRef<OpOperand> operands,
                                     BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
      if (failed(resultBuffer))
        return {};
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition). If the given tensor
/// is equivalent to the corresponding block argument (as indicated by
/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
/// copy must be yielded.
///
/// According to the `BufferizableOpInterface` implementation of scf loops,
/// a bufferized OpResult may alias only with the corresponding bufferized
/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
/// the i-th init_arg; but not with any other OpOperand. If a corresponding
/// OpResult/init_arg pair bufferizes to equivalent buffers (as indicated by
/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
/// alias with any buffer.)
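///
/// E.g., in the following loop, the yielded value %w is generally not
/// equivalent to the iter_arg %t, so a new buffer is allocated and the buffer
/// of %w is copied into it before yielding:
///
/// %r = scf.for ... iter_args(%t = %init) -> (tensor<?xf32>) {
///   %w = "some_writing_op"(%other) : (tensor<?xf32>) -> tensor<?xf32>
///   scf.yield %w : tensor<?xf32>
/// }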
static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                              BaseMemRefType type, bool isEquivalent,
                              BufferizationState &state) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  Value yieldedVal =
      bufferization::lookupBuffer(rewriter, tensor, state.getOptions());

  if (isEquivalent)
    // Yielded value is equivalent to the corresponding iter_arg bbArg.
    // Yield the value directly. Most IR should be like that. Everything
    // else must be resolved with copies and is potentially inefficient.
    // By default, such problematic IR would already have been rejected
    // during `verifyAnalysis`, unless `allow-return-allocs` is set.
    return castBuffer(rewriter, yieldedVal, type);

  // It is not certain that the yielded value and the iter_arg bbArg
  // have the same buffer. Allocate a new buffer and copy. The yielded
  // buffer will get deallocated by `deallocateBuffers`.

  // TODO: There are cases in which it is not necessary to return a new
  // buffer allocation. E.g., when equivalent values are yielded in a
  // different order. This could be resolved with copies.
  Optional<Value> yieldedAlloc = state.createAlloc(
      rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
  // TODO: We should rollback, but for now just assume that this always
  // succeeds.
  assert(yieldedAlloc.hasValue() && "could not create alloc");
  LogicalResult copyStatus = bufferization::createMemCpy(
      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions());
  (void)copyStatus;
  assert(succeeded(copyStatus) && "could not create memcpy");

  // The iter_arg memref type may have a layout map. Cast the new buffer
  // to the same type if needed.
  return castBuffer(rewriter, *yieldedAlloc, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static SmallVector<Value>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<Value(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
                                    TypeRange bufferizedTypes,
                                    const DenseSet<int64_t> &tensorIndices,
                                    const DenseSet<int64_t> &equivalentTensors,
                                    BufferizationState &state) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                equivalentTensors.contains(index), state);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
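///
/// E.g., a new memref bbArg %arg1 : memref<?xf32> is exposed to the moved
/// (still tensor-typed) block body as
///   %0 = bufferization.to_tensor %arg1 : memref<?xf32>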
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  return convertTensorValues(
      bbArgs, tensorIndices, [&](Value val, int64_t index) {
        return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
      });
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
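///
/// For example,
///
/// %r = scf.for %iv = %lb to %ub step %s iter_args(%t = %init)
///     -> (tensor<?xf32>) {
///   ...
///   scf.yield %v : tensor<?xf32>
/// }
///
/// may bufferize (roughly) to an scf.for with a memref iter_arg: the body
/// wraps the memref bbArg in a bufferization.to_tensor op, the scf.yield
/// operand is rewritten to a bufferization.to_memref'd value, and uses of the
/// old tensor result are replaced based on the memref result of the new loop.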
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto oldYieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields =
        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
                             state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), state);
    // `getBuffers` returns an empty vector if the bufferization of an operand
    // failed. Compare against the total number of iter operands (tensor and
    // non-tensor) to detect that case.
    if (initArgs.size() != forOp.getIterOpOperands().size())
      return failure();

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
                         equivalentYields, state);
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
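  ///
  /// E.g., the following loop is rejected unless `allow-return-allocs` is set,
  /// because the yielded tensor does not bufferize to the buffer of the
  /// iter_arg %t:
  ///
  /// %r = scf.for ... iter_args(%t = %init) -> (tensor<?xf32>) {
  ///   %0 = "some_op"(%other) : (tensor<?xf32>) -> tensor<?xf32>
  ///   scf.yield %0 : tensor<?xf32>
  /// }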
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
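///
/// For example,
///
/// %r = scf.while (%arg0 = %init) : (tensor<?xf32>) -> tensor<?xf32> {
///   %c = "some_condition_op"(%arg0) : (tensor<?xf32>) -> i1
///   scf.condition(%c) %arg0 : tensor<?xf32>
/// } do {
/// ^bb0(%arg1: tensor<?xf32>):
///   ...
///   scf.yield %next : tensor<?xf32>
/// }
///
/// may bufferize (roughly) to an scf.while on memrefs, with
/// bufferization.to_tensor ops wrapping the bbArgs of both regions and
/// bufferization.to_memref ops feeding the scf.condition and scf.yield
/// operands.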
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    return {whileOp->getResult(opOperand.getOperandNumber())};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
        state.getAnalysisState());
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
        state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), state);
    // `getBuffers` returns an empty vector if the bufferization of an operand
    // failed. Compare against the total number of operands (tensor and
    // non-tensor) to detect that case.
    if (initArgs.size() != whileOp->getNumOperands())
      return failure();

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRange(initArgs);
    TypeRange argsTypes(argsRange);
    auto newWhileOp =
        rewriter.create<scf::WhileOp>(whileOp.getLoc(), argsTypes, initArgs);
    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocs(initArgs.size(), whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indices);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    SmallVector<Value> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices,
                         equivalentYieldsBefore, state);
    newConditionOp.getArgsMutable().assign(newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    SmallVector<Value> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices,
                         equivalentYieldsAfter, state);
    newYieldOp.getResultsMutable().assign(newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

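// Clients typically add this extension to their DialectRegistry before
// creating the MLIRContext, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);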
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}