1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 
21 //===----------------------------------------------------------------------===//
22 // BufferizeTypeConverter
23 //===----------------------------------------------------------------------===//
24 
25 static Value materializeToTensor(OpBuilder &builder, TensorType type,
26                                  ValueRange inputs, Location loc) {
27   assert(inputs.size() == 1);
28   assert(inputs[0].getType().isa<BaseMemRefType>());
29   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
30 }
31 
32 /// Registers conversions into BufferizeTypeConverter
33 BufferizeTypeConverter::BufferizeTypeConverter() {
34   // Keep all types unchanged.
35   addConversion([](Type type) { return type; });
36   // Convert RankedTensorType to MemRefType.
37   addConversion([](RankedTensorType type) -> Type {
38     return MemRefType::get(type.getShape(), type.getElementType());
39   });
40   // Convert UnrankedTensorType to UnrankedMemRefType.
41   addConversion([](UnrankedTensorType type) -> Type {
42     return UnrankedMemRefType::get(type.getElementType(), 0);
43   });
44   addArgumentMaterialization(materializeToTensor);
45   addSourceMaterialization(materializeToTensor);
46   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
47                               ValueRange inputs, Location loc) -> Value {
48     assert(inputs.size() == 1 && "expected exactly one input");
49 
50     if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
51       // MemRef to MemRef cast.
52       assert(inputType != type && "expected different types");
53       // Unranked to ranked and ranked to unranked casts must be explicit.
54       auto rankedDestType = type.dyn_cast<MemRefType>();
55       if (!rankedDestType)
56         return nullptr;
57       FailureOr<Value> replacement =
58           castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
59       if (failed(replacement))
60         return nullptr;
61       return *replacement;
62     }
63 
64     if (inputs[0].getType().isa<TensorType>()) {
65       // Tensor to MemRef cast.
66       return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
67     }
68 
69     llvm_unreachable("only tensor/memref input types supported");
70   });
71 }
72 
73 void mlir::bufferization::populateBufferizeMaterializationLegality(
74     ConversionTarget &target) {
75   target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
76 }
77 
78 namespace {
79 // In a finalizing bufferize conversion, we know that all tensors have been
80 // converted to memrefs, thus, this op becomes an identity.
81 class BufferizeToTensorOp
82     : public OpConversionPattern<bufferization::ToTensorOp> {
83 public:
84   using OpConversionPattern::OpConversionPattern;
85   LogicalResult
86   matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
87                   ConversionPatternRewriter &rewriter) const override {
88     rewriter.replaceOp(op, adaptor.memref());
89     return success();
90   }
91 };
92 } // namespace
93 
94 namespace {
95 // In a finalizing bufferize conversion, we know that all tensors have been
96 // converted to memrefs, thus, this op becomes an identity.
97 class BufferizeToMemrefOp
98     : public OpConversionPattern<bufferization::ToMemrefOp> {
99 public:
100   using OpConversionPattern::OpConversionPattern;
101   LogicalResult
102   matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
103                   ConversionPatternRewriter &rewriter) const override {
104     rewriter.replaceOp(op, adaptor.tensor());
105     return success();
106   }
107 };
108 } // namespace
109 
110 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
111     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
112   patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
113                                                          patterns.getContext());
114 }
115 
116 namespace {
117 struct FinalizingBufferizePass
118     : public FinalizingBufferizeBase<FinalizingBufferizePass> {
119   using FinalizingBufferizeBase<
120       FinalizingBufferizePass>::FinalizingBufferizeBase;
121 
122   void runOnOperation() override {
123     auto func = getOperation();
124     auto *context = &getContext();
125 
126     BufferizeTypeConverter typeConverter;
127     RewritePatternSet patterns(context);
128     ConversionTarget target(*context);
129 
130     populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
131 
132     // If all result types are legal, and all block arguments are legal (ensured
133     // by func conversion above), then all types in the program are legal.
134     //
135     // We also check that the operand types are legal to avoid creating invalid
136     // IR. For example, this prevents
137     // populateEliminateBufferizeMaterializationsPatterns from updating the
138     // types of the operands to a return op without updating the enclosing
139     // function.
140     target.markUnknownOpDynamicallyLegal(
141         [&](Operation *op) { return typeConverter.isLegal(op); });
142 
143     if (failed(applyFullConversion(func, target, std::move(patterns))))
144       signalPassFailure();
145   }
146 };
147 } // namespace
148 
/// Create an instance of the FinalizingBufferize pass (runs on FuncOp).
std::unique_ptr<OperationPass<FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}
153 
154 //===----------------------------------------------------------------------===//
155 // BufferizableOpInterface-based Bufferization
156 //===----------------------------------------------------------------------===//
157 
158 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
159 
160 /// Return true if the given op has a tensor result or a tensor operand.
161 static bool hasTensorSemantics(Operation *op) {
162   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
163   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
164   return hasTensorResult || hasTensorOperand;
165 }
166 
167 /// Rewrite pattern that bufferizes bufferizable ops.
168 struct BufferizationPattern
169     : public OpInterfaceRewritePattern<BufferizableOpInterface> {
170   BufferizationPattern(MLIRContext *context, const BufferizationState &state,
171                        PatternBenefit benefit = 1)
172       : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
173         state(state) {}
174 
175   LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
176                                 PatternRewriter &rewriter) const override {
177     // No tensors => no buffers.
178     if (!hasTensorSemantics(bufferizableOp.getOperation()))
179       return failure();
180     if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation()))
181       return failure();
182     return bufferizableOp.bufferize(rewriter, state);
183   }
184 
185 private:
186   const BufferizationState &state;
187 };
188 
189 /// Check the result of bufferization. Return an error if an op was not
190 /// bufferized, unless partial bufferization is allowed.
191 static LogicalResult
192 checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
193   if (!options.allowUnknownOps) {
194     // Check if all ops were bufferized.
195     LogicalResult status = success();
196     op->walk([&](Operation *op) {
197       if (!hasTensorSemantics(op))
198         return WalkResult::advance();
199 
200       // Bufferization dialect ops will canonicalize away if all other ops are
201       // bufferized.
202       if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
203         return WalkResult::advance();
204 
205       // Ops that are not in the allow list can be ignored.
206       if (!options.isOpAllowed(op))
207         return WalkResult::advance();
208 
209       // Ops without any uses and no side effects will fold away.
210       if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
211         return WalkResult::advance();
212 
213       status = op->emitError("op was not bufferized");
214       return WalkResult::interrupt();
215     });
216 
217     if (failed(status))
218       return status;
219   }
220 
221   return success();
222 }
223 
224 LogicalResult bufferization::bufferizeOp(Operation *op,
225                                          const BufferizationState &state) {
226   // Bufferize the op and its nested ops.
227   RewritePatternSet patterns(op->getContext());
228   populateBufferizationPattern(state, patterns);
229   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
230     return failure();
231 
232   return checkBufferizationResult(op, state.getOptions());
233 }
234 
namespace {
/// This is a "no analysis, always copy" BufferizationState. In the absence of
/// an analysis, a buffer must be copied each time it is written to. Therefore,
/// all OpOperands that bufferize to a memory write must bufferize
/// out-of-place.
class AlwaysCopyBufferizationState : public BufferizationState {
public:
  AlwaysCopyBufferizationState(const BufferizationOptions &options)
      : BufferizationState(options) {}

  // The state is not copyable.
  AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;

  virtual ~AlwaysCopyBufferizationState() = default;

  /// Return `true` if the given OpOperand has been decided to bufferize
  /// in-place.
  bool isInPlace(OpOperand &opOperand) const override {
    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
    // alloc and copy is inserted.
    return !bufferizesToMemoryWrite(opOperand);
  }

  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
    // There is no analysis, so we do not know if the values are equivalent. The
    // conservative answer is "false".
    return false;
  }
};
} // namespace
263 
264 LogicalResult bufferization::bufferizeOp(Operation *op,
265                                          const BufferizationOptions &options) {
266   AlwaysCopyBufferizationState state(options);
267   return bufferizeOp(op, state);
268 }
269 
270 void bufferization::populateBufferizationPattern(
271     const BufferizationState &state, RewritePatternSet &patterns) {
272   patterns.add<BufferizationPattern>(patterns.getContext(), state);
273 }
274 
275 BufferizationOptions bufferization::getPartialBufferizationOptions() {
276   BufferizationOptions options;
277   options.allowReturnMemref = true;
278   options.allowUnknownOps = true;
279   options.createDeallocs = false;
280   options.fullyDynamicLayoutMaps = false;
281   return options;
282 }
283