//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
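///
/// A minimal sketch of the rewrite performed by `bufferize` below. The
/// `#layout` placeholder is an assumption; the actual memref layout depends
/// on the bufferization options:
///
///   %0 = scf.execute_region -> tensor<5xf32> {
///     scf.yield %t : tensor<5xf32>
///   }
///
/// is rewritten to:
///
///   %0 = scf.execute_region -> memref<5xf32, #layout> {
///     scf.yield %m : memref<5xf32, #layout>
///   }
///   %1 = bufferization.to_tensor %0 : memref<5xf32, #layout>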
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    auto yieldOp =
        cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
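///
/// A minimal sketch (the `#layout` placeholder is an assumption; if the two
/// branches yield buffers of different types, both are first casted to a
/// memref type with a fully dynamic layout map):
///
///   %0 = scf.if %c -> (tensor<5xf32>) { ... }
///
/// is rewritten to:
///
///   %0 = scf.if %c -> (memref<5xf32, #layout>) { ... }
///   %1 = bufferization.to_tensor %0 : memref<5xf32, #layout>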
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is
  // not allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Reconcile type mismatches between then/else branches by inserting memref
    // casts.
    SmallVector<Value> thenResults, elseResults;
    bool insertedCast = false;
    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
      Value thenValue = thenYieldOp.getResults()[i];
      Value elseValue = elseYieldOp.getResults()[i];
      if (thenValue.getType() == elseValue.getType()) {
        thenResults.push_back(thenValue);
        elseResults.push_back(elseValue);
        continue;
      }

      // Type mismatch between then/else yield value. Cast both to a memref
      // type with a fully dynamic layout map.
      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
      if (thenMemrefType.getMemorySpaceAsInt() !=
          elseMemrefType.getMemorySpaceAsInt())
        return op->emitError("inconsistent memory space on then/else branches");
      rewriter.setInsertionPoint(thenYieldOp);
      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
          ifOp.getResultTypes()[i].cast<TensorType>(),
          thenMemrefType.getMemorySpaceAsInt());
      thenResults.push_back(rewriter.create<memref::CastOp>(
          thenYieldOp.getLoc(), memrefType, thenValue));
      rewriter.setInsertionPoint(elseYieldOp);
      elseResults.push_back(rewriter.create<memref::CastOp>(
          elseYieldOp.getLoc(), memrefType, elseValue));
      insertedCast = true;
    }

    if (insertedCast) {
      rewriter.setInsertionPoint(thenYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
      rewriter.setInsertionPoint(elseYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    ValueRange resultsValueRange(thenResults);
    TypeRange newTypes(resultsValueRange);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!bbArgs[i].getType().isa<TensorType>() ||
        !yieldedValues[i].getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
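///
/// E.g., a sketch of such a cast, where `#layout` stands for the fully
/// dynamic layout map:
///
///   %0 = memref.cast %b : memref<5xf32> to memref<5xf32, #layout>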
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition).
static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                                         BaseMemRefType type,
                                         const BufferizationOptions &options) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
  if (failed(yieldedVal))
    return failure();
  return castBuffer(rewriter, *yieldedVal, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static FailureOr<SmallVector<Value>>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      FailureOr<Value> maybeVal = func(val, idx);
      if (failed(maybeVal))
        return failure();
      result.push_back(*maybeVal);
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
FailureOr<SmallVector<Value>>
getYieldedValues(RewriterBase &rewriter, ValueRange values,
                 TypeRange bufferizedTypes,
                 const DenseSet<int64_t> &tensorIndices,
                 const BufferizationOptions &options) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                options);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
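///
/// A small sketch (the memref type is illustrative): for a memref bbArg %b,
/// this inserts
///
///   %t = bufferization.to_tensor %b : memref<5xf32, #layout>
///
/// and %t then replaces the corresponding tensor bbArg in the moved block.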
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
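///
/// A minimal sketch (the `#layout` placeholder is an assumption; the actual
/// layout map depends on the bufferization options):
///
///   %0 = scf.for %iv = %lb to %ub step %s iter_args(%arg = %init)
///       -> (tensor<5xf32>) { ... }
///
/// is rewritten to:
///
///   %0 = scf.for %iv = %lb to %ub step %s iter_args(%arg = %buf)
///       -> (memref<5xf32, #layout>) { ... }
///   %1 = bufferization.to_tensor %0 : memref<5xf32, #layout>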
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg,
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
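    //
    // A sketch of the rewrite below, assuming the yielded %x is not
    // equivalent to its bbArg %arg:
    //
    //   scf.yield %x : tensor<5xf32>
    //
    // becomes
    //
    //   %copy = bufferization.alloc_tensor() copy(%x) : tensor<5xf32>
    //   scf.yield %copy : tensor<5xf32>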
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
    SmallVector<Value> yieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
        yieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, BlockArgument bbArg,
                const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    return bufferization::getBufferType(
        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
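///
/// A minimal sketch (the `#layout` placeholders are assumptions; the actual
/// layout maps depend on the bufferization options):
///
///   %0 = scf.while (%arg = %init) : (tensor<5xf32>) -> tensor<5xf32> {
///     ...
///   }
///
/// is rewritten to:
///
///   %0 = scf.while (%arg = %buf)
///       : (memref<5xf32, #layout>) -> memref<5xf32, #layout> {
///     ...
///   }
///   %1 = bufferization.to_tensor %0 : memref<5xf32, #layout>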
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
    auto yieldOp = whileOp.getYieldOp();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!indicesBefore.contains(idx) ||
          equivalentYieldsBefore.contains(idx)) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    // Update "after" region.
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> afterYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
        afterYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      afterYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().assign(afterYieldValues);
    });

    return success();
  }

  // TODO: Implement getBufferType interface method and infer buffer types.

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          // TODO: error handling
          return bufferization::getBufferType(bbArg, options)->cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, options);
    if (failed(newConditionArgs))
      return failure();
    newConditionOp.getArgsMutable().assign(*newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, options);
    if (failed(newYieldValues))
      return failure();
    newYieldOp.getResultsMutable().assign(*newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
    // together with scf.while.)
    if (isa<scf::WhileOp>(yieldOp->getParentOp()))
      return success();

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (value.getType().isa<TensorType>()) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType =
              cast<BufferizableOpInterface>(forOp.getOperation())
                  .getBufferType(forOp.getRegionIterArgs()[it.index()],
                                 options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Return the destinations that a ForeachThreadOp is inserting into. One per
/// ParallelInsertSliceOp.
static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
  SmallVector<OpOperand *> result;
  terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) {
    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
  });
  return result;
}

/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of
/// the region. There are op interfaces for the terminators
/// (PerformConcurrentlyOp and ParallelInsertSliceOp), but these are only used
/// during analysis, not for bufferization.
struct ForeachThreadOpInterface
    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                    ForeachThreadOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
  }

  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // This op is a memory write. Stop lookup here to avoid finding false
    // conflicts involving this op and one of the ops in the region. This is
    // similar to how scf.if ops are analyzed.
    return true;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto foreachThreadOp = cast<ForeachThreadOp>(op);

#ifndef NDEBUG
    // ParallelInsertSliceOpInterface replaces all uses.
    for (OpResult opResult : foreachThreadOp->getOpResults())
      assert(opResult.getUses().empty() &&
             "expected that all uses were already replaced");
#endif // NDEBUG

    // Create new ForeachThreadOp without any results and drop the
    // automatically introduced terminator.
    TypeRange newResultTypes;
    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
        foreachThreadOp.getLoc(), newResultTypes,
        foreachThreadOp.getNumThreads(),
        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
    newForeachThreadOp.getBody()->getTerminator()->erase();

    // Move over block contents of the old op.
    rewriter.mergeBlocks(foreachThreadOp.getBody(),
                         newForeachThreadOp.getBody(),
                         {newForeachThreadOp.getBody()->getArguments()});

    // Remove the old op.
    rewriter.eraseOp(op);

    return success();
  }
};

/// Nothing to do for PerformConcurrentlyOp.
struct PerformConcurrentlyOpInterface
    : public BufferizableOpInterface::ExternalModel<
          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
        *ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}