1 //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
10
11 #include "PassDetail.h"
12
13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
17 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
19
20 using namespace mlir;
21 using namespace mlir::bufferization;
22
insertTensorCopies(Operation * op,const OneShotBufferizationOptions & options)23 LogicalResult mlir::bufferization::insertTensorCopies(
24 Operation *op, const OneShotBufferizationOptions &options) {
25 OneShotAnalysisState state(op, options);
26 // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
27 // analysis depending on whether function boundary bufferization is enabled or
28 // not.
29 if (options.bufferizeFunctionBoundaries) {
30 if (failed(analyzeModuleOp(cast<ModuleOp>(op), state)))
31 return failure();
32 } else {
33 if (failed(analyzeOp(op, state)))
34 return failure();
35 }
36
37 if (options.testAnalysisOnly)
38 return success();
39
40 return insertTensorCopies(op, state);
41 }
42
43 LogicalResult
insertTensorCopies(Operation * op,const AnalysisState & state)44 mlir::bufferization::insertTensorCopies(Operation *op,
45 const AnalysisState &state) {
46 IRRewriter rewriter(op->getContext());
47 StringRef escapeAttrName = BufferizationDialect::kEscapeAttrName;
48
49 WalkResult result = op->walk([&](Operation *op) {
50 auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
51 if (!bufferizableOp)
52 return WalkResult::skip();
53
54 // Find allocations without an `escape` attribute and add the attribute
55 // based on analysis results.
56 if (!op->hasAttr(escapeAttrName)) {
57 SmallVector<bool> escapeAttrValue;
58 bool foundTensorResult = false;
59 for (OpResult opResult : op->getOpResults()) {
60 if (!opResult.getType().isa<TensorType>() ||
61 !bufferizableOp.bufferizesToAllocation(opResult)) {
62 escapeAttrValue.push_back(false);
63 continue;
64 }
65 foundTensorResult = true;
66 bool escape = !state.getOptions().createDeallocs ||
67 state.isTensorYielded(opResult);
68 escapeAttrValue.push_back(escape);
69 }
70 if (foundTensorResult)
71 op->setAttr(escapeAttrName, rewriter.getBoolArrayAttr(escapeAttrValue));
72 }
73
74 // Find inplacability conflicts and resolve them. (Typically with explicit
75 // tensor copies in the form of AllocTensorOps.)
76 rewriter.setInsertionPoint(op);
77 if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
78 return WalkResult::interrupt();
79
80 return WalkResult::advance();
81 });
82
83 return failure(result.wasInterrupted());
84 }
85
86 namespace {
87 struct TensorCopyInsertionPass
88 : TensorCopyInsertionBase<TensorCopyInsertionPass> {
TensorCopyInsertionPass__anon012e1bb40211::TensorCopyInsertionPass89 TensorCopyInsertionPass()
90 : TensorCopyInsertionBase<TensorCopyInsertionPass>(),
91 options(llvm::None) {}
TensorCopyInsertionPass__anon012e1bb40211::TensorCopyInsertionPass92 TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
93 : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}
94
getDependentDialects__anon012e1bb40211::TensorCopyInsertionPass95 void getDependentDialects(DialectRegistry ®istry) const override {
96 registry.insert<bufferization::BufferizationDialect>();
97 }
98
runOnOperation__anon012e1bb40211::TensorCopyInsertionPass99 void runOnOperation() override {
100 if (options) {
101 if (failed(insertTensorCopies(getOperation(), *options)))
102 signalPassFailure();
103 } else {
104 OneShotBufferizationOptions options;
105 options.allowReturnAllocs = allowReturnAllocs;
106 options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
107 options.createDeallocs = createDeallocs;
108 if (mustInferMemorySpace)
109 options.defaultMemorySpace = None;
110 if (failed(insertTensorCopies(getOperation(), options)))
111 signalPassFailure();
112 }
113 }
114
115 private:
116 Optional<OneShotBufferizationOptions> options;
117 };
118 } // namespace
119
createTensorCopyInsertionPass()120 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() {
121 return std::make_unique<TensorCopyInsertionPass>();
122 }
123
createTensorCopyInsertionPass(const OneShotBufferizationOptions & options)124 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass(
125 const OneShotBufferizationOptions &options) {
126 return std::make_unique<TensorCopyInsertionPass>(options);
127 }
128