//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

/// Materialize the given memref value as a tensor of the requested type by
/// wrapping it in a bufferization.to_tensor op.
static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter.
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, so this op becomes an identity.
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getMemref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, so this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
} // namespace

void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}

namespace {
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal
    // (ensured by func conversion above), then all types in the program are
    // legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

/// Parse the layout map pass option string into the corresponding enum value.
static BufferizationOptions::LayoutMapOption
parseLayoutMapOption(const std::string &s) {
  if (s == "fully-dynamic-layout-map")
    return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
  if (s == "identity-layout-map")
    return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
  if (s == "infer-layout-map")
    return BufferizationOptions::LayoutMapOption::InferLayoutMap;
  llvm_unreachable("invalid layout map option");
}

struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.functionBoundaryTypeConversion =
          parseLayoutMapOption(functionBoundaryTypeConversion);
      if (mustInferMemorySpace)
        opt.defaultMemorySpace = None;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;

      // Configure type converter.
      BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
          parseLayoutMapOption(unknownTypeConversion);
      opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
                                       const BufferizationOptions &options) {
        auto tensorType = value.getType().cast<TensorType>();
        if (unknownTypeConversionOption ==
            BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(unknownTypeConversionOption ==
                   BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
               "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // Filter may be specified via options.
        if (this->dialectFilter.hasValue())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: All other ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    } else {
      if (failed(runOneShotBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    }

    if (opt.testAnalysisOnly)
      return;

    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace

namespace {
struct BufferizationBufferizePass
    : public BufferizationBufferizeBase<BufferizationBufferizePass> {
  void runOnOperation() override {
    BufferizationOptions options = getPartialBufferizationOptions();
    options.opFilter.allowDialect<BufferizationDialect>();

    if (failed(bufferizeOp(getOperation(), options)))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
  return std::make_unique<BufferizationBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    return hasTensorArg || hasTensorResult;
  }

  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        const OpFilter *opFilter)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist), analysisState(options), opFilter(opFilter) {}

protected:
  void notifyOperationRemoved(Operation *op) override {
    IRRewriter::notifyOperationRemoved(op);
    erasedOps.insert(op);
    // Erase if present.
    toMemrefOps.erase(op);
  }

  void notifyOperationInserted(Operation *op) override {
    IRRewriter::notifyOperationInserted(op);
    erasedOps.erase(op);

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
      return;

#ifndef NDEBUG
    // Read-only tensor ops may be created during bufferization. Ops that are
    // writing should not be created because such ops were never analyzed.
    // Bufferizing such ops could introduce a RaW conflict.
    for (OpOperand &operand : op->getOpOperands())
      if (operand.get().getType().isa<TensorType>())
        assert(!analysisState.bufferizesToMemoryWrite(operand) &&
               "creating tensor ops that bufferize to a memory write is not "
               "allowed during bufferization");
#endif // NDEBUG

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// An extra op filter for bufferization.
  const OpFilter *opFilter;
};
} // namespace

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         bool copyBeforeWrite,
                                         const OpFilter *opFilter) {
  if (copyBeforeWrite) {
    AnalysisState state(options);
    if (failed(insertTensorCopies(op, state)))
      return failure();
  }

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  //
  // FuncOps must be bufferized before their bodies, so add them to the
  // worklist first.
  SmallVector<Operation *> worklist;
  op->walk([&](func::FuncOp funcOp) {
    if (hasTensorSemantics(funcOp))
      worklist.push_back(funcOp);
  });
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, opFilter);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *op = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(op))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      continue;
    if (opFilter && !opFilter->isOpAllowed(op))
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(op))
      continue;
    // Bufferize the op.
    rewriter.setInsertionPoint(op);
    if (failed(bufferizableOp.bufferize(rewriter, options)))
      return op->emitError("failed to bufferize op");
  }

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                  cast<ToMemrefOp>(op));
  }

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    if (opFilter && !opFilter->isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
      continue;
    // ToTensorOps/ToMemrefOps are allowed in the output.
    if (isa<ToTensorOp, ToMemrefOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.enforceAliasingInvariants = false;
  options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
                                      const BufferizationOptions &options) {
    return getMemRefTypeWithStaticIdentityLayout(
        value.getType().cast<TensorType>(), memorySpace);
  };
  options.opFilter.allowDialect<BufferizationDialect>();
  return options;
}