1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
16 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Pass/PassManager.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "mlir/Transforms/Passes.h"
23 
24 using namespace mlir;
25 using namespace mlir::bufferization;
26 
27 //===----------------------------------------------------------------------===//
28 // BufferizeTypeConverter
29 //===----------------------------------------------------------------------===//
30 
31 static Value materializeToTensor(OpBuilder &builder, TensorType type,
32                                  ValueRange inputs, Location loc) {
33   assert(inputs.size() == 1);
34   assert(inputs[0].getType().isa<BaseMemRefType>());
35   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
36 }
37 
38 /// Registers conversions into BufferizeTypeConverter
39 BufferizeTypeConverter::BufferizeTypeConverter() {
40   // Keep all types unchanged.
41   addConversion([](Type type) { return type; });
42   // Convert RankedTensorType to MemRefType.
43   addConversion([](RankedTensorType type) -> Type {
44     return MemRefType::get(type.getShape(), type.getElementType());
45   });
46   // Convert UnrankedTensorType to UnrankedMemRefType.
47   addConversion([](UnrankedTensorType type) -> Type {
48     return UnrankedMemRefType::get(type.getElementType(), 0);
49   });
50   addArgumentMaterialization(materializeToTensor);
51   addSourceMaterialization(materializeToTensor);
52   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
53                               ValueRange inputs, Location loc) -> Value {
54     assert(inputs.size() == 1 && "expected exactly one input");
55 
56     if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
57       // MemRef to MemRef cast.
58       assert(inputType != type && "expected different types");
59       // Unranked to ranked and ranked to unranked casts must be explicit.
60       auto rankedDestType = type.dyn_cast<MemRefType>();
61       if (!rankedDestType)
62         return nullptr;
63       FailureOr<Value> replacement =
64           castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
65       if (failed(replacement))
66         return nullptr;
67       return *replacement;
68     }
69 
70     if (inputs[0].getType().isa<TensorType>()) {
71       // Tensor to MemRef cast.
72       return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
73     }
74 
75     llvm_unreachable("only tensor/memref input types supported");
76   });
77 }
78 
79 void mlir::bufferization::populateBufferizeMaterializationLegality(
80     ConversionTarget &target) {
81   target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
82 }
83 
84 namespace {
85 // In a finalizing bufferize conversion, we know that all tensors have been
86 // converted to memrefs, thus, this op becomes an identity.
87 class BufferizeToTensorOp
88     : public OpConversionPattern<bufferization::ToTensorOp> {
89 public:
90   using OpConversionPattern::OpConversionPattern;
91   LogicalResult
92   matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
93                   ConversionPatternRewriter &rewriter) const override {
94     rewriter.replaceOp(op, adaptor.memref());
95     return success();
96   }
97 };
98 } // namespace
99 
100 namespace {
101 // In a finalizing bufferize conversion, we know that all tensors have been
102 // converted to memrefs, thus, this op becomes an identity.
103 class BufferizeToMemrefOp
104     : public OpConversionPattern<bufferization::ToMemrefOp> {
105 public:
106   using OpConversionPattern::OpConversionPattern;
107   LogicalResult
108   matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
109                   ConversionPatternRewriter &rewriter) const override {
110     rewriter.replaceOp(op, adaptor.tensor());
111     return success();
112   }
113 };
114 } // namespace
115 
116 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
117     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
118   patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
119                                                          patterns.getContext());
120 }
121 
122 namespace {
123 struct FinalizingBufferizePass
124     : public FinalizingBufferizeBase<FinalizingBufferizePass> {
125   using FinalizingBufferizeBase<
126       FinalizingBufferizePass>::FinalizingBufferizeBase;
127 
128   void runOnOperation() override {
129     auto func = getOperation();
130     auto *context = &getContext();
131 
132     BufferizeTypeConverter typeConverter;
133     RewritePatternSet patterns(context);
134     ConversionTarget target(*context);
135 
136     populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
137 
138     // If all result types are legal, and all block arguments are legal (ensured
139     // by func conversion above), then all types in the program are legal.
140     //
141     // We also check that the operand types are legal to avoid creating invalid
142     // IR. For example, this prevents
143     // populateEliminateBufferizeMaterializationsPatterns from updating the
144     // types of the operands to a return op without updating the enclosing
145     // function.
146     target.markUnknownOpDynamicallyLegal(
147         [&](Operation *op) { return typeConverter.isLegal(op); });
148 
149     if (failed(applyFullConversion(func, target, std::move(patterns))))
150       signalPassFailure();
151   }
152 };
153 
/// Module pass that runs One-Shot Bufferize, optionally followed by a cleanup
/// pipeline. Options either come pre-built (second constructor) or are
/// assembled from the pass's command-line flags in runOnOperation.
struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  /// Construct the pass with pre-built options; the pass's command-line flags
  /// are ignored in that case (see runOnOperation).
  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass. Each field mirrors the identically named pass option.
      opt.dropEquivalentFuncResults = dropEquivalentFuncResults;
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;

      // Op filter: when a dialect filter was given on the command line, only
      // ops from the listed dialects are bufferized; otherwise all ops are.
      BufferizationOptions::OpFilterEntry::FilterFn filterFn =
          [&](Operation *op) {
            // Filter may be specified via options.
            if (this->dialectFilter.hasValue())
              return llvm::find(this->dialectFilter,
                                op->getDialect()->getNamespace()) !=
                     this->dialectFilter.end();
            // No filter specified: All other ops are allowed.
            return true;
          };
      opt.allowOperationInFilter(filterFn);
    } else {
      // Explicitly provided options take precedence over pass flags.
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      // Function signatures/calls are bufferized too; this requires the
      // module-level driver.
      if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    } else {
      if (failed(runOneShotBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    }

    // In test-analysis-only mode no IR was rewritten; skip the cleanup.
    if (opt.testAnalysisOnly)
      return;

    // Clean up IR produced by bufferization. The pipeline result is
    // intentionally ignored: cleanup failure does not fail this pass.
    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  // Pre-built options, set only by the explicit constructor. When empty, the
  // options are assembled from pass flags in runOnOperation.
  llvm::Optional<OneShotBufferizationOptions> options;
};
224 } // namespace
225 
/// Create a One-Shot Bufferize pass that assembles its options from the
/// pass's command-line flags.
std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}
229 
/// Create a One-Shot Bufferize pass with pre-built options; the pass's
/// command-line flags are ignored.
std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}
234 
/// Create a function pass that eliminates remaining to_tensor/to_memref
/// materializations and verifies that no tensor-typed IR is left.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}
239 
240 //===----------------------------------------------------------------------===//
241 // BufferizableOpInterface-based Bufferization
242 //===----------------------------------------------------------------------===//
243 
244 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
245 
246 /// Return true if the given op has a tensor result or a tensor operand.
247 static bool hasTensorSemantics(Operation *op) {
248   if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
249     bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
250     bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
251     return hasTensorArg || hasTensorResult;
252   }
253 
254   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
255   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
256   return hasTensorResult || hasTensorOperand;
257 }
258 
/// Post-process bufferized IR: hoist buffer allocations, then insert
/// alloc/dealloc ops according to `options`. Returns failure if any step
/// fails.
LogicalResult
bufferization::finalizeBuffers(Operation *op,
                               const BufferizationOptions &options) {
  // Hoist buffers (see hoistBufferAllocations for the exact placement rules).
  if (failed(hoistBufferAllocations(op, options)))
    return failure();

  // Deallocate buffers that escape block boundaries ("leaking buffers") with
  // the buffer deallocation pass. First, materialize allocs/deallocs only for
  // the leaking allocations and record whether any exist.
  bool hasLeakingAlloc = false;
  if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
                                   &hasLeakingAlloc)))
    return failure();
  // Run the deallocation step only when requested and actually needed.
  if (options.createDeallocs && hasLeakingAlloc &&
      failed(deallocateBuffers(op)))
    return failure();

  // Deallocate all remaining buffers at the end of the block.
  if (failed(createAllocDeallocOps(op, options)))
    return failure();

  return success();
}
282 
283 LogicalResult bufferization::bufferizeOp(Operation *op,
284                                          const AnalysisState &analysisState) {
285   // Catch incorrect API usage.
286   assert((analysisState.hasDialectState(
287               func::FuncDialect::getDialectNamespace()) ||
288           !analysisState.getOptions().bufferizeFunctionBoundaries) &&
289          "must use ModuleBufferize to bufferize function boundaries");
290 
291   BufferizationState bufferizationState(analysisState);
292   if (failed(bufferizeOp(op, bufferizationState)))
293     return failure();
294   if (failed(finalizeBuffers(op, analysisState.getOptions())))
295     return failure();
296   return success();
297 }
298 
namespace {
/// A rewriter that keeps track of extra information during bufferization.
/// The containers passed to the constructor are owned by the caller (the
/// bufferization driver) and updated in place as ops are created or erased.
class BufferizationRewriter : public IRRewriter {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist) {}

protected:
  /// Record erased ops so the driver can skip stale worklist entries.
  void notifyOperationRemoved(Operation *op) override {
    IRRewriter::notifyOperationRemoved(op);
    erasedOps.insert(op);
  }

