//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}
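
// Example (illustrative; SSA names are hypothetical): given a value
// %m : memref<4xf32>, this materialization builds
//   %t = bufferization.to_tensor %m : memref<4xf32>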

/// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}
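
// Illustrative type mapping performed by this converter (identity layouts and
// memory space 0 assumed):
//   tensor<4x?xf32> -> memref<4x?xf32>
//   tensor<*xf32>   -> memref<*xf32>
//   i64             -> i64 (non-tensor types are left unchanged)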

void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, so this op becomes an identity.
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.memref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, so this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.tensor());
    return success();
  }
};
} // namespace

void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}
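
// Illustrative effect of these patterns (SSA names hypothetical): once all
// surrounding ops operate on memrefs, the pair
//   %t = bufferization.to_tensor %m : memref<?xf32>
//   %r = bufferization.to_memref %t : memref<?xf32>
// disappears entirely and all uses of %r are rewired to %m.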

namespace {
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal
    // (ensured by the earlier func conversion), then all types in the program
    // are legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating
      // the pass.
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;

      BufferizationOptions::OpFilterEntry::FilterFn filterFn =
          [&](Operation *op) {
            // A dialect filter may be specified via pass options.
            if (this->dialectFilter.hasValue())
              return llvm::find(this->dialectFilter,
                                op->getDialect()->getNamespace()) !=
                     this->dialectFilter.end();
            // No filter specified: all ops are allowed.
            return true;
          };
      opt.allowOperationInFilter(filterFn);
    } else {
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    } else {
      if (failed(runOneShotBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    }

    if (opt.testAnalysisOnly)
      return;

    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}
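
// Typical usage (a sketch; `pm` is assumed to be an mlir::PassManager that
// the caller has already set up):
//   OneShotBufferizationOptions options;
//   options.bufferizeFunctionBoundaries = true;
//   pm.addPass(bufferization::createOneShotBufferizePass(options));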

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    return hasTensorArg || hasTensorResult;
  }

  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}
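
// For example (illustrative): `%e = tensor.extract %t[%i] : tensor<4xf32>`
// has tensor semantics because of its tensor operand, even though its result
// is a scalar; the equivalent memref.load op does not.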

LogicalResult
bufferization::finalizeBuffers(Operation *op,
                               const BufferizationOptions &options) {
  // Hoist buffers.
  if (failed(hoistBufferAllocations(op, options)))
    return failure();

  // Deallocate buffers that escape block boundaries ("leaking buffers") with
  // the buffer deallocation pass.
  bool hasLeakingAlloc = false;
  if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
                                   &hasLeakingAlloc)))
    return failure();
  if (options.createDeallocs && hasLeakingAlloc &&
      failed(deallocateBuffers(op)))
    return failure();

  // Deallocate all remaining buffers at the end of the block.
  if (failed(createAllocDeallocOps(op, options)))
    return failure();

  return success();
}

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const AnalysisState &analysisState) {
  // Catch incorrect API usage.
  assert((analysisState.hasDialectState(
              func::FuncDialect::getDialectNamespace()) ||
          !analysisState.getOptions().bufferizeFunctionBoundaries) &&
         "must use ModuleBufferize to bufferize function boundaries");

  BufferizationState bufferizationState(analysisState);
  if (failed(bufferizeOp(op, bufferizationState)))
    return failure();
  if (failed(finalizeBuffers(op, analysisState.getOptions())))
    return failure();
  return success();
}

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist) {}

protected:
  void notifyOperationRemoved(Operation *op) override {
    IRRewriter::notifyOperationRemoved(op);
    erasedOps.insert(op);
  }

  void notifyOperationInserted(Operation *op) override {
    IRRewriter::notifyOperationInserted(op);

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // A new bufferizable op was inserted. Add it to the worklist.
    if (hasTensorSemantics(op))
      worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The list of bufferizable ops.
  SmallVector<Operation *> &worklist;
};
} // namespace

LogicalResult
bufferization::bufferizeOp(Operation *op,
                           BufferizationState &bufferizationState) {
  const auto &options = bufferizationState.getOptions();

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map,
  // which must be canonicalized away later. This is less efficient.
  //
  // If "fullyDynamicLayoutMaps = false", we would have to insert buffer copies
  // to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-compatible
  // layout maps when doing a traversal other than top-to-bottom. These would
  // not easily fold away.
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (hasTensorSemantics(op))
      worklist.push_back(op);
  });
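
  // For example (illustrative): when `%e = tensor.extract %t[%i]` is visited
  // after the op defining %t has been bufferized, the exact memref type of %t
  // is already known; visiting it earlier would force a fully dynamic layout
  // map that must be canonicalized away afterwards.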

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *op = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(op))
      continue;
    // Skip ops that are not bufferizable.
    auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
    if (!bufferizableOp)
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Bufferize the op.
    rewriter.setInsertionPoint(op);
    (void)bufferizableOp.bufferize(rewriter, bufferizationState);
  }

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    if (erasedOps.contains(op))
      continue;
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                  cast<ToMemrefOp>(op));
  }
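
  // Illustrative fold (SSA names hypothetical): with cast-compatible types,
  //   %t = bufferization.to_tensor %m : memref<4xf32>
  //   %r = bufferization.to_memref %t : memref<4xf32>
  // is replaced by direct uses of %m; otherwise a memref.cast or a
  // reallocation with a copy may be inserted instead.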

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (bufferizationState.getOptions().allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

namespace {
/// This is a "no analysis, always copy" AnalysisState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore,
/// all OpOperands that bufferize to a memory write must bufferize
/// out-of-place.
class AlwaysCopyAnalysisState : public AnalysisState {
public:
  AlwaysCopyAnalysisState(const BufferizationOptions &options)
      : AnalysisState(options) {}

  AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;

  virtual ~AlwaysCopyAnalysisState() = default;

  /// Return `true` if the given OpOperand has been decided to bufferize
  /// in-place.
  bool isInPlace(OpOperand &opOperand) const override {
    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
    // alloc and copy is inserted.
    return !bufferizesToMemoryWrite(opOperand);
  }

  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
    // There is no analysis, so we do not know if the values are equivalent.
    // The conservative answer is "false".
    return false;
  }
};
} // namespace
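
// Sketch of the resulting behavior (illustrative; SSA names hypothetical):
// with this state, a write such as
//   %0 = tensor.insert %f into %t[%i] : tensor<4xf32>
// bufferizes to an allocation and a copy of the buffer of %t, followed by a
// memref.store into the copy, because every write is treated as out-of-place.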

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options) {
  AlwaysCopyAnalysisState state(options);
  return bufferizeOp(op, state);
}

BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.fullyDynamicLayoutMaps = false;
  return options;
}
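
// Typical usage (a sketch; calling this from inside a pass's runOnOperation
// is an assumption): dialect-specific partial bufferization passes may run
//   (void)bufferizeOp(getOperation(), getPartialBufferizationOptions());
// so that unknown ops are left untouched and bridged with materializations.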