//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

/// Materialize a tensor value of the given `type` from a single memref input
/// by inserting a bufferization.to_tensor op at `loc`.
static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    // Memory space 0 — the default memory space.
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  // memref -> tensor materializations (for both block arguments and op
  // results) insert a bufferization.to_tensor op.
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  // Target materialization: produce a value of memref `type` from a single
  // tensor or memref input. Returning nullptr signals that the
  // materialization failed and the conversion framework should give up.
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      // Reconcile the two ranked memref types; per the helper's name this
      // casts when possible and reallocates otherwise.
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

/// Marks the bufferization materialization ops (to_tensor/to_memref) legal so
/// a partial bufferization conversion can leave them at the tensor/memref
/// boundary.
void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Forward the (already converted) memref operand as the replacement.
    rewriter.replaceOp(op, adaptor.memref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Forward the (already converted) tensor operand as the replacement.
    rewriter.replaceOp(op, adaptor.tensor());
    return success();
  }
};
} // namespace

/// Populates `patterns` with the identity-folding patterns above, which erase
/// leftover to_tensor/to_memref materializations.
void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}

namespace {
/// Pass that runs a full conversion eliminating any remaining
/// to_tensor/to_memref materializations; fails if any illegal (tensor-typed)
/// op survives.
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal (ensured
    // by func conversion above), then all types in the program are legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

/// Module pass driving One-Shot Bufferize, either with externally supplied
/// OneShotBufferizationOptions or with options assembled from the pass's
/// command-line flags.
struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass. Each field mirrors the corresponding tablegen'd pass option.
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;

      BufferizationOptions::OpFilterEntry::FilterFn filterFn =
          [&](Operation *op) {
            // Disallow non-func dialect ops. I.e., no ops related to function
            // calls.
            if (isa<func::FuncDialect>(op->getDialect()))
              return false;
            // Filter may be specified via options.
            if (this->dialectFilter.hasValue())
              return llvm::find(this->dialectFilter,
                                op->getDialect()->getNamespace()) !=
                     this->dialectFilter.end();
            // No filter specified: All other ops are allowed.
            return true;
          };
      opt.allowOperationInFilter(filterFn);
    } else {
      // Options were given at pass construction time; use them verbatim.
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (failed(runOneShotBufferize(moduleOp, opt))) {
      signalPassFailure();
      return;
    }

    // In test-analysis-only mode the IR is annotated but not rewritten, so
    // there is nothing to clean up.
    if (opt.testAnalysisOnly)
      return;

    // Post-bufferization cleanup. The pipeline result is intentionally
    // ignored: cleanup failure does not fail the pass.
    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  // Set only by the explicit constructor; empty when the pass was built from
  // command-line flags.
  llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

/// Rewrite pattern that bufferizes bufferizable ops.
struct BufferizationPattern
    : public OpInterfaceRewritePattern<BufferizableOpInterface> {
  BufferizationPattern(MLIRContext *context, BufferizationState &state,
                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
        state(&state) {}

  LogicalResult
  matchAndRewrite(BufferizableOpInterface bufferizableOp,
                  PatternRewriter &rewriter) const override {
    const BufferizationOptions &options = state->getOptions();

    // No tensors => no buffers.
    if (!hasTensorSemantics(bufferizableOp.getOperation()))
      return failure();
    // Skip ops excluded by the options' op filter.
    if (!options.isOpAllowed(bufferizableOp.getOperation()))
      return failure();
    // Delegate the actual rewrite to the op's interface implementation.
    return bufferizableOp.bufferize(rewriter, *state);
  }

private:
  // Shared mutable bufferization state; non-owning, must outlive the pattern.
  BufferizationState *const state;
};

/// Check the result of bufferization. Return an error if an op was not
/// bufferized, unless partial bufferization is allowed.
static LogicalResult
checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
  if (!options.allowUnknownOps) {
    // Check if all ops were bufferized.
    LogicalResult status = success();
    op->walk([&](Operation *op) {
      if (!hasTensorSemantics(op))
        return WalkResult::advance();

      // Bufferization dialect ops will canonicalize away if all other ops are
      // bufferized.
      if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
        return WalkResult::advance();

      // Ops that are not in the allow list can be ignored.
      if (!options.isOpAllowed(op))
        return WalkResult::advance();

      // Ops without any uses and no side effects will fold away.
      if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
        return WalkResult::advance();

      status = op->emitError("op was not bufferized");
      return WalkResult::interrupt();
    });

    if (failed(status))
      return status;
  }

  return success();
}

/// Post-process bufferized IR: hoist allocations and insert alloc/dealloc ops.
/// The steps below run in a fixed order; reordering them would change which
/// allocations are treated as leaking.
LogicalResult
bufferization::finalizeBuffers(Operation *op,
                               const BufferizationOptions &options) {
  // Hoist buffers.
  if (failed(hoistBufferAllocations(op, options)))
    return failure();

  // Deallocate buffers that escape block boundaries ("leaking buffers") with
  // the buffer deallocation pass.
  bool hasLeakingAlloc = false;
  if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
                                   &hasLeakingAlloc)))
    return failure();
  // Only run the deallocation pass when deallocs are requested and at least
  // one leaking allocation was found.
  if (options.createDeallocs && hasLeakingAlloc &&
      failed(deallocateBuffers(op)))
    return failure();

  // Deallocate all remaining buffers at the end of the block.
  if (failed(createAllocDeallocOps(op, options)))
    return failure();

  return success();
}

