1 //===- FuncConversions.cpp - Function conversions -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/Transforms/DialectConversion.h" 12 13 using namespace mlir; 14 using namespace mlir::func; 15 16 namespace { 17 /// Converts the operand and result types of the CallOp, used together with the 18 /// FuncOpSignatureConversion. 19 struct CallOpSignatureConversion : public OpConversionPattern<CallOp> { 20 using OpConversionPattern<CallOp>::OpConversionPattern; 21 22 /// Hook for derived classes to implement combined matching and rewriting. 23 LogicalResult 24 matchAndRewrite(CallOp callOp, OpAdaptor adaptor, 25 ConversionPatternRewriter &rewriter) const override { 26 // Convert the original function results. 27 SmallVector<Type, 1> convertedResults; 28 if (failed(typeConverter->convertTypes(callOp.getResultTypes(), 29 convertedResults))) 30 return failure(); 31 32 // Substitute with the new result types from the corresponding FuncType 33 // conversion. 34 rewriter.replaceOpWithNewOp<CallOp>( 35 callOp, callOp.getCallee(), convertedResults, adaptor.getOperands()); 36 return success(); 37 } 38 }; 39 } // namespace 40 41 void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns, 42 TypeConverter &converter) { 43 patterns.add<CallOpSignatureConversion>(converter, patterns.getContext()); 44 } 45 46 namespace { 47 /// Only needed to support partial conversion of functions where this pattern 48 /// ensures that the branch operation arguments matches up with the succesor 49 /// block arguments. 50 class BranchOpInterfaceTypeConversion 51 : public OpInterfaceConversionPattern<BranchOpInterface> { 52 public: 53 using OpInterfaceConversionPattern< 54 BranchOpInterface>::OpInterfaceConversionPattern; 55 56 BranchOpInterfaceTypeConversion( 57 TypeConverter &typeConverter, MLIRContext *ctx, 58 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) 59 : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1), 60 shouldConvertBranchOperand(shouldConvertBranchOperand) {} 61 62 LogicalResult 63 matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands, 64 ConversionPatternRewriter &rewriter) const final { 65 // For a branch operation, only some operands go to the target blocks, so 66 // only rewrite those. 67 SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end()); 68 for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); 69 succIdx < succEnd; ++succIdx) { 70 OperandRange forwardedOperands = 71 op.getSuccessorOperands(succIdx).getForwardedOperands(); 72 if (forwardedOperands.empty()) 73 continue; 74 75 for (int idx = forwardedOperands.getBeginOperandIndex(), 76 eidx = idx + forwardedOperands.size(); 77 idx < eidx; ++idx) { 78 if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx)) 79 newOperands[idx] = operands[idx]; 80 } 81 } 82 rewriter.updateRootInPlace( 83 op, [newOperands, op]() { op->setOperands(newOperands); }); 84 return success(); 85 } 86 87 private: 88 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand; 89 }; 90 } // namespace 91 92 namespace { 93 /// Only needed to support partial conversion of functions where this pattern 94 /// ensures that the branch operation arguments matches up with the succesor 95 /// block arguments. 96 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> { 97 public: 98 using OpConversionPattern<ReturnOp>::OpConversionPattern; 99 100 LogicalResult 101 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 102 ConversionPatternRewriter &rewriter) const final { 103 // For a return, all operands go to the results of the parent, so 104 // rewrite them all. 105 rewriter.updateRootInPlace(op, 106 [&] { op->setOperands(adaptor.getOperands()); }); 107 return success(); 108 } 109 }; 110 } // namespace 111 112 void mlir::populateBranchOpInterfaceTypeConversionPattern( 113 RewritePatternSet &patterns, TypeConverter &typeConverter, 114 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) { 115 patterns.add<BranchOpInterfaceTypeConversion>( 116 typeConverter, patterns.getContext(), shouldConvertBranchOperand); 117 } 118 119 bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern( 120 Operation *op, TypeConverter &converter) { 121 // All successor operands of branch like operations must be rewritten. 122 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 123 for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { 124 auto successorOperands = branchOp.getSuccessorOperands(p); 125 if (!converter.isLegal( 126 successorOperands.getForwardedOperands().getTypes())) 127 return false; 128 } 129 return true; 130 } 131 132 return false; 133 } 134 135 void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, 136 TypeConverter &typeConverter) { 137 patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext()); 138 } 139 140 bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op, 141 TypeConverter &converter, 142 bool returnOpAlwaysLegal) { 143 // If this is a `return` and the user pass wants to convert/transform across 144 // function boundaries, then `converter` is invoked to check whether the the 145 // `return` op is legal. 146 if (dyn_cast<ReturnOp>(op) && !returnOpAlwaysLegal) 147 return converter.isLegal(op); 148 149 // ReturnLike operations have to be legalized with their parent. For 150 // return this is handled, for other ops they remain as is. 151 return op->hasTrait<OpTrait::ReturnLike>(); 152 } 153 154 bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) { 155 // If it is not a terminator, ignore it. 156 if (!op->mightHaveTrait<OpTrait::IsTerminator>()) 157 return true; 158 159 // If it is not the last operation in the block, also ignore it. We do 160 // this to handle unknown operations, as well. 161 Block *block = op->getBlock(); 162 if (!block || &block->back() != op) 163 return true; 164 165 // We don't want to handle terminators in nested regions, assume they are 166 // always legal. 167 if (!isa_and_nonnull<FuncOp>(op->getParentOp())) 168 return true; 169 170 return false; 171 } 172