1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "mlir/Transforms/Passes.h"
21 
22 using namespace mlir;
23 using namespace mlir::bufferization;
24 
25 //===----------------------------------------------------------------------===//
26 // BufferizeTypeConverter
27 //===----------------------------------------------------------------------===//
28 
29 static Value materializeToTensor(OpBuilder &builder, TensorType type,
30                                  ValueRange inputs, Location loc) {
31   assert(inputs.size() == 1);
32   assert(inputs[0].getType().isa<BaseMemRefType>());
33   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
34 }
35 
36 /// Registers conversions into BufferizeTypeConverter
37 BufferizeTypeConverter::BufferizeTypeConverter() {
38   // Keep all types unchanged.
39   addConversion([](Type type) { return type; });
40   // Convert RankedTensorType to MemRefType.
41   addConversion([](RankedTensorType type) -> Type {
42     return MemRefType::get(type.getShape(), type.getElementType());
43   });
44   // Convert UnrankedTensorType to UnrankedMemRefType.
45   addConversion([](UnrankedTensorType type) -> Type {
46     return UnrankedMemRefType::get(type.getElementType(), 0);
47   });
48   addArgumentMaterialization(materializeToTensor);
49   addSourceMaterialization(materializeToTensor);
50   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
51                               ValueRange inputs, Location loc) -> Value {
52     assert(inputs.size() == 1 && "expected exactly one input");
53 
54     if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
55       // MemRef to MemRef cast.
56       assert(inputType != type && "expected different types");
57       // Unranked to ranked and ranked to unranked casts must be explicit.
58       auto rankedDestType = type.dyn_cast<MemRefType>();
59       if (!rankedDestType)
60         return nullptr;
61       FailureOr<Value> replacement =
62           castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
63       if (failed(replacement))
64         return nullptr;
65       return *replacement;
66     }
67 
68     if (inputs[0].getType().isa<TensorType>()) {
69       // Tensor to MemRef cast.
70       return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
71     }
72 
73     llvm_unreachable("only tensor/memref input types supported");
74   });
75 }
76 
77 void mlir::bufferization::populateBufferizeMaterializationLegality(
78     ConversionTarget &target) {
79   target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
80 }
81 
82 namespace {
83 // In a finalizing bufferize conversion, we know that all tensors have been
84 // converted to memrefs, thus, this op becomes an identity.
85 class BufferizeToTensorOp
86     : public OpConversionPattern<bufferization::ToTensorOp> {
87 public:
88   using OpConversionPattern::OpConversionPattern;
89   LogicalResult
90   matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
91                   ConversionPatternRewriter &rewriter) const override {
92     rewriter.replaceOp(op, adaptor.memref());
93     return success();
94   }
95 };
96 } // namespace
97 
98 namespace {
99 // In a finalizing bufferize conversion, we know that all tensors have been
100 // converted to memrefs, thus, this op becomes an identity.
101 class BufferizeToMemrefOp
102     : public OpConversionPattern<bufferization::ToMemrefOp> {
103 public:
104   using OpConversionPattern::OpConversionPattern;
105   LogicalResult
106   matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
107                   ConversionPatternRewriter &rewriter) const override {
108     rewriter.replaceOp(op, adaptor.tensor());
109     return success();
110   }
111 };
112 } // namespace
113 
114 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
115     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
116   patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
117                                                          patterns.getContext());
118 }
119 
120 namespace {
121 struct FinalizingBufferizePass
122     : public FinalizingBufferizeBase<FinalizingBufferizePass> {
123   using FinalizingBufferizeBase<
124       FinalizingBufferizePass>::FinalizingBufferizeBase;
125 
126   void runOnOperation() override {
127     auto func = getOperation();
128     auto *context = &getContext();
129 
130     BufferizeTypeConverter typeConverter;
131     RewritePatternSet patterns(context);
132     ConversionTarget target(*context);
133 
134     populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
135 
136     // If all result types are legal, and all block arguments are legal (ensured
137     // by func conversion above), then all types in the program are legal.
138     //
139     // We also check that the operand types are legal to avoid creating invalid
140     // IR. For example, this prevents
141     // populateEliminateBufferizeMaterializationsPatterns from updating the
142     // types of the operands to a return op without updating the enclosing
143     // function.
144     target.markUnknownOpDynamicallyLegal(
145         [&](Operation *op) { return typeConverter.isLegal(op); });
146 
147     if (failed(applyFullConversion(func, target, std::move(patterns))))
148       signalPassFailure();
149   }
150 };
151 
152 struct OneShotBufferizePass
153     : public OneShotBufferizeBase<OneShotBufferizePass> {
154   OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}
155 
156   explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
157       : options(options) {}
158 
159   void getDependentDialects(DialectRegistry &registry) const override {
160     registry
161         .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
162     registerAllocationOpInterfaceExternalModels(registry);
163   }
164 
165   void runOnOperation() override {
166     OneShotBufferizationOptions opt;
167     if (!options) {
168       // Make new bufferization options if none were provided when creating the
169       // pass.
170       opt.allowReturnAllocs = allowReturnAllocs;
171       opt.allowUnknownOps = allowUnknownOps;
172       opt.analysisFuzzerSeed = analysisFuzzerSeed;
173       opt.createDeallocs = createDeallocs;
174       opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
175       opt.printConflicts = printConflicts;
176       opt.testAnalysisOnly = testAnalysisOnly;
177 
178       BufferizationOptions::OpFilterEntry::FilterFn filterFn =
179           [&](Operation *op) {
180             // Disallow non-func dialect ops. I.e., no ops related to function
181             // calls.
182             if (isa<func::FuncDialect>(op->getDialect()))
183               return false;
184             // Filter may be specified via options.
185             if (this->dialectFilter.hasValue())
186               return llvm::find(this->dialectFilter,
187                                 op->getDialect()->getNamespace()) !=
188                      this->dialectFilter.end();
189             // No filter specified: All other ops are allowed.
190             return true;
191           };
192       opt.allowOperationInFilter(filterFn);
193     } else {
194       opt = *options;
195     }
196 
197     ModuleOp moduleOp = getOperation();
198     if (failed(runOneShotBufferize(moduleOp, opt))) {
199       signalPassFailure();
200       return;
201     }
202 
203     if (opt.testAnalysisOnly)
204       return;
205 
206     OpPassManager cleanupPipeline("builtin.module");
207     cleanupPipeline.addPass(createCanonicalizerPass());
208     cleanupPipeline.addPass(createCSEPass());
209     cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
210     (void)runPipeline(cleanupPipeline, moduleOp);
211   }
212 
213 private:
214   llvm::Optional<OneShotBufferizationOptions> options;
215 };
216 } // namespace
217 
218 std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
219   return std::make_unique<OneShotBufferizePass>();
220 }
221 
222 std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
223     const OneShotBufferizationOptions &options) {
224   return std::make_unique<OneShotBufferizePass>(options);
225 }
226 
227 std::unique_ptr<OperationPass<func::FuncOp>>
228 mlir::bufferization::createFinalizingBufferizePass() {
229   return std::make_unique<FinalizingBufferizePass>();
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // BufferizableOpInterface-based Bufferization
234 //===----------------------------------------------------------------------===//
235 
236 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
237 
238 /// Return true if the given op has a tensor result or a tensor operand.
239 static bool hasTensorSemantics(Operation *op) {
240   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
241   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
242   return hasTensorResult || hasTensorOperand;
243 }
244 
245 /// Rewrite pattern that bufferizes bufferizable ops.
246 struct BufferizationPattern
247     : public OpInterfaceRewritePattern<BufferizableOpInterface> {
248   BufferizationPattern(MLIRContext *context, BufferizationState &state,
249                        PatternBenefit benefit = 1)
250       : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
251         state(&state) {}
252 
253   LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
254                                 PatternRewriter &rewriter) const override {
255     const BufferizationOptions &options = state->getOptions();
256 
257     // No tensors => no buffers.
258     if (!hasTensorSemantics(bufferizableOp.getOperation()))
259       return failure();
260     if (!options.isOpAllowed(bufferizableOp.getOperation()))
261       return failure();
262     return bufferizableOp.bufferize(rewriter, *state);
263   }
264 
265 private:
266   BufferizationState *const state;
267 };
268 
269 /// Check the result of bufferization. Return an error if an op was not
270 /// bufferized, unless partial bufferization is allowed.
271 static LogicalResult
272 checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
273   if (!options.allowUnknownOps) {
274     // Check if all ops were bufferized.
275     LogicalResult status = success();
276     op->walk([&](Operation *op) {
277       if (!hasTensorSemantics(op))
278         return WalkResult::advance();
279 
280       // Bufferization dialect ops will canonicalize away if all other ops are
281       // bufferized.
282       if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
283         return WalkResult::advance();
284 
285       // Ops that are not in the allow list can be ignored.
286       if (!options.isOpAllowed(op))
287         return WalkResult::advance();
288 
289       // Ops without any uses and no side effects will fold away.
290       if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
291         return WalkResult::advance();
292 
293       status = op->emitError("op was not bufferized");
294       return WalkResult::interrupt();
295     });
296 
297     if (failed(status))
298       return status;
299   }
300 
301   return success();
302 }
303 
304 LogicalResult
305 bufferization::finalizeBuffers(Operation *op,
306                                const BufferizationOptions &options) {
307   // Hoist buffers.
308   if (failed(hoistBufferAllocations(op, options)))
309     return failure();
310 
311   // Deallocate buffers that escape block boundaries ("leaking buffers") with
312   // the buffer deallocation pass.
313   bool hasLeakingAlloc = false;
314   if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
315                                    &hasLeakingAlloc)))
316     return failure();
317   if (options.createDeallocs && hasLeakingAlloc &&
318       failed(deallocateBuffers(op)))
319     return failure();
320 
321   // Deallocate all remaining buffers at the end of the block.
322   if (failed(createAllocDeallocOps(op, options)))
323     return failure();
324 
325   return success();
326 }
327 
328 LogicalResult bufferization::bufferizeOp(Operation *op,
329                                          const AnalysisState &analysisState) {
330   BufferizationState bufferizationState(analysisState);
331   if (failed(bufferizeOp(op, bufferizationState)))
332     return failure();
333   if (failed(finalizeBuffers(op, analysisState.getOptions())))
334     return failure();
335   return success();
336 }
337 
338 LogicalResult
339 bufferization::bufferizeOp(Operation *op,
340                            BufferizationState &bufferizationState) {
341   // Bufferize the op and its nested ops.
342   RewritePatternSet patterns(op->getContext());
343   patterns.add<BufferizationPattern>(patterns.getContext(), bufferizationState);
344 
345   // Bufferize ops top-to-bottom. When creating a new op, we should ideally
346   // know the exact memref type of all operands. Otherwise, we have to use a
347   // memref type with a fully dynamic layout map, which has to canonicalize
348   // away. This is less efficient.
349   //
350   // Note: If "fullyDynamicLayoutMaps = false", we may have to insert buffer
351   // copies to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-
352   // compatible layout maps when doing a traversal other than top-to-bottom.
353   // There are currently no canonicalization patterns to fold these away.
354   GreedyRewriteConfig config;
355   config.useTopDownTraversal = true;
356 
357   // TODO: Perform a preorder walk instead of the greedy pattern rewriter. This
358   // would be more efficient because every bufferization pattern is guaranteed
359   // to apply only a single time (otherwise, an assertion would be triggered).
360   // However, there are restrictions wrt. erasing ops during a preorder walk,
361   // which would likely require a larger refactoring.
362   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
363     return failure();
364 
365   if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
366     return failure();
367 
368   return success();
369 }
370 
371 namespace {
372 /// This a "no analysis, always copy" AnalysisState. In the absence of an
373 /// analysis, a buffer must be copied each time it is written to. Therefore, all
374 /// OpOperands that bufferize to a memory write must bufferize out-of-place.
375 class AlwaysCopyAnalysisState : public AnalysisState {
376 public:
377   AlwaysCopyAnalysisState(const BufferizationOptions &options)
378       : AnalysisState(options) {}
379 
380   AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;
381 
382   virtual ~AlwaysCopyAnalysisState() = default;
383 
384   /// Return `true` if the given OpResult has been decided to bufferize inplace.
385   bool isInPlace(OpOperand &opOperand) const override {
386     // OpOperands that bufferize to a memory write are out-of-place, i.e., an
387     // alloc and copy is inserted.
388     return !bufferizesToMemoryWrite(opOperand);
389   }
390 
391   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
392   bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
393     // There is no analysis, so we do not know if the values are equivalent. The
394     // conservative answer is "false".
395     return false;
396   }
397 };
398 } // namespace
399 
400 LogicalResult bufferization::bufferizeOp(Operation *op,
401                                          const BufferizationOptions &options) {
402   AlwaysCopyAnalysisState state(options);
403   return bufferizeOp(op, state);
404 }
405 
406 BufferizationOptions bufferization::getPartialBufferizationOptions() {
407   BufferizationOptions options;
408   options.allowUnknownOps = true;
409   options.createDeallocs = false;
410   options.fullyDynamicLayoutMaps = false;
411   return options;
412 }
413