/// Bufferize `op` and its nested ops using the given (pre-computed) analysis
/// state, then finalize the resulting buffers.
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const AnalysisState &analysisState) {
  BufferizationState bufferizationState(analysisState);
  if (failed(bufferizeOp(op, bufferizationState)))
    return failure();
  if (failed(finalizeBuffers(op, analysisState.getOptions())))
    return failure();
  return success();
}

/// Bufferize `op` and its nested ops by greedily applying BufferizationPattern
/// top-to-bottom, then verify the result. Does NOT finalize buffers.
LogicalResult
bufferization::bufferizeOp(Operation *op,
                           BufferizationState &bufferizationState) {
  // Bufferize the op and its nested ops.
  RewritePatternSet patterns(op->getContext());
  patterns.add<BufferizationPattern>(patterns.getContext(), bufferizationState);

  // Bufferize ops top-to-bottom. When creating a new op, we should ideally
  // know the exact memref type of all operands. Otherwise, we have to use a
  // memref type with a fully dynamic layout map, which has to canonicalize
  // away. This is less efficient.
  //
  // Note: If "fullyDynamicLayoutMaps = false", we may have to insert buffer
  // copies to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-
  // compatible layout maps when doing a traversal other than top-to-bottom.
  // There are currently no canonicalization patterns to fold these away.
  GreedyRewriteConfig config;
  config.useTopDownTraversal = true;

  // TODO: Perform a preorder walk instead of the greedy pattern rewriter. This
  // would be more efficient because every bufferization pattern is guaranteed
  // to apply only a single time (otherwise, an assertion would be triggered).
  // However, there are restrictions wrt. erasing ops during a preorder walk,
  // which would likely require a larger refactoring.
  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
    return failure();

  // Fail if any op that should have been bufferized was left behind (unless
  // unknown ops are allowed).
  if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
    return failure();

  return success();
}

namespace {
/// This is a "no analysis, always copy" AnalysisState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore, all
/// OpOperands that bufferize to a memory write must bufferize out-of-place.
class AlwaysCopyAnalysisState : public AnalysisState {
public:
  AlwaysCopyAnalysisState(const BufferizationOptions &options)
      : AnalysisState(options) {}

  // Non-copyable: this state is passed around by reference.
  AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;

  virtual ~AlwaysCopyAnalysisState() = default;

  /// Return `true` if the given OpResult has been decided to bufferize inplace.
  bool isInPlace(OpOperand &opOperand) const override {
    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
    // alloc and copy is inserted.
    return !bufferizesToMemoryWrite(opOperand);
  }

  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
    // There is no analysis, so we do not know if the values are equivalent. The
    // conservative answer is "false".
    return false;
  }
};
} // namespace

/// Bufferize `op` without a prior analysis: every write triggers an
/// out-of-place buffer (see AlwaysCopyAnalysisState above).
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options) {
  AlwaysCopyAnalysisState state(options);
  return bufferizeOp(op, state);
}

/// Return options suitable for partial bufferization: unknown ops are kept,
/// deallocs are not inserted, and layout maps are not fully dynamic.
BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.fullyDynamicLayoutMaps = false;
  return options;
}