13b2004e1SMatthias Springer //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===// 23b2004e1SMatthias Springer // 33b2004e1SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43b2004e1SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 53b2004e1SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63b2004e1SMatthias Springer // 73b2004e1SMatthias Springer //===----------------------------------------------------------------------===// 83b2004e1SMatthias Springer 93b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h" 103b2004e1SMatthias Springer 113b2004e1SMatthias Springer #include "PassDetail.h" 123b2004e1SMatthias Springer 133b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 143b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 153b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 163b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 173b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 183b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 193b2004e1SMatthias Springer 203b2004e1SMatthias Springer using namespace mlir; 213b2004e1SMatthias Springer using namespace mlir::bufferization; 223b2004e1SMatthias Springer 233b2004e1SMatthias Springer LogicalResult mlir::bufferization::insertTensorCopies( 243b2004e1SMatthias Springer Operation *op, const OneShotBufferizationOptions &options) { 253b2004e1SMatthias Springer OneShotAnalysisState state(op, options); 263b2004e1SMatthias Springer // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize 273b2004e1SMatthias Springer // analysis depending on whether function boundary bufferization is enabled or 283b2004e1SMatthias Springer // not. 293b2004e1SMatthias Springer if (options.bufferizeFunctionBoundaries) { 303b2004e1SMatthias Springer if (failed(analyzeModuleOp(cast<ModuleOp>(op), state))) 313b2004e1SMatthias Springer return failure(); 323b2004e1SMatthias Springer } else { 333b2004e1SMatthias Springer if (failed(analyzeOp(op, state))) 343b2004e1SMatthias Springer return failure(); 353b2004e1SMatthias Springer } 363b2004e1SMatthias Springer 373b2004e1SMatthias Springer if (options.testAnalysisOnly) 383b2004e1SMatthias Springer return success(); 393b2004e1SMatthias Springer 403b2004e1SMatthias Springer return insertTensorCopies(op, state); 413b2004e1SMatthias Springer } 423b2004e1SMatthias Springer 433b2004e1SMatthias Springer LogicalResult 443b2004e1SMatthias Springer mlir::bufferization::insertTensorCopies(Operation *op, 453b2004e1SMatthias Springer const AnalysisState &state) { 46*87c770bbSMatthias Springer IRRewriter rewriter(op->getContext()); 473b2004e1SMatthias Springer WalkResult result = op->walk([&](Operation *op) { 483b2004e1SMatthias Springer auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op); 493b2004e1SMatthias Springer if (!bufferizableOp) 503b2004e1SMatthias Springer return WalkResult::skip(); 513b2004e1SMatthias Springer 523b2004e1SMatthias Springer // Find AllocTensorOps without an `escape` attribute and add the attribute 533b2004e1SMatthias Springer // based on analysis results. 543b2004e1SMatthias Springer if (auto allocTensorOp = dyn_cast<AllocTensorOp>(op)) { 553b2004e1SMatthias Springer if (allocTensorOp.escape()) 563b2004e1SMatthias Springer return WalkResult::advance(); 573b2004e1SMatthias Springer bool escape = state.isTensorYielded(allocTensorOp.result()); 58*87c770bbSMatthias Springer allocTensorOp.escapeAttr(rewriter.getBoolAttr(escape)); 593b2004e1SMatthias Springer return WalkResult::advance(); 603b2004e1SMatthias Springer } 613b2004e1SMatthias Springer 62*87c770bbSMatthias Springer // Find inplacability conflicts and resolve them. (Typically with explicit 63*87c770bbSMatthias Springer // tensor copies in the form of AllocTensorOps.) 64*87c770bbSMatthias Springer rewriter.setInsertionPoint(op); 65*87c770bbSMatthias Springer if (failed(bufferizableOp.resolveConflicts(rewriter, state))) 663b2004e1SMatthias Springer return WalkResult::interrupt(); 673b2004e1SMatthias Springer 683b2004e1SMatthias Springer return WalkResult::advance(); 693b2004e1SMatthias Springer }); 703b2004e1SMatthias Springer 713b2004e1SMatthias Springer return failure(result.wasInterrupted()); 723b2004e1SMatthias Springer } 733b2004e1SMatthias Springer 743b2004e1SMatthias Springer namespace { 753b2004e1SMatthias Springer struct TensorCopyInsertionPass 763b2004e1SMatthias Springer : TensorCopyInsertionBase<TensorCopyInsertionPass> { 773b2004e1SMatthias Springer TensorCopyInsertionPass() 783b2004e1SMatthias Springer : TensorCopyInsertionBase<TensorCopyInsertionPass>(), 793b2004e1SMatthias Springer options(llvm::None) {} 803b2004e1SMatthias Springer TensorCopyInsertionPass(const OneShotBufferizationOptions &options) 813b2004e1SMatthias Springer : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {} 823b2004e1SMatthias Springer 833b2004e1SMatthias Springer void getDependentDialects(DialectRegistry ®istry) const override { 843b2004e1SMatthias Springer registry.insert<bufferization::BufferizationDialect>(); 853b2004e1SMatthias Springer } 863b2004e1SMatthias Springer 873b2004e1SMatthias Springer void runOnOperation() override { 883b2004e1SMatthias Springer if (options.hasValue()) { 893b2004e1SMatthias Springer if (failed(insertTensorCopies(getOperation(), *options))) 903b2004e1SMatthias Springer signalPassFailure(); 913b2004e1SMatthias Springer } else { 923b2004e1SMatthias Springer OneShotBufferizationOptions options; 933b2004e1SMatthias Springer options.allowReturnAllocs = allowReturnAllocs; 943b2004e1SMatthias Springer options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; 953b2004e1SMatthias Springer if (failed(insertTensorCopies(getOperation(), options))) 963b2004e1SMatthias Springer signalPassFailure(); 973b2004e1SMatthias Springer } 983b2004e1SMatthias Springer } 993b2004e1SMatthias Springer 1003b2004e1SMatthias Springer private: 1013b2004e1SMatthias Springer Optional<OneShotBufferizationOptions> options; 1023b2004e1SMatthias Springer }; 1033b2004e1SMatthias Springer } // namespace 1043b2004e1SMatthias Springer 1053b2004e1SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() { 1063b2004e1SMatthias Springer return std::make_unique<TensorCopyInsertionPass>(); 1073b2004e1SMatthias Springer } 1083b2004e1SMatthias Springer 1093b2004e1SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass( 1103b2004e1SMatthias Springer const OneShotBufferizationOptions &options) { 1113b2004e1SMatthias Springer return std::make_unique<TensorCopyInsertionPass>(options); 1123b2004e1SMatthias Springer } 113