123aa5a74SRiver Riddle //===- Bufferize.cpp - Bufferization for func ops -------------------------===//
223aa5a74SRiver Riddle //
323aa5a74SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
423aa5a74SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
523aa5a74SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
623aa5a74SRiver Riddle //
723aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
823aa5a74SRiver Riddle //
936550692SRiver Riddle // This file implements bufferization of func.func's and func.call's.
1023aa5a74SRiver Riddle //
1123aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
1223aa5a74SRiver Riddle
1323aa5a74SRiver Riddle #include "PassDetail.h"
1423aa5a74SRiver Riddle #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1523aa5a74SRiver Riddle #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
1623aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1723aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1823aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/Passes.h"
19*eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
2023aa5a74SRiver Riddle #include "mlir/Transforms/DialectConversion.h"
2123aa5a74SRiver Riddle
2223aa5a74SRiver Riddle using namespace mlir;
2323aa5a74SRiver Riddle using namespace mlir::func;
2423aa5a74SRiver Riddle
2523aa5a74SRiver Riddle namespace {
2623aa5a74SRiver Riddle struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
2723aa5a74SRiver Riddle using FuncBufferizeBase<FuncBufferizePass>::FuncBufferizeBase;
runOnOperation__anon2abbd2450111::FuncBufferizePass2823aa5a74SRiver Riddle void runOnOperation() override {
2923aa5a74SRiver Riddle auto module = getOperation();
3023aa5a74SRiver Riddle auto *context = &getContext();
3123aa5a74SRiver Riddle
3223aa5a74SRiver Riddle bufferization::BufferizeTypeConverter typeConverter;
3323aa5a74SRiver Riddle RewritePatternSet patterns(context);
3423aa5a74SRiver Riddle ConversionTarget target(*context);
3523aa5a74SRiver Riddle
3623aa5a74SRiver Riddle populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
3723aa5a74SRiver Riddle typeConverter);
3823aa5a74SRiver Riddle target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
394a3460a7SRiver Riddle return typeConverter.isSignatureLegal(op.getFunctionType()) &&
4023aa5a74SRiver Riddle typeConverter.isLegal(&op.getBody());
4123aa5a74SRiver Riddle });
4223aa5a74SRiver Riddle populateCallOpTypeConversionPattern(patterns, typeConverter);
4323aa5a74SRiver Riddle target.addDynamicallyLegalOp<CallOp>(
4423aa5a74SRiver Riddle [&](CallOp op) { return typeConverter.isLegal(op); });
4523aa5a74SRiver Riddle
4623aa5a74SRiver Riddle populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
4723aa5a74SRiver Riddle populateReturnOpTypeConversionPattern(patterns, typeConverter);
4823aa5a74SRiver Riddle target.addLegalOp<ModuleOp, bufferization::ToTensorOp,
4923aa5a74SRiver Riddle bufferization::ToMemrefOp>();
5023aa5a74SRiver Riddle
5123aa5a74SRiver Riddle target.markUnknownOpDynamicallyLegal([&](Operation *op) {
5223aa5a74SRiver Riddle return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
5323aa5a74SRiver Riddle isLegalForBranchOpInterfaceTypeConversionPattern(op,
5423aa5a74SRiver Riddle typeConverter) ||
5523aa5a74SRiver Riddle isLegalForReturnOpTypeConversionPattern(op, typeConverter);
5623aa5a74SRiver Riddle });
5723aa5a74SRiver Riddle
5823aa5a74SRiver Riddle if (failed(applyFullConversion(module, target, std::move(patterns))))
5923aa5a74SRiver Riddle signalPassFailure();
6023aa5a74SRiver Riddle }
6123aa5a74SRiver Riddle };
6223aa5a74SRiver Riddle } // namespace
6323aa5a74SRiver Riddle
createFuncBufferizePass()6423aa5a74SRiver Riddle std::unique_ptr<Pass> mlir::func::createFuncBufferizePass() {
6523aa5a74SRiver Riddle return std::make_unique<FuncBufferizePass>();
6623aa5a74SRiver Riddle }
67