1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 using namespace mlir; 19 using namespace mlir::bufferization; 20 21 //===----------------------------------------------------------------------===// 22 // BufferizeTypeConverter 23 //===----------------------------------------------------------------------===// 24 25 static Value materializeToTensor(OpBuilder &builder, TensorType type, 26 ValueRange inputs, Location loc) { 27 assert(inputs.size() == 1); 28 assert(inputs[0].getType().isa<BaseMemRefType>()); 29 return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]); 30 } 31 32 /// Registers conversions into BufferizeTypeConverter 33 BufferizeTypeConverter::BufferizeTypeConverter() { 34 // Keep all types unchanged. 35 addConversion([](Type type) { return type; }); 36 // Convert RankedTensorType to MemRefType. 37 addConversion([](RankedTensorType type) -> Type { 38 return MemRefType::get(type.getShape(), type.getElementType()); 39 }); 40 // Convert UnrankedTensorType to UnrankedMemRefType. 41 addConversion([](UnrankedTensorType type) -> Type { 42 return UnrankedMemRefType::get(type.getElementType(), 0); 43 }); 44 addArgumentMaterialization(materializeToTensor); 45 addSourceMaterialization(materializeToTensor); 46 addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, 47 ValueRange inputs, Location loc) -> Value { 48 assert(inputs.size() == 1 && "expected exactly one input"); 49 50 if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) { 51 // MemRef to MemRef cast. 52 assert(inputType != type && "expected different types"); 53 // Unranked to ranked and ranked to unranked casts must be explicit. 54 auto rankedDestType = type.dyn_cast<MemRefType>(); 55 if (!rankedDestType) 56 return nullptr; 57 FailureOr<Value> replacement = 58 castOrReallocMemRefValue(builder, inputs[0], rankedDestType); 59 if (failed(replacement)) 60 return nullptr; 61 return *replacement; 62 } 63 64 if (inputs[0].getType().isa<TensorType>()) { 65 // Tensor to MemRef cast. 66 return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]); 67 } 68 69 llvm_unreachable("only tensor/memref input types supported"); 70 }); 71 } 72 73 void mlir::bufferization::populateBufferizeMaterializationLegality( 74 ConversionTarget &target) { 75 target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>(); 76 } 77 78 namespace { 79 // In a finalizing bufferize conversion, we know that all tensors have been 80 // converted to memrefs, thus, this op becomes an identity. 81 class BufferizeToTensorOp 82 : public OpConversionPattern<bufferization::ToTensorOp> { 83 public: 84 using OpConversionPattern::OpConversionPattern; 85 LogicalResult 86 matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor, 87 ConversionPatternRewriter &rewriter) const override { 88 rewriter.replaceOp(op, adaptor.memref()); 89 return success(); 90 } 91 }; 92 } // namespace 93 94 namespace { 95 // In a finalizing bufferize conversion, we know that all tensors have been 96 // converted to memrefs, thus, this op becomes an identity. 97 class BufferizeToMemrefOp 98 : public OpConversionPattern<bufferization::ToMemrefOp> { 99 public: 100 using OpConversionPattern::OpConversionPattern; 101 LogicalResult 102 matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor, 103 ConversionPatternRewriter &rewriter) const override { 104 rewriter.replaceOp(op, adaptor.tensor()); 105 return success(); 106 } 107 }; 108 } // namespace 109 110 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns( 111 BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { 112 patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter, 113 patterns.getContext()); 114 } 115 116 namespace { 117 struct FinalizingBufferizePass 118 : public FinalizingBufferizeBase<FinalizingBufferizePass> { 119 using FinalizingBufferizeBase< 120 FinalizingBufferizePass>::FinalizingBufferizeBase; 121 122 void runOnOperation() override { 123 auto func = getOperation(); 124 auto *context = &getContext(); 125 126 BufferizeTypeConverter typeConverter; 127 RewritePatternSet patterns(context); 128 ConversionTarget target(*context); 129 130 populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns); 131 132 // If all result types are legal, and all block arguments are legal (ensured 133 // by func conversion above), then all types in the program are legal. 134 // 135 // We also check that the operand types are legal to avoid creating invalid 136 // IR. For example, this prevents 137 // populateEliminateBufferizeMaterializationsPatterns from updating the 138 // types of the operands to a return op without updating the enclosing 139 // function. 140 target.markUnknownOpDynamicallyLegal( 141 [&](Operation *op) { return typeConverter.isLegal(op); }); 142 143 if (failed(applyFullConversion(func, target, std::move(patterns)))) 144 signalPassFailure(); 145 } 146 }; 147 } // namespace 148 149 std::unique_ptr<OperationPass<FuncOp>> 150 mlir::bufferization::createFinalizingBufferizePass() { 151 return std::make_unique<FinalizingBufferizePass>(); 152 } 153 154 //===----------------------------------------------------------------------===// 155 // BufferizableOpInterface-based Bufferization 156 //===----------------------------------------------------------------------===// 157 158 static bool isaTensor(Type t) { return t.isa<TensorType>(); } 159 160 /// Return true if the given op has a tensor result or a tensor operand. 161 static bool hasTensorSemantics(Operation *op) { 162 bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); 163 bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); 164 return hasTensorResult || hasTensorOperand; 165 } 166 167 /// Rewrite pattern that bufferizes bufferizable ops. 168 struct BufferizationPattern 169 : public OpInterfaceRewritePattern<BufferizableOpInterface> { 170 BufferizationPattern(MLIRContext *context, const BufferizationState &state, 171 PatternBenefit benefit = 1) 172 : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit), 173 state(state) {} 174 175 LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp, 176 PatternRewriter &rewriter) const override { 177 // No tensors => no buffers. 178 if (!hasTensorSemantics(bufferizableOp.getOperation())) 179 return failure(); 180 if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation())) 181 return failure(); 182 return bufferizableOp.bufferize(rewriter, state); 183 } 184 185 private: 186 const BufferizationState &state; 187 }; 188 189 /// Check the result of bufferization. Return an error if an op was not 190 /// bufferized, unless partial bufferization is allowed. 191 static LogicalResult 192 checkBufferizationResult(Operation *op, const BufferizationOptions &options) { 193 if (!options.allowUnknownOps) { 194 // Check if all ops were bufferized. 195 LogicalResult status = success(); 196 op->walk([&](Operation *op) { 197 if (!hasTensorSemantics(op)) 198 return WalkResult::advance(); 199 200 // Bufferization dialect ops will canonicalize away if all other ops are 201 // bufferized. 202 if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op)) 203 return WalkResult::advance(); 204 205 // Ops that are not in the allow list can be ignored. 206 if (!options.isOpAllowed(op)) 207 return WalkResult::advance(); 208 209 // Ops without any uses and no side effects will fold away. 210 if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op)) 211 return WalkResult::advance(); 212 213 status = op->emitError("op was not bufferized"); 214 return WalkResult::interrupt(); 215 }); 216 217 if (failed(status)) 218 return status; 219 } 220 221 return success(); 222 } 223 224 LogicalResult bufferization::bufferizeOp(Operation *op, 225 const BufferizationState &state) { 226 // Bufferize the op and its nested ops. 227 RewritePatternSet patterns(op->getContext()); 228 populateBufferizationPattern(state, patterns); 229 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) 230 return failure(); 231 232 return checkBufferizationResult(op, state.getOptions()); 233 } 234 235 namespace { 236 /// This a "no analysis, always copy" BufferizationState. In the absence of an 237 /// analysis, a buffer must be copied each time it is written to. Therefore, all 238 /// OpOperands that bufferize to a memory write must bufferize out-of-place. 239 class AlwaysCopyBufferizationState : public BufferizationState { 240 public: 241 AlwaysCopyBufferizationState(const BufferizationOptions &options) 242 : BufferizationState(options) {} 243 244 AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete; 245 246 virtual ~AlwaysCopyBufferizationState() = default; 247 248 /// Return `true` if the given OpResult has been decided to bufferize inplace. 249 bool isInPlace(OpOperand &opOperand) const override { 250 // OpOperands that bufferize to a memory write are out-of-place, i.e., an 251 // alloc and copy is inserted. 252 return !bufferizesToMemoryWrite(opOperand); 253 } 254 255 /// Return true if `v1` and `v2` bufferize to equivalent buffers. 256 bool areEquivalentBufferizedValues(Value v1, Value v2) const override { 257 // There is no analysis, so we do not know if the values are equivalent. The 258 // conservative answer is "false". 259 return false; 260 } 261 }; 262 } // namespace 263 264 LogicalResult bufferization::bufferizeOp(Operation *op, 265 const BufferizationOptions &options) { 266 AlwaysCopyBufferizationState state(options); 267 return bufferizeOp(op, state); 268 } 269 270 void bufferization::populateBufferizationPattern( 271 const BufferizationState &state, RewritePatternSet &patterns) { 272 patterns.add<BufferizationPattern>(patterns.getContext(), state); 273 } 274 275 BufferizationOptions bufferization::getPartialBufferizationOptions() { 276 BufferizationOptions options; 277 options.allowReturnMemref = true; 278 options.allowUnknownOps = true; 279 options.createDeallocs = false; 280 options.fullyDynamicLayoutMaps = false; 281 return options; 282 } 283