1 //===- FuncConversions.cpp - Function conversions -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/Transforms/DialectConversion.h" 12 13 using namespace mlir; 14 using namespace mlir::func; 15 16 namespace { 17 /// Converts the operand and result types of the CallOp, used together with the 18 /// FuncOpSignatureConversion. 19 struct CallOpSignatureConversion : public OpConversionPattern<CallOp> { 20 using OpConversionPattern<CallOp>::OpConversionPattern; 21 22 /// Hook for derived classes to implement combined matching and rewriting. 23 LogicalResult 24 matchAndRewrite(CallOp callOp, OpAdaptor adaptor, 25 ConversionPatternRewriter &rewriter) const override { 26 // Convert the original function results. 27 SmallVector<Type, 1> convertedResults; 28 if (failed(typeConverter->convertTypes(callOp.getResultTypes(), 29 convertedResults))) 30 return failure(); 31 32 // Substitute with the new result types from the corresponding FuncType 33 // conversion. 34 rewriter.replaceOpWithNewOp<CallOp>( 35 callOp, callOp.getCallee(), convertedResults, adaptor.getOperands()); 36 return success(); 37 } 38 }; 39 } // namespace 40 41 void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns, 42 TypeConverter &converter) { 43 patterns.add<CallOpSignatureConversion>(converter, patterns.getContext()); 44 } 45 46 namespace { 47 /// Only needed to support partial conversion of functions where this pattern 48 /// ensures that the branch operation arguments matches up with the succesor 49 /// block arguments. 50 class BranchOpInterfaceTypeConversion 51 : public OpInterfaceConversionPattern<BranchOpInterface> { 52 public: 53 using OpInterfaceConversionPattern< 54 BranchOpInterface>::OpInterfaceConversionPattern; 55 56 BranchOpInterfaceTypeConversion( 57 TypeConverter &typeConverter, MLIRContext *ctx, 58 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) 59 : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1), 60 shouldConvertBranchOperand(shouldConvertBranchOperand) {} 61 62 LogicalResult 63 matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands, 64 ConversionPatternRewriter &rewriter) const final { 65 // For a branch operation, only some operands go to the target blocks, so 66 // only rewrite those. 67 SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end()); 68 for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); 69 succIdx < succEnd; ++succIdx) { 70 auto successorOperands = op.getSuccessorOperands(succIdx); 71 if (!successorOperands || successorOperands->empty()) 72 continue; 73 74 for (int idx = successorOperands->getBeginOperandIndex(), 75 eidx = idx + successorOperands->size(); 76 idx < eidx; ++idx) { 77 if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx)) 78 newOperands[idx] = operands[idx]; 79 } 80 } 81 rewriter.updateRootInPlace( 82 op, [newOperands, op]() { op->setOperands(newOperands); }); 83 return success(); 84 } 85 86 private: 87 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand; 88 }; 89 } // namespace 90 91 namespace { 92 /// Only needed to support partial conversion of functions where this pattern 93 /// ensures that the branch operation arguments matches up with the succesor 94 /// block arguments. 95 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> { 96 public: 97 using OpConversionPattern<ReturnOp>::OpConversionPattern; 98 99 LogicalResult 100 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 101 ConversionPatternRewriter &rewriter) const final { 102 // For a return, all operands go to the results of the parent, so 103 // rewrite them all. 104 rewriter.updateRootInPlace(op, 105 [&] { op->setOperands(adaptor.getOperands()); }); 106 return success(); 107 } 108 }; 109 } // namespace 110 111 void mlir::populateBranchOpInterfaceTypeConversionPattern( 112 RewritePatternSet &patterns, TypeConverter &typeConverter, 113 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) { 114 patterns.add<BranchOpInterfaceTypeConversion>( 115 typeConverter, patterns.getContext(), shouldConvertBranchOperand); 116 } 117 118 bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern( 119 Operation *op, TypeConverter &converter) { 120 // All successor operands of branch like operations must be rewritten. 121 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 122 for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { 123 auto successorOperands = branchOp.getSuccessorOperands(p); 124 if (successorOperands.hasValue() && 125 !converter.isLegal(successorOperands.getValue().getTypes())) 126 return false; 127 } 128 return true; 129 } 130 131 return false; 132 } 133 134 void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, 135 TypeConverter &typeConverter) { 136 patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext()); 137 } 138 139 bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op, 140 TypeConverter &converter, 141 bool returnOpAlwaysLegal) { 142 // If this is a `return` and the user pass wants to convert/transform across 143 // function boundaries, then `converter` is invoked to check whether the the 144 // `return` op is legal. 145 if (dyn_cast<ReturnOp>(op) && !returnOpAlwaysLegal) 146 return converter.isLegal(op); 147 148 // ReturnLike operations have to be legalized with their parent. For 149 // return this is handled, for other ops they remain as is. 150 return op->hasTrait<OpTrait::ReturnLike>(); 151 } 152 153 bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) { 154 // If it is not a terminator, ignore it. 155 if (!op->mightHaveTrait<OpTrait::IsTerminator>()) 156 return true; 157 158 // If it is not the last operation in the block, also ignore it. We do 159 // this to handle unknown operations, as well. 160 Block *block = op->getBlock(); 161 if (!block || &block->back() != op) 162 return true; 163 164 // We don't want to handle terminators in nested regions, assume they are 165 // always legal. 166 if (!isa_and_nonnull<FuncOp>(op->getParentOp())) 167 return true; 168 169 return false; 170 } 171