//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

/// Materialize a tensor of type `type` from a single memref input by inserting
/// a bufferization.to_tensor op. Registered below as both the argument and the
/// source materialization of BufferizeTypeConverter.
static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  // Target materialization: produce a memref of `type` from either another
  // memref (via cast/realloc) or a tensor (via bufferization.to_memref).
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

/// Mark the bufferization materialization ops (to_tensor/to_memref) as legal
/// so that a partial bufferization conversion can leave them behind at the
/// tensor/memref boundary.
void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the op with its (already converted) memref operand.
    rewriter.replaceOp(op, adaptor.memref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the op with its (already converted) tensor operand.
    rewriter.replaceOp(op, adaptor.tensor());
    return success();
  }
};
} // namespace

/// Populate patterns that fold away the to_tensor/to_memref materialization
/// ops once all surrounding ops have been bufferized.
void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}

namespace {
/// Pass that removes all remaining bufferization materialization ops from a
/// function, failing if any tensor-typed SSA value survives.
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal (ensured
    // by func conversion above), then all types in the program are legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

/// Translate the pass-option string form of a layout map choice into the
/// corresponding BufferizationOptions enum value. Aborts on any other string
/// (the pass option declaration restricts the accepted values).
static BufferizationOptions::LayoutMapOption
parseLayoutMapOption(const std::string &s) {
  if (s == "fully-dynamic-layout-map")
    return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
  if (s == "identity-layout-map")
    return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
  if (s == "infer-layout-map")
    return BufferizationOptions::LayoutMapOption::InferLayoutMap;
  llvm_unreachable("invalid layout map option");
}

/// Module pass that runs One-Shot Bufferize: a single whole-module
/// bufferization based on BufferizableOpInterface, optionally across function
/// boundaries.
struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass. Each field is populated from the corresponding pass option.
      opt.dropEquivalentFuncResults = dropEquivalentFuncResults;
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.functionBoundaryTypeConversion =
          parseLayoutMapOption(functionBoundaryTypeConversion);
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.promoteBufferResultsToOutParams = promoteBufferResultsToOutParams;
      opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);

      OpFilter::Entry::FilterFn filterFn =
          [&](Operation *op) {
            // Filter may be specified via options.
            if (this->dialectFilter.hasValue())
              return llvm::is_contained(this->dialectFilter,
                                        op->getDialect()->getNamespace());
            // No filter specified: All other ops are allowed.
            return true;
          };
      opt.opFilter.allowOperation(filterFn);
    } else {
      // Options were provided programmatically; use them as-is.
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    } else {
      if (failed(runOneShotBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    }

    if (opt.testAnalysisOnly)
      return;

    // Clean up the bufferized IR. Failures of this best-effort pipeline are
    // deliberately ignored.
    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  // Set only by the options-taking constructor; empty when the pass was
  // created from pass options.
  llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace

namespace {
/// Partial bufferization pass that bufferizes only ops from the bufferization
/// dialect itself.
struct BufferizationBufferizePass
    : public BufferizationBufferizeBase<BufferizationBufferizePass> {
  void runOnOperation() override {
    BufferizationOptions options = getPartialBufferizationOptions();
    options.opFilter.allowDialect<BufferizationDialect>();

    if (failed(bufferizeOp(getOperation(), options)))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
  return std::make_unique<BufferizationBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
    // For function-like ops, look at the signature instead of the
    // operands/results of the op itself.
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    return hasTensorArg || hasTensorResult;
  }

  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

/// Post-process bufferized IR: insert buffer allocations/deallocations and
/// optionally promote returned buffers to out-parameters, according to
/// `options`.
LogicalResult
bufferization::finalizeBuffers(Operation *op,
                               const BufferizationOptions &options) {
  // Create allocation ops for "leaking buffers", i.e., buffer allocations that
  // escape block boundaries. If there are no leaking allocs, `hasLeakingAllocs`
  // is set to `false`.
  bool hasLeakingAllocs = false;
  if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
                                   &hasLeakingAllocs)))
    return failure();

  // Promote returned buffers to "out" parameters.
  // TODO: Pass options to support custom dealloc ops.
  if (options.promoteBufferResultsToOutParams && isa<ModuleOp>(op) &&
      failed(promoteBufferResultsToOutParams(cast<ModuleOp>(op))))
    return failure();

  // Create deallocation ops for all "leaking buffers" and all buffer
  // allocations that were added during the above promotion process.
  // TODO: Pass options to support custom dealloc ops.
  if (hasLeakingAllocs && options.createDeallocs &&
      failed(deallocateBuffers(op)))
    return failure();

  // Deallocate all remaining buffers at the end of their parent blocks.
  if (failed(createAllocDeallocOps(op, options)))
    return failure();

  return success();
}

/// Bufferize `op` and its nested ops using the given (pre-computed) analysis
/// state, then finalize the buffers.
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const AnalysisState &analysisState) {
  // Catch incorrect API usage.
  assert((analysisState.hasDialectState(
              func::FuncDialect::getDialectNamespace()) ||
          !analysisState.getOptions().bufferizeFunctionBoundaries) &&
         "must use ModuleBufferize to bufferize function boundaries");

  BufferizationState bufferizationState(analysisState);
  if (failed(bufferizeOp(op, bufferizationState)))
    return failure();
  if (failed(finalizeBuffers(op, analysisState.getOptions())))
    return failure();
  return success();
}

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter {
public:
  /// `erasedOps` and `toMemrefOps` are caller-owned sets that this rewriter
  /// keeps up to date as ops are removed/inserted.
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        const BufferizationOptions &options,
                        const OpFilter *opFilter)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        options(options), opFilter(opFilter) {}

