1 //===- Bufferize.cpp - Bufferization for func ops -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements bufferization of func.func's and func.call's.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
18 #include "mlir/Dialect/Func/Transforms/Passes.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Transforms/DialectConversion.h"
21
22 using namespace mlir;
23 using namespace mlir::func;
24
25 namespace {
26 struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
27 using FuncBufferizeBase<FuncBufferizePass>::FuncBufferizeBase;
runOnOperation__anon2abbd2450111::FuncBufferizePass28 void runOnOperation() override {
29 auto module = getOperation();
30 auto *context = &getContext();
31
32 bufferization::BufferizeTypeConverter typeConverter;
33 RewritePatternSet patterns(context);
34 ConversionTarget target(*context);
35
36 populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
37 typeConverter);
38 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
39 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
40 typeConverter.isLegal(&op.getBody());
41 });
42 populateCallOpTypeConversionPattern(patterns, typeConverter);
43 target.addDynamicallyLegalOp<CallOp>(
44 [&](CallOp op) { return typeConverter.isLegal(op); });
45
46 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
47 populateReturnOpTypeConversionPattern(patterns, typeConverter);
48 target.addLegalOp<ModuleOp, bufferization::ToTensorOp,
49 bufferization::ToMemrefOp>();
50
51 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
52 return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
53 isLegalForBranchOpInterfaceTypeConversionPattern(op,
54 typeConverter) ||
55 isLegalForReturnOpTypeConversionPattern(op, typeConverter);
56 });
57
58 if (failed(applyFullConversion(module, target, std::move(patterns))))
59 signalPassFailure();
60 }
61 };
62 } // namespace
63
createFuncBufferizePass()64 std::unique_ptr<Pass> mlir::func::createFuncBufferizePass() {
65 return std::make_unique<FuncBufferizePass>();
66 }
67