1 //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
10 
11 #include "PassDetail.h"
12 
13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
17 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
19 
20 using namespace mlir;
21 using namespace mlir::bufferization;
22 
23 LogicalResult mlir::bufferization::insertTensorCopies(
24     Operation *op, const OneShotBufferizationOptions &options) {
25   OneShotAnalysisState state(op, options);
26   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
27   // analysis depending on whether function boundary bufferization is enabled or
28   // not.
29   if (options.bufferizeFunctionBoundaries) {
30     if (failed(analyzeModuleOp(cast<ModuleOp>(op), state)))
31       return failure();
32   } else {
33     if (failed(analyzeOp(op, state)))
34       return failure();
35   }
36 
37   if (options.testAnalysisOnly)
38     return success();
39 
40   return insertTensorCopies(op, state);
41 }
42 
43 LogicalResult
44 mlir::bufferization::insertTensorCopies(Operation *op,
45                                         const AnalysisState &state) {
46   IRRewriter rewriter(op->getContext());
47   WalkResult result = op->walk([&](Operation *op) {
48     auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
49     if (!bufferizableOp)
50       return WalkResult::skip();
51 
52     // Find AllocTensorOps without an `escape` attribute and add the attribute
53     // based on analysis results.
54     if (auto allocTensorOp = dyn_cast<AllocTensorOp>(op)) {
55       if (allocTensorOp.escape())
56         return WalkResult::advance();
57       bool escape = !state.getOptions().createDeallocs ||
58                     state.isTensorYielded(allocTensorOp.result());
59       allocTensorOp.escapeAttr(rewriter.getBoolAttr(escape));
60       return WalkResult::advance();
61     }
62 
63     // Find inplacability conflicts and resolve them. (Typically with explicit
64     // tensor copies in the form of AllocTensorOps.)
65     rewriter.setInsertionPoint(op);
66     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
67       return WalkResult::interrupt();
68 
69     return WalkResult::advance();
70   });
71 
72   return failure(result.wasInterrupted());
73 }
74 
75 namespace {
76 struct TensorCopyInsertionPass
77     : TensorCopyInsertionBase<TensorCopyInsertionPass> {
78   TensorCopyInsertionPass()
79       : TensorCopyInsertionBase<TensorCopyInsertionPass>(),
80         options(llvm::None) {}
81   TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
82       : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}
83 
84   void getDependentDialects(DialectRegistry &registry) const override {
85     registry.insert<bufferization::BufferizationDialect>();
86   }
87 
88   void runOnOperation() override {
89     if (options.hasValue()) {
90       if (failed(insertTensorCopies(getOperation(), *options)))
91         signalPassFailure();
92     } else {
93       OneShotBufferizationOptions options;
94       options.allowReturnAllocs = allowReturnAllocs;
95       options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
96       options.createDeallocs = createDeallocs;
97       if (failed(insertTensorCopies(getOperation(), options)))
98         signalPassFailure();
99     }
100   }
101 
102 private:
103   Optional<OneShotBufferizationOptions> options;
104 };
105 } // namespace
106 
107 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() {
108   return std::make_unique<TensorCopyInsertionPass>();
109 }
110 
111 std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass(
112     const OneShotBufferizationOptions &options) {
113   return std::make_unique<TensorCopyInsertionPass>(options);
114 }
115