1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns and a pass to convert the MLIR ControlFlow
// (`cf`) dialect into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15 #include "../PassDetail.h"
16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include <functional>
26
27 using namespace mlir;
28
29 #define PASS_NAME "convert-cf-to-llvm"
30
31 namespace {
32 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
33 /// assertion is violated and has no effect otherwise. The failure message is
34 /// ignored by the default lowering but should be propagated by any custom
35 /// lowering.
36 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
37 using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
38
39 LogicalResult
matchAndRewrite__anon90dfa76b0111::AssertOpLowering40 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
41 ConversionPatternRewriter &rewriter) const override {
42 auto loc = op.getLoc();
43
44 // Insert the `abort` declaration if necessary.
45 auto module = op->getParentOfType<ModuleOp>();
46 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
47 if (!abortFunc) {
48 OpBuilder::InsertionGuard guard(rewriter);
49 rewriter.setInsertionPointToStart(module.getBody());
50 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
51 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
52 "abort", abortFuncTy);
53 }
54
55 // Split block at `assert` operation.
56 Block *opBlock = rewriter.getInsertionBlock();
57 auto opPosition = rewriter.getInsertionPoint();
58 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
59
60 // Generate IR to call `abort`.
61 Block *failureBlock = rewriter.createBlock(opBlock->getParent());
62 rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
63 rewriter.create<LLVM::UnreachableOp>(loc);
64
65 // Generate assertion test.
66 rewriter.setInsertionPointToEnd(opBlock);
67 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
68 op, adaptor.getArg(), continuationBlock, failureBlock);
69
70 return success();
71 }
72 };
73
74 // Base class for LLVM IR lowering terminator operations with successors.
75 template <typename SourceOp, typename TargetOp>
76 struct OneToOneLLVMTerminatorLowering
77 : public ConvertOpToLLVMPattern<SourceOp> {
78 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
79 using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
80
81 LogicalResult
matchAndRewrite__anon90dfa76b0111::OneToOneLLVMTerminatorLowering82 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
83 ConversionPatternRewriter &rewriter) const override {
84 rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
85 op->getSuccessors(), op->getAttrs());
86 return success();
87 }
88 };
89
90 // FIXME: this should be tablegen'ed as well.
91 struct BranchOpLowering
92 : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
93 using Base::Base;
94 };
95 struct CondBranchOpLowering
96 : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
97 using Base::Base;
98 };
99 struct SwitchOpLowering
100 : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
101 using Base::Base;
102 };
103
104 } // namespace
105
populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)106 void mlir::cf::populateControlFlowToLLVMConversionPatterns(
107 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
108 // clang-format off
109 patterns.add<
110 AssertOpLowering,
111 BranchOpLowering,
112 CondBranchOpLowering,
113 SwitchOpLowering>(converter);
114 // clang-format on
115 }
116
117 //===----------------------------------------------------------------------===//
118 // Pass Definition
119 //===----------------------------------------------------------------------===//
120
121 namespace {
122 /// A pass converting MLIR operations into the LLVM IR dialect.
123 struct ConvertControlFlowToLLVM
124 : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
125 ConvertControlFlowToLLVM() = default;
126
127 /// Run the dialect converter on the module.
runOnOperation__anon90dfa76b0211::ConvertControlFlowToLLVM128 void runOnOperation() override {
129 LLVMConversionTarget target(getContext());
130 RewritePatternSet patterns(&getContext());
131
132 LowerToLLVMOptions options(&getContext());
133 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
134 options.overrideIndexBitwidth(indexBitwidth);
135
136 LLVMTypeConverter converter(&getContext(), options);
137 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
138
139 if (failed(applyPartialConversion(getOperation(), target,
140 std::move(patterns))))
141 signalPassFailure();
142 }
143 };
144 } // namespace
145
createConvertControlFlowToLLVMPass()146 std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
147 return std::make_unique<ConvertControlFlowToLLVM>();
148 }
149