1 //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
10 
11 #include "PassDetail.h"
12 
13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
17 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
19 
20 using namespace mlir;
21 using namespace mlir::bufferization;
22 
23 LogicalResult mlir::bufferization::insertTensorCopies(
24     Operation *op, const OneShotBufferizationOptions &options) {
25   OneShotAnalysisState state(op, options);
26   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
27   // analysis depending on whether function boundary bufferization is enabled or
28   // not.
29   if (options.bufferizeFunctionBoundaries) {
30     if (failed(analyzeModuleOp(cast<ModuleOp>(op), state)))
31       return failure();
32   } else {
33     if (failed(analyzeOp(op, state)))
34       return failure();
35   }
36 
37   if (options.testAnalysisOnly)
38     return success();
39 
40   return insertTensorCopies(op, state);
41 }
42 
43 LogicalResult
44 mlir::bufferization::insertTensorCopies(Operation *op,
45                                         const AnalysisState &state) {
46   IRRewriter rewriter(op->getContext());
47   WalkResult result = op->walk([&](Operation *op) {
48     auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
49     if (!bufferizableOp)
50       return WalkResult::skip();
51 
52     // Find AllocTensorOps without an `escape` attribute and add the attribute
53     // based on analysis results.
54     if (auto allocTensorOp = dyn_cast<AllocTensorOp>(op)) {
55       if (allocTensorOp.escape())
56         return WalkResult::advance();
57       bool escape = state.isTensorYielded(allocTensorOp.result());
58       allocTensorOp.escapeAttr(rewriter.getBoolAttr(escape));
59       return WalkResult::advance();
60     }
61 
62     // Find inplacability conflicts and resolve them. (Typically with explicit
63     // tensor copies in the form of AllocTensorOps.)
64     rewriter.setInsertionPoint(op);
65     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
66       return WalkResult::interrupt();
67 
68     return WalkResult::advance();
69   });
70 
71   return failure(result.wasInterrupted());
72 }
73 
74 namespace {
75 struct TensorCopyInsertionPass
76     : TensorCopyInsertionBase<TensorCopyInsertionPass> {
77   TensorCopyInsertionPass()
78       : TensorCopyInsertionBase<TensorCopyInsertionPass>(),
79         options(llvm::None) {}
80   TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
81       : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}
82 
83   void getDependentDialects(DialectRegistry &registry) const override {
84     registry.insert<bufferization::BufferizationDialect>();
85   }
86 
87   void runOnOperation() override {
88     if (options.hasValue()) {
89       if (failed(insertTensorCopies(getOperation(), *options)))
90         signalPassFailure();
91     } else {
92       OneShotBufferizationOptions options;
93       options.allowReturnAllocs = allowReturnAllocs;
94       options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
95       if (failed(insertTensorCopies(getOperation(), options)))
96         signalPassFailure();
97     }
98   }
99 
100 private:
101   Optional<OneShotBufferizationOptions> options;
102 };
103 } // namespace
104 
105 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() {
106   return std::make_unique<TensorCopyInsertionPass>();
107 }
108 
109 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass(
110     const OneShotBufferizationOptions &options) {
111   return std::make_unique<TensorCopyInsertionPass>(options);
112 }
113