1 //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h" 10 11 #include "PassDetail.h" 12 13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 16 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 17 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 19 20 using namespace mlir; 21 using namespace mlir::bufferization; 22 23 LogicalResult mlir::bufferization::insertTensorCopies( 24 Operation *op, const OneShotBufferizationOptions &options) { 25 OneShotAnalysisState state(op, options); 26 // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize 27 // analysis depending on whether function boundary bufferization is enabled or 28 // not. 29 if (options.bufferizeFunctionBoundaries) { 30 if (failed(analyzeModuleOp(cast<ModuleOp>(op), state))) 31 return failure(); 32 } else { 33 if (failed(analyzeOp(op, state))) 34 return failure(); 35 } 36 37 if (options.testAnalysisOnly) 38 return success(); 39 40 return insertTensorCopies(op, state); 41 } 42 43 LogicalResult 44 mlir::bufferization::insertTensorCopies(Operation *op, 45 const AnalysisState &state) { 46 IRRewriter rewriter(op->getContext()); 47 WalkResult result = op->walk([&](Operation *op) { 48 auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op); 49 if (!bufferizableOp) 50 return WalkResult::skip(); 51 52 // Find AllocTensorOps without an `escape` attribute and add the attribute 53 // based on analysis results. 54 if (auto allocTensorOp = dyn_cast<AllocTensorOp>(op)) { 55 if (allocTensorOp.getEscape()) 56 return WalkResult::advance(); 57 bool escape = !state.getOptions().createDeallocs || 58 state.isTensorYielded(allocTensorOp.getResult()); 59 allocTensorOp.setEscapeAttr(rewriter.getBoolAttr(escape)); 60 return WalkResult::advance(); 61 } 62 63 // Find inplacability conflicts and resolve them. (Typically with explicit 64 // tensor copies in the form of AllocTensorOps.) 65 rewriter.setInsertionPoint(op); 66 if (failed(bufferizableOp.resolveConflicts(rewriter, state))) 67 return WalkResult::interrupt(); 68 69 return WalkResult::advance(); 70 }); 71 72 return failure(result.wasInterrupted()); 73 } 74 75 namespace { 76 struct TensorCopyInsertionPass 77 : TensorCopyInsertionBase<TensorCopyInsertionPass> { 78 TensorCopyInsertionPass() 79 : TensorCopyInsertionBase<TensorCopyInsertionPass>(), 80 options(llvm::None) {} 81 TensorCopyInsertionPass(const OneShotBufferizationOptions &options) 82 : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {} 83 84 void getDependentDialects(DialectRegistry ®istry) const override { 85 registry.insert<bufferization::BufferizationDialect>(); 86 } 87 88 void runOnOperation() override { 89 if (options.hasValue()) { 90 if (failed(insertTensorCopies(getOperation(), *options))) 91 signalPassFailure(); 92 } else { 93 OneShotBufferizationOptions options; 94 options.allowReturnAllocs = allowReturnAllocs; 95 options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; 96 options.createDeallocs = createDeallocs; 97 if (failed(insertTensorCopies(getOperation(), options))) 98 signalPassFailure(); 99 } 100 } 101 102 private: 103 Optional<OneShotBufferizationOptions> options; 104 }; 105 } // namespace 106 107 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() { 108 return std::make_unique<TensorCopyInsertionPass>(); 109 } 110 111 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass( 112 const OneShotBufferizationOptions &options) { 113 return std::make_unique<TensorCopyInsertionPass>(options); 114 } 115