1*ace01605SRiver Riddle //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// 2*ace01605SRiver Riddle // 3*ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5*ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*ace01605SRiver Riddle // 7*ace01605SRiver Riddle //===----------------------------------------------------------------------===// 8*ace01605SRiver Riddle // 9*ace01605SRiver Riddle // This file implements a pass to convert MLIR standard and builtin dialects 10*ace01605SRiver Riddle // into the LLVM IR dialect. 11*ace01605SRiver Riddle // 12*ace01605SRiver Riddle //===----------------------------------------------------------------------===// 13*ace01605SRiver Riddle 14*ace01605SRiver Riddle #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 15*ace01605SRiver Riddle #include "../PassDetail.h" 16*ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 17*ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/Pattern.h" 18*ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 19*ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 20*ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 21*ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22*ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h" 23*ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h" 24*ace01605SRiver Riddle #include "mlir/Transforms/DialectConversion.h" 25*ace01605SRiver Riddle #include <functional> 26*ace01605SRiver Riddle 27*ace01605SRiver Riddle using namespace mlir; 28*ace01605SRiver Riddle 29*ace01605SRiver Riddle #define PASS_NAME "convert-cf-to-llvm" 30*ace01605SRiver Riddle 31*ace01605SRiver Riddle namespace { 32*ace01605SRiver Riddle /// Lower `std.assert`. The default lowering calls the `abort` function if the 33*ace01605SRiver Riddle /// assertion is violated and has no effect otherwise. The failure message is 34*ace01605SRiver Riddle /// ignored by the default lowering but should be propagated by any custom 35*ace01605SRiver Riddle /// lowering. 36*ace01605SRiver Riddle struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { 37*ace01605SRiver Riddle using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern; 38*ace01605SRiver Riddle 39*ace01605SRiver Riddle LogicalResult 40*ace01605SRiver Riddle matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 41*ace01605SRiver Riddle ConversionPatternRewriter &rewriter) const override { 42*ace01605SRiver Riddle auto loc = op.getLoc(); 43*ace01605SRiver Riddle 44*ace01605SRiver Riddle // Insert the `abort` declaration if necessary. 45*ace01605SRiver Riddle auto module = op->getParentOfType<ModuleOp>(); 46*ace01605SRiver Riddle auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); 47*ace01605SRiver Riddle if (!abortFunc) { 48*ace01605SRiver Riddle OpBuilder::InsertionGuard guard(rewriter); 49*ace01605SRiver Riddle rewriter.setInsertionPointToStart(module.getBody()); 50*ace01605SRiver Riddle auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); 51*ace01605SRiver Riddle abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), 52*ace01605SRiver Riddle "abort", abortFuncTy); 53*ace01605SRiver Riddle } 54*ace01605SRiver Riddle 55*ace01605SRiver Riddle // Split block at `assert` operation. 56*ace01605SRiver Riddle Block *opBlock = rewriter.getInsertionBlock(); 57*ace01605SRiver Riddle auto opPosition = rewriter.getInsertionPoint(); 58*ace01605SRiver Riddle Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); 59*ace01605SRiver Riddle 60*ace01605SRiver Riddle // Generate IR to call `abort`. 61*ace01605SRiver Riddle Block *failureBlock = rewriter.createBlock(opBlock->getParent()); 62*ace01605SRiver Riddle rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None); 63*ace01605SRiver Riddle rewriter.create<LLVM::UnreachableOp>(loc); 64*ace01605SRiver Riddle 65*ace01605SRiver Riddle // Generate assertion test. 66*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(opBlock); 67*ace01605SRiver Riddle rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 68*ace01605SRiver Riddle op, adaptor.getArg(), continuationBlock, failureBlock); 69*ace01605SRiver Riddle 70*ace01605SRiver Riddle return success(); 71*ace01605SRiver Riddle } 72*ace01605SRiver Riddle }; 73*ace01605SRiver Riddle 74*ace01605SRiver Riddle // Base class for LLVM IR lowering terminator operations with successors. 75*ace01605SRiver Riddle template <typename SourceOp, typename TargetOp> 76*ace01605SRiver Riddle struct OneToOneLLVMTerminatorLowering 77*ace01605SRiver Riddle : public ConvertOpToLLVMPattern<SourceOp> { 78*ace01605SRiver Riddle using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 79*ace01605SRiver Riddle using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>; 80*ace01605SRiver Riddle 81*ace01605SRiver Riddle LogicalResult 82*ace01605SRiver Riddle matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 83*ace01605SRiver Riddle ConversionPatternRewriter &rewriter) const override { 84*ace01605SRiver Riddle rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(), 85*ace01605SRiver Riddle op->getSuccessors(), op->getAttrs()); 86*ace01605SRiver Riddle return success(); 87*ace01605SRiver Riddle } 88*ace01605SRiver Riddle }; 89*ace01605SRiver Riddle 90*ace01605SRiver Riddle // FIXME: this should be tablegen'ed as well. 91*ace01605SRiver Riddle struct BranchOpLowering 92*ace01605SRiver Riddle : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> { 93*ace01605SRiver Riddle using Base::Base; 94*ace01605SRiver Riddle }; 95*ace01605SRiver Riddle struct CondBranchOpLowering 96*ace01605SRiver Riddle : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> { 97*ace01605SRiver Riddle using Base::Base; 98*ace01605SRiver Riddle }; 99*ace01605SRiver Riddle struct SwitchOpLowering 100*ace01605SRiver Riddle : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> { 101*ace01605SRiver Riddle using Base::Base; 102*ace01605SRiver Riddle }; 103*ace01605SRiver Riddle 104*ace01605SRiver Riddle } // namespace 105*ace01605SRiver Riddle 106*ace01605SRiver Riddle void mlir::cf::populateControlFlowToLLVMConversionPatterns( 107*ace01605SRiver Riddle LLVMTypeConverter &converter, RewritePatternSet &patterns) { 108*ace01605SRiver Riddle // clang-format off 109*ace01605SRiver Riddle patterns.add< 110*ace01605SRiver Riddle AssertOpLowering, 111*ace01605SRiver Riddle BranchOpLowering, 112*ace01605SRiver Riddle CondBranchOpLowering, 113*ace01605SRiver Riddle SwitchOpLowering>(converter); 114*ace01605SRiver Riddle // clang-format on 115*ace01605SRiver Riddle } 116*ace01605SRiver Riddle 117*ace01605SRiver Riddle //===----------------------------------------------------------------------===// 118*ace01605SRiver Riddle // Pass Definition 119*ace01605SRiver Riddle //===----------------------------------------------------------------------===// 120*ace01605SRiver Riddle 121*ace01605SRiver Riddle namespace { 122*ace01605SRiver Riddle /// A pass converting MLIR operations into the LLVM IR dialect. 123*ace01605SRiver Riddle struct ConvertControlFlowToLLVM 124*ace01605SRiver Riddle : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> { 125*ace01605SRiver Riddle ConvertControlFlowToLLVM() = default; 126*ace01605SRiver Riddle 127*ace01605SRiver Riddle /// Run the dialect converter on the module. 128*ace01605SRiver Riddle void runOnOperation() override { 129*ace01605SRiver Riddle LLVMConversionTarget target(getContext()); 130*ace01605SRiver Riddle RewritePatternSet patterns(&getContext()); 131*ace01605SRiver Riddle 132*ace01605SRiver Riddle LowerToLLVMOptions options(&getContext()); 133*ace01605SRiver Riddle if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 134*ace01605SRiver Riddle options.overrideIndexBitwidth(indexBitwidth); 135*ace01605SRiver Riddle 136*ace01605SRiver Riddle LLVMTypeConverter converter(&getContext(), options); 137*ace01605SRiver Riddle mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 138*ace01605SRiver Riddle 139*ace01605SRiver Riddle if (failed(applyPartialConversion(getOperation(), target, 140*ace01605SRiver Riddle std::move(patterns)))) 141*ace01605SRiver Riddle signalPassFailure(); 142*ace01605SRiver Riddle } 143*ace01605SRiver Riddle }; 144*ace01605SRiver Riddle } // namespace 145*ace01605SRiver Riddle 146*ace01605SRiver Riddle std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() { 147*ace01605SRiver Riddle return std::make_unique<ConvertControlFlowToLLVM>(); 148*ace01605SRiver Riddle } 149