//===- Bufferize.cpp - Bufferization for func ops -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of func.func and func.call operations.
//
//===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
18 #include "mlir/Dialect/Func/Transforms/Passes.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 using namespace mlir::func;
24 
25 namespace {
26 struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
27   using FuncBufferizeBase<FuncBufferizePass>::FuncBufferizeBase;
runOnOperation__anon2abbd2450111::FuncBufferizePass28   void runOnOperation() override {
29     auto module = getOperation();
30     auto *context = &getContext();
31 
32     bufferization::BufferizeTypeConverter typeConverter;
33     RewritePatternSet patterns(context);
34     ConversionTarget target(*context);
35 
36     populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
37                                                              typeConverter);
38     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
39       return typeConverter.isSignatureLegal(op.getFunctionType()) &&
40              typeConverter.isLegal(&op.getBody());
41     });
42     populateCallOpTypeConversionPattern(patterns, typeConverter);
43     target.addDynamicallyLegalOp<CallOp>(
44         [&](CallOp op) { return typeConverter.isLegal(op); });
45 
46     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
47     populateReturnOpTypeConversionPattern(patterns, typeConverter);
48     target.addLegalOp<ModuleOp, bufferization::ToTensorOp,
49                       bufferization::ToMemrefOp>();
50 
51     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
52       return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
53              isLegalForBranchOpInterfaceTypeConversionPattern(op,
54                                                               typeConverter) ||
55              isLegalForReturnOpTypeConversionPattern(op, typeConverter);
56     });
57 
58     if (failed(applyFullConversion(module, target, std::move(patterns))))
59       signalPassFailure();
60   }
61 };
62 } // namespace
63 
createFuncBufferizePass()64 std::unique_ptr<Pass> mlir::func::createFuncBufferizePass() {
65   return std::make_unique<FuncBufferizePass>();
66 }
67