1ace01605SRiver Riddle //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2ace01605SRiver Riddle //
3ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ace01605SRiver Riddle //
7ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8ace01605SRiver Riddle //
// This file implements a pass to convert the MLIR ControlFlow dialect into
// the LLVM IR dialect.
11ace01605SRiver Riddle //
12ace01605SRiver Riddle //===----------------------------------------------------------------------===//
13ace01605SRiver Riddle 
14ace01605SRiver Riddle #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15ace01605SRiver Riddle #include "../PassDetail.h"
16ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/Pattern.h"
18ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
19ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h"
23ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h"
24ace01605SRiver Riddle #include "mlir/Transforms/DialectConversion.h"
25ace01605SRiver Riddle #include <functional>
26ace01605SRiver Riddle 
27ace01605SRiver Riddle using namespace mlir;
28ace01605SRiver Riddle 
29ace01605SRiver Riddle #define PASS_NAME "convert-cf-to-llvm"
30ace01605SRiver Riddle 
31ace01605SRiver Riddle namespace {
32*23aa5a74SRiver Riddle /// Lower `cf.assert`. The default lowering calls the `abort` function if the
33ace01605SRiver Riddle /// assertion is violated and has no effect otherwise. The failure message is
34ace01605SRiver Riddle /// ignored by the default lowering but should be propagated by any custom
35ace01605SRiver Riddle /// lowering.
36ace01605SRiver Riddle struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
37ace01605SRiver Riddle   using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
38ace01605SRiver Riddle 
39ace01605SRiver Riddle   LogicalResult
matchAndRewrite__anon90dfa76b0111::AssertOpLowering40ace01605SRiver Riddle   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
41ace01605SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
42ace01605SRiver Riddle     auto loc = op.getLoc();
43ace01605SRiver Riddle 
44ace01605SRiver Riddle     // Insert the `abort` declaration if necessary.
45ace01605SRiver Riddle     auto module = op->getParentOfType<ModuleOp>();
46ace01605SRiver Riddle     auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
47ace01605SRiver Riddle     if (!abortFunc) {
48ace01605SRiver Riddle       OpBuilder::InsertionGuard guard(rewriter);
49ace01605SRiver Riddle       rewriter.setInsertionPointToStart(module.getBody());
50ace01605SRiver Riddle       auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
51ace01605SRiver Riddle       abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
52ace01605SRiver Riddle                                                     "abort", abortFuncTy);
53ace01605SRiver Riddle     }
54ace01605SRiver Riddle 
55ace01605SRiver Riddle     // Split block at `assert` operation.
56ace01605SRiver Riddle     Block *opBlock = rewriter.getInsertionBlock();
57ace01605SRiver Riddle     auto opPosition = rewriter.getInsertionPoint();
58ace01605SRiver Riddle     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
59ace01605SRiver Riddle 
60ace01605SRiver Riddle     // Generate IR to call `abort`.
61ace01605SRiver Riddle     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
62ace01605SRiver Riddle     rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
63ace01605SRiver Riddle     rewriter.create<LLVM::UnreachableOp>(loc);
64ace01605SRiver Riddle 
65ace01605SRiver Riddle     // Generate assertion test.
66ace01605SRiver Riddle     rewriter.setInsertionPointToEnd(opBlock);
67ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
68ace01605SRiver Riddle         op, adaptor.getArg(), continuationBlock, failureBlock);
69ace01605SRiver Riddle 
70ace01605SRiver Riddle     return success();
71ace01605SRiver Riddle   }
72ace01605SRiver Riddle };
73ace01605SRiver Riddle 
74ace01605SRiver Riddle // Base class for LLVM IR lowering terminator operations with successors.
75ace01605SRiver Riddle template <typename SourceOp, typename TargetOp>
76ace01605SRiver Riddle struct OneToOneLLVMTerminatorLowering
77ace01605SRiver Riddle     : public ConvertOpToLLVMPattern<SourceOp> {
78ace01605SRiver Riddle   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
79ace01605SRiver Riddle   using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
80ace01605SRiver Riddle 
81ace01605SRiver Riddle   LogicalResult
matchAndRewrite__anon90dfa76b0111::OneToOneLLVMTerminatorLowering82ace01605SRiver Riddle   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
83ace01605SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
84ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
85ace01605SRiver Riddle                                           op->getSuccessors(), op->getAttrs());
86ace01605SRiver Riddle     return success();
87ace01605SRiver Riddle   }
88ace01605SRiver Riddle };
89ace01605SRiver Riddle 
90ace01605SRiver Riddle // FIXME: this should be tablegen'ed as well.
91ace01605SRiver Riddle struct BranchOpLowering
92ace01605SRiver Riddle     : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
93ace01605SRiver Riddle   using Base::Base;
94ace01605SRiver Riddle };
95ace01605SRiver Riddle struct CondBranchOpLowering
96ace01605SRiver Riddle     : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
97ace01605SRiver Riddle   using Base::Base;
98ace01605SRiver Riddle };
99ace01605SRiver Riddle struct SwitchOpLowering
100ace01605SRiver Riddle     : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
101ace01605SRiver Riddle   using Base::Base;
102ace01605SRiver Riddle };
103ace01605SRiver Riddle 
104ace01605SRiver Riddle } // namespace
105ace01605SRiver Riddle 
populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)106ace01605SRiver Riddle void mlir::cf::populateControlFlowToLLVMConversionPatterns(
107ace01605SRiver Riddle     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
108ace01605SRiver Riddle   // clang-format off
109ace01605SRiver Riddle   patterns.add<
110ace01605SRiver Riddle       AssertOpLowering,
111ace01605SRiver Riddle       BranchOpLowering,
112ace01605SRiver Riddle       CondBranchOpLowering,
113ace01605SRiver Riddle       SwitchOpLowering>(converter);
114ace01605SRiver Riddle   // clang-format on
115ace01605SRiver Riddle }
116ace01605SRiver Riddle 
117ace01605SRiver Riddle //===----------------------------------------------------------------------===//
118ace01605SRiver Riddle // Pass Definition
119ace01605SRiver Riddle //===----------------------------------------------------------------------===//
120ace01605SRiver Riddle 
121ace01605SRiver Riddle namespace {
122ace01605SRiver Riddle /// A pass converting MLIR operations into the LLVM IR dialect.
123ace01605SRiver Riddle struct ConvertControlFlowToLLVM
124ace01605SRiver Riddle     : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
125ace01605SRiver Riddle   ConvertControlFlowToLLVM() = default;
126ace01605SRiver Riddle 
127ace01605SRiver Riddle   /// Run the dialect converter on the module.
runOnOperation__anon90dfa76b0211::ConvertControlFlowToLLVM128ace01605SRiver Riddle   void runOnOperation() override {
129ace01605SRiver Riddle     LLVMConversionTarget target(getContext());
130ace01605SRiver Riddle     RewritePatternSet patterns(&getContext());
131ace01605SRiver Riddle 
132ace01605SRiver Riddle     LowerToLLVMOptions options(&getContext());
133ace01605SRiver Riddle     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
134ace01605SRiver Riddle       options.overrideIndexBitwidth(indexBitwidth);
135ace01605SRiver Riddle 
136ace01605SRiver Riddle     LLVMTypeConverter converter(&getContext(), options);
137ace01605SRiver Riddle     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
138ace01605SRiver Riddle 
139ace01605SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
140ace01605SRiver Riddle                                       std::move(patterns))))
141ace01605SRiver Riddle       signalPassFailure();
142ace01605SRiver Riddle   }
143ace01605SRiver Riddle };
144ace01605SRiver Riddle } // namespace
145ace01605SRiver Riddle 
createConvertControlFlowToLLVMPass()146ace01605SRiver Riddle std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
147ace01605SRiver Riddle   return std::make_unique<ConvertControlFlowToLLVM>();
148ace01605SRiver Riddle }
149