1 //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h" 10 11 #include "PassDetail.h" 12 13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 16 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 17 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 19 20 using namespace mlir; 21 using namespace mlir::bufferization; 22 23 LogicalResult mlir::bufferization::insertTensorCopies( 24 Operation *op, const OneShotBufferizationOptions &options) { 25 OneShotAnalysisState state(op, options); 26 // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize 27 // analysis depending on whether function boundary bufferization is enabled or 28 // not. 29 if (options.bufferizeFunctionBoundaries) { 30 if (failed(analyzeModuleOp(cast<ModuleOp>(op), state))) 31 return failure(); 32 } else { 33 if (failed(analyzeOp(op, state))) 34 return failure(); 35 } 36 37 if (options.testAnalysisOnly) 38 return success(); 39 40 return insertTensorCopies(op, state); 41 } 42 43 LogicalResult 44 mlir::bufferization::insertTensorCopies(Operation *op, 45 const AnalysisState &state) { 46 IRRewriter rewriter(op->getContext()); 47 StringRef escapeAttrName = BufferizationDialect::kEscapeAttrName; 48 49 WalkResult result = op->walk([&](Operation *op) { 50 auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op); 51 if (!bufferizableOp) 52 return WalkResult::skip(); 53 54 // Find allocations without an `escape` attribute and add the attribute 55 // based on analysis results. 56 if (!op->hasAttr(escapeAttrName)) { 57 SmallVector<bool> escapeAttrValue; 58 bool foundTensorResult = false; 59 for (OpResult opResult : op->getOpResults()) { 60 if (!opResult.getType().isa<TensorType>() || 61 !bufferizableOp.bufferizesToAllocation(opResult)) { 62 escapeAttrValue.push_back(false); 63 continue; 64 } 65 foundTensorResult = true; 66 bool escape = !state.getOptions().createDeallocs || 67 state.isTensorYielded(opResult); 68 escapeAttrValue.push_back(escape); 69 } 70 if (foundTensorResult) 71 op->setAttr(escapeAttrName, rewriter.getBoolArrayAttr(escapeAttrValue)); 72 } 73 74 // Find inplacability conflicts and resolve them. (Typically with explicit 75 // tensor copies in the form of AllocTensorOps.) 76 rewriter.setInsertionPoint(op); 77 if (failed(bufferizableOp.resolveConflicts(rewriter, state))) 78 return WalkResult::interrupt(); 79 80 return WalkResult::advance(); 81 }); 82 83 return failure(result.wasInterrupted()); 84 } 85 86 namespace { 87 struct TensorCopyInsertionPass 88 : TensorCopyInsertionBase<TensorCopyInsertionPass> { 89 TensorCopyInsertionPass() 90 : TensorCopyInsertionBase<TensorCopyInsertionPass>(), 91 options(llvm::None) {} 92 TensorCopyInsertionPass(const OneShotBufferizationOptions &options) 93 : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {} 94 95 void getDependentDialects(DialectRegistry ®istry) const override { 96 registry.insert<bufferization::BufferizationDialect>(); 97 } 98 99 void runOnOperation() override { 100 if (options) { 101 if (failed(insertTensorCopies(getOperation(), *options))) 102 signalPassFailure(); 103 } else { 104 OneShotBufferizationOptions options; 105 options.allowReturnAllocs = allowReturnAllocs; 106 options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; 107 options.createDeallocs = createDeallocs; 108 if (mustInferMemorySpace) 109 options.defaultMemorySpace = None; 110 if (failed(insertTensorCopies(getOperation(), options))) 111 signalPassFailure(); 112 } 113 } 114 115 private: 116 Optional<OneShotBufferizationOptions> options; 117 }; 118 } // namespace 119 120 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() { 121 return std::make_unique<TensorCopyInsertionPass>(); 122 } 123 124 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass( 125 const OneShotBufferizationOptions &options) { 126 return std::make_unique<TensorCopyInsertionPass>(options); 127 } 128