1ace01605SRiver Riddle //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2ace01605SRiver Riddle //
3ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ace01605SRiver Riddle //
7ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8ace01605SRiver Riddle //
// This file implements a pass to convert the MLIR ControlFlow (`cf`) dialect
// into the LLVM IR dialect.
11ace01605SRiver Riddle //
12ace01605SRiver Riddle //===----------------------------------------------------------------------===//
13ace01605SRiver Riddle
14ace01605SRiver Riddle #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15ace01605SRiver Riddle #include "../PassDetail.h"
16ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/Pattern.h"
18ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
19ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h"
23ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h"
24ace01605SRiver Riddle #include "mlir/Transforms/DialectConversion.h"
25ace01605SRiver Riddle #include <functional>
26ace01605SRiver Riddle
27ace01605SRiver Riddle using namespace mlir;
28ace01605SRiver Riddle
// Pass name string. NOTE(review): this macro is not referenced anywhere else
// in this file; presumably it mirrors the pass's command-line argument —
// confirm and consider removing it if unused.
#define PASS_NAME "convert-cf-to-llvm"
30ace01605SRiver Riddle
31ace01605SRiver Riddle namespace {
32*23aa5a74SRiver Riddle /// Lower `cf.assert`. The default lowering calls the `abort` function if the
33ace01605SRiver Riddle /// assertion is violated and has no effect otherwise. The failure message is
34ace01605SRiver Riddle /// ignored by the default lowering but should be propagated by any custom
35ace01605SRiver Riddle /// lowering.
36ace01605SRiver Riddle struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
37ace01605SRiver Riddle using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
38ace01605SRiver Riddle
39ace01605SRiver Riddle LogicalResult
matchAndRewrite__anon90dfa76b0111::AssertOpLowering40ace01605SRiver Riddle matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
41ace01605SRiver Riddle ConversionPatternRewriter &rewriter) const override {
42ace01605SRiver Riddle auto loc = op.getLoc();
43ace01605SRiver Riddle
44ace01605SRiver Riddle // Insert the `abort` declaration if necessary.
45ace01605SRiver Riddle auto module = op->getParentOfType<ModuleOp>();
46ace01605SRiver Riddle auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
47ace01605SRiver Riddle if (!abortFunc) {
48ace01605SRiver Riddle OpBuilder::InsertionGuard guard(rewriter);
49ace01605SRiver Riddle rewriter.setInsertionPointToStart(module.getBody());
50ace01605SRiver Riddle auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
51ace01605SRiver Riddle abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
52ace01605SRiver Riddle "abort", abortFuncTy);
53ace01605SRiver Riddle }
54ace01605SRiver Riddle
55ace01605SRiver Riddle // Split block at `assert` operation.
56ace01605SRiver Riddle Block *opBlock = rewriter.getInsertionBlock();
57ace01605SRiver Riddle auto opPosition = rewriter.getInsertionPoint();
58ace01605SRiver Riddle Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
59ace01605SRiver Riddle
60ace01605SRiver Riddle // Generate IR to call `abort`.
61ace01605SRiver Riddle Block *failureBlock = rewriter.createBlock(opBlock->getParent());
62ace01605SRiver Riddle rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
63ace01605SRiver Riddle rewriter.create<LLVM::UnreachableOp>(loc);
64ace01605SRiver Riddle
65ace01605SRiver Riddle // Generate assertion test.
66ace01605SRiver Riddle rewriter.setInsertionPointToEnd(opBlock);
67ace01605SRiver Riddle rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
68ace01605SRiver Riddle op, adaptor.getArg(), continuationBlock, failureBlock);
69ace01605SRiver Riddle
70ace01605SRiver Riddle return success();
71ace01605SRiver Riddle }
72ace01605SRiver Riddle };
73ace01605SRiver Riddle
74ace01605SRiver Riddle // Base class for LLVM IR lowering terminator operations with successors.
75ace01605SRiver Riddle template <typename SourceOp, typename TargetOp>
76ace01605SRiver Riddle struct OneToOneLLVMTerminatorLowering
77ace01605SRiver Riddle : public ConvertOpToLLVMPattern<SourceOp> {
78ace01605SRiver Riddle using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
79ace01605SRiver Riddle using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
80ace01605SRiver Riddle
81ace01605SRiver Riddle LogicalResult
matchAndRewrite__anon90dfa76b0111::OneToOneLLVMTerminatorLowering82ace01605SRiver Riddle matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
83ace01605SRiver Riddle ConversionPatternRewriter &rewriter) const override {
84ace01605SRiver Riddle rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
85ace01605SRiver Riddle op->getSuccessors(), op->getAttrs());
86ace01605SRiver Riddle return success();
87ace01605SRiver Riddle }
88ace01605SRiver Riddle };
89ace01605SRiver Riddle
90ace01605SRiver Riddle // FIXME: this should be tablegen'ed as well.
91ace01605SRiver Riddle struct BranchOpLowering
92ace01605SRiver Riddle : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
93ace01605SRiver Riddle using Base::Base;
94ace01605SRiver Riddle };
95ace01605SRiver Riddle struct CondBranchOpLowering
96ace01605SRiver Riddle : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
97ace01605SRiver Riddle using Base::Base;
98ace01605SRiver Riddle };
99ace01605SRiver Riddle struct SwitchOpLowering
100ace01605SRiver Riddle : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
101ace01605SRiver Riddle using Base::Base;
102ace01605SRiver Riddle };
103ace01605SRiver Riddle
104ace01605SRiver Riddle } // namespace
105ace01605SRiver Riddle
populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)106ace01605SRiver Riddle void mlir::cf::populateControlFlowToLLVMConversionPatterns(
107ace01605SRiver Riddle LLVMTypeConverter &converter, RewritePatternSet &patterns) {
108ace01605SRiver Riddle // clang-format off
109ace01605SRiver Riddle patterns.add<
110ace01605SRiver Riddle AssertOpLowering,
111ace01605SRiver Riddle BranchOpLowering,
112ace01605SRiver Riddle CondBranchOpLowering,
113ace01605SRiver Riddle SwitchOpLowering>(converter);
114ace01605SRiver Riddle // clang-format on
115ace01605SRiver Riddle }
116ace01605SRiver Riddle
117ace01605SRiver Riddle //===----------------------------------------------------------------------===//
118ace01605SRiver Riddle // Pass Definition
119ace01605SRiver Riddle //===----------------------------------------------------------------------===//
120ace01605SRiver Riddle
121ace01605SRiver Riddle namespace {
122ace01605SRiver Riddle /// A pass converting MLIR operations into the LLVM IR dialect.
123ace01605SRiver Riddle struct ConvertControlFlowToLLVM
124ace01605SRiver Riddle : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
125ace01605SRiver Riddle ConvertControlFlowToLLVM() = default;
126ace01605SRiver Riddle
127ace01605SRiver Riddle /// Run the dialect converter on the module.
runOnOperation__anon90dfa76b0211::ConvertControlFlowToLLVM128ace01605SRiver Riddle void runOnOperation() override {
129ace01605SRiver Riddle LLVMConversionTarget target(getContext());
130ace01605SRiver Riddle RewritePatternSet patterns(&getContext());
131ace01605SRiver Riddle
132ace01605SRiver Riddle LowerToLLVMOptions options(&getContext());
133ace01605SRiver Riddle if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
134ace01605SRiver Riddle options.overrideIndexBitwidth(indexBitwidth);
135ace01605SRiver Riddle
136ace01605SRiver Riddle LLVMTypeConverter converter(&getContext(), options);
137ace01605SRiver Riddle mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
138ace01605SRiver Riddle
139ace01605SRiver Riddle if (failed(applyPartialConversion(getOperation(), target,
140ace01605SRiver Riddle std::move(patterns))))
141ace01605SRiver Riddle signalPassFailure();
142ace01605SRiver Riddle }
143ace01605SRiver Riddle };
144ace01605SRiver Riddle } // namespace
145ace01605SRiver Riddle
createConvertControlFlowToLLVMPass()146ace01605SRiver Riddle std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
147ace01605SRiver Riddle return std::make_unique<ConvertControlFlowToLLVM>();
148ace01605SRiver Riddle }
149