1*ace01605SRiver Riddle //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2*ace01605SRiver Riddle //
3*ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5*ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*ace01605SRiver Riddle //
7*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8*ace01605SRiver Riddle //
// This file implements a pass to convert the MLIR ControlFlow (`cf`) dialect
// into the LLVM IR dialect.
11*ace01605SRiver Riddle //
12*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
13*ace01605SRiver Riddle 
14*ace01605SRiver Riddle #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15*ace01605SRiver Riddle #include "../PassDetail.h"
16*ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17*ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/Pattern.h"
18*ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
19*ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20*ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21*ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22*ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h"
23*ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h"
24*ace01605SRiver Riddle #include "mlir/Transforms/DialectConversion.h"
25*ace01605SRiver Riddle #include <functional>
26*ace01605SRiver Riddle 
27*ace01605SRiver Riddle using namespace mlir;
28*ace01605SRiver Riddle 
29*ace01605SRiver Riddle #define PASS_NAME "convert-cf-to-llvm"
30*ace01605SRiver Riddle 
31*ace01605SRiver Riddle namespace {
32*ace01605SRiver Riddle /// Lower `std.assert`. The default lowering calls the `abort` function if the
33*ace01605SRiver Riddle /// assertion is violated and has no effect otherwise. The failure message is
34*ace01605SRiver Riddle /// ignored by the default lowering but should be propagated by any custom
35*ace01605SRiver Riddle /// lowering.
36*ace01605SRiver Riddle struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
37*ace01605SRiver Riddle   using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
38*ace01605SRiver Riddle 
39*ace01605SRiver Riddle   LogicalResult
40*ace01605SRiver Riddle   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
41*ace01605SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
42*ace01605SRiver Riddle     auto loc = op.getLoc();
43*ace01605SRiver Riddle 
44*ace01605SRiver Riddle     // Insert the `abort` declaration if necessary.
45*ace01605SRiver Riddle     auto module = op->getParentOfType<ModuleOp>();
46*ace01605SRiver Riddle     auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
47*ace01605SRiver Riddle     if (!abortFunc) {
48*ace01605SRiver Riddle       OpBuilder::InsertionGuard guard(rewriter);
49*ace01605SRiver Riddle       rewriter.setInsertionPointToStart(module.getBody());
50*ace01605SRiver Riddle       auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
51*ace01605SRiver Riddle       abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
52*ace01605SRiver Riddle                                                     "abort", abortFuncTy);
53*ace01605SRiver Riddle     }
54*ace01605SRiver Riddle 
55*ace01605SRiver Riddle     // Split block at `assert` operation.
56*ace01605SRiver Riddle     Block *opBlock = rewriter.getInsertionBlock();
57*ace01605SRiver Riddle     auto opPosition = rewriter.getInsertionPoint();
58*ace01605SRiver Riddle     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
59*ace01605SRiver Riddle 
60*ace01605SRiver Riddle     // Generate IR to call `abort`.
61*ace01605SRiver Riddle     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
62*ace01605SRiver Riddle     rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
63*ace01605SRiver Riddle     rewriter.create<LLVM::UnreachableOp>(loc);
64*ace01605SRiver Riddle 
65*ace01605SRiver Riddle     // Generate assertion test.
66*ace01605SRiver Riddle     rewriter.setInsertionPointToEnd(opBlock);
67*ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
68*ace01605SRiver Riddle         op, adaptor.getArg(), continuationBlock, failureBlock);
69*ace01605SRiver Riddle 
70*ace01605SRiver Riddle     return success();
71*ace01605SRiver Riddle   }
72*ace01605SRiver Riddle };
73*ace01605SRiver Riddle 
74*ace01605SRiver Riddle // Base class for LLVM IR lowering terminator operations with successors.
75*ace01605SRiver Riddle template <typename SourceOp, typename TargetOp>
76*ace01605SRiver Riddle struct OneToOneLLVMTerminatorLowering
77*ace01605SRiver Riddle     : public ConvertOpToLLVMPattern<SourceOp> {
78*ace01605SRiver Riddle   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
79*ace01605SRiver Riddle   using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
80*ace01605SRiver Riddle 
81*ace01605SRiver Riddle   LogicalResult
82*ace01605SRiver Riddle   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
83*ace01605SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
84*ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
85*ace01605SRiver Riddle                                           op->getSuccessors(), op->getAttrs());
86*ace01605SRiver Riddle     return success();
87*ace01605SRiver Riddle   }
88*ace01605SRiver Riddle };
89*ace01605SRiver Riddle 
90*ace01605SRiver Riddle // FIXME: this should be tablegen'ed as well.
91*ace01605SRiver Riddle struct BranchOpLowering
92*ace01605SRiver Riddle     : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
93*ace01605SRiver Riddle   using Base::Base;
94*ace01605SRiver Riddle };
95*ace01605SRiver Riddle struct CondBranchOpLowering
96*ace01605SRiver Riddle     : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
97*ace01605SRiver Riddle   using Base::Base;
98*ace01605SRiver Riddle };
99*ace01605SRiver Riddle struct SwitchOpLowering
100*ace01605SRiver Riddle     : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
101*ace01605SRiver Riddle   using Base::Base;
102*ace01605SRiver Riddle };
103*ace01605SRiver Riddle 
104*ace01605SRiver Riddle } // namespace
105*ace01605SRiver Riddle 
106*ace01605SRiver Riddle void mlir::cf::populateControlFlowToLLVMConversionPatterns(
107*ace01605SRiver Riddle     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
108*ace01605SRiver Riddle   // clang-format off
109*ace01605SRiver Riddle   patterns.add<
110*ace01605SRiver Riddle       AssertOpLowering,
111*ace01605SRiver Riddle       BranchOpLowering,
112*ace01605SRiver Riddle       CondBranchOpLowering,
113*ace01605SRiver Riddle       SwitchOpLowering>(converter);
114*ace01605SRiver Riddle   // clang-format on
115*ace01605SRiver Riddle }
116*ace01605SRiver Riddle 
117*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
118*ace01605SRiver Riddle // Pass Definition
119*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
120*ace01605SRiver Riddle 
121*ace01605SRiver Riddle namespace {
122*ace01605SRiver Riddle /// A pass converting MLIR operations into the LLVM IR dialect.
123*ace01605SRiver Riddle struct ConvertControlFlowToLLVM
124*ace01605SRiver Riddle     : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
125*ace01605SRiver Riddle   ConvertControlFlowToLLVM() = default;
126*ace01605SRiver Riddle 
127*ace01605SRiver Riddle   /// Run the dialect converter on the module.
128*ace01605SRiver Riddle   void runOnOperation() override {
129*ace01605SRiver Riddle     LLVMConversionTarget target(getContext());
130*ace01605SRiver Riddle     RewritePatternSet patterns(&getContext());
131*ace01605SRiver Riddle 
132*ace01605SRiver Riddle     LowerToLLVMOptions options(&getContext());
133*ace01605SRiver Riddle     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
134*ace01605SRiver Riddle       options.overrideIndexBitwidth(indexBitwidth);
135*ace01605SRiver Riddle 
136*ace01605SRiver Riddle     LLVMTypeConverter converter(&getContext(), options);
137*ace01605SRiver Riddle     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
138*ace01605SRiver Riddle 
139*ace01605SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
140*ace01605SRiver Riddle                                       std::move(patterns))))
141*ace01605SRiver Riddle       signalPassFailure();
142*ace01605SRiver Riddle   }
143*ace01605SRiver Riddle };
144*ace01605SRiver Riddle } // namespace
145*ace01605SRiver Riddle 
146*ace01605SRiver Riddle std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
147*ace01605SRiver Riddle   return std::make_unique<ConvertControlFlowToLLVM>();
148*ace01605SRiver Riddle }
149