  /// Triage newly inserted ops: collect to_memref ops for later folding,
  /// ignore to_tensor ops, and append any other op with tensor semantics to
  /// the driver's worklist so it gets bufferized too.
  void notifyOperationInserted(Operation *op) override {
    IRRewriter::notifyOperationInserted(op);

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // A new bufferizable op was inserted. Add it to the worklist.
    if (hasTensorSemantics(op))
      worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The list of bufferizable ops (may grow during bufferization).
  SmallVector<Operation *> &worklist;
};
} // namespace
344 
/// Driver: bufferize all ops with tensor semantics nested under `op` in
/// top-to-bottom order, then fold away to_memref(to_tensor(x)) pairs.
/// Unless `allowUnknownOps` is set, an error is emitted for any op that was
/// left unbufferized.
LogicalResult
bufferization::bufferizeOp(Operation *op,
                           BufferizationState &bufferizationState) {
  const auto &options = bufferizationState.getOptions();

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map,
  // which has to canonicalize away. This is less efficient.
  //
  // If "fullyDynamicLayoutMaps = false", we would have to insert buffer copies
  // to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-compatible
  // layout maps when doing a traversal other than top-to-bottom. These would
  // not easily fold away.
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops. Note: the worklist may grow while iterating, because
  // the rewriter appends newly created ops with tensor semantics (see
  // BufferizationRewriter::notifyOperationInserted); hence index-based
  // iteration instead of a range-for.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *op = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(op))
      continue;
    // Skip ops that are not bufferizable.
    auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
    if (!bufferizableOp)
      continue;
    // Skip ops that are not allowed (e.g. excluded by an op filter).
    if (!options.isOpAllowed(op))
      continue;
    // Bufferize the op. The result is deliberately ignored here: leftover
    // tensor ops are diagnosed by the check below (unless unknown ops are
    // allowed).
    rewriter.setInsertionPoint(op);
    (void)bufferizableOp.bufferize(rewriter, bufferizationState);
  }

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    if (erasedOps.contains(op))
      continue;
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                  cast<ToMemrefOp>(op));
  }

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (bufferizationState.getOptions().allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed; they were deliberately left alone.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}
427 
namespace {
/// This is a "no analysis, always copy" AnalysisState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore,
/// all OpOperands that bufferize to a memory write must bufferize
/// out-of-place.
class AlwaysCopyAnalysisState : public AnalysisState {
public:
  AlwaysCopyAnalysisState(const BufferizationOptions &options)
      : AnalysisState(options) {}

  // Non-copyable.
  AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;

  virtual ~AlwaysCopyAnalysisState() = default;

  /// Return `true` if the given OpOperand has been decided to bufferize
  /// in-place.
  bool isInPlace(OpOperand &opOperand) const override {
    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
    // alloc and copy is inserted.
    return !bufferizesToMemoryWrite(opOperand);
  }

  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
    // There is no analysis, so we do not know if the values are equivalent. The
    // conservative answer is "false".
    return false;
  }
};
} // namespace
456 
457 LogicalResult bufferization::bufferizeOp(Operation *op,
458                                          const BufferizationOptions &options) {
459   AlwaysCopyAnalysisState state(options);
460   return bufferizeOp(op, state);
461 }
462 
463 BufferizationOptions bufferization::getPartialBufferizationOptions() {
464   BufferizationOptions options;
465   options.allowUnknownOps = true;
466   options.createDeallocs = false;
467   options.fullyDynamicLayoutMaps = false;
468   return options;
469 }
470