//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter.
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs; thus, this op becomes an identity.
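// For example (a sketch):
//   %t = bufferization.to_tensor %m : memref<4xf32>
// Once all surrounding ops have been converted, every use of %t expects a
// memref, so the op can simply be replaced by its operand %m.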
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.memref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs; thus, this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.tensor());
    return success();
  }
};
} // namespace

void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}

namespace {
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal
    // (ensured by a previous func conversion pass), then all types in the
    // program are legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<bufferization::BufferizationDialect>();
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnMemref = allowReturnMemref;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;

      BufferizationOptions::OpFilterEntry::FilterFn filterFn =
          [&](Operation *op) {
            // Disallow ops from the func dialect, i.e., ops related to
            // function calls.
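            // (Function boundary ops such as calls and returns are presumably
            // handled by the separate module-level bufferization; this pass
            // never bufferizes them itself.)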
            if (isa<func::FuncDialect>(op->getDialect()))
              return false;
            // Filter may be specified via options.
            if (this->dialectFilter.hasValue())
              return llvm::find(this->dialectFilter,
                                op->getDialect()->getNamespace()) !=
                     this->dialectFilter.end();
            // No filter specified: All other ops are allowed.
            return true;
          };
      opt.allowOperationInFilter(filterFn);
    } else {
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (failed(runOneShotBufferize(moduleOp, opt))) {
      signalPassFailure();
      return;
    }

    if (opt.testAnalysisOnly)
      return;

    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

/// Rewrite pattern that bufferizes bufferizable ops.
struct BufferizationPattern
    : public OpInterfaceRewritePattern<BufferizableOpInterface> {
  BufferizationPattern(MLIRContext *context, BufferizationState &state,
                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
        state(&state) {}

  LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
                                PatternRewriter &rewriter) const override {
    const BufferizationOptions &options = state->getOptions();

    // No tensors => no buffers.
    if (!hasTensorSemantics(bufferizableOp.getOperation()))
      return failure();
    if (!options.isOpAllowed(bufferizableOp.getOperation()))
      return failure();
    return bufferizableOp.bufferize(rewriter, *state);
  }

private:
  BufferizationState *const state;
};

/// Check the result of bufferization. Return an error if an op was not
/// bufferized, unless partial bufferization is allowed.
static LogicalResult
checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
  if (!options.allowUnknownOps) {
    // Check if all ops were bufferized.
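    // The walk below flags any op that still has tensor operands or results,
    // skipping ops that are expected to disappear on their own: leftover
    // materializations, ops outside the allow list, and dead ops without side
    // effects.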
    LogicalResult status = success();
    op->walk([&](Operation *op) {
      if (!hasTensorSemantics(op))
        return WalkResult::advance();

      // Bufferization dialect ops will canonicalize away if all other ops are
      // bufferized.
      if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
        return WalkResult::advance();

      // Ops that are not in the allow list can be ignored.
      if (!options.isOpAllowed(op))
        return WalkResult::advance();

      // Ops without any uses and no side effects will fold away.
      if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
        return WalkResult::advance();

      status = op->emitError("op was not bufferized");
      return WalkResult::interrupt();
    });

    if (failed(status))
      return status;
  }

  return success();
}

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const AnalysisState &analysisState) {
  BufferizationState bufferizationState(analysisState);
  if (failed(bufferizeOp(op, bufferizationState)))
    return failure();
  if (failed(finalizeBuffers(op, analysisState.getOptions())))
    return failure();
  return success();
}

LogicalResult
bufferization::bufferizeOp(Operation *op,
                           BufferizationState &bufferizationState) {
  // Bufferize the op and its nested ops.
  RewritePatternSet patterns(op->getContext());
  patterns.add<BufferizationPattern>(patterns.getContext(), bufferizationState);

  // Bufferize ops top-to-bottom. When creating a new op, we should ideally
  // know the exact memref type of all operands. Otherwise, we have to use a
  // memref type with a fully dynamic layout map, which has to canonicalize
  // away. This is less efficient.
  //
  // Note: If "fullyDynamicLayoutMaps = false", we may have to insert buffer
  // copies to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-
  // compatible layout maps when doing a traversal other than top-to-bottom.
  // There are currently no canonicalization patterns to fold these away.
  GreedyRewriteConfig config;
  config.useTopDownTraversal = true;

  // TODO: Perform a preorder walk instead of the greedy pattern rewriter. This
  // would be more efficient because every bufferization pattern is guaranteed
  // to apply only a single time (otherwise, an assertion would be triggered).
  // However, there are restrictions wrt. erasing ops during a preorder walk,
  // which would likely require a larger refactoring.
  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
    return failure();

  if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
    return failure();

  return success();
}

namespace {
/// This is a "no analysis, always copy" AnalysisState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore,
/// all OpOperands that bufferize to a memory write must bufferize
/// out-of-place.
class AlwaysCopyAnalysisState : public AnalysisState {
public:
  AlwaysCopyAnalysisState(const BufferizationOptions &options)
      : AnalysisState(options) {}

  AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;

  virtual ~AlwaysCopyAnalysisState() = default;

  /// Return `true` if the given OpOperand has been decided to bufferize
  /// in-place.
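  /// As a sketch of the resulting behavior: without aliasing information,
  /// every operand that is written to gets a fresh allocation plus a copy
  /// instead of an in-place update of the original buffer.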
  bool isInPlace(OpOperand &opOperand) const override {
    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
    // allocation and a copy are inserted.
    return !bufferizesToMemoryWrite(opOperand);
  }

  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
    // There is no analysis, so we do not know if the values are equivalent.
    // The conservative answer is "false".
    return false;
  }
};
} // namespace

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options) {
  AlwaysCopyAnalysisState state(options);
  return bufferizeOp(op, state);
}

BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowReturnMemref = true;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.fullyDynamicLayoutMaps = false;
  return options;
}
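
//===----------------------------------------------------------------------===//
// Usage sketch
//===----------------------------------------------------------------------===//
//
// A minimal way to drive the entry points above, assuming a ModuleOp is at
// hand (the calls mirror those used in this file; everything else, including
// the pass flags, is illustrative and may differ between MLIR versions):
//
//   OneShotBufferizationOptions options;
//   options.allowUnknownOps = true;
//   if (failed(runOneShotBufferize(moduleOp, options)))
//     return failure();
//
// Or, from the command line:
//
//   mlir-opt input.mlir -one-shot-bufferize -finalizing-bufferize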