//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter.
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}
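
// Illustrative sketch (example IR, not normative output of this file): with
// the converter above, tensor types convert as
//
//   tensor<4x?xf32> -> memref<4x?xf32>
//   tensor<*xf32>   -> memref<*xf32>
//
// and live tensor/memref boundaries are bridged by the two materialization
// ops, e.g.:
//
//   %t = bufferization.to_tensor %m : memref<4x?xf32>
//   %m2 = bufferization.to_memref %t : memref<4x?xf32>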

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs; thus, this op becomes an identity.
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.memref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs; thus, this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.tensor());
    return success();
  }
};
} // namespace

void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}
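
// Illustrative sketch (example IR, for exposition only): once all ops have
// been bufferized, the two patterns above reduce leftover materialization
// pairs to identities. For example,
//
//   %t = bufferization.to_tensor %m : memref<?xf32>
//   %m2 = bufferization.to_memref %t : memref<?xf32>
//   "some.user"(%m2) : (memref<?xf32>) -> ()
//
// folds into a direct use of %m by "some.user".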

namespace {
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal (ensured
    // by func conversion above), then all types in the program are legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;

      BufferizationOptions::OpFilterEntry::FilterFn filterFn =
          [&](Operation *op) {
            // A filter may be specified via options.
            if (this->dialectFilter.hasValue())
              return llvm::find(this->dialectFilter,
                                op->getDialect()->getNamespace()) !=
                     this->dialectFilter.end();
            // No filter specified: all ops are allowed.
            return true;
          };
      opt.allowOperationInFilter(filterFn);
    } else {
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    } else {
      if (failed(runOneShotBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    }

    if (opt.testAnalysisOnly)
      return;

    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}
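
// Illustrative usage sketch (the exact command-line spelling of the list
// option backing `dialectFilter` is assumed here, not defined in this file):
// the pass can be restricted to selected dialects, e.g.:
//
//   mlir-opt %s -one-shot-bufferize="dialect-filter=tensor,arith"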

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    return hasTensorArg || hasTensorResult;
  }

  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

LogicalResult
bufferization::finalizeBuffers(Operation *op,
                               const BufferizationOptions &options) {
  // Hoist buffers.
  if (failed(hoistBufferAllocations(op, options)))
    return failure();

  // Deallocate buffers that escape block boundaries ("leaking buffers") with
  // the buffer deallocation pass.
  bool hasLeakingAlloc = false;
  if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
                                   &hasLeakingAlloc)))
    return failure();
  if (options.createDeallocs && hasLeakingAlloc &&
      failed(deallocateBuffers(op)))
    return failure();

  // Deallocate all remaining buffers at the end of the block.
  if (failed(createAllocDeallocOps(op, options)))
    return failure();

  return success();
}

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const AnalysisState &analysisState) {
  // Catch incorrect API usage.
  assert((analysisState.hasDialectState(
              func::FuncDialect::getDialectNamespace()) ||
          !analysisState.getOptions().bufferizeFunctionBoundaries) &&
         "must use ModuleBufferize to bufferize function boundaries");

  BufferizationState bufferizationState(analysisState);
  if (failed(bufferizeOp(op, bufferizationState)))
    return failure();
  if (failed(finalizeBuffers(op, analysisState.getOptions())))
    return failure();
  return success();
}

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist) {}

protected:
  void notifyOperationRemoved(Operation *op) override {
    IRRewriter::notifyOperationRemoved(op);
    erasedOps.insert(op);
  }

  void notifyOperationInserted(Operation *op) override {
    IRRewriter::notifyOperationInserted(op);

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // A new bufferizable op was inserted. Add it to the worklist.
    if (hasTensorSemantics(op))
      worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The list of bufferizable ops.
  SmallVector<Operation *> &worklist;
};
} // namespace

LogicalResult
bufferization::bufferizeOp(Operation *op,
                           BufferizationState &bufferizationState) {
  const auto &options = bufferizationState.getOptions();

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map,
  // which has to canonicalize away. This is less efficient.
  //
  // If "fullyDynamicLayoutMaps = false", we would have to insert buffer copies
  // to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-compatible
  // layout maps when doing a traversal other than top-to-bottom. These would
  // not easily fold away.
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (hasTensorSemantics(op))
      worklist.push_back(op);
  });
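
  // Illustrative sketch (example type, for exposition only): a "fully dynamic
  // layout map" for a 1-D memref uses a symbolic offset and stride, e.g.:
  //
  //   memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
  //
  // Casts into such a type can later canonicalize away once the exact buffer
  // type is known.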

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *op = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(op))
      continue;
    // Skip ops that are not bufferizable.
    auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
    if (!bufferizableOp)
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Bufferize the op.
    rewriter.setInsertionPoint(op);
    (void)bufferizableOp.bufferize(rewriter, bufferizationState);
  }

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    if (erasedOps.contains(op))
      continue;
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                  cast<ToMemrefOp>(op));
  }

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (bufferizationState.getOptions().allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

namespace {
/// This is a "no analysis, always copy" AnalysisState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore,
/// all OpOperands that bufferize to a memory write must bufferize
/// out-of-place.
class AlwaysCopyAnalysisState : public AnalysisState {
public:
  AlwaysCopyAnalysisState(const BufferizationOptions &options)
      : AnalysisState(options) {}

  AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;

  virtual ~AlwaysCopyAnalysisState() = default;

  /// Return `true` if the given OpOperand has been decided to bufferize
  /// in-place.
  bool isInPlace(OpOperand &opOperand) const override {
    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
    // alloc and a copy are inserted.
    return !bufferizesToMemoryWrite(opOperand);
  }

  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
    // There is no analysis, so we do not know if the values are equivalent.
    // The conservative answer is "false".
    return false;
  }
};
} // namespace

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options) {
  AlwaysCopyAnalysisState state(options);
  return bufferizeOp(op, state);
}

BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.fullyDynamicLayoutMaps = false;
  return options;
}
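
// Illustrative usage sketch (hypothetical caller, not part of this file): a
// dialect-specific partial bufferization pass would typically combine the two
// helpers above inside its runOnOperation(), e.g.:
//
//   BufferizationOptions options = getPartialBufferizationOptions();
//   if (failed(bufferizeOp(getOperation(), options)))
//     signalPassFailure();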