1 //===- FuncConversions.cpp - Function conversions -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Transforms/DialectConversion.h"
12 
13 using namespace mlir;
14 using namespace mlir::func;
15 
16 namespace {
17 /// Converts the operand and result types of the CallOp, used together with the
18 /// FuncOpSignatureConversion.
19 struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
20   using OpConversionPattern<CallOp>::OpConversionPattern;
21 
22   /// Hook for derived classes to implement combined matching and rewriting.
23   LogicalResult
24   matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
25                   ConversionPatternRewriter &rewriter) const override {
26     // Convert the original function results.
27     SmallVector<Type, 1> convertedResults;
28     if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
29                                            convertedResults)))
30       return failure();
31 
32     // Substitute with the new result types from the corresponding FuncType
33     // conversion.
34     rewriter.replaceOpWithNewOp<CallOp>(
35         callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
36     return success();
37   }
38 };
39 } // namespace
40 
41 void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
42                                                TypeConverter &converter) {
43   patterns.add<CallOpSignatureConversion>(converter, patterns.getContext());
44 }
45 
46 namespace {
47 /// Only needed to support partial conversion of functions where this pattern
48 /// ensures that the branch operation arguments matches up with the succesor
49 /// block arguments.
50 class BranchOpInterfaceTypeConversion
51     : public OpInterfaceConversionPattern<BranchOpInterface> {
52 public:
53   using OpInterfaceConversionPattern<
54       BranchOpInterface>::OpInterfaceConversionPattern;
55 
56   BranchOpInterfaceTypeConversion(
57       TypeConverter &typeConverter, MLIRContext *ctx,
58       function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
59       : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
60         shouldConvertBranchOperand(shouldConvertBranchOperand) {}
61 
62   LogicalResult
63   matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
64                   ConversionPatternRewriter &rewriter) const final {
65     // For a branch operation, only some operands go to the target blocks, so
66     // only rewrite those.
67     SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
68     for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
69          succIdx < succEnd; ++succIdx) {
70       auto successorOperands = op.getSuccessorOperands(succIdx);
71       if (!successorOperands || successorOperands->empty())
72         continue;
73 
74       for (int idx = successorOperands->getBeginOperandIndex(),
75                eidx = idx + successorOperands->size();
76            idx < eidx; ++idx) {
77         if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
78           newOperands[idx] = operands[idx];
79       }
80     }
81     rewriter.updateRootInPlace(
82         op, [newOperands, op]() { op->setOperands(newOperands); });
83     return success();
84   }
85 
86 private:
87   function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
88 };
89 } // namespace
90 
91 namespace {
92 /// Only needed to support partial conversion of functions where this pattern
93 /// ensures that the branch operation arguments matches up with the succesor
94 /// block arguments.
95 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
96 public:
97   using OpConversionPattern<ReturnOp>::OpConversionPattern;
98 
99   LogicalResult
100   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
101                   ConversionPatternRewriter &rewriter) const final {
102     // For a return, all operands go to the results of the parent, so
103     // rewrite them all.
104     rewriter.updateRootInPlace(op,
105                                [&] { op->setOperands(adaptor.getOperands()); });
106     return success();
107   }
108 };
109 } // namespace
110 
111 void mlir::populateBranchOpInterfaceTypeConversionPattern(
112     RewritePatternSet &patterns, TypeConverter &typeConverter,
113     function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
114   patterns.add<BranchOpInterfaceTypeConversion>(
115       typeConverter, patterns.getContext(), shouldConvertBranchOperand);
116 }
117 
118 bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
119     Operation *op, TypeConverter &converter) {
120   // All successor operands of branch like operations must be rewritten.
121   if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
122     for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
123       auto successorOperands = branchOp.getSuccessorOperands(p);
124       if (successorOperands.hasValue() &&
125           !converter.isLegal(successorOperands.getValue().getTypes()))
126         return false;
127     }
128     return true;
129   }
130 
131   return false;
132 }
133 
134 void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
135                                                  TypeConverter &typeConverter) {
136   patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
137 }
138 
139 bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
140                                                    TypeConverter &converter,
141                                                    bool returnOpAlwaysLegal) {
142   // If this is a `return` and the user pass wants to convert/transform across
143   // function boundaries, then `converter` is invoked to check whether the the
144   // `return` op is legal.
145   if (dyn_cast<ReturnOp>(op) && !returnOpAlwaysLegal)
146     return converter.isLegal(op);
147 
148   // ReturnLike operations have to be legalized with their parent. For
149   // return this is handled, for other ops they remain as is.
150   return op->hasTrait<OpTrait::ReturnLike>();
151 }
152 
153 bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) {
154   // If it is not a terminator, ignore it.
155   if (!op->mightHaveTrait<OpTrait::IsTerminator>())
156     return true;
157 
158   // If it is not the last operation in the block, also ignore it. We do
159   // this to handle unknown operations, as well.
160   Block *block = op->getBlock();
161   if (!block || &block->back() != op)
162     return true;
163 
164   // We don't want to handle terminators in nested regions, assume they are
165   // always legal.
166   if (!isa_and_nonnull<FuncOp>(op->getParentOp()))
167     return true;
168 
169   return false;
170 }
171