//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
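// E.g. (illustrative): converting a tensor<?xf32> to a memref<?xf32> keeps
// the rank at 1, whereas converting it to a memref<4x4xf32> would change the
// rank and trips the assertion below.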
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.
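///
/// Illustrative sketch of the rewrite performed by `bufferize` below (memref
/// types simplified; the actual types depend on the bufferization options):
///
///   %r = scf.execute_region -> tensor<?xf32> {
///     ...
///     scf.yield %t : tensor<?xf32>
///   }
///
/// becomes roughly:
///
///   %m = scf.execute_region -> memref<?xf32> {
///     ...
///     %b = bufferization.to_memref %t : memref<?xf32>
///     scf.yield %b : memref<?xf32>
///   }
///   %r = bufferization.to_tensor %m : memref<?xf32>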
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
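///
/// Illustrative sketch of the rewrite (memref types simplified; the actual
/// types depend on the bufferization options):
///
///   %r = scf.if %c -> (tensor<?xf32>) {
///     scf.yield %t0 : tensor<?xf32>
///   } else {
///     scf.yield %t1 : tensor<?xf32>
///   }
///
/// becomes roughly:
///
///   %m = scf.if %c -> (memref<?xf32>) {
///     %b0 = bufferization.to_memref %t0 : memref<?xf32>
///     scf.yield %b0 : memref<?xf32>
///   } else {
///     %b1 = bufferization.to_memref %t1 : memref<?xf32>
///     scf.yield %b1 : memref<?xf32>
///   }
///   %r = bufferization.to_tensor %m : memref<?xf32>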
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
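///
/// Illustrative sketch of the rewrite performed by `bufferize` below (memref
/// types simplified; the actual types and the buffer chosen for %init depend
/// on the bufferization options and analysis):
///
///   %r = scf.for %iv = %lb to %ub step %s iter_args(%t = %init)
///       -> (tensor<?xf32>) {
///     %t2 = ... %t ...
///     scf.yield %t2 : tensor<?xf32>
///   }
///
/// becomes roughly:
///
///   %m = scf.for %iv = %lb to %ub step %s iter_args(%b = %init_buffer)
///       -> (memref<?xf32>) {
///     %t = bufferization.to_tensor %b : memref<?xf32>
///     %t2 = ... %t ...
///     %b2 = bufferization.to_memref %t2 : memref<?xf32>
///     scf.yield %b2 : memref<?xf32>
///   }
///   %r = bufferization.to_tensor %m : memref<?xf32>
///
/// where %init_buffer (a name chosen here for illustration) is the buffer
/// computed for the %init operand.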
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    // Tensor iter_args of scf::ForOps are always considered a write. This is
    // to simplify the analysis.
    // TODO: Consider doing something like isValueWritten.
    return true;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    if (!opOperand.get().getType().isa<RankedTensorType>())
      return {};
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const BufferizationState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed inplace from
    // the perspective of ops nested under it:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices;
    for (const auto &it : llvm::enumerate(forOp.getInitArgs()))
      if (it.value().getType().isa<TensorType>())
        indices.insert(it.index());

    // Given a range of values, apply `func` to those marked in `indices`.
    // Otherwise, store the unmodified value in the result vector.
    auto convert = [&](ValueRange values,
                       llvm::function_ref<Value(Value, int64_t)> func) {
      SmallVector<Value> result;
      for (const auto &it : llvm::enumerate(values)) {
        size_t idx = it.index();
        Value val = it.value();
        result.push_back(indices.contains(idx) ? func(val, idx) : val);
      }
      return result;
    };

    // Construct a new scf.for op with memref instead of tensor values.
    SmallVector<Value> initArgs;
    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
      if (opOperand.get().getType().isa<TensorType>()) {
        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
        if (failed(resultBuffer))
          return failure();
        initArgs.push_back(*resultBuffer);
      } else {
        initArgs.push_back(opOperand.get());
      }
    }
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
        });
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // If the new loop has no iter_args, the ForOp builder already created an
    // scf.yield terminator in the new loop body; erase it before merging in
    // the old body.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
          ensureToMemrefOpIsValid(val, initArgs[index].getType());
          return rewriter.create<bufferization::ToMemrefOp>(
              val.getLoc(), initArgs[index].getType(), val);
        });
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }
};

/// Bufferization of scf.yield. scf.yield ops are bufferized as part of their
/// enclosing ops, so this interface is used for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (isa<scf::IfOp, scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const BufferizationState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should avoid returning/yielding
    // allocations whenever possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

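/// Assert that, for each tensor iter_arg of a scf.for op, the corresponding
/// scf.yield operand bufferizes to a buffer that is equivalent to the
/// iter_arg's bbArg. E.g. (illustrative), the following loop is rejected
/// because the yield swaps the two iter_args:
///
///   %0:2 = scf.for ... iter_args(%a = %t0, %b = %t1)
///       -> (tensor<?xf32>, tensor<?xf32>) {
///     scf.yield %b, %a : tensor<?xf32>, tensor<?xf32>
///   }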
LogicalResult mlir::scf::assertScfForAliasingProperties(
    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
    SmallVector<Operation *> &newOps) {
  LogicalResult status = success();

  op->walk([&](scf::ForOp forOp) {
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpOperand &operand : yieldOp->getOpOperands()) {
      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
      if (!tensorType)
        continue;

      OpOperand &forOperand = forOp.getOpOperandForResult(
          forOp->getResult(operand.getOperandNumber()));
      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
        // TODO: this could get resolved with copies but it can also turn into
        // swaps so we need to be careful about order of copies.
        status =
            yieldOp->emitError()
            << "Yield operand #" << operand.getOperandNumber()
            << " does not bufferize to a buffer that is aliasing the matching"
            << " enclosing scf::for operand";
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return status;
}

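/// Register the BufferizableOpInterface external models for the SCF dialect
/// ops. Illustrative usage sketch (the exact setup depends on the pass or
/// tool driving bufferization):
///
///   DialectRegistry registry;
///   registry.insert<scf::SCFDialect, bufferization::BufferizationDialect>();
///   scf::registerBufferizableOpInterfaceExternalModels(registry);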
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>();
  registry.addOpInterface<ForOp, ForOpInterface>();
  registry.addOpInterface<IfOp, IfOpInterface>();
  registry.addOpInterface<YieldOp, YieldOpInterface>();
  registry
      .addOpInterface<ParallelOp, AllocationHoistingBarrierOnly<ParallelOp>>();
}