1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 using namespace mlir; 19 using namespace mlir::bufferization; 20 21 //===----------------------------------------------------------------------===// 22 // BufferizeTypeConverter 23 //===----------------------------------------------------------------------===// 24 25 static Value materializeToTensor(OpBuilder &builder, TensorType type, 26 ValueRange inputs, Location loc) { 27 assert(inputs.size() == 1); 28 assert(inputs[0].getType().isa<BaseMemRefType>()); 29 return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]); 30 } 31 32 /// Registers conversions into BufferizeTypeConverter 33 BufferizeTypeConverter::BufferizeTypeConverter() { 34 // Keep all types unchanged. 35 addConversion([](Type type) { return type; }); 36 // Convert RankedTensorType to MemRefType. 37 addConversion([](RankedTensorType type) -> Type { 38 return MemRefType::get(type.getShape(), type.getElementType()); 39 }); 40 // Convert UnrankedTensorType to UnrankedMemRefType. 41 addConversion([](UnrankedTensorType type) -> Type { 42 return UnrankedMemRefType::get(type.getElementType(), 0); 43 }); 44 addArgumentMaterialization(materializeToTensor); 45 addSourceMaterialization(materializeToTensor); 46 addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, 47 ValueRange inputs, Location loc) -> Value { 48 assert(inputs.size() == 1); 49 assert(inputs[0].getType().isa<TensorType>()); 50 return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]); 51 }); 52 } 53 54 void mlir::bufferization::populateBufferizeMaterializationLegality( 55 ConversionTarget &target) { 56 target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>(); 57 } 58 59 namespace { 60 // In a finalizing bufferize conversion, we know that all tensors have been 61 // converted to memrefs, thus, this op becomes an identity. 62 class BufferizeToTensorOp 63 : public OpConversionPattern<bufferization::ToTensorOp> { 64 public: 65 using OpConversionPattern::OpConversionPattern; 66 LogicalResult 67 matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor, 68 ConversionPatternRewriter &rewriter) const override { 69 rewriter.replaceOp(op, adaptor.memref()); 70 return success(); 71 } 72 }; 73 } // namespace 74 75 namespace { 76 // In a finalizing bufferize conversion, we know that all tensors have been 77 // converted to memrefs, thus, this op becomes an identity. 78 class BufferizeToMemrefOp 79 : public OpConversionPattern<bufferization::ToMemrefOp> { 80 public: 81 using OpConversionPattern::OpConversionPattern; 82 LogicalResult 83 matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor, 84 ConversionPatternRewriter &rewriter) const override { 85 rewriter.replaceOp(op, adaptor.tensor()); 86 return success(); 87 } 88 }; 89 } // namespace 90 91 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns( 92 BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { 93 patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter, 94 patterns.getContext()); 95 } 96 97 namespace { 98 struct FinalizingBufferizePass 99 : public FinalizingBufferizeBase<FinalizingBufferizePass> { 100 using FinalizingBufferizeBase< 101 FinalizingBufferizePass>::FinalizingBufferizeBase; 102 103 void runOnOperation() override { 104 auto func = getOperation(); 105 auto *context = &getContext(); 106 107 BufferizeTypeConverter typeConverter; 108 RewritePatternSet patterns(context); 109 ConversionTarget target(*context); 110 111 populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns); 112 113 // If all result types are legal, and all block arguments are legal (ensured 114 // by func conversion above), then all types in the program are legal. 115 // 116 // We also check that the operand types are legal to avoid creating invalid 117 // IR. For example, this prevents 118 // populateEliminateBufferizeMaterializationsPatterns from updating the 119 // types of the operands to a return op without updating the enclosing 120 // function. 121 target.markUnknownOpDynamicallyLegal( 122 [&](Operation *op) { return typeConverter.isLegal(op); }); 123 124 if (failed(applyFullConversion(func, target, std::move(patterns)))) 125 signalPassFailure(); 126 } 127 }; 128 } // namespace 129 130 std::unique_ptr<OperationPass<FuncOp>> 131 mlir::bufferization::createFinalizingBufferizePass() { 132 return std::make_unique<FinalizingBufferizePass>(); 133 } 134 135 //===----------------------------------------------------------------------===// 136 // BufferizableOpInterface-based Bufferization 137 //===----------------------------------------------------------------------===// 138 139 static bool isaTensor(Type t) { return t.isa<TensorType>(); } 140 141 /// Return true if the given op has a tensor result or a tensor operand. 142 static bool hasTensorSemantics(Operation *op) { 143 bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); 144 bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); 145 return hasTensorResult || hasTensorOperand; 146 } 147 148 /// Rewrite pattern that bufferizes bufferizable ops. 149 struct BufferizationPattern 150 : public OpInterfaceRewritePattern<BufferizableOpInterface> { 151 BufferizationPattern(MLIRContext *context, const BufferizationState &state, 152 PatternBenefit benefit = 1) 153 : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit), 154 state(state) {} 155 156 LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp, 157 PatternRewriter &rewriter) const override { 158 // No tensors => no buffers. 159 if (!hasTensorSemantics(bufferizableOp.getOperation())) 160 return failure(); 161 if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation())) 162 return failure(); 163 return bufferizableOp.bufferize(rewriter, state); 164 } 165 166 private: 167 const BufferizationState &state; 168 }; 169 170 /// Check the result of bufferization. Return an error if an op was not 171 /// bufferized, unless partial bufferization is allowed. 172 static LogicalResult 173 checkBufferizationResult(Operation *op, const BufferizationOptions &options) { 174 if (!options.allowUnknownOps) { 175 // Check if all ops were bufferized. 176 LogicalResult status = success(); 177 op->walk([&](Operation *op) { 178 if (!hasTensorSemantics(op)) 179 return WalkResult::advance(); 180 181 // Bufferization dialect ops will canonicalize away if all other ops are 182 // bufferized. 183 if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op)) 184 return WalkResult::advance(); 185 186 // Ops that are not in the allow list can be ignored. 187 if (!options.isOpAllowed(op)) 188 return WalkResult::advance(); 189 190 // Ops without any uses and no side effects will fold away. 191 if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op)) 192 return WalkResult::advance(); 193 194 status = op->emitError("op was not bufferized"); 195 return WalkResult::interrupt(); 196 }); 197 198 if (failed(status)) 199 return status; 200 } 201 202 return success(); 203 } 204 205 LogicalResult bufferization::bufferizeOp(Operation *op, 206 const BufferizationState &state) { 207 // Bufferize the op and its nested ops. 208 RewritePatternSet patterns(op->getContext()); 209 populateBufferizationPattern(state, patterns); 210 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) 211 return failure(); 212 213 return checkBufferizationResult(op, state.getOptions()); 214 } 215 216 namespace { 217 /// This a "no analysis, always copy" BufferizationState. In the absence of an 218 /// analysis, a buffer must be copied each time it is written to. Therefore, all 219 /// OpOperands that bufferize to a memory write must bufferize out-of-place. 220 class AlwaysCopyBufferizationState : public BufferizationState { 221 public: 222 AlwaysCopyBufferizationState(const BufferizationOptions &options) 223 : BufferizationState(options) {} 224 225 AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete; 226 227 virtual ~AlwaysCopyBufferizationState() = default; 228 229 /// Return `true` if the given OpResult has been decided to bufferize inplace. 230 bool isInPlace(OpOperand &opOperand) const override { 231 // OpOperands that bufferize to a memory write are out-of-place, i.e., an 232 // alloc and copy is inserted. 233 return !bufferizesToMemoryWrite(opOperand); 234 } 235 236 /// Return true if `v1` and `v2` bufferize to equivalent buffers. 237 bool areEquivalentBufferizedValues(Value v1, Value v2) const override { 238 // There is no analysis, so we do not know if the values are equivalent. The 239 // conservative answer is "false". 240 return false; 241 } 242 }; 243 } // namespace 244 245 LogicalResult bufferization::bufferizeOp(Operation *op, 246 const BufferizationOptions &options) { 247 AlwaysCopyBufferizationState state(options); 248 return bufferizeOp(op, state); 249 } 250 251 void bufferization::populateBufferizationPattern( 252 const BufferizationState &state, RewritePatternSet &patterns) { 253 patterns.add<BufferizationPattern>(patterns.getContext(), state); 254 } 255 256 BufferizationOptions bufferization::getPartialBufferizationOptions() { 257 BufferizationOptions options; 258 options.allowReturnMemref = true; 259 options.allowUnknownOps = true; 260 options.createDeallocs = false; 261 options.fullyDynamicLayoutMaps = false; 262 return options; 263 } 264