//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
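    //
    // Example (sketch; op names are illustrative): in the IR below, the
    // analysis treats %res as aliasing %t, the operand of the scf.yield.
    //
    //   %res = scf.execute_region -> tensor<?xf32> {
    //     %t = "some_op"() : () -> tensor<?xf32>
    //     scf.yield %t : tensor<?xf32>
    //   }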
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

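  // Sketch of the rewrite performed by `bufferize` below (illustrative IR;
  // layout maps are omitted and the actual memref types depend on the
  // bufferization options). An op such as
  //
  //   %0 = scf.execute_region -> tensor<?xf32> {
  //     scf.yield %t : tensor<?xf32>
  //   }
  //
  // is replaced with an scf.execute_region that yields a memref, with
  // to_memref / to_tensor ops inserted at the region boundary:
  //
  //   %0 = scf.execute_region -> memref<?xf32> {
  //     %m = bufferization.to_memref %t : memref<?xf32>
  //     scf.yield %m : memref<?xf32>
  //   }
  //   %1 = bufferization.to_tensor %0 : memref<?xf32>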
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    //   scf.yield %1
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // were not considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

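  // Sketch of the rewrite performed by `bufferize` below (illustrative IR;
  // layout maps are omitted and the actual memref types depend on the
  // bufferization options). An op such as
  //
  //   %r = scf.if %c -> (tensor<?xf32>) {
  //     scf.yield %t0 : tensor<?xf32>
  //   } else {
  //     scf.yield %t1 : tensor<?xf32>
  //   }
  //
  // is replaced with an scf.if that yields memrefs, with to_memref ops
  // inserted in front of the scf.yield ops:
  //
  //   %r = scf.if %c -> (memref<?xf32>) {
  //     %m0 = bufferization.to_memref %t0 : memref<?xf32>
  //     scf.yield %m0 : memref<?xf32>
  //   } else {
  //     %m1 = bufferization.to_memref %t1 : memref<?xf32>
  //     scf.yield %m1 : memref<?xf32>
  //   }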
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
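    //
    // For example (sketch):
    //
    //   %r = scf.for ... iter_args(%t = %init) -> (tensor<?xf32>) {
    //     %w = tensor.insert %f into %t[%i] : tensor<?xf32>
    //     scf.yield %w : tensor<?xf32>
    //   }
    //
    // Here, %r is equivalent to %init if %w and %t bufferize to equivalent
    // buffers, i.e., if the tensor.insert bufferizes in-place.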
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed inplace from
    // the perspective of ops nested under it:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

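  // Sketch of the rewrite performed by `bufferize` below (illustrative IR;
  // layout maps are omitted and the actual memref types depend on the
  // bufferization options). A loop such as
  //
  //   %r = scf.for ... iter_args(%t = %init) -> (tensor<?xf32>) {
  //     ...
  //     scf.yield %y : tensor<?xf32>
  //   }
  //
  // is replaced with an scf.for on memrefs. Each new (memref) bbArg is wrapped
  // in a bufferization.to_tensor op for the still tensor-based loop body, and
  // each yielded tensor is converted back to a buffer, with a cast or an
  // alloc + copy if the yielded buffer is not equivalent to the bbArg:
  //
  //   %r = scf.for ... iter_args(%b = %init_buffer) -> (memref<?xf32>) {
  //     %bt = bufferization.to_tensor %b : memref<?xf32>
  //     ...
  //     scf.yield %y_buffer : memref<?xf32>
  //   }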
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Helper function for casting MemRef buffers.
    auto castBuffer = [&](Value buffer, Type type) {
      assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
      assert(buffer.getType().isa<BaseMemRefType>() &&
             "expected BaseMemRefType");
      // If the buffer already has the correct type, no cast is needed.
      if (buffer.getType() == type)
        return buffer;
      // TODO: In case `type` has a layout map that is not the fully dynamic
      // one, we may not be able to cast the buffer. In that case, the loop
      // iter_arg's layout map must be changed (see uses of `castBuffer`).
      assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
             "scf.for op bufferization: cast incompatible");
      return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
          .getResult();
    };

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices;
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    SmallVector<bool> equivalentYields;
    for (const auto &it : llvm::enumerate(forOp.getInitArgs())) {
      if (it.value().getType().isa<TensorType>()) {
        indices.insert(it.index());
        BufferRelation relation = bufferizableOp.bufferRelation(
            forOp->getResult(it.index()), state.getAnalysisState());
        equivalentYields.push_back(relation == BufferRelation::Equivalent);
      } else {
        equivalentYields.push_back(false);
      }
    }

    // Given a range of values, apply `func` to those marked in `indices`.
    // Otherwise, store the unmodified value in the result vector.
    auto convert = [&](ValueRange values,
                       llvm::function_ref<Value(Value, int64_t)> func) {
      SmallVector<Value> result;
      for (const auto &it : llvm::enumerate(values)) {
        size_t idx = it.index();
        Value val = it.value();
        result.push_back(indices.contains(idx) ? func(val, idx) : val);
      }
      return result;
    };

    // Construct a new scf.for op with memref instead of tensor values.
    SmallVector<Value> initArgs;
    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
      if (opOperand.get().getType().isa<TensorType>()) {
        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
        if (failed(resultBuffer))
          return failure();
        initArgs.push_back(*resultBuffer);
      } else {
        initArgs.push_back(opOperand.get());
      }
    }
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
        });
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
          Type initArgType = initArgs[index].getType();
          ensureToMemrefOpIsValid(val, initArgType);
          Value yieldedVal =
              bufferization::lookupBuffer(rewriter, val, state.getOptions());

          if (equivalentYields[index])
            // Yielded value is equivalent to the corresponding iter_arg bbArg.
            // Yield the value directly. Most IR should be like that. Everything
            // else must be resolved with copies and is potentially inefficient.
            // By default, such problematic IR would already have been rejected
            // during `verifyAnalysis`, unless `allow-return-allocs` is set.
            return castBuffer(yieldedVal, initArgType);

          // It is not certain that the yielded value and the iter_arg bbArg
          // have the same buffer. Allocate a new buffer and copy. The yielded
          // buffer will get deallocated by `deallocateBuffers`.

          // TODO: There are cases in which it is not necessary to return a new
          // buffer allocation. E.g., when equivalent values are yielded in a
          // different order. This could be resolved with copies.
          Optional<Value> yieldedAlloc = state.createAlloc(
              rewriter, val.getLoc(), yieldedVal, /*deallocMemref=*/false);
          // TODO: We should roll back, but for now just assume that this always
          // succeeds.
          assert(yieldedAlloc.hasValue() && "could not create alloc");
          LogicalResult copyStatus =
              bufferization::createMemCpy(rewriter, val.getLoc(), yieldedVal,
                                          *yieldedAlloc, state.getOptions());
          (void)copyStatus;
          assert(succeeded(copyStatus) && "could not create memcpy");

          // The iter_arg memref type may have a layout map. Cast the new buffer
          // to the same type if needed.
          return castBuffer(*yieldedAlloc, initArgType);
        });
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. Otherwise, an alloc + copy is inserted and yielded
  /// from the loop. This could be a performance problem, so it must be
  /// explicitly allowed with `allow-return-allocs`.
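  ///
  /// For example (sketch), the following loop yields its iter_args in swapped
  /// order, so the yielded values are not equivalent to their bbArgs and the
  /// IR is rejected unless `allow-return-allocs` is set:
  ///
  ///   %r:2 = scf.for ... iter_args(%a = %t0, %b = %t1)
  ///       -> (tensor<?xf32>, tensor<?xf32>) {
  ///     scf.yield %b, %a : tensor<?xf32>, tensor<?xf32>
  ///   }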
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpOperand &operand : yieldOp->getOpOperands()) {
      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
      if (!tensorType)
        continue;

      OpOperand &forOperand = forOp.getOpOperandForResult(
          forOp->getResult(operand.getOperandNumber()));
      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (!state.areEquivalentBufferizedValues(operand.get(), bbArg))
        return yieldOp->emitError()
               << "Yield operand #" << operand.getOperandNumber()
               << " does not bufferize to a buffer that is aliasing the "
                  "matching enclosing scf::for operand";
    }
    return success();
  }
};

/// Bufferization of scf.yield. scf.yield ops are bufferized as part of their
/// enclosing ops, so this interface is used for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should avoid returning/yielding
    // allocations whenever possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

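// Example usage (sketch): a tool or pass pipeline that wants these external
// models available typically registers them on its DialectRegistry before
// creating the MLIRContext, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);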
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}