//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter.
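/// For example (illustrative), the registered conversions map
///   tensor<4x?xf32>  to  memref<4x?xf32>
///   tensor<*xf32>    to  memref<*xf32>
/// and leave all non-tensor types unchanged.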
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}
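
// For example (illustrative), the target materialization above turns a
// tensor-typed value %t into
//   %m = bufferization.to_memref %t : memref<4xf32>
// while memref-to-memref materializations become a memref.cast (or an
// alloc-and-copy) via castOrReallocMemRefValue.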

void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, so this op becomes an identity.
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getMemref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, so this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
} // namespace
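
// For example (illustrative), with the two patterns above,
//   %t = bufferization.to_tensor %m : memref<4xf32>
//   %u = bufferization.to_memref %t : memref<4xf32>
// fold away entirely; uses of %u are replaced with %m.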

void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}

namespace {
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal
    // (ensured by a previous func conversion pass), then all types in the
    // program are legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

static BufferizationOptions::LayoutMapOption
parseLayoutMapOption(const std::string &s) {
  if (s == "fully-dynamic-layout-map")
    return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
  if (s == "identity-layout-map")
    return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
  if (s == "infer-layout-map")
    return BufferizationOptions::LayoutMapOption::InferLayoutMap;
  llvm_unreachable("invalid layout map option");
}

struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Create new bufferization options if none were provided when
      // constructing the pass.
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.functionBoundaryTypeConversion =
          parseLayoutMapOption(functionBoundaryTypeConversion);
      if (mustInferMemorySpace)
        opt.defaultMemorySpace = None;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;

      // Configure type converter.
      BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
          parseLayoutMapOption(unknownTypeConversion);
      opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
                                       const BufferizationOptions &options) {
        auto tensorType = value.getType().cast<TensorType>();
        if (unknownTypeConversionOption ==
            BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(
            unknownTypeConversionOption ==
                BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
            "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };
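
      // For example (illustrative), for a value of type tensor<4x8xf32>:
      //   IdentityLayoutMap     bufferizes to memref<4x8xf32>
      //   FullyDynamicLayoutMap bufferizes to a memref<4x8xf32> with dynamic
      //   strides and offset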

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // A dialect filter may be specified via pass options.
        if (this->dialectFilter.hasValue())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: all ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    } else {
      if (failed(runOneShotBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    }

    if (opt.testAnalysisOnly)
      return;

    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace

namespace {
struct BufferizationBufferizePass
    : public BufferizationBufferizeBase<BufferizationBufferizePass> {
  void runOnOperation() override {
    BufferizationOptions options = getPartialBufferizationOptions();
    options.opFilter.allowDialect<BufferizationDialect>();

    if (failed(bufferizeOp(getOperation(), options)))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
  return std::make_unique<BufferizationBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// Return true if the given op has a tensor result or a tensor operand.
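/// For example (illustrative), "%0 = tensor.extract %t[%i] : tensor<?xf32>"
/// has a tensor operand and therefore has tensor semantics.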
static bool hasTensorSemantics(Operation *op) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    return hasTensorArg || hasTensorResult;
  }

  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        const OpFilter *opFilter)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist), analysisState(options), opFilter(opFilter) {}

protected:
  void notifyOperationRemoved(Operation *op) override {
    IRRewriter::notifyOperationRemoved(op);
    erasedOps.insert(op);
    // Erase if present.
    toMemrefOps.erase(op);
  }

  void notifyOperationInserted(Operation *op) override {
    IRRewriter::notifyOperationInserted(op);
    erasedOps.erase(op);

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    const auto &options = analysisState.getOptions();
    if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
      return;

#ifndef NDEBUG
    // Read-only tensor ops may be created during bufferization. Ops that
    // write should not be created because such ops were never analyzed.
    // Bufferizing such ops could introduce a RaW conflict.
    for (OpOperand &operand : op->getOpOperands())
      if (operand.get().getType().isa<TensorType>())
        assert(!analysisState.bufferizesToMemoryWrite(operand) &&
               "creating tensor ops that bufferize to a memory write is not "
               "allowed during bufferization");
#endif // NDEBUG

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// An extra op filter for bufferization.
  const OpFilter *opFilter;
};
} // namespace

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         bool copyBeforeWrite,
                                         const OpFilter *opFilter) {
  if (copyBeforeWrite) {
    AnalysisState state(options);
    if (failed(insertTensorCopies(op, state)))
      return failure();
  }
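
  // Note: with copyBeforeWrite, no aliasing analysis is performed; buffers
  // are conservatively copied before every write.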

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  //
  // FuncOps must be bufferized before their bodies, so add them to the
  // worklist first.
  SmallVector<Operation *> worklist;
  op->walk([&](func::FuncOp funcOp) {
    if (hasTensorSemantics(funcOp))
      worklist.push_back(funcOp);
  });
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, opFilter);
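
  // Ops created while bufferizing are appended to the worklist by the
  // rewriter's insertion callback, so re-reading worklist.size() on every
  // iteration below is intentional.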
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    if (opFilter && !opFilter->isOpAllowed(nextOp))
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Bufferize the op.
    rewriter.setInsertionPoint(nextOp);
    if (failed(bufferizableOp.bufferize(rewriter, options)))
      return nextOp->emitError("failed to bufferize op");
  }

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                  cast<ToMemrefOp>(op));
  }
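
  // For example (illustrative), the pair
  //   %0 = bufferization.to_tensor %m : memref<?xf32>
  //   %1 = bufferization.to_memref %0 : memref<?xf32>
  // folds so that uses of %1 are replaced with %m directly (a cast or copy is
  // inserted if the memref types differ).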

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    if (opFilter && !opFilter->isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
      continue;
    // ToTensorOps/ToMemrefOps are allowed in the output.
    if (isa<ToTensorOp, ToMemrefOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.enforceAliasingInvariants = false;
  options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
                                      const BufferizationOptions &options) {
    return getMemRefTypeWithStaticIdentityLayout(
        value.getType().cast<TensorType>(), memorySpace);
  };
  options.opFilter.allowDialect<BufferizationDialect>();
  return options;
}
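
// Example usage (illustrative), mirroring BufferizationBufferizePass above:
//   BufferizationOptions options = getPartialBufferizationOptions();
//   if (failed(bufferizeOp(op, options)))
//     /* handle failure */;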