protected:
  void notifyOperationRemoved(Operation *op) override {
    IRRewriter::notifyOperationRemoved(op);
    erasedOps.insert(op);
    // Erase if present.
    toMemrefOps.erase(op);
  }

  void notifyOperationInserted(Operation *op) override {
    IRRewriter::notifyOperationInserted(op);

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
      return;

    // Adding new bufferizable ops is not allowed during bufferization. Such ops
    // would not be analyzed and can lead to surprising behavior.
    llvm_unreachable(
        "creating new tensor ops is not allowed during bufferization");
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The bufferization options.
  /// Used for debug modes.
  LLVM_ATTRIBUTE_UNUSED
  const BufferizationOptions &options;

  /// Optional additional filter restricting which ops may be bufferized.
  const OpFilter *opFilter;
};
} // namespace

/// Core bufferization driver: rewrite all bufferizable ops nested under `op`
/// (top-to-bottom) via their BufferizableOpInterface implementation.
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         BufferizationState &bufferizationState,
                                         const OpFilter *opFilter) {
  const auto &options = bufferizationState.getOptions();
  assert(options.unknownTypeConversion !=
             BufferizationOptions::LayoutMapOption::InferLayoutMap &&
         "invalid layout map option");

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops. Note: the worklist may grow while iterating, so use an
  // index-based loop.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 bufferizationState.getOptions(), opFilter);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *op = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(op))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      continue;
    if (opFilter && !opFilter->isOpAllowed(op))
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(op))
      continue;
    // Bufferize the op.
    rewriter.setInsertionPoint(op);
    if (failed(bufferizableOp.bufferize(rewriter, bufferizationState)))
      return op->emitError("failed to bufferize op");
  }

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                  cast<ToMemrefOp>(op));
  }

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (bufferizationState.getOptions().allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed to be bufferized.
    if (!options.isOpAllowed(op))
      continue;
    if (opFilter && !opFilter->isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

namespace {
/// This is a "no analysis, always copy" AnalysisState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore, all
/// OpOperands that bufferize to a memory write must bufferize out-of-place.
class AlwaysCopyAnalysisState : public AnalysisState {
public:
  AlwaysCopyAnalysisState(const BufferizationOptions &options)
      : AnalysisState(options) {
    // Note: Allocations must be deallocated with a subsequent run of the buffer
    // deallocation pass.
    assert(!options.createDeallocs &&
           "cannot create deallocs with AlwaysCopyBufferizationState");
  }

  AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;

  virtual ~AlwaysCopyAnalysisState() = default;

  /// Return `true` if the given OpOperand has been decided to bufferize
  /// inplace.
  bool isInPlace(OpOperand &opOperand) const override {
    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
    // alloc and copy is inserted.
    return !bufferizesToMemoryWrite(opOperand);
  }

  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
    // There is no analysis, so we do not know if the values are equivalent. The
    // conservative answer is "false".
    return false;
  }

  /// Return true if `v1` and `v2` may bufferize to aliasing buffers.
  bool areAliasingBufferizedValues(Value v1, Value v2) const override {
    // There is no analysis, so we do not know if the values are equivalent. The
    // conservative answer is "true".
    return true;
  }

  /// Return `true` if the given tensor has undefined contents.
  bool hasUndefinedContents(OpOperand *opOperand) const override {
    // There is no analysis, so the conservative answer is "false".
    return false;
  }

  /// Return true if the given tensor (or an aliasing tensor) is yielded from
  /// the containing block. Also include all aliasing tensors in the same block.
  bool isTensorYielded(Value tensor) const override {
    // There is no analysis, so conservatively answer "true".
    return true;
  }
};
} // namespace

/// Bufferize `op` without running any analysis: every write bufferizes
/// out-of-place (see AlwaysCopyAnalysisState above).
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options) {
  AlwaysCopyAnalysisState state(options);
  return bufferizeOp(op, state);
}

/// Return the default options for partial bufferization passes: unknown ops
/// are allowed, no deallocs are created, and unknown types use an identity
/// layout map.
BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.unknownTypeConversion =
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
  return options;
}