//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
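    //
    // E.g. (illustrative IR), result #0 of the scf.execute_region op below is
    // considered to alias with %t, the corresponding scf.yield operand:
    //
    //   %0 = scf.execute_region -> tensor<?xf32> {
    //     %t = "some_op"() : () -> tensor<?xf32>
    //     scf.yield %t : tensor<?xf32>
    //   }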
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
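    //
    // E.g. (illustrative IR), in the following scf.if, both branches yield %t,
    // so the result %r is equivalent to %t. If the two branches yielded values
    // that do not bufferize to the same buffer, the relation would be
    // BufferRelation::None.
    //
    //   %r = scf.if %c -> (tensor<?xf32>) {
    //     scf.yield %t : tensor<?xf32>
    //   } else {
    //     scf.yield %t : tensor<?xf32>
    //   }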
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // An scf::ForOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
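    //
    // E.g. (illustrative IR), result #0 of the scf.for op below is equivalent
    // to the init_arg %t because the yielded value is the unmodified iter_arg
    // bbArg %a:
    //
    //   %r = scf.for %iv = %lb to %ub step %step iter_args(%a = %t)
    //       -> (tensor<?xf32>) {
    //     scf.yield %a : tensor<?xf32>
    //   }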
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Helper function for casting MemRef buffers.
    auto castBuffer = [&](Value buffer, Type type) {
      assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
      assert(buffer.getType().isa<BaseMemRefType>() &&
             "expected BaseMemRefType");
      // If the buffer already has the correct type, no cast is needed.
      if (buffer.getType() == type)
        return buffer;
      // TODO: In case `type` has a layout map that is not the fully dynamic
      // one, we may not be able to cast the buffer. In that case, the loop
      // iter_arg's layout map must be changed (see uses of `castBuffer`).
      assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
             "scf.for op bufferization: cast incompatible");
      return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
          .getResult();
    };

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices;
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    SmallVector<bool> equivalentYields;
    for (const auto &it : llvm::enumerate(forOp.getInitArgs())) {
      if (it.value().getType().isa<TensorType>()) {
        indices.insert(it.index());
        BufferRelation relation = bufferizableOp.bufferRelation(
            forOp->getResult(it.index()), state.getAnalysisState());
        equivalentYields.push_back(relation == BufferRelation::Equivalent);
      } else {
        equivalentYields.push_back(false);
      }
    }

    // Given a range of values, apply `func` to those marked in `indices`.
    // Otherwise, store the unmodified value in the result vector.
    auto convert = [&](ValueRange values,
                       llvm::function_ref<Value(Value, int64_t)> func) {
      SmallVector<Value> result;
      for (const auto &it : llvm::enumerate(values)) {
        size_t idx = it.index();
        Value val = it.value();
        result.push_back(indices.contains(idx) ? func(val, idx) : val);
      }
      return result;
    };

    // Construct a new scf.for op with memref instead of tensor values.
    SmallVector<Value> initArgs;
    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
      if (opOperand.get().getType().isa<TensorType>()) {
        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
        if (failed(resultBuffer))
          return failure();
        initArgs.push_back(*resultBuffer);
      } else {
        initArgs.push_back(opOperand.get());
      }
    }
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
        });
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
          Type initArgType = initArgs[index].getType();
          ensureToMemrefOpIsValid(val, initArgType);
          Value yieldedVal =
              bufferization::lookupBuffer(rewriter, val, state.getOptions());

          if (equivalentYields[index])
            // Yielded value is equivalent to the corresponding iter_arg bbArg.
            // Yield the value directly. Most IR should be like that. Everything
            // else must be resolved with copies and is potentially inefficient.
            // By default, such problematic IR would already have been rejected
            // during `verifyAnalysis`, unless `allow-return-allocs`.
            return castBuffer(yieldedVal, initArgType);

          // It is not certain that the yielded value and the iter_arg bbArg
          // have the same buffer. Allocate a new buffer and copy. The yielded
          // buffer will get deallocated by `deallocateBuffers`.

          // TODO: There are cases in which it is not necessary to return a new
          // buffer allocation. E.g., when equivalent values are yielded in a
          // different order. This could be resolved with copies.
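          //
          // E.g. (illustrative IR), a loop that swaps two equivalent iter_args
          // currently takes the alloc + copy path below, because each yielded
          // value is not equivalent to its own bbArg:
          //
          //   %r0, %r1 = scf.for ... iter_args(%a = %t0, %b = %t1)
          //       -> (tensor<?xf32>, tensor<?xf32>) {
          //     scf.yield %b, %a : tensor<?xf32>, tensor<?xf32>
          //   }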
          Optional<Value> yieldedAlloc = state.createAlloc(
              rewriter, val.getLoc(), yieldedVal, /*deallocMemref=*/false);
          // TODO: We should rollback, but for now just assume that this always
          // succeeds.
          assert(yieldedAlloc.hasValue() && "could not create alloc");
          LogicalResult copyStatus =
              bufferization::createMemCpy(rewriter, val.getLoc(), yieldedVal,
                                          *yieldedAlloc, state.getOptions());
          (void)copyStatus;
          assert(succeeded(copyStatus) && "could not create memcpy");

          // The iter_arg memref type may have a layout map. Cast the new buffer
          // to the same type if needed.
          return castBuffer(*yieldedAlloc, initArgType);
        });
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. Otherwise, an alloc + copy is inserted and yielded
  /// from the loop. This could be a performance problem, so non-equivalent
  /// yields must be explicitly allowed with `allow-return-allocs`.
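  ///
  /// E.g. (illustrative IR), without `allow-return-allocs`, a loop such as the
  /// following is rejected if %0 does not bufferize to the same buffer as the
  /// iter_arg bbArg %a (e.g., because "some_op" bufferizes out-of-place):
  ///
  ///   %r = scf.for %iv = %lb to %ub step %step iter_args(%a = %t)
  ///       -> (tensor<?xf32>) {
  ///     %0 = "some_op"(%a) : (tensor<?xf32>) -> tensor<?xf32>
  ///     scf.yield %0 : tensor<?xf32>
  ///   }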
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpOperand &operand : yieldOp->getOpOperands()) {
      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
      if (!tensorType)
        continue;

      OpOperand &forOperand = forOp.getOpOperandForResult(
          forOp->getResult(operand.getOperandNumber()));
      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (!state.areEquivalentBufferizedValues(operand.get(), bbArg))
        return yieldOp->emitError()
               << "Yield operand #" << operand.getOperandNumber()
               << " does not bufferize to a buffer that is aliasing the "
                  "matching enclosing scf::for operand";
    }
    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

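// Typical client-side setup (illustrative sketch; the exact dialects to load
// depend on the embedding tool or pass pipeline):
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect, bufferization::BufferizationDialect,
//                   memref::MemRefDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);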
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}