13b2004e1SMatthias Springer //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
23b2004e1SMatthias Springer //
33b2004e1SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43b2004e1SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
53b2004e1SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63b2004e1SMatthias Springer //
73b2004e1SMatthias Springer //===----------------------------------------------------------------------===//
83b2004e1SMatthias Springer 
93b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
103b2004e1SMatthias Springer 
113b2004e1SMatthias Springer #include "PassDetail.h"
123b2004e1SMatthias Springer 
133b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
143b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
153b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
163b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
173b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
183b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
193b2004e1SMatthias Springer 
203b2004e1SMatthias Springer using namespace mlir;
213b2004e1SMatthias Springer using namespace mlir::bufferization;
223b2004e1SMatthias Springer 
233b2004e1SMatthias Springer LogicalResult mlir::bufferization::insertTensorCopies(
243b2004e1SMatthias Springer     Operation *op, const OneShotBufferizationOptions &options) {
253b2004e1SMatthias Springer   OneShotAnalysisState state(op, options);
263b2004e1SMatthias Springer   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
273b2004e1SMatthias Springer   // analysis depending on whether function boundary bufferization is enabled or
283b2004e1SMatthias Springer   // not.
293b2004e1SMatthias Springer   if (options.bufferizeFunctionBoundaries) {
303b2004e1SMatthias Springer     if (failed(analyzeModuleOp(cast<ModuleOp>(op), state)))
313b2004e1SMatthias Springer       return failure();
323b2004e1SMatthias Springer   } else {
333b2004e1SMatthias Springer     if (failed(analyzeOp(op, state)))
343b2004e1SMatthias Springer       return failure();
353b2004e1SMatthias Springer   }
363b2004e1SMatthias Springer 
373b2004e1SMatthias Springer   if (options.testAnalysisOnly)
383b2004e1SMatthias Springer     return success();
393b2004e1SMatthias Springer 
403b2004e1SMatthias Springer   return insertTensorCopies(op, state);
413b2004e1SMatthias Springer }
423b2004e1SMatthias Springer 
433b2004e1SMatthias Springer LogicalResult
443b2004e1SMatthias Springer mlir::bufferization::insertTensorCopies(Operation *op,
453b2004e1SMatthias Springer                                         const AnalysisState &state) {
46*87c770bbSMatthias Springer   IRRewriter rewriter(op->getContext());
473b2004e1SMatthias Springer   WalkResult result = op->walk([&](Operation *op) {
483b2004e1SMatthias Springer     auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
493b2004e1SMatthias Springer     if (!bufferizableOp)
503b2004e1SMatthias Springer       return WalkResult::skip();
513b2004e1SMatthias Springer 
523b2004e1SMatthias Springer     // Find AllocTensorOps without an `escape` attribute and add the attribute
533b2004e1SMatthias Springer     // based on analysis results.
543b2004e1SMatthias Springer     if (auto allocTensorOp = dyn_cast<AllocTensorOp>(op)) {
553b2004e1SMatthias Springer       if (allocTensorOp.escape())
563b2004e1SMatthias Springer         return WalkResult::advance();
573b2004e1SMatthias Springer       bool escape = state.isTensorYielded(allocTensorOp.result());
58*87c770bbSMatthias Springer       allocTensorOp.escapeAttr(rewriter.getBoolAttr(escape));
593b2004e1SMatthias Springer       return WalkResult::advance();
603b2004e1SMatthias Springer     }
613b2004e1SMatthias Springer 
62*87c770bbSMatthias Springer     // Find inplacability conflicts and resolve them. (Typically with explicit
63*87c770bbSMatthias Springer     // tensor copies in the form of AllocTensorOps.)
64*87c770bbSMatthias Springer     rewriter.setInsertionPoint(op);
65*87c770bbSMatthias Springer     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
663b2004e1SMatthias Springer       return WalkResult::interrupt();
673b2004e1SMatthias Springer 
683b2004e1SMatthias Springer     return WalkResult::advance();
693b2004e1SMatthias Springer   });
703b2004e1SMatthias Springer 
713b2004e1SMatthias Springer   return failure(result.wasInterrupted());
723b2004e1SMatthias Springer }
733b2004e1SMatthias Springer 
743b2004e1SMatthias Springer namespace {
753b2004e1SMatthias Springer struct TensorCopyInsertionPass
763b2004e1SMatthias Springer     : TensorCopyInsertionBase<TensorCopyInsertionPass> {
773b2004e1SMatthias Springer   TensorCopyInsertionPass()
783b2004e1SMatthias Springer       : TensorCopyInsertionBase<TensorCopyInsertionPass>(),
793b2004e1SMatthias Springer         options(llvm::None) {}
803b2004e1SMatthias Springer   TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
813b2004e1SMatthias Springer       : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}
823b2004e1SMatthias Springer 
833b2004e1SMatthias Springer   void getDependentDialects(DialectRegistry &registry) const override {
843b2004e1SMatthias Springer     registry.insert<bufferization::BufferizationDialect>();
853b2004e1SMatthias Springer   }
863b2004e1SMatthias Springer 
873b2004e1SMatthias Springer   void runOnOperation() override {
883b2004e1SMatthias Springer     if (options.hasValue()) {
893b2004e1SMatthias Springer       if (failed(insertTensorCopies(getOperation(), *options)))
903b2004e1SMatthias Springer         signalPassFailure();
913b2004e1SMatthias Springer     } else {
923b2004e1SMatthias Springer       OneShotBufferizationOptions options;
933b2004e1SMatthias Springer       options.allowReturnAllocs = allowReturnAllocs;
943b2004e1SMatthias Springer       options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
953b2004e1SMatthias Springer       if (failed(insertTensorCopies(getOperation(), options)))
963b2004e1SMatthias Springer         signalPassFailure();
973b2004e1SMatthias Springer     }
983b2004e1SMatthias Springer   }
993b2004e1SMatthias Springer 
1003b2004e1SMatthias Springer private:
1013b2004e1SMatthias Springer   Optional<OneShotBufferizationOptions> options;
1023b2004e1SMatthias Springer };
1033b2004e1SMatthias Springer } // namespace
1043b2004e1SMatthias Springer 
1053b2004e1SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() {
1063b2004e1SMatthias Springer   return std::make_unique<TensorCopyInsertionPass>();
1073b2004e1SMatthias Springer }
1083b2004e1SMatthias Springer 
1093b2004e1SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass(
1103b2004e1SMatthias Springer     const OneShotBufferizationOptions &options) {
1113b2004e1SMatthias Springer   return std::make_unique<TensorCopyInsertionPass>(options);
1123b2004e1SMatthias Springer }
113