1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 
21 //===----------------------------------------------------------------------===//
22 // BufferizeTypeConverter
23 //===----------------------------------------------------------------------===//
24 
25 static Value materializeToTensor(OpBuilder &builder, TensorType type,
26                                  ValueRange inputs, Location loc) {
27   assert(inputs.size() == 1);
28   assert(inputs[0].getType().isa<BaseMemRefType>());
29   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
30 }
31 
32 /// Registers conversions into BufferizeTypeConverter
33 BufferizeTypeConverter::BufferizeTypeConverter() {
34   // Keep all types unchanged.
35   addConversion([](Type type) { return type; });
36   // Convert RankedTensorType to MemRefType.
37   addConversion([](RankedTensorType type) -> Type {
38     return MemRefType::get(type.getShape(), type.getElementType());
39   });
40   // Convert UnrankedTensorType to UnrankedMemRefType.
41   addConversion([](UnrankedTensorType type) -> Type {
42     return UnrankedMemRefType::get(type.getElementType(), 0);
43   });
44   addArgumentMaterialization(materializeToTensor);
45   addSourceMaterialization(materializeToTensor);
46   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
47                               ValueRange inputs, Location loc) -> Value {
48     assert(inputs.size() == 1);
49     assert(inputs[0].getType().isa<TensorType>());
50     return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
51   });
52 }
53 
54 void mlir::bufferization::populateBufferizeMaterializationLegality(
55     ConversionTarget &target) {
56   target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
57 }
58 
59 namespace {
60 // In a finalizing bufferize conversion, we know that all tensors have been
61 // converted to memrefs, thus, this op becomes an identity.
62 class BufferizeToTensorOp
63     : public OpConversionPattern<bufferization::ToTensorOp> {
64 public:
65   using OpConversionPattern::OpConversionPattern;
66   LogicalResult
67   matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
68                   ConversionPatternRewriter &rewriter) const override {
69     rewriter.replaceOp(op, adaptor.memref());
70     return success();
71   }
72 };
73 } // namespace
74 
75 namespace {
76 // In a finalizing bufferize conversion, we know that all tensors have been
77 // converted to memrefs, thus, this op becomes an identity.
78 class BufferizeToMemrefOp
79     : public OpConversionPattern<bufferization::ToMemrefOp> {
80 public:
81   using OpConversionPattern::OpConversionPattern;
82   LogicalResult
83   matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
84                   ConversionPatternRewriter &rewriter) const override {
85     rewriter.replaceOp(op, adaptor.tensor());
86     return success();
87   }
88 };
89 } // namespace
90 
91 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
92     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
93   patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
94                                                          patterns.getContext());
95 }
96 
97 namespace {
98 struct FinalizingBufferizePass
99     : public FinalizingBufferizeBase<FinalizingBufferizePass> {
100   using FinalizingBufferizeBase<
101       FinalizingBufferizePass>::FinalizingBufferizeBase;
102 
103   void runOnOperation() override {
104     auto func = getOperation();
105     auto *context = &getContext();
106 
107     BufferizeTypeConverter typeConverter;
108     RewritePatternSet patterns(context);
109     ConversionTarget target(*context);
110 
111     populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
112 
113     // If all result types are legal, and all block arguments are legal (ensured
114     // by func conversion above), then all types in the program are legal.
115     //
116     // We also check that the operand types are legal to avoid creating invalid
117     // IR. For example, this prevents
118     // populateEliminateBufferizeMaterializationsPatterns from updating the
119     // types of the operands to a return op without updating the enclosing
120     // function.
121     target.markUnknownOpDynamicallyLegal(
122         [&](Operation *op) { return typeConverter.isLegal(op); });
123 
124     if (failed(applyFullConversion(func, target, std::move(patterns))))
125       signalPassFailure();
126   }
127 };
128 } // namespace
129 
130 std::unique_ptr<OperationPass<FuncOp>>
131 mlir::bufferization::createFinalizingBufferizePass() {
132   return std::make_unique<FinalizingBufferizePass>();
133 }
134 
135 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
136 
137 /// Return true if the given op has a tensor result or a tensor operand.
138 static bool hasTensorSemantics(Operation *op) {
139   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
140   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
141   return hasTensorResult || hasTensorOperand;
142 }
143 
144 /// Rewrite pattern that bufferizes bufferizable ops.
145 struct BufferizationPattern
146     : public OpInterfaceRewritePattern<BufferizableOpInterface> {
147   BufferizationPattern(MLIRContext *context, const BufferizationState &state,
148                        PatternBenefit benefit = 1)
149       : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
150         state(state) {}
151 
152   LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
153                                 PatternRewriter &rewriter) const override {
154     // No tensors => no buffers.
155     if (!hasTensorSemantics(bufferizableOp.getOperation()))
156       return failure();
157     if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation()))
158       return failure();
159     return bufferizableOp.bufferize(rewriter, state);
160   }
161 
162 private:
163   const BufferizationState &state;
164 };
165 
166 /// Check the result of bufferization. Return an error if an op was not
167 /// bufferized, unless partial bufferization is allowed.
168 static LogicalResult
169 checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
170   if (!options.allowUnknownOps) {
171     // Check if all ops were bufferized.
172     LogicalResult status = success();
173     op->walk([&](Operation *op) {
174       if (!hasTensorSemantics(op))
175         return WalkResult::advance();
176 
177       // Bufferization dialect ops will canonicalize away if all other ops are
178       // bufferized.
179       if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
180         return WalkResult::advance();
181 
182       // Ops that are not in the allow list can be ignored.
183       if (!options.isOpAllowed(op))
184         return WalkResult::advance();
185 
186       // Ops without any uses and no side effects will fold away.
187       if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
188         return WalkResult::advance();
189 
190       status = op->emitError("op was not bufferized");
191       return WalkResult::interrupt();
192     });
193 
194     if (failed(status))
195       return status;
196   }
197 
198   return success();
199 }
200 
201 LogicalResult bufferization::bufferizeOp(Operation *op,
202                                          const BufferizationState &state) {
203   // Bufferize the op and its nested ops.
204   OwningRewritePatternList patterns(op->getContext());
205   patterns.add<BufferizationPattern>(op->getContext(), state);
206   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
207     return failure();
208 
209   return checkBufferizationResult(op, state.getOptions());
210 }
211