1 //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
10 
11 #include "PassDetail.h"
12 
13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
17 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
19 
20 using namespace mlir;
21 using namespace mlir::bufferization;
22 
23 LogicalResult mlir::bufferization::insertTensorCopies(
24     Operation *op, const OneShotBufferizationOptions &options) {
25   OneShotAnalysisState state(op, options);
26   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
27   // analysis depending on whether function boundary bufferization is enabled or
28   // not.
29   if (options.bufferizeFunctionBoundaries) {
30     if (failed(analyzeModuleOp(cast<ModuleOp>(op), state)))
31       return failure();
32   } else {
33     if (failed(analyzeOp(op, state)))
34       return failure();
35   }
36 
37   if (options.testAnalysisOnly)
38     return success();
39 
40   return insertTensorCopies(op, state);
41 }
42 
43 LogicalResult
44 mlir::bufferization::insertTensorCopies(Operation *op,
45                                         const AnalysisState &state) {
46   OpBuilder builder(op->getContext());
47   WalkResult result = op->walk([&](Operation *op) {
48     auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
49     if (!bufferizableOp)
50       return WalkResult::skip();
51 
52     // Find AllocTensorOps without an `escape` attribute and add the attribute
53     // based on analysis results.
54     if (auto allocTensorOp = dyn_cast<AllocTensorOp>(op)) {
55       if (allocTensorOp.escape())
56         return WalkResult::advance();
57       bool escape = state.isTensorYielded(allocTensorOp.result());
58       allocTensorOp.escapeAttr(builder.getBoolAttr(escape));
59       return WalkResult::advance();
60     }
61 
62     // Find out-of-place tensor OpOperands and resolve them with an explicit
63     // tensor copy in the form of an AllocTensorOp.
64     builder.setInsertionPoint(op);
65     for (OpOperand &opOperand : op->getOpOperands()) {
66       if (opOperand.get().getType().isa<UnrankedTensorType>()) {
67         op->emitError("copies of unranked tensors are not supported");
68         return WalkResult::interrupt();
69       }
70       auto tensorType = opOperand.get().getType().dyn_cast<RankedTensorType>();
71       if (!tensorType)
72         continue;
73       if (state.isInPlace(opOperand))
74         continue;
75       SmallVector<OpResult> aliasingOpResults =
76           state.getAliasingOpResult(opOperand);
77       bool escape = llvm::any_of(
78           aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
79       Value copy = builder.create<AllocTensorOp>(
80           op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
81       opOperand.set(copy);
82     }
83 
84     return WalkResult::advance();
85   });
86 
87   return failure(result.wasInterrupted());
88 }
89 
90 namespace {
91 struct TensorCopyInsertionPass
92     : TensorCopyInsertionBase<TensorCopyInsertionPass> {
93   TensorCopyInsertionPass()
94       : TensorCopyInsertionBase<TensorCopyInsertionPass>(),
95         options(llvm::None) {}
96   TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
97       : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}
98 
99   void getDependentDialects(DialectRegistry &registry) const override {
100     registry.insert<bufferization::BufferizationDialect>();
101   }
102 
103   void runOnOperation() override {
104     if (options.hasValue()) {
105       if (failed(insertTensorCopies(getOperation(), *options)))
106         signalPassFailure();
107     } else {
108       OneShotBufferizationOptions options;
109       options.allowReturnAllocs = allowReturnAllocs;
110       options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
111       if (failed(insertTensorCopies(getOperation(), options)))
112         signalPassFailure();
113     }
114   }
115 
116 private:
117   Optional<OneShotBufferizationOptions> options;
118 };
119 } // namespace
120 
121 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() {
122   return std::make_unique<TensorCopyInsertionPass>();
123 }
124 
125 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass(
126     const OneShotBufferizationOptions &options) {
127   return std::make_unique<TensorCopyInsertionPass>(options);
128 }
129