1 //===- Bufferize.cpp - Bufferization for func ops -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements bufferization of builtin.func's and func.call's.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
18 #include "mlir/Dialect/Func/Transforms/Passes.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 using namespace mlir;
22 using namespace mlir::func;
23 
24 namespace {
25 struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
26   using FuncBufferizeBase<FuncBufferizePass>::FuncBufferizeBase;
27   void runOnOperation() override {
28     auto module = getOperation();
29     auto *context = &getContext();
30 
31     bufferization::BufferizeTypeConverter typeConverter;
32     RewritePatternSet patterns(context);
33     ConversionTarget target(*context);
34 
35     populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
36                                                              typeConverter);
37     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
38       return typeConverter.isSignatureLegal(op.getType()) &&
39              typeConverter.isLegal(&op.getBody());
40     });
41     populateCallOpTypeConversionPattern(patterns, typeConverter);
42     target.addDynamicallyLegalOp<CallOp>(
43         [&](CallOp op) { return typeConverter.isLegal(op); });
44 
45     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
46     populateReturnOpTypeConversionPattern(patterns, typeConverter);
47     target.addLegalOp<ModuleOp, bufferization::ToTensorOp,
48                       bufferization::ToMemrefOp>();
49 
50     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
51       return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
52              isLegalForBranchOpInterfaceTypeConversionPattern(op,
53                                                               typeConverter) ||
54              isLegalForReturnOpTypeConversionPattern(op, typeConverter);
55     });
56 
57     if (failed(applyFullConversion(module, target, std::move(patterns))))
58       signalPassFailure();
59   }
60 };
61 } // namespace
62 
63 std::unique_ptr<Pass> mlir::func::createFuncBufferizePass() {
64   return std::make_unique<FuncBufferizePass>();
65 }
66