1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
16 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Pass/PassManager.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "mlir/Transforms/Passes.h"
23 
24 using namespace mlir;
25 using namespace mlir::bufferization;
26 
27 //===----------------------------------------------------------------------===//
28 // BufferizeTypeConverter
29 //===----------------------------------------------------------------------===//
30 
31 static Value materializeToTensor(OpBuilder &builder, TensorType type,
32                                  ValueRange inputs, Location loc) {
33   assert(inputs.size() == 1);
34   assert(inputs[0].getType().isa<BaseMemRefType>());
35   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
36 }
37 
38 /// Registers conversions into BufferizeTypeConverter
39 BufferizeTypeConverter::BufferizeTypeConverter() {
40   // Keep all types unchanged.
41   addConversion([](Type type) { return type; });
42   // Convert RankedTensorType to MemRefType.
43   addConversion([](RankedTensorType type) -> Type {
44     return MemRefType::get(type.getShape(), type.getElementType());
45   });
46   // Convert UnrankedTensorType to UnrankedMemRefType.
47   addConversion([](UnrankedTensorType type) -> Type {
48     return UnrankedMemRefType::get(type.getElementType(), 0);
49   });
50   addArgumentMaterialization(materializeToTensor);
51   addSourceMaterialization(materializeToTensor);
52   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
53                               ValueRange inputs, Location loc) -> Value {
54     assert(inputs.size() == 1 && "expected exactly one input");
55 
56     if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
57       // MemRef to MemRef cast.
58       assert(inputType != type && "expected different types");
59       // Unranked to ranked and ranked to unranked casts must be explicit.
60       auto rankedDestType = type.dyn_cast<MemRefType>();
61       if (!rankedDestType)
62         return nullptr;
63       FailureOr<Value> replacement =
64           castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
65       if (failed(replacement))
66         return nullptr;
67       return *replacement;
68     }
69 
70     if (inputs[0].getType().isa<TensorType>()) {
71       // Tensor to MemRef cast.
72       return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
73     }
74 
75     llvm_unreachable("only tensor/memref input types supported");
76   });
77 }
78 
79 void mlir::bufferization::populateBufferizeMaterializationLegality(
80     ConversionTarget &target) {
81   target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
82 }
83 
84 namespace {
85 // In a finalizing bufferize conversion, we know that all tensors have been
86 // converted to memrefs, thus, this op becomes an identity.
87 class BufferizeToTensorOp
88     : public OpConversionPattern<bufferization::ToTensorOp> {
89 public:
90   using OpConversionPattern::OpConversionPattern;
91   LogicalResult
92   matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
93                   ConversionPatternRewriter &rewriter) const override {
94     rewriter.replaceOp(op, adaptor.memref());
95     return success();
96   }
97 };
98 } // namespace
99 
100 namespace {
101 // In a finalizing bufferize conversion, we know that all tensors have been
102 // converted to memrefs, thus, this op becomes an identity.
103 class BufferizeToMemrefOp
104     : public OpConversionPattern<bufferization::ToMemrefOp> {
105 public:
106   using OpConversionPattern::OpConversionPattern;
107   LogicalResult
108   matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
109                   ConversionPatternRewriter &rewriter) const override {
110     rewriter.replaceOp(op, adaptor.tensor());
111     return success();
112   }
113 };
114 } // namespace
115 
116 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
117     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
118   patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
119                                                          patterns.getContext());
120 }
121 
122 namespace {
123 struct FinalizingBufferizePass
124     : public FinalizingBufferizeBase<FinalizingBufferizePass> {
125   using FinalizingBufferizeBase<
126       FinalizingBufferizePass>::FinalizingBufferizeBase;
127 
128   void runOnOperation() override {
129     auto func = getOperation();
130     auto *context = &getContext();
131 
132     BufferizeTypeConverter typeConverter;
133     RewritePatternSet patterns(context);
134     ConversionTarget target(*context);
135 
136     populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
137 
138     // If all result types are legal, and all block arguments are legal (ensured
139     // by func conversion above), then all types in the program are legal.
140     //
141     // We also check that the operand types are legal to avoid creating invalid
142     // IR. For example, this prevents
143     // populateEliminateBufferizeMaterializationsPatterns from updating the
144     // types of the operands to a return op without updating the enclosing
145     // function.
146     target.markUnknownOpDynamicallyLegal(
147         [&](Operation *op) { return typeConverter.isLegal(op); });
148 
149     if (failed(applyFullConversion(func, target, std::move(patterns))))
150       signalPassFailure();
151   }
152 };
153 
154 static BufferizationOptions::LayoutMapOption
155 parseLayoutMapOption(const std::string &s) {
156   if (s == "fully-dynamic-layout-map")
157     return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
158   if (s == "identity-layout-map")
159     return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
160   if (s == "infer-layout-map")
161     return BufferizationOptions::LayoutMapOption::InferLayoutMap;
162   llvm_unreachable("invalid layout map option");
163 }
164 
165 struct OneShotBufferizePass
166     : public OneShotBufferizeBase<OneShotBufferizePass> {
167   OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}
168 
169   explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
170       : options(options) {}
171 
172   void getDependentDialects(DialectRegistry &registry) const override {
173     registry
174         .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
175     registerAllocationOpInterfaceExternalModels(registry);
176   }
177 
178   void runOnOperation() override {
179     OneShotBufferizationOptions opt;
180     if (!options) {
181       // Make new bufferization options if none were provided when creating the
182       // pass.
183       opt.dropEquivalentFuncResults = dropEquivalentFuncResults;
184       opt.allowReturnAllocs = allowReturnAllocs;
185       opt.allowUnknownOps = allowUnknownOps;
186       opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
187       opt.analysisFuzzerSeed = analysisFuzzerSeed;
188       opt.createDeallocs = createDeallocs;
189       opt.functionBoundaryTypeConversion =
190           parseLayoutMapOption(functionBoundaryTypeConversion);
191       opt.printConflicts = printConflicts;
192       opt.testAnalysisOnly = testAnalysisOnly;
193       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
194       opt.promoteBufferResultsToOutParams = promoteBufferResultsToOutParams;
195       opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
196 
197       OpFilter::Entry::FilterFn filterFn =
198           [&](Operation *op) {
199             // Filter may be specified via options.
200             if (this->dialectFilter.hasValue())
201               return llvm::find(this->dialectFilter,
202                                 op->getDialect()->getNamespace()) !=
203                      this->dialectFilter.end();
204             // No filter specified: All other ops are allowed.
205             return true;
206           };
207       opt.opFilter.allowOperation(filterFn);
208     } else {
209       opt = *options;
210     }
211 
212     ModuleOp moduleOp = getOperation();
213     if (opt.bufferizeFunctionBoundaries) {
214       if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
215         signalPassFailure();
216         return;
217       }
218     } else {
219       if (failed(runOneShotBufferize(moduleOp, opt))) {
220         signalPassFailure();
221         return;
222       }
223     }
224 
225     if (opt.testAnalysisOnly)
226       return;
227 
228     OpPassManager cleanupPipeline("builtin.module");
229     cleanupPipeline.addPass(createCanonicalizerPass());
230     cleanupPipeline.addPass(createCSEPass());
231     cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
232     (void)runPipeline(cleanupPipeline, moduleOp);
233   }
234 
235 private:
236   llvm::Optional<OneShotBufferizationOptions> options;
237 };
238 } // namespace
239 
240 namespace {
241 struct BufferizationBufferizePass
242     : public BufferizationBufferizeBase<BufferizationBufferizePass> {
243   void runOnOperation() override {
244     BufferizationOptions options = getPartialBufferizationOptions();
245     options.opFilter.allowDialect<BufferizationDialect>();
246 
247     if (failed(bufferizeOp(getOperation(), options)))
248       signalPassFailure();
249   }
250 
251   void getDependentDialects(DialectRegistry &registry) const override {
252     registry
253         .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
254   }
255 };
256 } // namespace
257 
258 std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
259   return std::make_unique<BufferizationBufferizePass>();
260 }
261 
262 std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
263   return std::make_unique<OneShotBufferizePass>();
264 }
265 
266 std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
267     const OneShotBufferizationOptions &options) {
268   return std::make_unique<OneShotBufferizePass>(options);
269 }
270 
271 std::unique_ptr<OperationPass<func::FuncOp>>
272 mlir::bufferization::createFinalizingBufferizePass() {
273   return std::make_unique<FinalizingBufferizePass>();
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // BufferizableOpInterface-based Bufferization
278 //===----------------------------------------------------------------------===//
279 
280 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
281 
282 /// Return true if the given op has a tensor result or a tensor operand.
283 static bool hasTensorSemantics(Operation *op) {
284   if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
285     bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
286     bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
287     return hasTensorArg || hasTensorResult;
288   }
289 
290   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
291   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
292   return hasTensorResult || hasTensorOperand;
293 }
294 
295 LogicalResult
296 bufferization::finalizeBuffers(Operation *op,
297                                const BufferizationOptions &options) {
298   // Create allocation ops for "leaking buffers", i.e., buffer allocations that
299   // escape block boundaries. If there are no leaking allocs, `hasLeakingAllocs`
300   // is set to `false`.
301   bool hasLeakingAllocs = false;
302   if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
303                                    &hasLeakingAllocs)))
304     return failure();
305 
306   // Promote returned buffers to "out" parameters.
307   // TODO: Pass options to support custom dealloc ops.
308   if (options.promoteBufferResultsToOutParams && isa<ModuleOp>(op) &&
309       failed(promoteBufferResultsToOutParams(cast<ModuleOp>(op))))
310     return failure();
311 
312   // Create deallocation ops for all "leaking buffers" and all buffer
313   // allocations that were added during the above promotion process.
314   // TODO: Pass options to support custom dealloc ops.
315   if (hasLeakingAllocs && options.createDeallocs &&
316       failed(deallocateBuffers(op)))
317     return failure();
318 
319   // Deallocate all remaining buffers at the end of their parent blocks.
320   if (failed(createAllocDeallocOps(op, options)))
321     return failure();
322 
323   return success();
324 }
325 
326 LogicalResult bufferization::bufferizeOp(Operation *op,
327                                          const AnalysisState &analysisState) {
328   // Catch incorrect API usage.
329   assert((analysisState.hasDialectState(
330               func::FuncDialect::getDialectNamespace()) ||
331           !analysisState.getOptions().bufferizeFunctionBoundaries) &&
332          "must use ModuleBufferize to bufferize function boundaries");
333 
334   BufferizationState bufferizationState(analysisState);
335   if (failed(bufferizeOp(op, bufferizationState)))
336     return failure();
337   if (failed(finalizeBuffers(op, analysisState.getOptions())))
338     return failure();
339   return success();
340 }
341 
342 namespace {
343 /// A rewriter that keeps track of extra information during bufferization.
344 class BufferizationRewriter : public IRRewriter {
345 public:
346   BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
347                         DenseSet<Operation *> &toMemrefOps,
348                         const BufferizationOptions &options,
349                         const OpFilter *opFilter)
350       : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
351         options(options), opFilter(opFilter) {}
352 
353 protected:
354   void notifyOperationRemoved(Operation *op) override {
355     IRRewriter::notifyOperationRemoved(op);
356     erasedOps.insert(op);
357     // Erase if present.
358     toMemrefOps.erase(op);
359   }
360 
361   void notifyOperationInserted(Operation *op) override {
362     IRRewriter::notifyOperationInserted(op);
363 
364     // Keep track of to_memref ops.
365     if (isa<ToMemrefOp>(op)) {
366       toMemrefOps.insert(op);
367       return;
368     }
369 
370     // Skip to_tensor ops.
371     if (isa<ToTensorOp>(op))
372       return;
373 
374     // Skip non-tensor ops.
375     if (!hasTensorSemantics(op))
376       return;
377 
378     // Skip ops that are not allowed.
379     if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
380       return;
381 
382     // Adding new bufferizable ops is not allowed during bufferization. Such ops
383     // would not be analyzed and can lead to surprising behavior.
384     llvm_unreachable(
385         "creating new tensor ops is not allowed during bufferization");
386   }
387 
388 private:
389   /// A set of all erased ops.
390   DenseSet<Operation *> &erasedOps;
391 
392   /// A set of all to_memref ops.
393   DenseSet<Operation *> &toMemrefOps;
394 
395   /// The bufferization options.
396   /// Used for debug modes.
397   LLVM_ATTRIBUTE_UNUSED
398   const BufferizationOptions &options;
399 
400   const OpFilter *opFilter;
401 };
402 } // namespace
403 
404 LogicalResult bufferization::bufferizeOp(Operation *op,
405                                          BufferizationState &bufferizationState,
406                                          const OpFilter *opFilter) {
407   const auto &options = bufferizationState.getOptions();
408   assert(options.unknownTypeConversion !=
409              BufferizationOptions::LayoutMapOption::InferLayoutMap &&
410          "invalid layout map option");
411 
412   // Keep track of to_memref ops.
413   DenseSet<Operation *> toMemrefOps;
414   op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
415 
416   // Gather all bufferizable ops in top-to-bottom order.
417   //
418   // We should ideally know the exact memref type of all operands when
419   // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
420   // Otherwise, we have to use a memref type with a fully dynamic layout map to
421   // avoid copies. We are currently missing patterns for layout maps to
422   // canonicalize away (or canonicalize to more precise layouts).
423   SmallVector<Operation *> worklist;
424   op->walk<WalkOrder::PreOrder>([&](Operation *op) {
425     if (hasTensorSemantics(op))
426       worklist.push_back(op);
427   });
428 
429   // Keep track of all erased ops.
430   DenseSet<Operation *> erasedOps;
431 
432   // Bufferize all ops.
433   BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
434                                  bufferizationState.getOptions(), opFilter);
435   for (unsigned i = 0; i < worklist.size(); ++i) {
436     Operation *op = worklist[i];
437     // Skip ops that were erased.
438     if (erasedOps.contains(op))
439       continue;
440     // Skip ops that are not bufferizable or not allowed.
441     auto bufferizableOp = options.dynCastBufferizableOp(op);
442     if (!bufferizableOp)
443       continue;
444     if (opFilter && !opFilter->isOpAllowed(op))
445       continue;
446     // Skip ops that no longer have tensor semantics.
447     if (!hasTensorSemantics(op))
448       continue;
449     // Bufferize the op.
450     rewriter.setInsertionPoint(op);
451     if (failed(bufferizableOp.bufferize(rewriter, bufferizationState)))
452       return op->emitError("failed to bufferize op");
453   }
454 
455   // Fold all to_memref(to_tensor(x)) pairs.
456   for (Operation *op : toMemrefOps) {
457     rewriter.setInsertionPoint(op);
458     (void)bufferization::foldToMemrefToTensorPair(rewriter,
459                                                   cast<ToMemrefOp>(op));
460   }
461 
462   /// Check the result of bufferization. Return an error if an op was not
463   /// bufferized, unless partial bufferization is allowed.
464   if (bufferizationState.getOptions().allowUnknownOps)
465     return success();
466 
467   for (Operation *op : worklist) {
468     // Skip ops that are entirely gone.
469     if (erasedOps.contains(op))
470       continue;
471     // Ops that no longer have tensor semantics (because they were updated
472     // in-place) are allowed.
473     if (!hasTensorSemantics(op))
474       continue;
475     // Continue ops that are not allowed.
476     if (!options.isOpAllowed(op))
477       continue;
478     if (opFilter && !opFilter->isOpAllowed(op))
479       continue;
480     // Ops without any uses and no side effects will fold away.
481     if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
482       continue;
483     return op->emitError("op was not bufferized");
484   }
485 
486   return success();
487 }
488 
489 namespace {
490 /// This a "no analysis, always copy" AnalysisState. In the absence of an
491 /// analysis, a buffer must be copied each time it is written to. Therefore, all
492 /// OpOperands that bufferize to a memory write must bufferize out-of-place.
493 class AlwaysCopyAnalysisState : public AnalysisState {
494 public:
495   AlwaysCopyAnalysisState(const BufferizationOptions &options)
496       : AnalysisState(options) {
497     // Note: Allocations must be deallocated with a subsequent run of the buffer
498     // deallocation pass.
499     assert(!options.createDeallocs &&
500            "cannot create deallocs with AlwaysCopyBufferizationState");
501   }
502 
503   AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;
504 
505   virtual ~AlwaysCopyAnalysisState() = default;
506 
507   /// Return `true` if the given OpResult has been decided to bufferize inplace.
508   bool isInPlace(OpOperand &opOperand) const override {
509     // OpOperands that bufferize to a memory write are out-of-place, i.e., an
510     // alloc and copy is inserted.
511     return !bufferizesToMemoryWrite(opOperand);
512   }
513 
514   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
515   bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
516     // There is no analysis, so we do not know if the values are equivalent. The
517     // conservative answer is "false".
518     return false;
519   }
520 
521   /// Return true if `v1` and `v2` may bufferize to aliasing buffers.
522   bool areAliasingBufferizedValues(Value v1, Value v2) const override {
523     // There is no analysis, so we do not know if the values are equivalent. The
524     // conservative answer is "true".
525     return true;
526   }
527 
528   /// Return `true` if the given tensor has undefined contents.
529   bool hasUndefinedContents(OpOperand *opOperand) const override {
530     // There is no analysis, so the conservative answer is "false".
531     return false;
532   }
533 
534   /// Return true if the given tensor (or an aliasing tensor) is yielded from
535   /// the containing block. Also include all aliasing tensors in the same block.
536   bool isTensorYielded(Value tensor) const override {
537     // There is no analysis, so conservatively answer "true".
538     return true;
539   }
540 };
541 } // namespace
542 
543 LogicalResult bufferization::bufferizeOp(Operation *op,
544                                          const BufferizationOptions &options) {
545   AlwaysCopyAnalysisState state(options);
546   return bufferizeOp(op, state);
547 }
548 
549 BufferizationOptions bufferization::getPartialBufferizationOptions() {
550   BufferizationOptions options;
551   options.allowUnknownOps = true;
552   options.createDeallocs = false;
553   options.unknownTypeConversion =
554       BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
555   return options;
556 }
557