1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/CodeGen/CodeGen.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Dialect/FIRType.h"
17 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/ArrayRef.h"
24 
25 #define DEBUG_TYPE "flang-codegen"
26 
27 // fir::LLVMTypeConverter for converting to LLVM IR dialect types.
28 #include "TypeConverter.h"
29 
30 namespace {
31 /// FIR conversion pattern template
32 template <typename FromOp>
33 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
34 public:
35   explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
36       : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}
37 
38 protected:
39   mlir::Type convertType(mlir::Type ty) const {
40     return lowerTy().convertType(ty);
41   }
42 
43   fir::LLVMTypeConverter &lowerTy() const {
44     return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
45   }
46 };
47 
48 // Lower `fir.address_of` operation to `llvm.address_of` operation.
49 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
50   using FIROpConversion::FIROpConversion;
51 
52   mlir::LogicalResult
53   matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
54                   mlir::ConversionPatternRewriter &rewriter) const override {
55     auto ty = convertType(addr.getType());
56     rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
57         addr, ty, addr.symbol().getRootReference().getValue());
58     return success();
59   }
60 };
61 
62 /// Lower `fir.has_value` operation to `llvm.return` operation.
63 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
64   using FIROpConversion::FIROpConversion;
65 
66   mlir::LogicalResult
67   matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
68                   mlir::ConversionPatternRewriter &rewriter) const override {
69     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
70     return success();
71   }
72 };
73 
74 /// Lower `fir.global` operation to `llvm.global` operation.
75 /// `fir.insert_on_range` operations are replaced with constant dense attribute
76 /// if they are applied on the full range.
77 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
78   using FIROpConversion::FIROpConversion;
79 
80   mlir::LogicalResult
81   matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
82                   mlir::ConversionPatternRewriter &rewriter) const override {
83     auto tyAttr = convertType(global.getType());
84     if (global.getType().isa<fir::BoxType>())
85       tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
86     auto loc = global.getLoc();
87     mlir::Attribute initAttr{};
88     if (global.initVal())
89       initAttr = global.initVal().getValue();
90     auto linkage = convertLinkage(global.linkName());
91     auto isConst = global.constant().hasValue();
92     auto g = rewriter.create<mlir::LLVM::GlobalOp>(
93         loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
94     auto &gr = g.getInitializerRegion();
95     rewriter.inlineRegionBefore(global.region(), gr, gr.end());
96     if (!gr.empty()) {
97       // Replace insert_on_range with a constant dense attribute if the
98       // initialization is on the full range.
99       auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
100       for (auto insertOp : insertOnRangeOps) {
101         if (isFullRange(insertOp.coor(), insertOp.getType())) {
102           auto seqTyAttr = convertType(insertOp.getType());
103           auto *op = insertOp.val().getDefiningOp();
104           auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
105           if (!constant) {
106             auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op);
107             if (!convertOp)
108               continue;
109             constant = cast<mlir::arith::ConstantOp>(
110                 convertOp.value().getDefiningOp());
111           }
112           mlir::Type vecType = mlir::VectorType::get(
113               insertOp.getType().getShape(), constant.getType());
114           auto denseAttr = mlir::DenseElementsAttr::get(
115               vecType.cast<ShapedType>(), constant.value());
116           rewriter.setInsertionPointAfter(insertOp);
117           rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
118               insertOp, seqTyAttr, denseAttr);
119         }
120       }
121     }
122     rewriter.eraseOp(global);
123     return success();
124   }
125 
126   bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
127     auto extents = seqTy.getShape();
128     if (indexes.size() / 2 != extents.size())
129       return false;
130     for (unsigned i = 0; i < indexes.size(); i += 2) {
131       if (indexes[i].cast<IntegerAttr>().getInt() != 0)
132         return false;
133       if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
134         return false;
135     }
136     return true;
137   }
138 
139   // TODO: String comparaison should be avoided. Replace linkName with an
140   // enumeration.
141   mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const {
142     if (optLinkage.hasValue()) {
143       auto name = optLinkage.getValue();
144       if (name == "internal")
145         return mlir::LLVM::Linkage::Internal;
146       if (name == "linkonce")
147         return mlir::LLVM::Linkage::Linkonce;
148       if (name == "common")
149         return mlir::LLVM::Linkage::Common;
150       if (name == "weak")
151         return mlir::LLVM::Linkage::Weak;
152     }
153     return mlir::LLVM::Linkage::External;
154   }
155 };
156 
157 // convert to LLVM IR dialect `undef`
158 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
159   using FIROpConversion::FIROpConversion;
160 
161   mlir::LogicalResult
162   matchAndRewrite(fir::UndefOp undef, OpAdaptor,
163                   mlir::ConversionPatternRewriter &rewriter) const override {
164     rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
165         undef, convertType(undef.getType()));
166     return success();
167   }
168 };
169 } // namespace
170 
171 namespace {
172 /// Convert FIR dialect to LLVM dialect
173 ///
174 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An
175 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect.
176 ///
177 /// This pass is not complete yet. We are upstreaming it in small patches.
178 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
179 public:
180   mlir::ModuleOp getModule() { return getOperation(); }
181 
182   void runOnOperation() override final {
183     auto *context = getModule().getContext();
184     fir::LLVMTypeConverter typeConverter{getModule()};
185     auto loc = mlir::UnknownLoc::get(context);
186     mlir::OwningRewritePatternList pattern(context);
187     pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
188                    UndefOpConversion>(typeConverter);
189     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
190     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
191                                                             pattern);
192     mlir::ConversionTarget target{*context};
193     target.addLegalDialect<mlir::LLVM::LLVMDialect>();
194 
195     // required NOPs for applying a full conversion
196     target.addLegalOp<mlir::ModuleOp>();
197 
198     // apply the patterns
199     if (mlir::failed(mlir::applyFullConversion(getModule(), target,
200                                                std::move(pattern)))) {
201       mlir::emitError(loc, "error in converting to LLVM-IR dialect\n");
202       signalPassFailure();
203     }
204   }
205 };
206 } // namespace
207 
208 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
209   return std::make_unique<FIRToLLVMLowering>();
210 }
211