1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/CodeGen/CodeGen.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Dialect/FIRType.h"
17 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/ArrayRef.h"
24 
25 #define DEBUG_TYPE "flang-codegen"
26 
27 // fir::LLVMTypeConverter for converting to LLVM IR dialect types.
28 #include "TypeConverter.h"
29 
30 namespace {
31 /// FIR conversion pattern template
32 template <typename FromOp>
33 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
34 public:
35   explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
36       : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}
37 
38 protected:
39   mlir::Type convertType(mlir::Type ty) const {
40     return lowerTy().convertType(ty);
41   }
42 
43   fir::LLVMTypeConverter &lowerTy() const {
44     return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
45   }
46 };
47 } // namespace
48 
49 namespace {
50 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
51   using FIROpConversion::FIROpConversion;
52 
53   mlir::LogicalResult
54   matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
55                   mlir::ConversionPatternRewriter &rewriter) const override {
56     auto ty = convertType(addr.getType());
57     rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
58         addr, ty, addr.symbol().getRootReference().getValue());
59     return success();
60   }
61 };
62 
63 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
64   using FIROpConversion::FIROpConversion;
65 
66   mlir::LogicalResult
67   matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
68                   mlir::ConversionPatternRewriter &rewriter) const override {
69     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
70     return success();
71   }
72 };
73 
74 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
75   using FIROpConversion::FIROpConversion;
76 
77   mlir::LogicalResult
78   matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
79                   mlir::ConversionPatternRewriter &rewriter) const override {
80     auto tyAttr = convertType(global.getType());
81     if (global.getType().isa<fir::BoxType>())
82       tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
83     auto loc = global.getLoc();
84     mlir::Attribute initAttr{};
85     if (global.initVal())
86       initAttr = global.initVal().getValue();
87     auto linkage = convertLinkage(global.linkName());
88     auto isConst = global.constant().hasValue();
89     auto g = rewriter.create<mlir::LLVM::GlobalOp>(
90         loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
91     auto &gr = g.getInitializerRegion();
92     rewriter.inlineRegionBefore(global.region(), gr, gr.end());
93     if (!gr.empty()) {
94       // Replace insert_on_range with a constant dense attribute if the
95       // initialization is on the full range.
96       auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
97       for (auto insertOp : insertOnRangeOps) {
98         if (isFullRange(insertOp.coor(), insertOp.getType())) {
99           auto seqTyAttr = convertType(insertOp.getType());
100           auto *op = insertOp.val().getDefiningOp();
101           auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
102           if (!constant) {
103             auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op);
104             if (!convertOp)
105               continue;
106             constant = cast<mlir::arith::ConstantOp>(
107                 convertOp.value().getDefiningOp());
108           }
109           mlir::Type vecType = mlir::VectorType::get(
110               insertOp.getType().getShape(), constant.getType());
111           auto denseAttr = mlir::DenseElementsAttr::get(
112               vecType.cast<ShapedType>(), constant.value());
113           rewriter.setInsertionPointAfter(insertOp);
114           rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
115               insertOp, seqTyAttr, denseAttr);
116         }
117       }
118     }
119     rewriter.eraseOp(global);
120     return success();
121   }
122 
123   bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
124     auto extents = seqTy.getShape();
125     if (indexes.size() / 2 != extents.size())
126       return false;
127     for (unsigned i = 0; i < indexes.size(); i += 2) {
128       if (indexes[i].cast<IntegerAttr>().getInt() != 0)
129         return false;
130       if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
131         return false;
132     }
133     return true;
134   }
135 
136   mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const {
137     if (optLinkage.hasValue()) {
138       auto name = optLinkage.getValue();
139       if (name == "internal")
140         return mlir::LLVM::Linkage::Internal;
141       if (name == "linkonce")
142         return mlir::LLVM::Linkage::Linkonce;
143       if (name == "common")
144         return mlir::LLVM::Linkage::Common;
145       if (name == "weak")
146         return mlir::LLVM::Linkage::Weak;
147     }
148     return mlir::LLVM::Linkage::External;
149   }
150 };
151 
152 // convert to LLVM IR dialect `undef`
153 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
154   using FIROpConversion::FIROpConversion;
155 
156   mlir::LogicalResult
157   matchAndRewrite(fir::UndefOp undef, OpAdaptor,
158                   mlir::ConversionPatternRewriter &rewriter) const override {
159     rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
160         undef, convertType(undef.getType()));
161     return success();
162   }
163 };
164 } // namespace
165 
166 namespace {
167 /// Convert FIR dialect to LLVM dialect
168 ///
169 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An
170 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect.
171 ///
172 /// This pass is not complete yet. We are upstreaming it in small patches.
173 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
174 public:
175   mlir::ModuleOp getModule() { return getOperation(); }
176 
177   void runOnOperation() override final {
178     auto *context = getModule().getContext();
179     fir::LLVMTypeConverter typeConverter{getModule()};
180     auto loc = mlir::UnknownLoc::get(context);
181     mlir::OwningRewritePatternList pattern(context);
182     pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
183                    UndefOpConversion>(typeConverter);
184     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
185     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
186                                                             pattern);
187     mlir::ConversionTarget target{*context};
188     target.addLegalDialect<mlir::LLVM::LLVMDialect>();
189 
190     // required NOPs for applying a full conversion
191     target.addLegalOp<mlir::ModuleOp>();
192 
193     // apply the patterns
194     if (mlir::failed(mlir::applyFullConversion(getModule(), target,
195                                                std::move(pattern)))) {
196       mlir::emitError(loc, "error in converting to LLVM-IR dialect\n");
197       signalPassFailure();
198     }
199   }
200 };
201 } // namespace
202 
203 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
204   return std::make_unique<FIRToLLVMLowering>();
205 }
206