1 //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h" 10 11 #include "PassDetail.h" 12 13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 16 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 17 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 19 20 using namespace mlir; 21 using namespace mlir::bufferization; 22 23 LogicalResult mlir::bufferization::insertTensorCopies( 24 Operation *op, const OneShotBufferizationOptions &options) { 25 OneShotAnalysisState state(op, options); 26 // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize 27 // analysis depending on whether function boundary bufferization is enabled or 28 // not. 29 if (options.bufferizeFunctionBoundaries) { 30 if (failed(analyzeModuleOp(cast<ModuleOp>(op), state))) 31 return failure(); 32 } else { 33 if (failed(analyzeOp(op, state))) 34 return failure(); 35 } 36 37 if (options.testAnalysisOnly) 38 return success(); 39 40 return insertTensorCopies(op, state); 41 } 42 43 LogicalResult 44 mlir::bufferization::insertTensorCopies(Operation *op, 45 const AnalysisState &state) { 46 OpBuilder builder(op->getContext()); 47 WalkResult result = op->walk([&](Operation *op) { 48 auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op); 49 if (!bufferizableOp) 50 return WalkResult::skip(); 51 52 // Find AllocTensorOps without an `escape` attribute and add the attribute 53 // based on analysis results. 54 if (auto allocTensorOp = dyn_cast<AllocTensorOp>(op)) { 55 if (allocTensorOp.escape()) 56 return WalkResult::advance(); 57 bool escape = state.isTensorYielded(allocTensorOp.result()); 58 allocTensorOp.escapeAttr(builder.getBoolAttr(escape)); 59 return WalkResult::advance(); 60 } 61 62 // Find out-of-place tensor OpOperands and resolve them with an explicit 63 // tensor copy in the form of an AllocTensorOp. 64 builder.setInsertionPoint(op); 65 for (OpOperand &opOperand : op->getOpOperands()) { 66 if (opOperand.get().getType().isa<UnrankedTensorType>()) { 67 op->emitError("copies of unranked tensors are not supported"); 68 return WalkResult::interrupt(); 69 } 70 auto tensorType = opOperand.get().getType().dyn_cast<RankedTensorType>(); 71 if (!tensorType) 72 continue; 73 if (state.isInPlace(opOperand)) 74 continue; 75 SmallVector<OpResult> aliasingOpResults = 76 state.getAliasingOpResult(opOperand); 77 bool escape = llvm::any_of( 78 aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); }); 79 Value copy = builder.create<AllocTensorOp>( 80 op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape); 81 opOperand.set(copy); 82 } 83 84 return WalkResult::advance(); 85 }); 86 87 return failure(result.wasInterrupted()); 88 } 89 90 namespace { 91 struct TensorCopyInsertionPass 92 : TensorCopyInsertionBase<TensorCopyInsertionPass> { 93 TensorCopyInsertionPass() 94 : TensorCopyInsertionBase<TensorCopyInsertionPass>(), 95 options(llvm::None) {} 96 TensorCopyInsertionPass(const OneShotBufferizationOptions &options) 97 : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {} 98 99 void getDependentDialects(DialectRegistry ®istry) const override { 100 registry.insert<bufferization::BufferizationDialect>(); 101 } 102 103 void runOnOperation() override { 104 if (options.hasValue()) { 105 if (failed(insertTensorCopies(getOperation(), *options))) 106 signalPassFailure(); 107 } else { 108 OneShotBufferizationOptions options; 109 options.allowReturnAllocs = allowReturnAllocs; 110 options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; 111 if (failed(insertTensorCopies(getOperation(), options))) 112 signalPassFailure(); 113 } 114 } 115 116 private: 117 Optional<OneShotBufferizationOptions> options; 118 }; 119 } // namespace 120 121 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() { 122 return std::make_unique<TensorCopyInsertionPass>(); 123 } 124 125 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass( 126 const OneShotBufferizationOptions &options) { 127 return std::make_unique<TensorCopyInsertionPass>(options); 128 } 129