//===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
83b2004e1SMatthias Springer
93b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
103b2004e1SMatthias Springer
113b2004e1SMatthias Springer #include "PassDetail.h"
123b2004e1SMatthias Springer
133b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
143b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
153b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
163b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
173b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
183b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
193b2004e1SMatthias Springer
203b2004e1SMatthias Springer using namespace mlir;
213b2004e1SMatthias Springer using namespace mlir::bufferization;
223b2004e1SMatthias Springer
insertTensorCopies(Operation * op,const OneShotBufferizationOptions & options)233b2004e1SMatthias Springer LogicalResult mlir::bufferization::insertTensorCopies(
243b2004e1SMatthias Springer Operation *op, const OneShotBufferizationOptions &options) {
253b2004e1SMatthias Springer OneShotAnalysisState state(op, options);
263b2004e1SMatthias Springer // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
273b2004e1SMatthias Springer // analysis depending on whether function boundary bufferization is enabled or
283b2004e1SMatthias Springer // not.
293b2004e1SMatthias Springer if (options.bufferizeFunctionBoundaries) {
303b2004e1SMatthias Springer if (failed(analyzeModuleOp(cast<ModuleOp>(op), state)))
313b2004e1SMatthias Springer return failure();
323b2004e1SMatthias Springer } else {
333b2004e1SMatthias Springer if (failed(analyzeOp(op, state)))
343b2004e1SMatthias Springer return failure();
353b2004e1SMatthias Springer }
363b2004e1SMatthias Springer
373b2004e1SMatthias Springer if (options.testAnalysisOnly)
383b2004e1SMatthias Springer return success();
393b2004e1SMatthias Springer
403b2004e1SMatthias Springer return insertTensorCopies(op, state);
413b2004e1SMatthias Springer }
423b2004e1SMatthias Springer
433b2004e1SMatthias Springer LogicalResult
insertTensorCopies(Operation * op,const AnalysisState & state)443b2004e1SMatthias Springer mlir::bufferization::insertTensorCopies(Operation *op,
453b2004e1SMatthias Springer const AnalysisState &state) {
4687c770bbSMatthias Springer IRRewriter rewriter(op->getContext());
473474d10eSMatthias Springer StringRef escapeAttrName = BufferizationDialect::kEscapeAttrName;
483474d10eSMatthias Springer
493b2004e1SMatthias Springer WalkResult result = op->walk([&](Operation *op) {
503b2004e1SMatthias Springer auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
513b2004e1SMatthias Springer if (!bufferizableOp)
523b2004e1SMatthias Springer return WalkResult::skip();
533b2004e1SMatthias Springer
543474d10eSMatthias Springer // Find allocations without an `escape` attribute and add the attribute
553b2004e1SMatthias Springer // based on analysis results.
563474d10eSMatthias Springer if (!op->hasAttr(escapeAttrName)) {
573474d10eSMatthias Springer SmallVector<bool> escapeAttrValue;
583474d10eSMatthias Springer bool foundTensorResult = false;
593474d10eSMatthias Springer for (OpResult opResult : op->getOpResults()) {
603474d10eSMatthias Springer if (!opResult.getType().isa<TensorType>() ||
613474d10eSMatthias Springer !bufferizableOp.bufferizesToAllocation(opResult)) {
623474d10eSMatthias Springer escapeAttrValue.push_back(false);
633474d10eSMatthias Springer continue;
643474d10eSMatthias Springer }
653474d10eSMatthias Springer foundTensorResult = true;
66a36c801dSMatthias Springer bool escape = !state.getOptions().createDeallocs ||
673474d10eSMatthias Springer state.isTensorYielded(opResult);
683474d10eSMatthias Springer escapeAttrValue.push_back(escape);
693474d10eSMatthias Springer }
703474d10eSMatthias Springer if (foundTensorResult)
713474d10eSMatthias Springer op->setAttr(escapeAttrName, rewriter.getBoolArrayAttr(escapeAttrValue));
723b2004e1SMatthias Springer }
733b2004e1SMatthias Springer
7487c770bbSMatthias Springer // Find inplacability conflicts and resolve them. (Typically with explicit
7587c770bbSMatthias Springer // tensor copies in the form of AllocTensorOps.)
7687c770bbSMatthias Springer rewriter.setInsertionPoint(op);
7787c770bbSMatthias Springer if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
783b2004e1SMatthias Springer return WalkResult::interrupt();
793b2004e1SMatthias Springer
803b2004e1SMatthias Springer return WalkResult::advance();
813b2004e1SMatthias Springer });
823b2004e1SMatthias Springer
833b2004e1SMatthias Springer return failure(result.wasInterrupted());
843b2004e1SMatthias Springer }
853b2004e1SMatthias Springer
863b2004e1SMatthias Springer namespace {
873b2004e1SMatthias Springer struct TensorCopyInsertionPass
883b2004e1SMatthias Springer : TensorCopyInsertionBase<TensorCopyInsertionPass> {
TensorCopyInsertionPass__anon012e1bb40211::TensorCopyInsertionPass893b2004e1SMatthias Springer TensorCopyInsertionPass()
903b2004e1SMatthias Springer : TensorCopyInsertionBase<TensorCopyInsertionPass>(),
913b2004e1SMatthias Springer options(llvm::None) {}
TensorCopyInsertionPass__anon012e1bb40211::TensorCopyInsertionPass923b2004e1SMatthias Springer TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
933b2004e1SMatthias Springer : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}
943b2004e1SMatthias Springer
getDependentDialects__anon012e1bb40211::TensorCopyInsertionPass953b2004e1SMatthias Springer void getDependentDialects(DialectRegistry ®istry) const override {
963b2004e1SMatthias Springer registry.insert<bufferization::BufferizationDialect>();
973b2004e1SMatthias Springer }
983b2004e1SMatthias Springer
runOnOperation__anon012e1bb40211::TensorCopyInsertionPass993b2004e1SMatthias Springer void runOnOperation() override {
100037f0995SKazu Hirata if (options) {
1013b2004e1SMatthias Springer if (failed(insertTensorCopies(getOperation(), *options)))
1023b2004e1SMatthias Springer signalPassFailure();
1033b2004e1SMatthias Springer } else {
1043b2004e1SMatthias Springer OneShotBufferizationOptions options;
1053b2004e1SMatthias Springer options.allowReturnAllocs = allowReturnAllocs;
1063b2004e1SMatthias Springer options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
107a36c801dSMatthias Springer options.createDeallocs = createDeallocs;
108*c0b0b6a0SMatthias Springer if (mustInferMemorySpace)
109*c0b0b6a0SMatthias Springer options.defaultMemorySpace = None;
1103b2004e1SMatthias Springer if (failed(insertTensorCopies(getOperation(), options)))
1113b2004e1SMatthias Springer signalPassFailure();
1123b2004e1SMatthias Springer }
1133b2004e1SMatthias Springer }
1143b2004e1SMatthias Springer
1153b2004e1SMatthias Springer private:
1163b2004e1SMatthias Springer Optional<OneShotBufferizationOptions> options;
1173b2004e1SMatthias Springer };
1183b2004e1SMatthias Springer } // namespace
1193b2004e1SMatthias Springer
createTensorCopyInsertionPass()1203b2004e1SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() {
1213b2004e1SMatthias Springer return std::make_unique<TensorCopyInsertionPass>();
1223b2004e1SMatthias Springer }
1233b2004e1SMatthias Springer
createTensorCopyInsertionPass(const OneShotBufferizationOptions & options)1243b2004e1SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass(
1253b2004e1SMatthias Springer const OneShotBufferizationOptions &options) {
1263b2004e1SMatthias Springer return std::make_unique<TensorCopyInsertionPass>(options);
1273b2004e1SMatthias Springer }
128