13b2004e1SMatthias Springer //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
23b2004e1SMatthias Springer //
33b2004e1SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43b2004e1SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
53b2004e1SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63b2004e1SMatthias Springer //
73b2004e1SMatthias Springer //===----------------------------------------------------------------------===//
83b2004e1SMatthias Springer 
93b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
103b2004e1SMatthias Springer 
113b2004e1SMatthias Springer #include "PassDetail.h"
123b2004e1SMatthias Springer 
133b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
143b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
153b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
163b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
173b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
183b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
193b2004e1SMatthias Springer 
203b2004e1SMatthias Springer using namespace mlir;
213b2004e1SMatthias Springer using namespace mlir::bufferization;
223b2004e1SMatthias Springer 
insertTensorCopies(Operation * op,const OneShotBufferizationOptions & options)233b2004e1SMatthias Springer LogicalResult mlir::bufferization::insertTensorCopies(
243b2004e1SMatthias Springer     Operation *op, const OneShotBufferizationOptions &options) {
253b2004e1SMatthias Springer   OneShotAnalysisState state(op, options);
263b2004e1SMatthias Springer   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
273b2004e1SMatthias Springer   // analysis depending on whether function boundary bufferization is enabled or
283b2004e1SMatthias Springer   // not.
293b2004e1SMatthias Springer   if (options.bufferizeFunctionBoundaries) {
303b2004e1SMatthias Springer     if (failed(analyzeModuleOp(cast<ModuleOp>(op), state)))
313b2004e1SMatthias Springer       return failure();
323b2004e1SMatthias Springer   } else {
333b2004e1SMatthias Springer     if (failed(analyzeOp(op, state)))
343b2004e1SMatthias Springer       return failure();
353b2004e1SMatthias Springer   }
363b2004e1SMatthias Springer 
373b2004e1SMatthias Springer   if (options.testAnalysisOnly)
383b2004e1SMatthias Springer     return success();
393b2004e1SMatthias Springer 
403b2004e1SMatthias Springer   return insertTensorCopies(op, state);
413b2004e1SMatthias Springer }
423b2004e1SMatthias Springer 
433b2004e1SMatthias Springer LogicalResult
insertTensorCopies(Operation * op,const AnalysisState & state)443b2004e1SMatthias Springer mlir::bufferization::insertTensorCopies(Operation *op,
453b2004e1SMatthias Springer                                         const AnalysisState &state) {
4687c770bbSMatthias Springer   IRRewriter rewriter(op->getContext());
473474d10eSMatthias Springer   StringRef escapeAttrName = BufferizationDialect::kEscapeAttrName;
483474d10eSMatthias Springer 
493b2004e1SMatthias Springer   WalkResult result = op->walk([&](Operation *op) {
503b2004e1SMatthias Springer     auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
513b2004e1SMatthias Springer     if (!bufferizableOp)
523b2004e1SMatthias Springer       return WalkResult::skip();
533b2004e1SMatthias Springer 
543474d10eSMatthias Springer     // Find allocations without an `escape` attribute and add the attribute
553b2004e1SMatthias Springer     // based on analysis results.
563474d10eSMatthias Springer     if (!op->hasAttr(escapeAttrName)) {
573474d10eSMatthias Springer       SmallVector<bool> escapeAttrValue;
583474d10eSMatthias Springer       bool foundTensorResult = false;
593474d10eSMatthias Springer       for (OpResult opResult : op->getOpResults()) {
603474d10eSMatthias Springer         if (!opResult.getType().isa<TensorType>() ||
613474d10eSMatthias Springer             !bufferizableOp.bufferizesToAllocation(opResult)) {
623474d10eSMatthias Springer           escapeAttrValue.push_back(false);
633474d10eSMatthias Springer           continue;
643474d10eSMatthias Springer         }
653474d10eSMatthias Springer         foundTensorResult = true;
66a36c801dSMatthias Springer         bool escape = !state.getOptions().createDeallocs ||
673474d10eSMatthias Springer                       state.isTensorYielded(opResult);
683474d10eSMatthias Springer         escapeAttrValue.push_back(escape);
693474d10eSMatthias Springer       }
703474d10eSMatthias Springer       if (foundTensorResult)
713474d10eSMatthias Springer         op->setAttr(escapeAttrName, rewriter.getBoolArrayAttr(escapeAttrValue));
723b2004e1SMatthias Springer     }
733b2004e1SMatthias Springer 
7487c770bbSMatthias Springer     // Find inplacability conflicts and resolve them. (Typically with explicit
7587c770bbSMatthias Springer     // tensor copies in the form of AllocTensorOps.)
7687c770bbSMatthias Springer     rewriter.setInsertionPoint(op);
7787c770bbSMatthias Springer     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
783b2004e1SMatthias Springer       return WalkResult::interrupt();
793b2004e1SMatthias Springer 
803b2004e1SMatthias Springer     return WalkResult::advance();
813b2004e1SMatthias Springer   });
823b2004e1SMatthias Springer 
833b2004e1SMatthias Springer   return failure(result.wasInterrupted());
843b2004e1SMatthias Springer }
853b2004e1SMatthias Springer 
863b2004e1SMatthias Springer namespace {
873b2004e1SMatthias Springer struct TensorCopyInsertionPass
883b2004e1SMatthias Springer     : TensorCopyInsertionBase<TensorCopyInsertionPass> {
TensorCopyInsertionPass__anon012e1bb40211::TensorCopyInsertionPass893b2004e1SMatthias Springer   TensorCopyInsertionPass()
903b2004e1SMatthias Springer       : TensorCopyInsertionBase<TensorCopyInsertionPass>(),
913b2004e1SMatthias Springer         options(llvm::None) {}
TensorCopyInsertionPass__anon012e1bb40211::TensorCopyInsertionPass923b2004e1SMatthias Springer   TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
933b2004e1SMatthias Springer       : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}
943b2004e1SMatthias Springer 
getDependentDialects__anon012e1bb40211::TensorCopyInsertionPass953b2004e1SMatthias Springer   void getDependentDialects(DialectRegistry &registry) const override {
963b2004e1SMatthias Springer     registry.insert<bufferization::BufferizationDialect>();
973b2004e1SMatthias Springer   }
983b2004e1SMatthias Springer 
runOnOperation__anon012e1bb40211::TensorCopyInsertionPass993b2004e1SMatthias Springer   void runOnOperation() override {
100037f0995SKazu Hirata     if (options) {
1013b2004e1SMatthias Springer       if (failed(insertTensorCopies(getOperation(), *options)))
1023b2004e1SMatthias Springer         signalPassFailure();
1033b2004e1SMatthias Springer     } else {
1043b2004e1SMatthias Springer       OneShotBufferizationOptions options;
1053b2004e1SMatthias Springer       options.allowReturnAllocs = allowReturnAllocs;
1063b2004e1SMatthias Springer       options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
107a36c801dSMatthias Springer       options.createDeallocs = createDeallocs;
108*c0b0b6a0SMatthias Springer       if (mustInferMemorySpace)
109*c0b0b6a0SMatthias Springer         options.defaultMemorySpace = None;
1103b2004e1SMatthias Springer       if (failed(insertTensorCopies(getOperation(), options)))
1113b2004e1SMatthias Springer         signalPassFailure();
1123b2004e1SMatthias Springer     }
1133b2004e1SMatthias Springer   }
1143b2004e1SMatthias Springer 
1153b2004e1SMatthias Springer private:
1163b2004e1SMatthias Springer   Optional<OneShotBufferizationOptions> options;
1173b2004e1SMatthias Springer };
1183b2004e1SMatthias Springer } // namespace
1193b2004e1SMatthias Springer 
createTensorCopyInsertionPass()1203b2004e1SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() {
1213b2004e1SMatthias Springer   return std::make_unique<TensorCopyInsertionPass>();
1223b2004e1SMatthias Springer }
1233b2004e1SMatthias Springer 
createTensorCopyInsertionPass(const OneShotBufferizationOptions & options)1243b2004e1SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass(
1253b2004e1SMatthias Springer     const OneShotBufferizationOptions &options) {
1263b2004e1SMatthias Springer   return std::make_unique<TensorCopyInsertionPass>(options);
1273b2004e1SMatthias Springer }
128