1 //===- FuncConversions.cpp - Function conversions -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Transforms/DialectConversion.h"
12
13 using namespace mlir;
14 using namespace mlir::func;
15
16 namespace {
17 /// Converts the operand and result types of the CallOp, used together with the
18 /// FuncOpSignatureConversion.
19 struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
20 using OpConversionPattern<CallOp>::OpConversionPattern;
21
22 /// Hook for derived classes to implement combined matching and rewriting.
23 LogicalResult
matchAndRewrite__anon0778ab1c0111::CallOpSignatureConversion24 matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
25 ConversionPatternRewriter &rewriter) const override {
26 // Convert the original function results.
27 SmallVector<Type, 1> convertedResults;
28 if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
29 convertedResults)))
30 return failure();
31
32 // Substitute with the new result types from the corresponding FuncType
33 // conversion.
34 rewriter.replaceOpWithNewOp<CallOp>(
35 callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
36 return success();
37 }
38 };
39 } // namespace
40
populateCallOpTypeConversionPattern(RewritePatternSet & patterns,TypeConverter & converter)41 void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
42 TypeConverter &converter) {
43 patterns.add<CallOpSignatureConversion>(converter, patterns.getContext());
44 }
45
46 namespace {
47 /// Only needed to support partial conversion of functions where this pattern
48 /// ensures that the branch operation arguments matches up with the succesor
49 /// block arguments.
50 class BranchOpInterfaceTypeConversion
51 : public OpInterfaceConversionPattern<BranchOpInterface> {
52 public:
53 using OpInterfaceConversionPattern<
54 BranchOpInterface>::OpInterfaceConversionPattern;
55
BranchOpInterfaceTypeConversion(TypeConverter & typeConverter,MLIRContext * ctx,function_ref<bool (BranchOpInterface,int)> shouldConvertBranchOperand)56 BranchOpInterfaceTypeConversion(
57 TypeConverter &typeConverter, MLIRContext *ctx,
58 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
59 : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
60 shouldConvertBranchOperand(shouldConvertBranchOperand) {}
61
62 LogicalResult
matchAndRewrite(BranchOpInterface op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const63 matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
64 ConversionPatternRewriter &rewriter) const final {
65 // For a branch operation, only some operands go to the target blocks, so
66 // only rewrite those.
67 SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
68 for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
69 succIdx < succEnd; ++succIdx) {
70 OperandRange forwardedOperands =
71 op.getSuccessorOperands(succIdx).getForwardedOperands();
72 if (forwardedOperands.empty())
73 continue;
74
75 for (int idx = forwardedOperands.getBeginOperandIndex(),
76 eidx = idx + forwardedOperands.size();
77 idx < eidx; ++idx) {
78 if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
79 newOperands[idx] = operands[idx];
80 }
81 }
82 rewriter.updateRootInPlace(
83 op, [newOperands, op]() { op->setOperands(newOperands); });
84 return success();
85 }
86
87 private:
88 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
89 };
90 } // namespace
91
92 namespace {
93 /// Only needed to support partial conversion of functions where this pattern
94 /// ensures that the branch operation arguments matches up with the succesor
95 /// block arguments.
96 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
97 public:
98 using OpConversionPattern<ReturnOp>::OpConversionPattern;
99
100 LogicalResult
matchAndRewrite(ReturnOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const101 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
102 ConversionPatternRewriter &rewriter) const final {
103 // For a return, all operands go to the results of the parent, so
104 // rewrite them all.
105 rewriter.updateRootInPlace(op,
106 [&] { op->setOperands(adaptor.getOperands()); });
107 return success();
108 }
109 };
110 } // namespace
111
populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet & patterns,TypeConverter & typeConverter,function_ref<bool (BranchOpInterface,int)> shouldConvertBranchOperand)112 void mlir::populateBranchOpInterfaceTypeConversionPattern(
113 RewritePatternSet &patterns, TypeConverter &typeConverter,
114 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
115 patterns.add<BranchOpInterfaceTypeConversion>(
116 typeConverter, patterns.getContext(), shouldConvertBranchOperand);
117 }
118
isLegalForBranchOpInterfaceTypeConversionPattern(Operation * op,TypeConverter & converter)119 bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
120 Operation *op, TypeConverter &converter) {
121 // All successor operands of branch like operations must be rewritten.
122 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
123 for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
124 auto successorOperands = branchOp.getSuccessorOperands(p);
125 if (!converter.isLegal(
126 successorOperands.getForwardedOperands().getTypes()))
127 return false;
128 }
129 return true;
130 }
131
132 return false;
133 }
134
populateReturnOpTypeConversionPattern(RewritePatternSet & patterns,TypeConverter & typeConverter)135 void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
136 TypeConverter &typeConverter) {
137 patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
138 }
139
isLegalForReturnOpTypeConversionPattern(Operation * op,TypeConverter & converter,bool returnOpAlwaysLegal)140 bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
141 TypeConverter &converter,
142 bool returnOpAlwaysLegal) {
143 // If this is a `return` and the user pass wants to convert/transform across
144 // function boundaries, then `converter` is invoked to check whether the the
145 // `return` op is legal.
146 if (dyn_cast<ReturnOp>(op) && !returnOpAlwaysLegal)
147 return converter.isLegal(op);
148
149 // ReturnLike operations have to be legalized with their parent. For
150 // return this is handled, for other ops they remain as is.
151 return op->hasTrait<OpTrait::ReturnLike>();
152 }
153
isNotBranchOpInterfaceOrReturnLikeOp(Operation * op)154 bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) {
155 // If it is not a terminator, ignore it.
156 if (!op->mightHaveTrait<OpTrait::IsTerminator>())
157 return true;
158
159 // If it is not the last operation in the block, also ignore it. We do
160 // this to handle unknown operations, as well.
161 Block *block = op->getBlock();
162 if (!block || &block->back() != op)
163 return true;
164
165 // We don't want to handle terminators in nested regions, assume they are
166 // always legal.
167 if (!isa_and_nonnull<FuncOp>(op->getParentOp()))
168 return true;
169
170 return false;
171 }
172