1 //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
10 
11 #include "PassDetail.h"
12 
13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
17 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
19 
20 using namespace mlir;
21 using namespace mlir::bufferization;
22 
insertTensorCopies(Operation * op,const OneShotBufferizationOptions & options)23 LogicalResult mlir::bufferization::insertTensorCopies(
24     Operation *op, const OneShotBufferizationOptions &options) {
25   OneShotAnalysisState state(op, options);
26   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
27   // analysis depending on whether function boundary bufferization is enabled or
28   // not.
29   if (options.bufferizeFunctionBoundaries) {
30     if (failed(analyzeModuleOp(cast<ModuleOp>(op), state)))
31       return failure();
32   } else {
33     if (failed(analyzeOp(op, state)))
34       return failure();
35   }
36 
37   if (options.testAnalysisOnly)
38     return success();
39 
40   return insertTensorCopies(op, state);
41 }
42 
43 LogicalResult
insertTensorCopies(Operation * op,const AnalysisState & state)44 mlir::bufferization::insertTensorCopies(Operation *op,
45                                         const AnalysisState &state) {
46   IRRewriter rewriter(op->getContext());
47   StringRef escapeAttrName = BufferizationDialect::kEscapeAttrName;
48 
49   WalkResult result = op->walk([&](Operation *op) {
50     auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
51     if (!bufferizableOp)
52       return WalkResult::skip();
53 
54     // Find allocations without an `escape` attribute and add the attribute
55     // based on analysis results.
56     if (!op->hasAttr(escapeAttrName)) {
57       SmallVector<bool> escapeAttrValue;
58       bool foundTensorResult = false;
59       for (OpResult opResult : op->getOpResults()) {
60         if (!opResult.getType().isa<TensorType>() ||
61             !bufferizableOp.bufferizesToAllocation(opResult)) {
62           escapeAttrValue.push_back(false);
63           continue;
64         }
65         foundTensorResult = true;
66         bool escape = !state.getOptions().createDeallocs ||
67                       state.isTensorYielded(opResult);
68         escapeAttrValue.push_back(escape);
69       }
70       if (foundTensorResult)
71         op->setAttr(escapeAttrName, rewriter.getBoolArrayAttr(escapeAttrValue));
72     }
73 
74     // Find inplacability conflicts and resolve them. (Typically with explicit
75     // tensor copies in the form of AllocTensorOps.)
76     rewriter.setInsertionPoint(op);
77     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
78       return WalkResult::interrupt();
79 
80     return WalkResult::advance();
81   });
82 
83   return failure(result.wasInterrupted());
84 }
85 
86 namespace {
87 struct TensorCopyInsertionPass
88     : TensorCopyInsertionBase<TensorCopyInsertionPass> {
TensorCopyInsertionPass__anon012e1bb40211::TensorCopyInsertionPass89   TensorCopyInsertionPass()
90       : TensorCopyInsertionBase<TensorCopyInsertionPass>(),
91         options(llvm::None) {}
TensorCopyInsertionPass__anon012e1bb40211::TensorCopyInsertionPass92   TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
93       : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}
94 
getDependentDialects__anon012e1bb40211::TensorCopyInsertionPass95   void getDependentDialects(DialectRegistry &registry) const override {
96     registry.insert<bufferization::BufferizationDialect>();
97   }
98 
runOnOperation__anon012e1bb40211::TensorCopyInsertionPass99   void runOnOperation() override {
100     if (options) {
101       if (failed(insertTensorCopies(getOperation(), *options)))
102         signalPassFailure();
103     } else {
104       OneShotBufferizationOptions options;
105       options.allowReturnAllocs = allowReturnAllocs;
106       options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
107       options.createDeallocs = createDeallocs;
108       if (mustInferMemorySpace)
109         options.defaultMemorySpace = None;
110       if (failed(insertTensorCopies(getOperation(), options)))
111         signalPassFailure();
112     }
113   }
114 
115 private:
116   Optional<OneShotBufferizationOptions> options;
117 };
118 } // namespace
119 
createTensorCopyInsertionPass()120 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() {
121   return std::make_unique<TensorCopyInsertionPass>();
122 }
123 
createTensorCopyInsertionPass(const OneShotBufferizationOptions & options)124 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass(
125     const OneShotBufferizationOptions &options) {
126   return std::make_unique<TensorCopyInsertionPass>(options);
127 }
128