1 //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h" 10 11 #include "PassDetail.h" 12 13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 16 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 17 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 19 20 using namespace mlir; 21 using namespace mlir::bufferization; 22 23 LogicalResult mlir::bufferization::insertTensorCopies( 24 Operation *op, const OneShotBufferizationOptions &options) { 25 OneShotAnalysisState state(op, options); 26 // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize 27 // analysis depending on whether function boundary bufferization is enabled or 28 // not. 29 if (options.bufferizeFunctionBoundaries) { 30 if (failed(analyzeModuleOp(cast<ModuleOp>(op), state))) 31 return failure(); 32 } else { 33 if (failed(analyzeOp(op, state))) 34 return failure(); 35 } 36 37 if (options.testAnalysisOnly) 38 return success(); 39 40 return insertTensorCopies(op, state); 41 } 42 43 LogicalResult 44 mlir::bufferization::insertTensorCopies(Operation *op, 45 const AnalysisState &state) { 46 IRRewriter rewriter(op->getContext()); 47 WalkResult result = op->walk([&](Operation *op) { 48 auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op); 49 if (!bufferizableOp) 50 return WalkResult::skip(); 51 52 // Find AllocTensorOps without an `escape` attribute and add the attribute 53 // based on analysis results. 54 if (auto allocTensorOp = dyn_cast<AllocTensorOp>(op)) { 55 if (allocTensorOp.escape()) 56 return WalkResult::advance(); 57 bool escape = state.isTensorYielded(allocTensorOp.result()); 58 allocTensorOp.escapeAttr(rewriter.getBoolAttr(escape)); 59 return WalkResult::advance(); 60 } 61 62 // Find inplacability conflicts and resolve them. (Typically with explicit 63 // tensor copies in the form of AllocTensorOps.) 64 rewriter.setInsertionPoint(op); 65 if (failed(bufferizableOp.resolveConflicts(rewriter, state))) 66 return WalkResult::interrupt(); 67 68 return WalkResult::advance(); 69 }); 70 71 return failure(result.wasInterrupted()); 72 } 73 74 namespace { 75 struct TensorCopyInsertionPass 76 : TensorCopyInsertionBase<TensorCopyInsertionPass> { 77 TensorCopyInsertionPass() 78 : TensorCopyInsertionBase<TensorCopyInsertionPass>(), 79 options(llvm::None) {} 80 TensorCopyInsertionPass(const OneShotBufferizationOptions &options) 81 : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {} 82 83 void getDependentDialects(DialectRegistry ®istry) const override { 84 registry.insert<bufferization::BufferizationDialect>(); 85 } 86 87 void runOnOperation() override { 88 if (options.hasValue()) { 89 if (failed(insertTensorCopies(getOperation(), *options))) 90 signalPassFailure(); 91 } else { 92 OneShotBufferizationOptions options; 93 options.allowReturnAllocs = allowReturnAllocs; 94 options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; 95 if (failed(insertTensorCopies(getOperation(), options))) 96 signalPassFailure(); 97 } 98 } 99 100 private: 101 Optional<OneShotBufferizationOptions> options; 102 }; 103 } // namespace 104 105 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() { 106 return std::make_unique<TensorCopyInsertionPass>(); 107 } 108 109 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass( 110 const OneShotBufferizationOptions &options) { 111 return std::make_unique<TensorCopyInsertionPass>(options); 112 } 113