//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter.
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}
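// For reference, a typical partial bufferization pass wires the converter
// and the materialization legality together roughly as follows (a sketch,
// not code from this file; the pattern population is dialect-specific):
//
//   BufferizeTypeConverter typeConverter;
//   RewritePatternSet patterns(context);
//   ConversionTarget target(*context);
//   populateBufferizeMaterializationLegality(target);
//   // ... add dialect-specific bufferization patterns ...
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();
//
// With this setup, the converter maps e.g. tensor<4x?xf32> to
// memref<4x?xf32> and tensor<*xf32> to memref<*xf32> (memory space 0), and
// to_tensor / to_memref materializations remain legal until a finalizing
// conversion removes them.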
namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs; thus, this op becomes an identity.
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.memref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs; thus, this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.tensor());
    return success();
  }
};
} // namespace

void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}
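// For illustration: once all surrounding ops operate on memrefs, a leftover
// materialization round-trip such as
//
//   %t  = bufferization.to_tensor %m : memref<4xf32>
//   %m2 = bufferization.to_memref %t : memref<4xf32>
//
// is eliminated by the patterns above: uses of %m2 become direct uses of %m,
// and both ops are erased by the conversion.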
namespace {
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal
    // (ensured by func conversion above), then all types in the program are
    // legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating
      // the pass.
      opt.dropEquivalentFuncResults = dropEquivalentFuncResults;
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;

      BufferizationOptions::OpFilterEntry::FilterFn filterFn =
          [&](Operation *op) {
            // A filter may be specified via options.
            if (this->dialectFilter.hasValue())
              return llvm::find(this->dialectFilter,
                                op->getDialect()->getNamespace()) !=
                     this->dialectFilter.end();
            // No filter specified: all ops are allowed.
            return true;
          };
      opt.allowOperationInFilter(filterFn);
    } else {
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    } else {
      if (failed(runOneShotBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    }

    if (opt.testAnalysisOnly)
      return;

    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// Return true if the given op has a tensor result or a tensor operand (or,
/// for function-like ops, a tensor in its signature).
static bool hasTensorSemantics(Operation *op) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    return hasTensorArg || hasTensorResult;
  }

  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}
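// For illustration: "func.func @f(%arg0: tensor<?xf32>) -> tensor<?xf32>"
// has tensor semantics because of its signature, and so does
// "%0 = tensor.extract %t[%i] : tensor<4xf32>" because of its tensor
// operand, whereas "%0 = arith.addf %a, %b : f32" does not.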
LogicalResult
bufferization::finalizeBuffers(Operation *op,
                               const BufferizationOptions &options) {
  // Hoist buffers.
  if (failed(hoistBufferAllocations(op, options)))
    return failure();

  // Deallocate buffers that escape block boundaries ("leaking buffers") with
  // the buffer deallocation pass.
  bool hasLeakingAlloc = false;
  if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
                                   &hasLeakingAlloc)))
    return failure();
  if (options.createDeallocs && hasLeakingAlloc &&
      failed(deallocateBuffers(op)))
    return failure();

  // Deallocate all remaining buffers at the end of the block.
  if (failed(createAllocDeallocOps(op, options)))
    return failure();

  return success();
}

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const AnalysisState &analysisState) {
  // Catch incorrect API usage.
  assert((analysisState.hasDialectState(
              func::FuncDialect::getDialectNamespace()) ||
          !analysisState.getOptions().bufferizeFunctionBoundaries) &&
         "must use ModuleBufferize to bufferize function boundaries");

  BufferizationState bufferizationState(analysisState);
  if (failed(bufferizeOp(op, bufferizationState)))
    return failure();
  if (failed(finalizeBuffers(op, analysisState.getOptions())))
    return failure();
  return success();
}

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist) {}

protected:
  void notifyOperationRemoved(Operation *op) override {
    IRRewriter::notifyOperationRemoved(op);
    erasedOps.insert(op);
  }

  void notifyOperationInserted(Operation *op) override {
    IRRewriter::notifyOperationInserted(op);

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // A new bufferizable op was inserted. Add it to the worklist.
    if (hasTensorSemantics(op))
      worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The list of bufferizable ops.
  SmallVector<Operation *> &worklist;
};
} // namespace
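// Note: because BufferizationRewriter appends newly inserted ops with tensor
// semantics to the worklist, ops created while bufferizing (e.g., by an op's
// bufferize() implementation) are themselves processed by the loop below.
// The recorded to_memref ops are folded afterwards; for illustration, in
//
//   %t  = bufferization.to_tensor %m : memref<?xf32>
//   %m2 = bufferization.to_memref %t : memref<?xf32>
//
// uses of %m2 are replaced with %m directly (a cast or copy may be inserted
// when the two memref types are not cast-compatible).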
LogicalResult
bufferization::bufferizeOp(Operation *op,
                           BufferizationState &bufferizationState) {
  const auto &options = bufferizationState.getOptions();

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map,
  // which then has to be canonicalized away. This is less efficient.
  //
  // If "fullyDynamicLayoutMaps = false", we would have to insert buffer copies
  // to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-compatible
  // layout maps when doing a traversal other than top-to-bottom. These would
  // not easily fold away.
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *op = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(op))
      continue;
    // Skip ops that are not bufferizable.
    auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
    if (!bufferizableOp)
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Bufferize the op.
    rewriter.setInsertionPoint(op);
    (void)bufferizableOp.bufferize(rewriter, bufferizationState);
  }

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    if (erasedOps.contains(op))
      continue;
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                  cast<ToMemrefOp>(op));
  }

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (bufferizationState.getOptions().allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

namespace {
/// This is a "no analysis, always copy" AnalysisState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore,
/// all OpOperands that bufferize to a memory write must bufferize
/// out-of-place.
class AlwaysCopyAnalysisState : public AnalysisState {
public:
  AlwaysCopyAnalysisState(const BufferizationOptions &options)
      : AnalysisState(options) {}

  AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;

  virtual ~AlwaysCopyAnalysisState() = default;

  /// Return `true` if the given OpOperand has been decided to bufferize
  /// in-place.
  bool isInPlace(OpOperand &opOperand) const override {
    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
    // alloc and copy is inserted.
    return !bufferizesToMemoryWrite(opOperand);
  }

  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
    // There is no analysis, so we do not know if the values are equivalent.
    // The conservative answer is "false".
    return false;
  }
};
} // namespace

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options) {
  AlwaysCopyAnalysisState state(options);
  return bufferizeOp(op, state);
}

BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.fullyDynamicLayoutMaps = false;
  return options;
}
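// A minimal driver sketch for the partial-bufferization entry points above
// (assuming the ops involved implement BufferizableOpInterface):
//
//   BufferizationOptions options = getPartialBufferizationOptions();
//   if (failed(bufferizeOp(op, options)))
//     return failure();
//
// bufferizeOp(op, options) constructs an AlwaysCopyAnalysisState, so every
// OpOperand that bufferizes to a memory write is copied out-of-place.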