//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

/// Materialize a tensor value from a single memref input by inserting a
/// bufferization.to_tensor op. Shared by the argument and source
/// materialization hooks of BufferizeTypeConverter below.
static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged. NOTE(review): this acts as the fallback for
  // non-tensor types; the tensor conversions below take precedence for
  // tensor types (TypeConverter consults later-added conversions first —
  // confirm against mlir::TypeConverter docs before reordering).
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  // Target materialization: produce a memref-typed value from either a
  // memref of a different (ranked) type or from a tensor. Returning nullptr
  // tells the conversion framework that no materialization was created.
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      // Cast if the layouts are compatible, otherwise realloc + copy into a
      // buffer of the destination type.
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

/// Mark the bufferization materialization ops as legal so a partial
/// bufferization conversion can leave them behind at tensor/memref
/// boundaries (they are cleaned up later by the finalizing pass).
void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
85 class BufferizeToTensorOp 86 : public OpConversionPattern<bufferization::ToTensorOp> { 87 public: 88 using OpConversionPattern::OpConversionPattern; 89 LogicalResult 90 matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor, 91 ConversionPatternRewriter &rewriter) const override { 92 rewriter.replaceOp(op, adaptor.memref()); 93 return success(); 94 } 95 }; 96 } // namespace 97 98 namespace { 99 // In a finalizing bufferize conversion, we know that all tensors have been 100 // converted to memrefs, thus, this op becomes an identity. 101 class BufferizeToMemrefOp 102 : public OpConversionPattern<bufferization::ToMemrefOp> { 103 public: 104 using OpConversionPattern::OpConversionPattern; 105 LogicalResult 106 matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor, 107 ConversionPatternRewriter &rewriter) const override { 108 rewriter.replaceOp(op, adaptor.tensor()); 109 return success(); 110 } 111 }; 112 } // namespace 113 114 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns( 115 BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { 116 patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter, 117 patterns.getContext()); 118 } 119 120 namespace { 121 struct FinalizingBufferizePass 122 : public FinalizingBufferizeBase<FinalizingBufferizePass> { 123 using FinalizingBufferizeBase< 124 FinalizingBufferizePass>::FinalizingBufferizeBase; 125 126 void runOnOperation() override { 127 auto func = getOperation(); 128 auto *context = &getContext(); 129 130 BufferizeTypeConverter typeConverter; 131 RewritePatternSet patterns(context); 132 ConversionTarget target(*context); 133 134 populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns); 135 136 // If all result types are legal, and all block arguments are legal (ensured 137 // by func conversion above), then all types in the program are legal. 
138 // 139 // We also check that the operand types are legal to avoid creating invalid 140 // IR. For example, this prevents 141 // populateEliminateBufferizeMaterializationsPatterns from updating the 142 // types of the operands to a return op without updating the enclosing 143 // function. 144 target.markUnknownOpDynamicallyLegal( 145 [&](Operation *op) { return typeConverter.isLegal(op); }); 146 147 if (failed(applyFullConversion(func, target, std::move(patterns)))) 148 signalPassFailure(); 149 } 150 }; 151 152 struct OneShotBufferizePass 153 : public OneShotBufferizeBase<OneShotBufferizePass> { 154 OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {} 155 156 explicit OneShotBufferizePass(const AnalysisBufferizationOptions &options) 157 : options(options) {} 158 159 void getDependentDialects(DialectRegistry ®istry) const override { 160 registry.insert<bufferization::BufferizationDialect>(); 161 } 162 163 void runOnOperation() override { 164 AnalysisBufferizationOptions opt; 165 if (!options) { 166 // Make new bufferization options if none were provided when creating the 167 // pass. 168 opt.allowReturnMemref = allowReturnMemref; 169 opt.allowUnknownOps = allowUnknownOps; 170 opt.analysisFuzzerSeed = analysisFuzzerSeed; 171 opt.createDeallocs = createDeallocs; 172 opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps; 173 opt.printConflicts = printConflicts; 174 opt.testAnalysisOnly = testAnalysisOnly; 175 176 BufferizationOptions::OpFilterEntry::FilterFn filterFn = 177 [&](Operation *op) { 178 // Disallow non-func dialect ops. I.e., no ops related to function 179 // calls. 180 if (isa<func::FuncDialect>(op->getDialect())) 181 return false; 182 // Filter may be specified via options. 183 if (this->dialectFilter.hasValue()) 184 return llvm::find(this->dialectFilter, 185 op->getDialect()->getNamespace()) != 186 this->dialectFilter.end(); 187 // No filter specified: All other ops are allowed. 
188 return true; 189 }; 190 opt.allowOperationInFilter(filterFn); 191 } else { 192 opt = *options; 193 } 194 195 ModuleOp moduleOp = getOperation(); 196 if (failed(runOneShotBufferize(moduleOp, opt))) { 197 signalPassFailure(); 198 return; 199 } 200 201 if (opt.testAnalysisOnly) 202 return; 203 204 OpPassManager cleanupPipeline("builtin.module"); 205 cleanupPipeline.addPass(createCanonicalizerPass()); 206 cleanupPipeline.addPass(createCSEPass()); 207 cleanupPipeline.addPass(createLoopInvariantCodeMotionPass()); 208 (void)runPipeline(cleanupPipeline, moduleOp); 209 } 210 211 private: 212 llvm::Optional<AnalysisBufferizationOptions> options; 213 }; 214 } // namespace 215 216 std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() { 217 return std::make_unique<OneShotBufferizePass>(); 218 } 219 220 std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass( 221 const AnalysisBufferizationOptions &options) { 222 return std::make_unique<OneShotBufferizePass>(options); 223 } 224 225 std::unique_ptr<OperationPass<FuncOp>> 226 mlir::bufferization::createFinalizingBufferizePass() { 227 return std::make_unique<FinalizingBufferizePass>(); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // BufferizableOpInterface-based Bufferization 232 //===----------------------------------------------------------------------===// 233 234 static bool isaTensor(Type t) { return t.isa<TensorType>(); } 235 236 /// Return true if the given op has a tensor result or a tensor operand. 237 static bool hasTensorSemantics(Operation *op) { 238 bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); 239 bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); 240 return hasTensorResult || hasTensorOperand; 241 } 242 243 /// Rewrite pattern that bufferizes bufferizable ops. 
244 struct BufferizationPattern 245 : public OpInterfaceRewritePattern<BufferizableOpInterface> { 246 BufferizationPattern(MLIRContext *context, const BufferizationState &state, 247 PatternBenefit benefit = 1) 248 : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit), 249 state(state) {} 250 251 LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp, 252 PatternRewriter &rewriter) const override { 253 // No tensors => no buffers. 254 if (!hasTensorSemantics(bufferizableOp.getOperation())) 255 return failure(); 256 if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation())) 257 return failure(); 258 return bufferizableOp.bufferize(rewriter, state); 259 } 260 261 private: 262 const BufferizationState &state; 263 }; 264 265 /// Check the result of bufferization. Return an error if an op was not 266 /// bufferized, unless partial bufferization is allowed. 267 static LogicalResult 268 checkBufferizationResult(Operation *op, const BufferizationOptions &options) { 269 if (!options.allowUnknownOps) { 270 // Check if all ops were bufferized. 271 LogicalResult status = success(); 272 op->walk([&](Operation *op) { 273 if (!hasTensorSemantics(op)) 274 return WalkResult::advance(); 275 276 // Bufferization dialect ops will canonicalize away if all other ops are 277 // bufferized. 278 if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op)) 279 return WalkResult::advance(); 280 281 // Ops that are not in the allow list can be ignored. 282 if (!options.isOpAllowed(op)) 283 return WalkResult::advance(); 284 285 // Ops without any uses and no side effects will fold away. 
286 if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op)) 287 return WalkResult::advance(); 288 289 status = op->emitError("op was not bufferized"); 290 return WalkResult::interrupt(); 291 }); 292 293 if (failed(status)) 294 return status; 295 } 296 297 return success(); 298 } 299 300 LogicalResult bufferization::bufferizeOp(Operation *op, 301 const BufferizationState &state) { 302 // Bufferize the op and its nested ops. 303 RewritePatternSet patterns(op->getContext()); 304 populateBufferizationPattern(state, patterns); 305 306 // Bufferize ops top-to-bottom. When creating a new op, we should ideally 307 // know the exact memref type of all operands. Otherwise, we have to use a 308 // memref type with a fully dynamic layout map, which has to canonicalize 309 // away. 310 // Moreover, if "fullyDynamicLayoutMaps = false", we may otherwise have to 311 // insert buffer copies to fold ("finalize") to_memref(to_tensor(x)) ops with 312 // non-cast-compatible layout maps. 313 GreedyRewriteConfig config; 314 config.useTopDownTraversal = true; 315 316 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) 317 return failure(); 318 319 return checkBufferizationResult(op, state.getOptions()); 320 } 321 322 namespace { 323 /// This a "no analysis, always copy" BufferizationState. In the absence of an 324 /// analysis, a buffer must be copied each time it is written to. Therefore, all 325 /// OpOperands that bufferize to a memory write must bufferize out-of-place. 326 class AlwaysCopyBufferizationState : public BufferizationState { 327 public: 328 AlwaysCopyBufferizationState(const BufferizationOptions &options) 329 : BufferizationState(options) {} 330 331 AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete; 332 333 virtual ~AlwaysCopyBufferizationState() = default; 334 335 /// Return `true` if the given OpResult has been decided to bufferize inplace. 
336 bool isInPlace(OpOperand &opOperand) const override { 337 // OpOperands that bufferize to a memory write are out-of-place, i.e., an 338 // alloc and copy is inserted. 339 return !bufferizesToMemoryWrite(opOperand); 340 } 341 342 /// Return true if `v1` and `v2` bufferize to equivalent buffers. 343 bool areEquivalentBufferizedValues(Value v1, Value v2) const override { 344 // There is no analysis, so we do not know if the values are equivalent. The 345 // conservative answer is "false". 346 return false; 347 } 348 }; 349 } // namespace 350 351 LogicalResult bufferization::bufferizeOp(Operation *op, 352 const BufferizationOptions &options) { 353 AlwaysCopyBufferizationState state(options); 354 return bufferizeOp(op, state); 355 } 356 357 void bufferization::populateBufferizationPattern( 358 const BufferizationState &state, RewritePatternSet &patterns) { 359 patterns.add<BufferizationPattern>(patterns.getContext(), state); 360 } 361 362 BufferizationOptions bufferization::getPartialBufferizationOptions() { 363 BufferizationOptions options; 364 options.allowReturnMemref = true; 365 options.allowUnknownOps = true; 366 options.createDeallocs = false; 367 options.fullyDynamicLayoutMaps = false; 368 return options; 369 } 370