123aa5a74SRiver Riddle //===- Bufferize.cpp - Bufferization for func ops -------------------------===//
223aa5a74SRiver Riddle //
323aa5a74SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
423aa5a74SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
523aa5a74SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
623aa5a74SRiver Riddle //
723aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
823aa5a74SRiver Riddle //
936550692SRiver Riddle // This file implements bufferization of func.func's and func.call's.
1023aa5a74SRiver Riddle //
1123aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
1223aa5a74SRiver Riddle 
1323aa5a74SRiver Riddle #include "PassDetail.h"
1423aa5a74SRiver Riddle #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1523aa5a74SRiver Riddle #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
1623aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1723aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1823aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/Passes.h"
19*eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
2023aa5a74SRiver Riddle #include "mlir/Transforms/DialectConversion.h"
2123aa5a74SRiver Riddle 
2223aa5a74SRiver Riddle using namespace mlir;
2323aa5a74SRiver Riddle using namespace mlir::func;
2423aa5a74SRiver Riddle 
2523aa5a74SRiver Riddle namespace {
2623aa5a74SRiver Riddle struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
2723aa5a74SRiver Riddle   using FuncBufferizeBase<FuncBufferizePass>::FuncBufferizeBase;
runOnOperation__anon2abbd2450111::FuncBufferizePass2823aa5a74SRiver Riddle   void runOnOperation() override {
2923aa5a74SRiver Riddle     auto module = getOperation();
3023aa5a74SRiver Riddle     auto *context = &getContext();
3123aa5a74SRiver Riddle 
3223aa5a74SRiver Riddle     bufferization::BufferizeTypeConverter typeConverter;
3323aa5a74SRiver Riddle     RewritePatternSet patterns(context);
3423aa5a74SRiver Riddle     ConversionTarget target(*context);
3523aa5a74SRiver Riddle 
3623aa5a74SRiver Riddle     populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
3723aa5a74SRiver Riddle                                                              typeConverter);
3823aa5a74SRiver Riddle     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
394a3460a7SRiver Riddle       return typeConverter.isSignatureLegal(op.getFunctionType()) &&
4023aa5a74SRiver Riddle              typeConverter.isLegal(&op.getBody());
4123aa5a74SRiver Riddle     });
4223aa5a74SRiver Riddle     populateCallOpTypeConversionPattern(patterns, typeConverter);
4323aa5a74SRiver Riddle     target.addDynamicallyLegalOp<CallOp>(
4423aa5a74SRiver Riddle         [&](CallOp op) { return typeConverter.isLegal(op); });
4523aa5a74SRiver Riddle 
4623aa5a74SRiver Riddle     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
4723aa5a74SRiver Riddle     populateReturnOpTypeConversionPattern(patterns, typeConverter);
4823aa5a74SRiver Riddle     target.addLegalOp<ModuleOp, bufferization::ToTensorOp,
4923aa5a74SRiver Riddle                       bufferization::ToMemrefOp>();
5023aa5a74SRiver Riddle 
5123aa5a74SRiver Riddle     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
5223aa5a74SRiver Riddle       return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
5323aa5a74SRiver Riddle              isLegalForBranchOpInterfaceTypeConversionPattern(op,
5423aa5a74SRiver Riddle                                                               typeConverter) ||
5523aa5a74SRiver Riddle              isLegalForReturnOpTypeConversionPattern(op, typeConverter);
5623aa5a74SRiver Riddle     });
5723aa5a74SRiver Riddle 
5823aa5a74SRiver Riddle     if (failed(applyFullConversion(module, target, std::move(patterns))))
5923aa5a74SRiver Riddle       signalPassFailure();
6023aa5a74SRiver Riddle   }
6123aa5a74SRiver Riddle };
6223aa5a74SRiver Riddle } // namespace
6323aa5a74SRiver Riddle 
createFuncBufferizePass()6423aa5a74SRiver Riddle std::unique_ptr<Pass> mlir::func::createFuncBufferizePass() {
6523aa5a74SRiver Riddle   return std::make_unique<FuncBufferizePass>();
6623aa5a74SRiver Riddle }
67