1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/CodeGen/CodeGen.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Dialect/FIRType.h"
17 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/Pass/Pass.h"
24 #include "llvm/ADT/ArrayRef.h"
25 
26 #define DEBUG_TYPE "flang-codegen"
27 
28 // fir::LLVMTypeConverter for converting to LLVM IR dialect types.
29 #include "TypeConverter.h"
30 
31 namespace {
32 /// FIR conversion pattern template
33 template <typename FromOp>
34 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
35 public:
36   explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
37       : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}
38 
39 protected:
40   mlir::Type convertType(mlir::Type ty) const {
41     return lowerTy().convertType(ty);
42   }
43 
44   fir::LLVMTypeConverter &lowerTy() const {
45     return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
46   }
47 };
48 
49 /// FIR conversion pattern template
50 template <typename FromOp>
51 class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
52 public:
53   using FIROpConversion<FromOp>::FIROpConversion;
54   using OpAdaptor = typename FromOp::Adaptor;
55 
56   mlir::LogicalResult
57   matchAndRewrite(FromOp op, OpAdaptor adaptor,
58                   mlir::ConversionPatternRewriter &rewriter) const final {
59     mlir::Type ty = this->convertType(op.getType());
60     return doRewrite(op, ty, adaptor, rewriter);
61   }
62 
63   virtual mlir::LogicalResult
64   doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor,
65             mlir::ConversionPatternRewriter &rewriter) const = 0;
66 };
67 
68 // Lower `fir.address_of` operation to `llvm.address_of` operation.
69 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
70   using FIROpConversion::FIROpConversion;
71 
72   mlir::LogicalResult
73   matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
74                   mlir::ConversionPatternRewriter &rewriter) const override {
75     auto ty = convertType(addr.getType());
76     rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
77         addr, ty, addr.symbol().getRootReference().getValue());
78     return success();
79   }
80 };
81 
82 /// Lower `fir.has_value` operation to `llvm.return` operation.
83 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
84   using FIROpConversion::FIROpConversion;
85 
86   mlir::LogicalResult
87   matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
88                   mlir::ConversionPatternRewriter &rewriter) const override {
89     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
90     return success();
91   }
92 };
93 
94 /// Lower `fir.global` operation to `llvm.global` operation.
95 /// `fir.insert_on_range` operations are replaced with constant dense attribute
96 /// if they are applied on the full range.
97 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
98   using FIROpConversion::FIROpConversion;
99 
100   mlir::LogicalResult
101   matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
102                   mlir::ConversionPatternRewriter &rewriter) const override {
103     auto tyAttr = convertType(global.getType());
104     if (global.getType().isa<fir::BoxType>())
105       tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
106     auto loc = global.getLoc();
107     mlir::Attribute initAttr{};
108     if (global.initVal())
109       initAttr = global.initVal().getValue();
110     auto linkage = convertLinkage(global.linkName());
111     auto isConst = global.constant().hasValue();
112     auto g = rewriter.create<mlir::LLVM::GlobalOp>(
113         loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
114     auto &gr = g.getInitializerRegion();
115     rewriter.inlineRegionBefore(global.region(), gr, gr.end());
116     if (!gr.empty()) {
117       // Replace insert_on_range with a constant dense attribute if the
118       // initialization is on the full range.
119       auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
120       for (auto insertOp : insertOnRangeOps) {
121         if (isFullRange(insertOp.coor(), insertOp.getType())) {
122           auto seqTyAttr = convertType(insertOp.getType());
123           auto *op = insertOp.val().getDefiningOp();
124           auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
125           if (!constant) {
126             auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op);
127             if (!convertOp)
128               continue;
129             constant = cast<mlir::arith::ConstantOp>(
130                 convertOp.value().getDefiningOp());
131           }
132           mlir::Type vecType = mlir::VectorType::get(
133               insertOp.getType().getShape(), constant.getType());
134           auto denseAttr = mlir::DenseElementsAttr::get(
135               vecType.cast<ShapedType>(), constant.value());
136           rewriter.setInsertionPointAfter(insertOp);
137           rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
138               insertOp, seqTyAttr, denseAttr);
139         }
140       }
141     }
142     rewriter.eraseOp(global);
143     return success();
144   }
145 
146   bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
147     auto extents = seqTy.getShape();
148     if (indexes.size() / 2 != extents.size())
149       return false;
150     for (unsigned i = 0; i < indexes.size(); i += 2) {
151       if (indexes[i].cast<IntegerAttr>().getInt() != 0)
152         return false;
153       if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
154         return false;
155     }
156     return true;
157   }
158 
159   // TODO: String comparaison should be avoided. Replace linkName with an
160   // enumeration.
161   mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const {
162     if (optLinkage.hasValue()) {
163       auto name = optLinkage.getValue();
164       if (name == "internal")
165         return mlir::LLVM::Linkage::Internal;
166       if (name == "linkonce")
167         return mlir::LLVM::Linkage::Linkonce;
168       if (name == "common")
169         return mlir::LLVM::Linkage::Common;
170       if (name == "weak")
171         return mlir::LLVM::Linkage::Weak;
172     }
173     return mlir::LLVM::Linkage::External;
174   }
175 };
176 
177 // convert to LLVM IR dialect `undef`
178 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
179   using FIROpConversion::FIROpConversion;
180 
181   mlir::LogicalResult
182   matchAndRewrite(fir::UndefOp undef, OpAdaptor,
183                   mlir::ConversionPatternRewriter &rewriter) const override {
184     rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
185         undef, convertType(undef.getType()));
186     return success();
187   }
188 };
189 
190 // convert to LLVM IR dialect `unreachable`
191 struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> {
192   using FIROpConversion::FIROpConversion;
193 
194   mlir::LogicalResult
195   matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor,
196                   mlir::ConversionPatternRewriter &rewriter) const override {
197     rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach);
198     return success();
199   }
200 };
201 
202 struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
203   using FIROpConversion::FIROpConversion;
204 
205   mlir::LogicalResult
206   matchAndRewrite(fir::ZeroOp zero, OpAdaptor,
207                   mlir::ConversionPatternRewriter &rewriter) const override {
208     auto ty = convertType(zero.getType());
209     if (ty.isa<mlir::LLVM::LLVMPointerType>()) {
210       rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty);
211     } else if (ty.isa<mlir::IntegerType>()) {
212       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
213           zero, ty, mlir::IntegerAttr::get(zero.getType(), 0));
214     } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) {
215       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
216           zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0));
217     } else {
218       // TODO: create ConstantAggregateZero for FIR aggregate/array types.
219       return rewriter.notifyMatchFailure(
220           zero,
221           "conversion of fir.zero with aggregate type not implemented yet");
222     }
223     return success();
224   }
225 };
226 
227 /// InsertOnRange inserts a value into a sequence over a range of offsets.
228 struct InsertOnRangeOpConversion
229     : public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
230   using FIROpAndTypeConversion::FIROpAndTypeConversion;
231 
232   // Increments an array of subscripts in a row major fasion.
233   void incrementSubscripts(const SmallVector<uint64_t> &dims,
234                            SmallVector<uint64_t> &subscripts) const {
235     for (size_t i = dims.size(); i > 0; --i) {
236       if (++subscripts[i - 1] < dims[i - 1]) {
237         return;
238       }
239       subscripts[i - 1] = 0;
240     }
241   }
242 
243   mlir::LogicalResult
244   doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
245             mlir::ConversionPatternRewriter &rewriter) const override {
246 
247     llvm::SmallVector<uint64_t> dims;
248     auto type = adaptor.getOperands()[0].getType();
249 
250     // Iteratively extract the array dimensions from the type.
251     while (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
252       dims.push_back(t.getNumElements());
253       type = t.getElementType();
254     }
255 
256     SmallVector<uint64_t> lBounds;
257     SmallVector<uint64_t> uBounds;
258 
259     // Extract integer value from the attribute
260     SmallVector<int64_t> coordinates = llvm::to_vector<4>(
261         llvm::map_range(range.coor(), [](Attribute a) -> int64_t {
262           return a.cast<IntegerAttr>().getInt();
263         }));
264 
265     // Unzip the upper and lower bound and convert to a row major format.
266     for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) {
267       uBounds.push_back(*i++);
268       lBounds.push_back(*i);
269     }
270 
271     auto &subscripts = lBounds;
272     auto loc = range.getLoc();
273     mlir::Value lastOp = adaptor.getOperands()[0];
274     mlir::Value insertVal = adaptor.getOperands()[1];
275 
276     auto i64Ty = rewriter.getI64Type();
277     while (subscripts != uBounds) {
278       // Convert uint64_t's to Attribute's.
279       SmallVector<mlir::Attribute> subscriptAttrs;
280       for (const auto &subscript : subscripts)
281         subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript));
282       lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
283           loc, ty, lastOp, insertVal,
284           ArrayAttr::get(range.getContext(), subscriptAttrs));
285 
286       incrementSubscripts(dims, subscripts);
287     }
288 
289     // Convert uint64_t's to Attribute's.
290     SmallVector<mlir::Attribute> subscriptAttrs;
291     for (const auto &subscript : subscripts)
292       subscriptAttrs.push_back(
293           IntegerAttr::get(rewriter.getI64Type(), subscript));
294     mlir::ArrayRef<mlir::Attribute> arrayRef(subscriptAttrs);
295 
296     rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
297         range, ty, lastOp, insertVal,
298         ArrayAttr::get(range.getContext(), arrayRef));
299 
300     return success();
301   }
302 };
303 } // namespace
304 
305 namespace {
306 /// Convert FIR dialect to LLVM dialect
307 ///
308 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An
309 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect.
310 ///
311 /// This pass is not complete yet. We are upstreaming it in small patches.
312 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
313 public:
314   mlir::ModuleOp getModule() { return getOperation(); }
315 
316   void runOnOperation() override final {
317     auto *context = getModule().getContext();
318     fir::LLVMTypeConverter typeConverter{getModule()};
319     mlir::OwningRewritePatternList pattern(context);
320     pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
321                    InsertOnRangeOpConversion, UndefOpConversion,
322                    UnreachableOpConversion, ZeroOpConversion>(typeConverter);
323     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
324     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
325                                                             pattern);
326     mlir::ConversionTarget target{*context};
327     target.addLegalDialect<mlir::LLVM::LLVMDialect>();
328 
329     // required NOPs for applying a full conversion
330     target.addLegalOp<mlir::ModuleOp>();
331 
332     // apply the patterns
333     if (mlir::failed(mlir::applyFullConversion(getModule(), target,
334                                                std::move(pattern)))) {
335       signalPassFailure();
336     }
337   }
338 };
339 } // namespace
340 
341 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
342   return std::make_unique<FIRToLLVMLowering>();
343 }
344