1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/CodeGen/CodeGen.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Dialect/FIRType.h"
17 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/Pass/Pass.h"
24 #include "llvm/ADT/ArrayRef.h"
25 
26 #define DEBUG_TYPE "flang-codegen"
27 
28 // fir::LLVMTypeConverter for converting to LLVM IR dialect types.
29 #include "TypeConverter.h"
30 
31 namespace {
32 /// FIR conversion pattern template
33 template <typename FromOp>
34 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
35 public:
36   explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
37       : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}
38 
39 protected:
40   mlir::Type convertType(mlir::Type ty) const {
41     return lowerTy().convertType(ty);
42   }
43 
44   fir::LLVMTypeConverter &lowerTy() const {
45     return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
46   }
47 };
48 
49 /// FIR conversion pattern template
50 template <typename FromOp>
51 class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
52 public:
53   using FIROpConversion<FromOp>::FIROpConversion;
54   using OpAdaptor = typename FromOp::Adaptor;
55 
56   mlir::LogicalResult
57   matchAndRewrite(FromOp op, OpAdaptor adaptor,
58                   mlir::ConversionPatternRewriter &rewriter) const final {
59     mlir::Type ty = this->convertType(op.getType());
60     return doRewrite(op, ty, adaptor, rewriter);
61   }
62 
63   virtual mlir::LogicalResult
64   doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor,
65             mlir::ConversionPatternRewriter &rewriter) const = 0;
66 };
67 
68 // Lower `fir.address_of` operation to `llvm.address_of` operation.
69 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
70   using FIROpConversion::FIROpConversion;
71 
72   mlir::LogicalResult
73   matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
74                   mlir::ConversionPatternRewriter &rewriter) const override {
75     auto ty = convertType(addr.getType());
76     rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
77         addr, ty, addr.symbol().getRootReference().getValue());
78     return success();
79   }
80 };
81 
82 /// Lower `fir.has_value` operation to `llvm.return` operation.
83 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
84   using FIROpConversion::FIROpConversion;
85 
86   mlir::LogicalResult
87   matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
88                   mlir::ConversionPatternRewriter &rewriter) const override {
89     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
90     return success();
91   }
92 };
93 
94 /// Lower `fir.global` operation to `llvm.global` operation.
95 /// `fir.insert_on_range` operations are replaced with constant dense attribute
96 /// if they are applied on the full range.
97 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
98   using FIROpConversion::FIROpConversion;
99 
100   mlir::LogicalResult
101   matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
102                   mlir::ConversionPatternRewriter &rewriter) const override {
103     auto tyAttr = convertType(global.getType());
104     if (global.getType().isa<fir::BoxType>())
105       tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
106     auto loc = global.getLoc();
107     mlir::Attribute initAttr{};
108     if (global.initVal())
109       initAttr = global.initVal().getValue();
110     auto linkage = convertLinkage(global.linkName());
111     auto isConst = global.constant().hasValue();
112     auto g = rewriter.create<mlir::LLVM::GlobalOp>(
113         loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
114     auto &gr = g.getInitializerRegion();
115     rewriter.inlineRegionBefore(global.region(), gr, gr.end());
116     if (!gr.empty()) {
117       // Replace insert_on_range with a constant dense attribute if the
118       // initialization is on the full range.
119       auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
120       for (auto insertOp : insertOnRangeOps) {
121         if (isFullRange(insertOp.coor(), insertOp.getType())) {
122           auto seqTyAttr = convertType(insertOp.getType());
123           auto *op = insertOp.val().getDefiningOp();
124           auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
125           if (!constant) {
126             auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op);
127             if (!convertOp)
128               continue;
129             constant = cast<mlir::arith::ConstantOp>(
130                 convertOp.value().getDefiningOp());
131           }
132           mlir::Type vecType = mlir::VectorType::get(
133               insertOp.getType().getShape(), constant.getType());
134           auto denseAttr = mlir::DenseElementsAttr::get(
135               vecType.cast<ShapedType>(), constant.value());
136           rewriter.setInsertionPointAfter(insertOp);
137           rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
138               insertOp, seqTyAttr, denseAttr);
139         }
140       }
141     }
142     rewriter.eraseOp(global);
143     return success();
144   }
145 
146   bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
147     auto extents = seqTy.getShape();
148     if (indexes.size() / 2 != extents.size())
149       return false;
150     for (unsigned i = 0; i < indexes.size(); i += 2) {
151       if (indexes[i].cast<IntegerAttr>().getInt() != 0)
152         return false;
153       if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
154         return false;
155     }
156     return true;
157   }
158 
159   // TODO: String comparaison should be avoided. Replace linkName with an
160   // enumeration.
161   mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const {
162     if (optLinkage.hasValue()) {
163       auto name = optLinkage.getValue();
164       if (name == "internal")
165         return mlir::LLVM::Linkage::Internal;
166       if (name == "linkonce")
167         return mlir::LLVM::Linkage::Linkonce;
168       if (name == "common")
169         return mlir::LLVM::Linkage::Common;
170       if (name == "weak")
171         return mlir::LLVM::Linkage::Weak;
172     }
173     return mlir::LLVM::Linkage::External;
174   }
175 };
176 
177 template <typename OP>
178 void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
179                            typename OP::Adaptor adaptor,
180                            mlir::ConversionPatternRewriter &rewriter) {
181   unsigned conds = select.getNumConditions();
182   auto cases = select.getCases().getValue();
183   mlir::Value selector = adaptor.selector();
184   auto loc = select.getLoc();
185   assert(conds > 0 && "select must have cases");
186 
187   llvm::SmallVector<mlir::Block *> destinations;
188   llvm::SmallVector<mlir::ValueRange> destinationsOperands;
189   mlir::Block *defaultDestination;
190   mlir::ValueRange defaultOperands;
191   llvm::SmallVector<int32_t> caseValues;
192 
193   for (unsigned t = 0; t != conds; ++t) {
194     mlir::Block *dest = select.getSuccessor(t);
195     auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
196     const mlir::Attribute &attr = cases[t];
197     if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) {
198       destinations.push_back(dest);
199       destinationsOperands.push_back(destOps.hasValue() ? *destOps
200                                                         : ValueRange());
201       caseValues.push_back(intAttr.getInt());
202       continue;
203     }
204     assert(attr.template dyn_cast_or_null<mlir::UnitAttr>());
205     assert((t + 1 == conds) && "unit must be last");
206     defaultDestination = dest;
207     defaultOperands = destOps.hasValue() ? *destOps : ValueRange();
208   }
209 
210   // LLVM::SwitchOp takes a i32 type for the selector.
211   if (select.getSelector().getType() != rewriter.getI32Type())
212     selector =
213         rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), selector);
214 
215   rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
216       select, selector,
217       /*defaultDestination=*/defaultDestination,
218       /*defaultOperands=*/defaultOperands,
219       /*caseValues=*/caseValues,
220       /*caseDestinations=*/destinations,
221       /*caseOperands=*/destinationsOperands,
222       /*branchWeights=*/ArrayRef<int32_t>());
223 }
224 
225 /// conversion of fir::SelectOp to an if-then-else ladder
226 struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
227   using FIROpConversion::FIROpConversion;
228 
229   mlir::LogicalResult
230   matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
231                   mlir::ConversionPatternRewriter &rewriter) const override {
232     selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
233     return success();
234   }
235 };
236 
237 /// conversion of fir::SelectRankOp to an if-then-else ladder
238 struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> {
239   using FIROpConversion::FIROpConversion;
240 
241   mlir::LogicalResult
242   matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
243                   mlir::ConversionPatternRewriter &rewriter) const override {
244     selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
245     return success();
246   }
247 };
248 
249 // convert to LLVM IR dialect `undef`
250 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
251   using FIROpConversion::FIROpConversion;
252 
253   mlir::LogicalResult
254   matchAndRewrite(fir::UndefOp undef, OpAdaptor,
255                   mlir::ConversionPatternRewriter &rewriter) const override {
256     rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
257         undef, convertType(undef.getType()));
258     return success();
259   }
260 };
261 
262 // convert to LLVM IR dialect `unreachable`
263 struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> {
264   using FIROpConversion::FIROpConversion;
265 
266   mlir::LogicalResult
267   matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor,
268                   mlir::ConversionPatternRewriter &rewriter) const override {
269     rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach);
270     return success();
271   }
272 };
273 
274 struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
275   using FIROpConversion::FIROpConversion;
276 
277   mlir::LogicalResult
278   matchAndRewrite(fir::ZeroOp zero, OpAdaptor,
279                   mlir::ConversionPatternRewriter &rewriter) const override {
280     auto ty = convertType(zero.getType());
281     if (ty.isa<mlir::LLVM::LLVMPointerType>()) {
282       rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty);
283     } else if (ty.isa<mlir::IntegerType>()) {
284       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
285           zero, ty, mlir::IntegerAttr::get(zero.getType(), 0));
286     } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) {
287       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
288           zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0));
289     } else {
290       // TODO: create ConstantAggregateZero for FIR aggregate/array types.
291       return rewriter.notifyMatchFailure(
292           zero,
293           "conversion of fir.zero with aggregate type not implemented yet");
294     }
295     return success();
296   }
297 };
298 
299 /// InsertOnRange inserts a value into a sequence over a range of offsets.
300 struct InsertOnRangeOpConversion
301     : public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
302   using FIROpAndTypeConversion::FIROpAndTypeConversion;
303 
304   // Increments an array of subscripts in a row major fasion.
305   void incrementSubscripts(const SmallVector<uint64_t> &dims,
306                            SmallVector<uint64_t> &subscripts) const {
307     for (size_t i = dims.size(); i > 0; --i) {
308       if (++subscripts[i - 1] < dims[i - 1]) {
309         return;
310       }
311       subscripts[i - 1] = 0;
312     }
313   }
314 
315   mlir::LogicalResult
316   doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
317             mlir::ConversionPatternRewriter &rewriter) const override {
318 
319     llvm::SmallVector<uint64_t> dims;
320     auto type = adaptor.getOperands()[0].getType();
321 
322     // Iteratively extract the array dimensions from the type.
323     while (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
324       dims.push_back(t.getNumElements());
325       type = t.getElementType();
326     }
327 
328     SmallVector<uint64_t> lBounds;
329     SmallVector<uint64_t> uBounds;
330 
331     // Extract integer value from the attribute
332     SmallVector<int64_t> coordinates = llvm::to_vector<4>(
333         llvm::map_range(range.coor(), [](Attribute a) -> int64_t {
334           return a.cast<IntegerAttr>().getInt();
335         }));
336 
337     // Unzip the upper and lower bound and convert to a row major format.
338     for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) {
339       uBounds.push_back(*i++);
340       lBounds.push_back(*i);
341     }
342 
343     auto &subscripts = lBounds;
344     auto loc = range.getLoc();
345     mlir::Value lastOp = adaptor.getOperands()[0];
346     mlir::Value insertVal = adaptor.getOperands()[1];
347 
348     auto i64Ty = rewriter.getI64Type();
349     while (subscripts != uBounds) {
350       // Convert uint64_t's to Attribute's.
351       SmallVector<mlir::Attribute> subscriptAttrs;
352       for (const auto &subscript : subscripts)
353         subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript));
354       lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
355           loc, ty, lastOp, insertVal,
356           ArrayAttr::get(range.getContext(), subscriptAttrs));
357 
358       incrementSubscripts(dims, subscripts);
359     }
360 
361     // Convert uint64_t's to Attribute's.
362     SmallVector<mlir::Attribute> subscriptAttrs;
363     for (const auto &subscript : subscripts)
364       subscriptAttrs.push_back(
365           IntegerAttr::get(rewriter.getI64Type(), subscript));
366     mlir::ArrayRef<mlir::Attribute> arrayRef(subscriptAttrs);
367 
368     rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
369         range, ty, lastOp, insertVal,
370         ArrayAttr::get(range.getContext(), arrayRef));
371 
372     return success();
373   }
374 };
375 } // namespace
376 
377 namespace {
378 /// Convert FIR dialect to LLVM dialect
379 ///
380 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An
381 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect.
382 ///
383 /// This pass is not complete yet. We are upstreaming it in small patches.
384 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
385 public:
386   mlir::ModuleOp getModule() { return getOperation(); }
387 
388   void runOnOperation() override final {
389     auto *context = getModule().getContext();
390     fir::LLVMTypeConverter typeConverter{getModule()};
391     mlir::OwningRewritePatternList pattern(context);
392     pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
393                    InsertOnRangeOpConversion, SelectOpConversion,
394                    SelectRankOpConversion, UnreachableOpConversion,
395                    ZeroOpConversion, UndefOpConversion>(typeConverter);
396     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
397     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
398                                                             pattern);
399     mlir::ConversionTarget target{*context};
400     target.addLegalDialect<mlir::LLVM::LLVMDialect>();
401 
402     // required NOPs for applying a full conversion
403     target.addLegalOp<mlir::ModuleOp>();
404 
405     // apply the patterns
406     if (mlir::failed(mlir::applyFullConversion(getModule(), target,
407                                                std::move(pattern)))) {
408       signalPassFailure();
409     }
410   }
411 };
412 } // namespace
413 
414 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
415   return std::make_unique<FIRToLLVMLowering>();
416 }
417