123aa5a74SRiver Riddle //===- FuncConversions.cpp - Function conversions -------------------------===//
223aa5a74SRiver Riddle //
323aa5a74SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
423aa5a74SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
523aa5a74SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
623aa5a74SRiver Riddle //
723aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
823aa5a74SRiver Riddle 
923aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1023aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1123aa5a74SRiver Riddle #include "mlir/Transforms/DialectConversion.h"
1223aa5a74SRiver Riddle 
1323aa5a74SRiver Riddle using namespace mlir;
1423aa5a74SRiver Riddle using namespace mlir::func;
1523aa5a74SRiver Riddle 
1623aa5a74SRiver Riddle namespace {
1723aa5a74SRiver Riddle /// Converts the operand and result types of the CallOp, used together with the
1823aa5a74SRiver Riddle /// FuncOpSignatureConversion.
1923aa5a74SRiver Riddle struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
2023aa5a74SRiver Riddle   using OpConversionPattern<CallOp>::OpConversionPattern;
2123aa5a74SRiver Riddle 
2223aa5a74SRiver Riddle   /// Hook for derived classes to implement combined matching and rewriting.
2323aa5a74SRiver Riddle   LogicalResult
matchAndRewrite__anon0778ab1c0111::CallOpSignatureConversion2423aa5a74SRiver Riddle   matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
2523aa5a74SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
2623aa5a74SRiver Riddle     // Convert the original function results.
2723aa5a74SRiver Riddle     SmallVector<Type, 1> convertedResults;
2823aa5a74SRiver Riddle     if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
2923aa5a74SRiver Riddle                                            convertedResults)))
3023aa5a74SRiver Riddle       return failure();
3123aa5a74SRiver Riddle 
3223aa5a74SRiver Riddle     // Substitute with the new result types from the corresponding FuncType
3323aa5a74SRiver Riddle     // conversion.
3423aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<CallOp>(
3523aa5a74SRiver Riddle         callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
3623aa5a74SRiver Riddle     return success();
3723aa5a74SRiver Riddle   }
3823aa5a74SRiver Riddle };
3923aa5a74SRiver Riddle } // namespace
4023aa5a74SRiver Riddle 
populateCallOpTypeConversionPattern(RewritePatternSet & patterns,TypeConverter & converter)4123aa5a74SRiver Riddle void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
4223aa5a74SRiver Riddle                                                TypeConverter &converter) {
4323aa5a74SRiver Riddle   patterns.add<CallOpSignatureConversion>(converter, patterns.getContext());
4423aa5a74SRiver Riddle }
4523aa5a74SRiver Riddle 
4623aa5a74SRiver Riddle namespace {
4723aa5a74SRiver Riddle /// Only needed to support partial conversion of functions where this pattern
4823aa5a74SRiver Riddle /// ensures that the branch operation arguments matches up with the succesor
4923aa5a74SRiver Riddle /// block arguments.
5023aa5a74SRiver Riddle class BranchOpInterfaceTypeConversion
5123aa5a74SRiver Riddle     : public OpInterfaceConversionPattern<BranchOpInterface> {
5223aa5a74SRiver Riddle public:
5323aa5a74SRiver Riddle   using OpInterfaceConversionPattern<
5423aa5a74SRiver Riddle       BranchOpInterface>::OpInterfaceConversionPattern;
5523aa5a74SRiver Riddle 
BranchOpInterfaceTypeConversion(TypeConverter & typeConverter,MLIRContext * ctx,function_ref<bool (BranchOpInterface,int)> shouldConvertBranchOperand)5623aa5a74SRiver Riddle   BranchOpInterfaceTypeConversion(
5723aa5a74SRiver Riddle       TypeConverter &typeConverter, MLIRContext *ctx,
5823aa5a74SRiver Riddle       function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
5923aa5a74SRiver Riddle       : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
6023aa5a74SRiver Riddle         shouldConvertBranchOperand(shouldConvertBranchOperand) {}
6123aa5a74SRiver Riddle 
6223aa5a74SRiver Riddle   LogicalResult
matchAndRewrite(BranchOpInterface op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const6323aa5a74SRiver Riddle   matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
6423aa5a74SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
6523aa5a74SRiver Riddle     // For a branch operation, only some operands go to the target blocks, so
6623aa5a74SRiver Riddle     // only rewrite those.
6723aa5a74SRiver Riddle     SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
6823aa5a74SRiver Riddle     for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
6923aa5a74SRiver Riddle          succIdx < succEnd; ++succIdx) {
70*0c789db5SMarkus Böck       OperandRange forwardedOperands =
71*0c789db5SMarkus Böck           op.getSuccessorOperands(succIdx).getForwardedOperands();
72*0c789db5SMarkus Böck       if (forwardedOperands.empty())
7323aa5a74SRiver Riddle         continue;
7423aa5a74SRiver Riddle 
75*0c789db5SMarkus Böck       for (int idx = forwardedOperands.getBeginOperandIndex(),
76*0c789db5SMarkus Böck                eidx = idx + forwardedOperands.size();
7723aa5a74SRiver Riddle            idx < eidx; ++idx) {
7823aa5a74SRiver Riddle         if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
7923aa5a74SRiver Riddle           newOperands[idx] = operands[idx];
8023aa5a74SRiver Riddle       }
8123aa5a74SRiver Riddle     }
8223aa5a74SRiver Riddle     rewriter.updateRootInPlace(
8323aa5a74SRiver Riddle         op, [newOperands, op]() { op->setOperands(newOperands); });
8423aa5a74SRiver Riddle     return success();
8523aa5a74SRiver Riddle   }
8623aa5a74SRiver Riddle 
8723aa5a74SRiver Riddle private:
8823aa5a74SRiver Riddle   function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
8923aa5a74SRiver Riddle };
9023aa5a74SRiver Riddle } // namespace
9123aa5a74SRiver Riddle 
9223aa5a74SRiver Riddle namespace {
9323aa5a74SRiver Riddle /// Only needed to support partial conversion of functions where this pattern
9423aa5a74SRiver Riddle /// ensures that the branch operation arguments matches up with the succesor
9523aa5a74SRiver Riddle /// block arguments.
9623aa5a74SRiver Riddle class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
9723aa5a74SRiver Riddle public:
9823aa5a74SRiver Riddle   using OpConversionPattern<ReturnOp>::OpConversionPattern;
9923aa5a74SRiver Riddle 
10023aa5a74SRiver Riddle   LogicalResult
matchAndRewrite(ReturnOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const10123aa5a74SRiver Riddle   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
10223aa5a74SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
10323aa5a74SRiver Riddle     // For a return, all operands go to the results of the parent, so
10423aa5a74SRiver Riddle     // rewrite them all.
10523aa5a74SRiver Riddle     rewriter.updateRootInPlace(op,
10623aa5a74SRiver Riddle                                [&] { op->setOperands(adaptor.getOperands()); });
10723aa5a74SRiver Riddle     return success();
10823aa5a74SRiver Riddle   }
10923aa5a74SRiver Riddle };
11023aa5a74SRiver Riddle } // namespace
11123aa5a74SRiver Riddle 
populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet & patterns,TypeConverter & typeConverter,function_ref<bool (BranchOpInterface,int)> shouldConvertBranchOperand)11223aa5a74SRiver Riddle void mlir::populateBranchOpInterfaceTypeConversionPattern(
11323aa5a74SRiver Riddle     RewritePatternSet &patterns, TypeConverter &typeConverter,
11423aa5a74SRiver Riddle     function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
11523aa5a74SRiver Riddle   patterns.add<BranchOpInterfaceTypeConversion>(
11623aa5a74SRiver Riddle       typeConverter, patterns.getContext(), shouldConvertBranchOperand);
11723aa5a74SRiver Riddle }
11823aa5a74SRiver Riddle 
isLegalForBranchOpInterfaceTypeConversionPattern(Operation * op,TypeConverter & converter)11923aa5a74SRiver Riddle bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
12023aa5a74SRiver Riddle     Operation *op, TypeConverter &converter) {
12123aa5a74SRiver Riddle   // All successor operands of branch like operations must be rewritten.
12223aa5a74SRiver Riddle   if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
12323aa5a74SRiver Riddle     for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
12423aa5a74SRiver Riddle       auto successorOperands = branchOp.getSuccessorOperands(p);
125*0c789db5SMarkus Böck       if (!converter.isLegal(
126*0c789db5SMarkus Böck               successorOperands.getForwardedOperands().getTypes()))
12723aa5a74SRiver Riddle         return false;
12823aa5a74SRiver Riddle     }
12923aa5a74SRiver Riddle     return true;
13023aa5a74SRiver Riddle   }
13123aa5a74SRiver Riddle 
13223aa5a74SRiver Riddle   return false;
13323aa5a74SRiver Riddle }
13423aa5a74SRiver Riddle 
populateReturnOpTypeConversionPattern(RewritePatternSet & patterns,TypeConverter & typeConverter)13523aa5a74SRiver Riddle void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
13623aa5a74SRiver Riddle                                                  TypeConverter &typeConverter) {
13723aa5a74SRiver Riddle   patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
13823aa5a74SRiver Riddle }
13923aa5a74SRiver Riddle 
isLegalForReturnOpTypeConversionPattern(Operation * op,TypeConverter & converter,bool returnOpAlwaysLegal)14023aa5a74SRiver Riddle bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
14123aa5a74SRiver Riddle                                                    TypeConverter &converter,
14223aa5a74SRiver Riddle                                                    bool returnOpAlwaysLegal) {
14323aa5a74SRiver Riddle   // If this is a `return` and the user pass wants to convert/transform across
14423aa5a74SRiver Riddle   // function boundaries, then `converter` is invoked to check whether the the
14523aa5a74SRiver Riddle   // `return` op is legal.
14623aa5a74SRiver Riddle   if (dyn_cast<ReturnOp>(op) && !returnOpAlwaysLegal)
14723aa5a74SRiver Riddle     return converter.isLegal(op);
14823aa5a74SRiver Riddle 
14923aa5a74SRiver Riddle   // ReturnLike operations have to be legalized with their parent. For
15023aa5a74SRiver Riddle   // return this is handled, for other ops they remain as is.
15123aa5a74SRiver Riddle   return op->hasTrait<OpTrait::ReturnLike>();
15223aa5a74SRiver Riddle }
15323aa5a74SRiver Riddle 
isNotBranchOpInterfaceOrReturnLikeOp(Operation * op)15423aa5a74SRiver Riddle bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) {
15523aa5a74SRiver Riddle   // If it is not a terminator, ignore it.
15623aa5a74SRiver Riddle   if (!op->mightHaveTrait<OpTrait::IsTerminator>())
15723aa5a74SRiver Riddle     return true;
15823aa5a74SRiver Riddle 
15923aa5a74SRiver Riddle   // If it is not the last operation in the block, also ignore it. We do
16023aa5a74SRiver Riddle   // this to handle unknown operations, as well.
16123aa5a74SRiver Riddle   Block *block = op->getBlock();
16223aa5a74SRiver Riddle   if (!block || &block->back() != op)
16323aa5a74SRiver Riddle     return true;
16423aa5a74SRiver Riddle 
16523aa5a74SRiver Riddle   // We don't want to handle terminators in nested regions, assume they are
16623aa5a74SRiver Riddle   // always legal.
16723aa5a74SRiver Riddle   if (!isa_and_nonnull<FuncOp>(op->getParentOp()))
16823aa5a74SRiver Riddle     return true;
16923aa5a74SRiver Riddle 
17023aa5a74SRiver Riddle   return false;
17123aa5a74SRiver Riddle }
172