1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 
21 //===----------------------------------------------------------------------===//
22 // BufferizeTypeConverter
23 //===----------------------------------------------------------------------===//
24 
25 static Value materializeToTensor(OpBuilder &builder, TensorType type,
26                                  ValueRange inputs, Location loc) {
27   assert(inputs.size() == 1);
28   assert(inputs[0].getType().isa<BaseMemRefType>());
29   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
30 }
31 
32 /// Registers conversions into BufferizeTypeConverter
33 BufferizeTypeConverter::BufferizeTypeConverter() {
34   // Keep all types unchanged.
35   addConversion([](Type type) { return type; });
36   // Convert RankedTensorType to MemRefType.
37   addConversion([](RankedTensorType type) -> Type {
38     return MemRefType::get(type.getShape(), type.getElementType());
39   });
40   // Convert UnrankedTensorType to UnrankedMemRefType.
41   addConversion([](UnrankedTensorType type) -> Type {
42     return UnrankedMemRefType::get(type.getElementType(), 0);
43   });
44   addArgumentMaterialization(materializeToTensor);
45   addSourceMaterialization(materializeToTensor);
46   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
47                               ValueRange inputs, Location loc) -> Value {
48     assert(inputs.size() == 1);
49     assert(inputs[0].getType().isa<TensorType>());
50     return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
51   });
52 }
53 
54 void mlir::bufferization::populateBufferizeMaterializationLegality(
55     ConversionTarget &target) {
56   target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
57 }
58 
59 namespace {
60 // In a finalizing bufferize conversion, we know that all tensors have been
61 // converted to memrefs, thus, this op becomes an identity.
62 class BufferizeToTensorOp
63     : public OpConversionPattern<bufferization::ToTensorOp> {
64 public:
65   using OpConversionPattern::OpConversionPattern;
66   LogicalResult
67   matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
68                   ConversionPatternRewriter &rewriter) const override {
69     rewriter.replaceOp(op, adaptor.memref());
70     return success();
71   }
72 };
73 } // namespace
74 
75 namespace {
76 // In a finalizing bufferize conversion, we know that all tensors have been
77 // converted to memrefs, thus, this op becomes an identity.
78 class BufferizeToMemrefOp
79     : public OpConversionPattern<bufferization::ToMemrefOp> {
80 public:
81   using OpConversionPattern::OpConversionPattern;
82   LogicalResult
83   matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
84                   ConversionPatternRewriter &rewriter) const override {
85     rewriter.replaceOp(op, adaptor.tensor());
86     return success();
87   }
88 };
89 } // namespace
90 
91 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
92     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
93   patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
94                                                          patterns.getContext());
95 }
96 
97 namespace {
98 struct FinalizingBufferizePass
99     : public FinalizingBufferizeBase<FinalizingBufferizePass> {
100   using FinalizingBufferizeBase<
101       FinalizingBufferizePass>::FinalizingBufferizeBase;
102 
103   void runOnOperation() override {
104     auto func = getOperation();
105     auto *context = &getContext();
106 
107     BufferizeTypeConverter typeConverter;
108     RewritePatternSet patterns(context);
109     ConversionTarget target(*context);
110 
111     populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
112 
113     // If all result types are legal, and all block arguments are legal (ensured
114     // by func conversion above), then all types in the program are legal.
115     //
116     // We also check that the operand types are legal to avoid creating invalid
117     // IR. For example, this prevents
118     // populateEliminateBufferizeMaterializationsPatterns from updating the
119     // types of the operands to a return op without updating the enclosing
120     // function.
121     target.markUnknownOpDynamicallyLegal(
122         [&](Operation *op) { return typeConverter.isLegal(op); });
123 
124     if (failed(applyFullConversion(func, target, std::move(patterns))))
125       signalPassFailure();
126   }
127 };
128 } // namespace
129 
130 std::unique_ptr<OperationPass<FuncOp>>
131 mlir::bufferization::createFinalizingBufferizePass() {
132   return std::make_unique<FinalizingBufferizePass>();
133 }
134 
135 //===----------------------------------------------------------------------===//
136 // BufferizableOpInterface-based Bufferization
137 //===----------------------------------------------------------------------===//
138 
139 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
140 
141 /// Return true if the given op has a tensor result or a tensor operand.
142 static bool hasTensorSemantics(Operation *op) {
143   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
144   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
145   return hasTensorResult || hasTensorOperand;
146 }
147 
148 /// Rewrite pattern that bufferizes bufferizable ops.
149 struct BufferizationPattern
150     : public OpInterfaceRewritePattern<BufferizableOpInterface> {
151   BufferizationPattern(MLIRContext *context, const BufferizationState &state,
152                        PatternBenefit benefit = 1)
153       : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
154         state(state) {}
155 
156   LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
157                                 PatternRewriter &rewriter) const override {
158     // No tensors => no buffers.
159     if (!hasTensorSemantics(bufferizableOp.getOperation()))
160       return failure();
161     if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation()))
162       return failure();
163     return bufferizableOp.bufferize(rewriter, state);
164   }
165 
166 private:
167   const BufferizationState &state;
168 };
169 
170 /// Check the result of bufferization. Return an error if an op was not
171 /// bufferized, unless partial bufferization is allowed.
172 static LogicalResult
173 checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
174   if (!options.allowUnknownOps) {
175     // Check if all ops were bufferized.
176     LogicalResult status = success();
177     op->walk([&](Operation *op) {
178       if (!hasTensorSemantics(op))
179         return WalkResult::advance();
180 
181       // Bufferization dialect ops will canonicalize away if all other ops are
182       // bufferized.
183       if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
184         return WalkResult::advance();
185 
186       // Ops that are not in the allow list can be ignored.
187       if (!options.isOpAllowed(op))
188         return WalkResult::advance();
189 
190       // Ops without any uses and no side effects will fold away.
191       if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
192         return WalkResult::advance();
193 
194       status = op->emitError("op was not bufferized");
195       return WalkResult::interrupt();
196     });
197 
198     if (failed(status))
199       return status;
200   }
201 
202   return success();
203 }
204 
205 LogicalResult bufferization::bufferizeOp(Operation *op,
206                                          const BufferizationState &state) {
207   // Bufferize the op and its nested ops.
208   RewritePatternSet patterns(op->getContext());
209   populateBufferizationPattern(state, patterns);
210   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
211     return failure();
212 
213   return checkBufferizationResult(op, state.getOptions());
214 }
215 
216 namespace {
217 /// This a "no analysis, always copy" BufferizationState. In the absence of an
218 /// analysis, a buffer must be copied each time it is written to. Therefore, all
219 /// OpOperands that bufferize to a memory write must bufferize out-of-place.
220 class AlwaysCopyBufferizationState : public BufferizationState {
221 public:
222   AlwaysCopyBufferizationState(const BufferizationOptions &options)
223       : BufferizationState(options) {}
224 
225   AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
226 
227   virtual ~AlwaysCopyBufferizationState() = default;
228 
229   /// Return `true` if the given OpResult has been decided to bufferize inplace.
230   bool isInPlace(OpOperand &opOperand) const override {
231     // OpOperands that bufferize to a memory write are out-of-place, i.e., an
232     // alloc and copy is inserted.
233     return !bufferizesToMemoryWrite(opOperand);
234   }
235 
236   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
237   bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
238     // There is no analysis, so we do not know if the values are equivalent. The
239     // conservative answer is "false".
240     return false;
241   }
242 };
243 } // namespace
244 
245 LogicalResult bufferization::bufferizeOp(Operation *op,
246                                          const BufferizationOptions &options) {
247   AlwaysCopyBufferizationState state(options);
248   return bufferizeOp(op, state);
249 }
250 
251 void bufferization::populateBufferizationPattern(
252     const BufferizationState &state, RewritePatternSet &patterns) {
253   patterns.add<BufferizationPattern>(patterns.getContext(), state);
254 }
255 
256 BufferizationOptions bufferization::getPartialBufferizationOptions() {
257   BufferizationOptions options;
258   options.allowReturnMemref = true;
259   options.allowUnknownOps = true;
260   options.createDeallocs = false;
261   options.fullyDynamicLayoutMaps = false;
262   return options;
263 }
264