//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
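///
/// Illustrative sketch (schematic IR, not verified pass output): in
///
///   %r = scf.execute_region -> (tensor<?xf32>) {
///     %t = "some_tensor_producing_op"() : () -> tensor<?xf32>
///     scf.yield %t : tensor<?xf32>
///   }
///
/// the scf.yield operand %t is reported as aliasing with %r, so the analysis
/// can traverse use-def chains through the op.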
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    auto yieldOp =
        cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
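///
/// Illustrative sketch (schematic IR, not verified pass output): an op such as
///
///   %r = scf.if %c -> (tensor<?xf32>) {
///     scf.yield %t0 : tensor<?xf32>
///   } else {
///     scf.yield %t1 : tensor<?xf32>
///   }
///
/// is rewritten into an scf.if that yields the buffers of %t0/%t1. If the two
/// buffer types differ, both are cast to a memref type with a fully dynamic
/// layout map first.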
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Reconcile type mismatches between then/else branches by inserting memref
    // casts.
    SmallVector<Value> thenResults, elseResults;
    bool insertedCast = false;
    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
      Value thenValue = thenYieldOp.getResults()[i];
      Value elseValue = elseYieldOp.getResults()[i];
      if (thenValue.getType() == elseValue.getType()) {
        thenResults.push_back(thenValue);
        elseResults.push_back(elseValue);
        continue;
      }

      // Type mismatch between then/else yield value. Cast both to a memref type
      // with a fully dynamic layout map.
      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
      if (thenMemrefType.getMemorySpaceAsInt() !=
          elseMemrefType.getMemorySpaceAsInt())
        return op->emitError("inconsistent memory space on then/else branches");
      rewriter.setInsertionPoint(thenYieldOp);
      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
          ifOp.getResultTypes()[i].cast<TensorType>(),
          thenMemrefType.getMemorySpaceAsInt());
      thenResults.push_back(rewriter.create<memref::CastOp>(
          thenYieldOp.getLoc(), memrefType, thenValue));
      rewriter.setInsertionPoint(elseYieldOp);
      elseResults.push_back(rewriter.create<memref::CastOp>(
          elseYieldOp.getLoc(), memrefType, elseValue));
      insertedCast = true;
    }

    if (insertedCast) {
      rewriter.setInsertionPoint(thenYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
      rewriter.setInsertionPoint(elseYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    ValueRange resultsValueRange(thenResults);
    TypeRange newTypes(resultsValueRange);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!bbArgs[i].getType().isa<TensorType>() ||
        !yieldedValues[i].getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
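///
/// Illustrative sketch (assumed types): a yielded buffer of type
/// `memref<5xf32>` may be cast to the iter_arg's bufferized type, e.g.,
/// `memref<?xf32>` or a memref with a fully dynamic layout map. If the types
/// already match, the buffer is returned unchanged.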
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition).
static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                                         BaseMemRefType type,
                                         const BufferizationOptions &options) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
  if (failed(yieldedVal))
    return failure();
  return castBuffer(rewriter, *yieldedVal, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static FailureOr<SmallVector<Value>>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      FailureOr<Value> maybeVal = func(val, idx);
      if (failed(maybeVal))
        return failure();
      result.push_back(*maybeVal);
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
FailureOr<SmallVector<Value>>
getYieldedValues(RewriterBase &rewriter, ValueRange values,
                 TypeRange bufferizedTypes,
                 const DenseSet<int64_t> &tensorIndices,
                 const BufferizationOptions &options) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                options);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
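///
/// Illustrative sketch (schematic IR): a bufferized bbArg `%arg` of type
/// `memref<?xf32>` whose index is contained in `tensorIndices` is replaced by
/// `%t = bufferization.to_tensor %arg : memref<?xf32>`, so that the (still
/// tensor-based) loop body can be moved over unchanged.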
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
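///
/// Illustrative sketch (schematic IR, not verified pass output): a loop such as
///
///   %r = scf.for %iv = %lb to %ub step %s iter_args(%t = %init)
///       -> (tensor<?xf32>) {
///     ...
///     scf.yield %t2 : tensor<?xf32>
///   }
///
/// becomes, roughly, an scf.for over memref iter_args whose bbArgs are wrapped
/// in bufferization.to_tensor ops at the top of the loop body and whose
/// scf.yield operands are the corresponding buffers (cast to the iter_arg
/// buffer type if necessary).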
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
    SmallVector<Value> yieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
        yieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, BlockArgument bbArg,
                const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    return bufferization::getBufferType(
        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
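  ///
  /// Illustrative sketch (schematic IR): a yield whose operand does not
  /// bufferize to the iter_arg's buffer, e.g.,
  ///
  ///   scf.for ... iter_args(%t = %init) -> (tensor<?xf32>) {
  ///     %other = "some_tensor_producing_op"() : () -> tensor<?xf32>
  ///     scf.yield %other : tensor<?xf32>
  ///   }
  ///
  /// is rejected here unless `allow-return-allocs` is enabled.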
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
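///
/// Illustrative sketch (schematic IR, not verified pass output): an op such as
///
///   %r = scf.while (%t = %init) : (tensor<?xf32>) -> tensor<?xf32> { ... }
///
/// becomes, roughly, an scf.while over memref values. The bbArgs of both
/// regions are wrapped in bufferization.to_tensor ops so that the original
/// tensor-based blocks can be moved over, and the scf.condition / scf.yield
/// operands are cast to the expected memref types.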
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
    auto yieldOp = whileOp.getYieldOp();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!indicesBefore.contains(idx) ||
          equivalentYieldsBefore.contains(idx)) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    // Update "after" region.
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> afterYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
        afterYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      afterYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().assign(afterYieldValues);
    });

    return success();
  }

  // TODO: Implement getBufferType interface method and infer buffer types.

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          // TODO: error handling
          return bufferization::getBufferType(bbArg, options)->cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, options);
    if (failed(newConditionArgs))
      return failure();
    newConditionOp.getArgsMutable().assign(*newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, options);
    if (failed(newYieldValues))
      return failure();
    newYieldOp.getResultsMutable().assign(*newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
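///
/// Illustrative sketch (schematic IR): inside a bufferized scf.for or scf.if,
/// `scf.yield %t : tensor<?xf32>` becomes `scf.yield %b : memref<?xf32>`,
/// where %b is the buffer of %t (cast to the enclosing loop's iter_arg buffer
/// type if necessary).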
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
    // together with scf.while.)
    if (isa<scf::WhileOp>(yieldOp->getParentOp()))
      return success();

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (value.getType().isa<TensorType>()) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType =
              cast<BufferizableOpInterface>(forOp.getOperation())
                  .getBufferType(forOp.getRegionIterArgs()[it.index()],
                                 options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Return the destinations that a ForeachThreadOp is inserting into. One per
/// ParallelInsertSliceOp.
static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
  SmallVector<OpOperand *> result;
  terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) {
    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
  });
  return result;
}

/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (PerformConcurrentlyOp
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
/// for bufferization.
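///
/// Illustrative sketch (schematic IR): a ForeachThreadOp whose terminator
/// contains `tensor.parallel_insert_slice ... into %dest` ops bufferizes to a
/// result-less ForeachThreadOp; the buffers of the %dest operands stand in for
/// the original op results, whose uses have already been replaced by the
/// ParallelInsertSliceOp bufferization.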
struct ForeachThreadOpInterface
    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                    ForeachThreadOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
  }

  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // This op is a memory write. Stop lookup here to avoid finding false
    // conflicts involving this op and one of the ops in the region. This is
    // similar to how scf.if ops are analyzed.
    return true;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto foreachThreadOp = cast<ForeachThreadOp>(op);

#ifndef NDEBUG
    // ParallelInsertSliceOpInterface replaces all uses.
    for (OpResult opResult : foreachThreadOp->getOpResults())
      assert(opResult.getUses().empty() &&
             "expected that all uses were already replaced");
#endif // NDEBUG

    // Create new ForeachThreadOp without any results and drop the automatically
    // introduced terminator.
    TypeRange newResultTypes;
    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
        foreachThreadOp.getLoc(), newResultTypes,
        foreachThreadOp.getNumThreads(),
        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
    newForeachThreadOp.getBody()->getTerminator()->erase();

    // Move over block contents of the old op.
    rewriter.mergeBlocks(foreachThreadOp.getBody(),
                         newForeachThreadOp.getBody(),
                         {newForeachThreadOp.getBody()->getArguments()});

    // Remove the old op.
    rewriter.eraseOp(op);

    return success();
  }
};

/// Nothing to do for PerformConcurrentlyOp.
struct PerformConcurrentlyOpInterface
    : public BufferizableOpInterface::ExternalModel<
          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
        *ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}