1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/CodeGen/CodeGen.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Dialect/FIRType.h"
17 #include "flang/Optimizer/Support/FIRContext.h"
18 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
19 #include "mlir/Conversion/LLVMCommon/Pattern.h"
20 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
21 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/Pass/Pass.h"
25 #include "llvm/ADT/ArrayRef.h"
26 
27 #define DEBUG_TYPE "flang-codegen"
28 
29 // fir::LLVMTypeConverter for converting to LLVM IR dialect types.
30 #include "TypeConverter.h"
31 
32 namespace {
33 /// FIR conversion pattern template
34 template <typename FromOp>
35 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
36 public:
37   explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
38       : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}
39 
40 protected:
41   mlir::Type convertType(mlir::Type ty) const {
42     return lowerTy().convertType(ty);
43   }
44 
45   fir::LLVMTypeConverter &lowerTy() const {
46     return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
47   }
48 };
49 
50 /// FIR conversion pattern template
51 template <typename FromOp>
52 class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
53 public:
54   using FIROpConversion<FromOp>::FIROpConversion;
55   using OpAdaptor = typename FromOp::Adaptor;
56 
57   mlir::LogicalResult
58   matchAndRewrite(FromOp op, OpAdaptor adaptor,
59                   mlir::ConversionPatternRewriter &rewriter) const final {
60     mlir::Type ty = this->convertType(op.getType());
61     return doRewrite(op, ty, adaptor, rewriter);
62   }
63 
64   virtual mlir::LogicalResult
65   doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor,
66             mlir::ConversionPatternRewriter &rewriter) const = 0;
67 };
68 
69 // Lower `fir.address_of` operation to `llvm.address_of` operation.
70 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
71   using FIROpConversion::FIROpConversion;
72 
73   mlir::LogicalResult
74   matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
75                   mlir::ConversionPatternRewriter &rewriter) const override {
76     auto ty = convertType(addr.getType());
77     rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
78         addr, ty, addr.symbol().getRootReference().getValue());
79     return success();
80   }
81 };
82 
83 // `fir.call` -> `llvm.call`
84 struct CallOpConversion : public FIROpConversion<fir::CallOp> {
85   using FIROpConversion::FIROpConversion;
86 
87   mlir::LogicalResult
88   matchAndRewrite(fir::CallOp call, OpAdaptor adaptor,
89                   mlir::ConversionPatternRewriter &rewriter) const override {
90     SmallVector<mlir::Type> resultTys;
91     for (auto r : call.getResults())
92       resultTys.push_back(convertType(r.getType()));
93     rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
94         call, resultTys, adaptor.getOperands(), call->getAttrs());
95     return success();
96   }
97 };
98 
99 static mlir::Type getComplexEleTy(mlir::Type complex) {
100   if (auto cc = complex.dyn_cast<mlir::ComplexType>())
101     return cc.getElementType();
102   return complex.cast<fir::ComplexType>().getElementType();
103 }
104 
105 /// convert value of from-type to value of to-type
106 struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
107   using FIROpConversion::FIROpConversion;
108 
109   static bool isFloatingPointTy(mlir::Type ty) {
110     return ty.isa<mlir::FloatType>();
111   }
112 
113   mlir::LogicalResult
114   matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor,
115                   mlir::ConversionPatternRewriter &rewriter) const override {
116     auto fromTy = convertType(convert.value().getType());
117     auto toTy = convertType(convert.res().getType());
118     mlir::Value op0 = adaptor.getOperands()[0];
119     if (fromTy == toTy) {
120       rewriter.replaceOp(convert, op0);
121       return success();
122     }
123     auto loc = convert.getLoc();
124     auto convertFpToFp = [&](mlir::Value val, unsigned fromBits,
125                              unsigned toBits, mlir::Type toTy) -> mlir::Value {
126       if (fromBits == toBits) {
127         // TODO: Converting between two floating-point representations with the
128         // same bitwidth is not allowed for now.
129         mlir::emitError(loc,
130                         "cannot implicitly convert between two floating-point "
131                         "representations of the same bitwidth");
132         return {};
133       }
134       if (fromBits > toBits)
135         return rewriter.create<mlir::LLVM::FPTruncOp>(loc, toTy, val);
136       return rewriter.create<mlir::LLVM::FPExtOp>(loc, toTy, val);
137     };
138     // Complex to complex conversion.
139     if (fir::isa_complex(convert.value().getType()) &&
140         fir::isa_complex(convert.res().getType())) {
141       // Special case: handle the conversion of a complex such that both the
142       // real and imaginary parts are converted together.
143       auto zero = mlir::ArrayAttr::get(convert.getContext(),
144                                        rewriter.getI32IntegerAttr(0));
145       auto one = mlir::ArrayAttr::get(convert.getContext(),
146                                       rewriter.getI32IntegerAttr(1));
147       auto ty = convertType(getComplexEleTy(convert.value().getType()));
148       auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, zero);
149       auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, one);
150       auto nt = convertType(getComplexEleTy(convert.res().getType()));
151       auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
152       auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(nt);
153       auto rc = convertFpToFp(rp, fromBits, toBits, nt);
154       auto ic = convertFpToFp(ip, fromBits, toBits, nt);
155       auto un = rewriter.create<mlir::LLVM::UndefOp>(loc, toTy);
156       auto i1 =
157           rewriter.create<mlir::LLVM::InsertValueOp>(loc, toTy, un, rc, zero);
158       rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(convert, toTy, i1,
159                                                              ic, one);
160       return mlir::success();
161     }
162     // Floating point to floating point conversion.
163     if (isFloatingPointTy(fromTy)) {
164       if (isFloatingPointTy(toTy)) {
165         auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
166         auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
167         auto v = convertFpToFp(op0, fromBits, toBits, toTy);
168         rewriter.replaceOp(convert, v);
169         return mlir::success();
170       }
171       if (toTy.isa<mlir::IntegerType>()) {
172         rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(convert, toTy, op0);
173         return mlir::success();
174       }
175     } else if (fromTy.isa<mlir::IntegerType>()) {
176       // Integer to integer conversion.
177       if (toTy.isa<mlir::IntegerType>()) {
178         auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
179         auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
180         assert(fromBits != toBits);
181         if (fromBits > toBits) {
182           rewriter.replaceOpWithNewOp<mlir::LLVM::TruncOp>(convert, toTy, op0);
183           return mlir::success();
184         }
185         rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(convert, toTy, op0);
186         return mlir::success();
187       }
188       // Integer to floating point conversion.
189       if (isFloatingPointTy(toTy)) {
190         rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(convert, toTy, op0);
191         return mlir::success();
192       }
193       // Integer to pointer conversion.
194       if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
195         rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(convert, toTy, op0);
196         return mlir::success();
197       }
198     } else if (fromTy.isa<mlir::LLVM::LLVMPointerType>()) {
199       // Pointer to integer conversion.
200       if (toTy.isa<mlir::IntegerType>()) {
201         rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(convert, toTy, op0);
202         return mlir::success();
203       }
204       // Pointer to pointer conversion.
205       if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
206         rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(convert, toTy, op0);
207         return mlir::success();
208       }
209     }
210     return emitError(loc) << "cannot convert " << fromTy << " to " << toTy;
211   }
212 };
213 
214 /// Lower `fir.has_value` operation to `llvm.return` operation.
215 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
216   using FIROpConversion::FIROpConversion;
217 
218   mlir::LogicalResult
219   matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
220                   mlir::ConversionPatternRewriter &rewriter) const override {
221     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
222     return success();
223   }
224 };
225 
226 /// Lower `fir.global` operation to `llvm.global` operation.
227 /// `fir.insert_on_range` operations are replaced with constant dense attribute
228 /// if they are applied on the full range.
229 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
230   using FIROpConversion::FIROpConversion;
231 
232   mlir::LogicalResult
233   matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
234                   mlir::ConversionPatternRewriter &rewriter) const override {
235     auto tyAttr = convertType(global.getType());
236     if (global.getType().isa<fir::BoxType>())
237       tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
238     auto loc = global.getLoc();
239     mlir::Attribute initAttr{};
240     if (global.initVal())
241       initAttr = global.initVal().getValue();
242     auto linkage = convertLinkage(global.linkName());
243     auto isConst = global.constant().hasValue();
244     auto g = rewriter.create<mlir::LLVM::GlobalOp>(
245         loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
246     auto &gr = g.getInitializerRegion();
247     rewriter.inlineRegionBefore(global.region(), gr, gr.end());
248     if (!gr.empty()) {
249       // Replace insert_on_range with a constant dense attribute if the
250       // initialization is on the full range.
251       auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
252       for (auto insertOp : insertOnRangeOps) {
253         if (isFullRange(insertOp.coor(), insertOp.getType())) {
254           auto seqTyAttr = convertType(insertOp.getType());
255           auto *op = insertOp.val().getDefiningOp();
256           auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
257           if (!constant) {
258             auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op);
259             if (!convertOp)
260               continue;
261             constant = cast<mlir::arith::ConstantOp>(
262                 convertOp.value().getDefiningOp());
263           }
264           mlir::Type vecType = mlir::VectorType::get(
265               insertOp.getType().getShape(), constant.getType());
266           auto denseAttr = mlir::DenseElementsAttr::get(
267               vecType.cast<ShapedType>(), constant.value());
268           rewriter.setInsertionPointAfter(insertOp);
269           rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
270               insertOp, seqTyAttr, denseAttr);
271         }
272       }
273     }
274     rewriter.eraseOp(global);
275     return success();
276   }
277 
278   bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
279     auto extents = seqTy.getShape();
280     if (indexes.size() / 2 != extents.size())
281       return false;
282     for (unsigned i = 0; i < indexes.size(); i += 2) {
283       if (indexes[i].cast<IntegerAttr>().getInt() != 0)
284         return false;
285       if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
286         return false;
287     }
288     return true;
289   }
290 
291   // TODO: String comparaison should be avoided. Replace linkName with an
292   // enumeration.
293   mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const {
294     if (optLinkage.hasValue()) {
295       auto name = optLinkage.getValue();
296       if (name == "internal")
297         return mlir::LLVM::Linkage::Internal;
298       if (name == "linkonce")
299         return mlir::LLVM::Linkage::Linkonce;
300       if (name == "common")
301         return mlir::LLVM::Linkage::Common;
302       if (name == "weak")
303         return mlir::LLVM::Linkage::Weak;
304     }
305     return mlir::LLVM::Linkage::External;
306   }
307 };
308 
309 template <typename OP>
310 void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
311                            typename OP::Adaptor adaptor,
312                            mlir::ConversionPatternRewriter &rewriter) {
313   unsigned conds = select.getNumConditions();
314   auto cases = select.getCases().getValue();
315   mlir::Value selector = adaptor.selector();
316   auto loc = select.getLoc();
317   assert(conds > 0 && "select must have cases");
318 
319   llvm::SmallVector<mlir::Block *> destinations;
320   llvm::SmallVector<mlir::ValueRange> destinationsOperands;
321   mlir::Block *defaultDestination;
322   mlir::ValueRange defaultOperands;
323   llvm::SmallVector<int32_t> caseValues;
324 
325   for (unsigned t = 0; t != conds; ++t) {
326     mlir::Block *dest = select.getSuccessor(t);
327     auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
328     const mlir::Attribute &attr = cases[t];
329     if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) {
330       destinations.push_back(dest);
331       destinationsOperands.push_back(destOps.hasValue() ? *destOps
332                                                         : ValueRange());
333       caseValues.push_back(intAttr.getInt());
334       continue;
335     }
336     assert(attr.template dyn_cast_or_null<mlir::UnitAttr>());
337     assert((t + 1 == conds) && "unit must be last");
338     defaultDestination = dest;
339     defaultOperands = destOps.hasValue() ? *destOps : ValueRange();
340   }
341 
342   // LLVM::SwitchOp takes a i32 type for the selector.
343   if (select.getSelector().getType() != rewriter.getI32Type())
344     selector =
345         rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), selector);
346 
347   rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
348       select, selector,
349       /*defaultDestination=*/defaultDestination,
350       /*defaultOperands=*/defaultOperands,
351       /*caseValues=*/caseValues,
352       /*caseDestinations=*/destinations,
353       /*caseOperands=*/destinationsOperands,
354       /*branchWeights=*/ArrayRef<int32_t>());
355 }
356 
357 /// conversion of fir::SelectOp to an if-then-else ladder
358 struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
359   using FIROpConversion::FIROpConversion;
360 
361   mlir::LogicalResult
362   matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
363                   mlir::ConversionPatternRewriter &rewriter) const override {
364     selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
365     return success();
366   }
367 };
368 
369 /// `fir.load` --> `llvm.load`
370 struct LoadOpConversion : public FIROpConversion<fir::LoadOp> {
371   using FIROpConversion::FIROpConversion;
372 
373   mlir::LogicalResult
374   matchAndRewrite(fir::LoadOp load, OpAdaptor adaptor,
375                   mlir::ConversionPatternRewriter &rewriter) const override {
376     // fir.box is a special case because it is considered as an ssa values in
377     // fir, but it is lowered as a pointer to a descriptor. So fir.ref<fir.box>
378     // and fir.box end up being the same llvm types and loading a
379     // fir.ref<fir.box> is actually a no op in LLVM.
380     if (load.getType().isa<fir::BoxType>()) {
381       rewriter.replaceOp(load, adaptor.getOperands()[0]);
382     } else {
383       mlir::Type ty = convertType(load.getType());
384       ArrayRef<NamedAttribute> at = load->getAttrs();
385       rewriter.replaceOpWithNewOp<mlir::LLVM::LoadOp>(
386           load, ty, adaptor.getOperands(), at);
387     }
388     return success();
389   }
390 };
391 
392 /// conversion of fir::SelectRankOp to an if-then-else ladder
393 struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> {
394   using FIROpConversion::FIROpConversion;
395 
396   mlir::LogicalResult
397   matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
398                   mlir::ConversionPatternRewriter &rewriter) const override {
399     selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
400     return success();
401   }
402 };
403 
404 /// `fir.store` --> `llvm.store`
405 struct StoreOpConversion : public FIROpConversion<fir::StoreOp> {
406   using FIROpConversion::FIROpConversion;
407 
408   mlir::LogicalResult
409   matchAndRewrite(fir::StoreOp store, OpAdaptor adaptor,
410                   mlir::ConversionPatternRewriter &rewriter) const override {
411     if (store.value().getType().isa<fir::BoxType>()) {
412       // fir.box value is actually in memory, load it first before storing it.
413       mlir::Location loc = store.getLoc();
414       mlir::Type boxPtrTy = adaptor.getOperands()[0].getType();
415       auto val = rewriter.create<mlir::LLVM::LoadOp>(
416           loc, boxPtrTy.cast<mlir::LLVM::LLVMPointerType>().getElementType(),
417           adaptor.getOperands()[0]);
418       rewriter.replaceOpWithNewOp<mlir::LLVM::StoreOp>(
419           store, val, adaptor.getOperands()[1]);
420     } else {
421       rewriter.replaceOpWithNewOp<mlir::LLVM::StoreOp>(
422           store, adaptor.getOperands()[0], adaptor.getOperands()[1]);
423     }
424     return success();
425   }
426 };
427 
428 /// convert to LLVM IR dialect `undef`
429 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
430   using FIROpConversion::FIROpConversion;
431 
432   mlir::LogicalResult
433   matchAndRewrite(fir::UndefOp undef, OpAdaptor,
434                   mlir::ConversionPatternRewriter &rewriter) const override {
435     rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
436         undef, convertType(undef.getType()));
437     return success();
438   }
439 };
440 
441 /// `fir.unreachable` --> `llvm.unreachable`
442 struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> {
443   using FIROpConversion::FIROpConversion;
444 
445   mlir::LogicalResult
446   matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor,
447                   mlir::ConversionPatternRewriter &rewriter) const override {
448     rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach);
449     return success();
450   }
451 };
452 
453 struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
454   using FIROpConversion::FIROpConversion;
455 
456   mlir::LogicalResult
457   matchAndRewrite(fir::ZeroOp zero, OpAdaptor,
458                   mlir::ConversionPatternRewriter &rewriter) const override {
459     auto ty = convertType(zero.getType());
460     if (ty.isa<mlir::LLVM::LLVMPointerType>()) {
461       rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty);
462     } else if (ty.isa<mlir::IntegerType>()) {
463       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
464           zero, ty, mlir::IntegerAttr::get(zero.getType(), 0));
465     } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) {
466       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
467           zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0));
468     } else {
469       // TODO: create ConstantAggregateZero for FIR aggregate/array types.
470       return rewriter.notifyMatchFailure(
471           zero,
472           "conversion of fir.zero with aggregate type not implemented yet");
473     }
474     return success();
475   }
476 };
477 
478 // Code shared between insert_value and extract_value Ops.
479 struct ValueOpCommon {
480   // Translate the arguments pertaining to any multidimensional array to
481   // row-major order for LLVM-IR.
482   static void toRowMajor(SmallVectorImpl<mlir::Attribute> &attrs,
483                          mlir::Type ty) {
484     assert(ty && "type is null");
485     const auto end = attrs.size();
486     for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) {
487       if (auto seq = ty.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
488         const auto dim = getDimension(seq);
489         if (dim > 1) {
490           auto ub = std::min(i + dim, end);
491           std::reverse(attrs.begin() + i, attrs.begin() + ub);
492           i += dim - 1;
493         }
494         ty = getArrayElementType(seq);
495       } else if (auto st = ty.dyn_cast<mlir::LLVM::LLVMStructType>()) {
496         ty = st.getBody()[attrs[i].cast<mlir::IntegerAttr>().getInt()];
497       } else {
498         llvm_unreachable("index into invalid type");
499       }
500     }
501   }
502 
503   static llvm::SmallVector<mlir::Attribute>
504   collectIndices(mlir::ConversionPatternRewriter &rewriter,
505                  mlir::ArrayAttr arrAttr) {
506     llvm::SmallVector<mlir::Attribute> attrs;
507     for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) {
508       if (i->isa<mlir::IntegerAttr>()) {
509         attrs.push_back(*i);
510       } else {
511         auto fieldName = i->cast<mlir::StringAttr>().getValue();
512         ++i;
513         auto ty = i->cast<mlir::TypeAttr>().getValue();
514         auto index = ty.cast<fir::RecordType>().getFieldIndex(fieldName);
515         attrs.push_back(mlir::IntegerAttr::get(rewriter.getI32Type(), index));
516       }
517     }
518     return attrs;
519   }
520 
521 private:
522   static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
523     unsigned result = 1;
524     for (auto eleTy = ty.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>();
525          eleTy;
526          eleTy = eleTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>())
527       ++result;
528     return result;
529   }
530 
531   static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) {
532     auto eleTy = ty.getElementType();
533     while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
534       eleTy = arrTy.getElementType();
535     return eleTy;
536   }
537 };
538 
539 /// Extract a subobject value from an ssa-value of aggregate type
540 struct ExtractValueOpConversion
541     : public FIROpAndTypeConversion<fir::ExtractValueOp>,
542       public ValueOpCommon {
543   using FIROpAndTypeConversion::FIROpAndTypeConversion;
544 
545   mlir::LogicalResult
546   doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor,
547             mlir::ConversionPatternRewriter &rewriter) const override {
548     auto attrs = collectIndices(rewriter, extractVal.coor());
549     toRowMajor(attrs, adaptor.getOperands()[0].getType());
550     auto position = mlir::ArrayAttr::get(extractVal.getContext(), attrs);
551     rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
552         extractVal, ty, adaptor.getOperands()[0], position);
553     return success();
554   }
555 };
556 
557 /// InsertValue is the generalized instruction for the composition of new
558 /// aggregate type values.
559 struct InsertValueOpConversion
560     : public FIROpAndTypeConversion<fir::InsertValueOp>,
561       public ValueOpCommon {
562   using FIROpAndTypeConversion::FIROpAndTypeConversion;
563 
564   mlir::LogicalResult
565   doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
566             mlir::ConversionPatternRewriter &rewriter) const override {
567     auto attrs = collectIndices(rewriter, insertVal.coor());
568     toRowMajor(attrs, adaptor.getOperands()[0].getType());
569     auto position = mlir::ArrayAttr::get(insertVal.getContext(), attrs);
570     rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
571         insertVal, ty, adaptor.getOperands()[0], adaptor.getOperands()[1],
572         position);
573     return success();
574   }
575 };
576 
577 /// InsertOnRange inserts a value into a sequence over a range of offsets.
578 struct InsertOnRangeOpConversion
579     : public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
580   using FIROpAndTypeConversion::FIROpAndTypeConversion;
581 
582   // Increments an array of subscripts in a row major fasion.
583   void incrementSubscripts(const SmallVector<uint64_t> &dims,
584                            SmallVector<uint64_t> &subscripts) const {
585     for (size_t i = dims.size(); i > 0; --i) {
586       if (++subscripts[i - 1] < dims[i - 1]) {
587         return;
588       }
589       subscripts[i - 1] = 0;
590     }
591   }
592 
593   mlir::LogicalResult
594   doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
595             mlir::ConversionPatternRewriter &rewriter) const override {
596 
597     llvm::SmallVector<uint64_t> dims;
598     auto type = adaptor.getOperands()[0].getType();
599 
600     // Iteratively extract the array dimensions from the type.
601     while (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
602       dims.push_back(t.getNumElements());
603       type = t.getElementType();
604     }
605 
606     SmallVector<uint64_t> lBounds;
607     SmallVector<uint64_t> uBounds;
608 
609     // Extract integer value from the attribute
610     SmallVector<int64_t> coordinates = llvm::to_vector<4>(
611         llvm::map_range(range.coor(), [](Attribute a) -> int64_t {
612           return a.cast<IntegerAttr>().getInt();
613         }));
614 
615     // Unzip the upper and lower bound and convert to a row major format.
616     for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) {
617       uBounds.push_back(*i++);
618       lBounds.push_back(*i);
619     }
620 
621     auto &subscripts = lBounds;
622     auto loc = range.getLoc();
623     mlir::Value lastOp = adaptor.getOperands()[0];
624     mlir::Value insertVal = adaptor.getOperands()[1];
625 
626     auto i64Ty = rewriter.getI64Type();
627     while (subscripts != uBounds) {
628       // Convert uint64_t's to Attribute's.
629       SmallVector<mlir::Attribute> subscriptAttrs;
630       for (const auto &subscript : subscripts)
631         subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript));
632       lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
633           loc, ty, lastOp, insertVal,
634           ArrayAttr::get(range.getContext(), subscriptAttrs));
635 
636       incrementSubscripts(dims, subscripts);
637     }
638 
639     // Convert uint64_t's to Attribute's.
640     SmallVector<mlir::Attribute> subscriptAttrs;
641     for (const auto &subscript : subscripts)
642       subscriptAttrs.push_back(
643           IntegerAttr::get(rewriter.getI64Type(), subscript));
644     mlir::ArrayRef<mlir::Attribute> arrayRef(subscriptAttrs);
645 
646     rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
647         range, ty, lastOp, insertVal,
648         ArrayAttr::get(range.getContext(), arrayRef));
649 
650     return success();
651   }
652 };
653 
654 //
655 // Primitive operations on Complex types
656 //
657 
658 /// Generate inline code for complex addition/subtraction
659 template <typename LLVMOP, typename OPTY>
660 mlir::LLVM::InsertValueOp complexSum(OPTY sumop, mlir::ValueRange opnds,
661                                      mlir::ConversionPatternRewriter &rewriter,
662                                      fir::LLVMTypeConverter &lowering) {
663   mlir::Value a = opnds[0];
664   mlir::Value b = opnds[1];
665   auto loc = sumop.getLoc();
666   auto ctx = sumop.getContext();
667   auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0));
668   auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1));
669   mlir::Type eleTy = lowering.convertType(getComplexEleTy(sumop.getType()));
670   mlir::Type ty = lowering.convertType(sumop.getType());
671   auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0);
672   auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1);
673   auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0);
674   auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1);
675   auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1);
676   auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1);
677   auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
678   auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r0, rx, c0);
679   return rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ry, c1);
680 }
681 
682 struct AddcOpConversion : public FIROpConversion<fir::AddcOp> {
683   using FIROpConversion::FIROpConversion;
684 
685   mlir::LogicalResult
686   matchAndRewrite(fir::AddcOp addc, OpAdaptor adaptor,
687                   mlir::ConversionPatternRewriter &rewriter) const override {
688     // given: (x + iy) + (x' + iy')
689     // result: (x + x') + i(y + y')
690     auto r = complexSum<mlir::LLVM::FAddOp>(addc, adaptor.getOperands(),
691                                             rewriter, lowerTy());
692     rewriter.replaceOp(addc, r.getResult());
693     return success();
694   }
695 };
696 
697 struct SubcOpConversion : public FIROpConversion<fir::SubcOp> {
698   using FIROpConversion::FIROpConversion;
699 
700   mlir::LogicalResult
701   matchAndRewrite(fir::SubcOp subc, OpAdaptor adaptor,
702                   mlir::ConversionPatternRewriter &rewriter) const override {
703     // given: (x + iy) - (x' + iy')
704     // result: (x - x') + i(y - y')
705     auto r = complexSum<mlir::LLVM::FSubOp>(subc, adaptor.getOperands(),
706                                             rewriter, lowerTy());
707     rewriter.replaceOp(subc, r.getResult());
708     return success();
709   }
710 };
711 
712 /// Inlined complex multiply
713 struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
714   using FIROpConversion::FIROpConversion;
715 
716   mlir::LogicalResult
717   matchAndRewrite(fir::MulcOp mulc, OpAdaptor adaptor,
718                   mlir::ConversionPatternRewriter &rewriter) const override {
719     // TODO: Can we use a call to __muldc3 ?
720     // given: (x + iy) * (x' + iy')
721     // result: (xx'-yy')+i(xy'+yx')
722     mlir::Value a = adaptor.getOperands()[0];
723     mlir::Value b = adaptor.getOperands()[1];
724     auto loc = mulc.getLoc();
725     auto *ctx = mulc.getContext();
726     auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0));
727     auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1));
728     mlir::Type eleTy = convertType(getComplexEleTy(mulc.getType()));
729     mlir::Type ty = convertType(mulc.getType());
730     auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0);
731     auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1);
732     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0);
733     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1);
734     auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
735     auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
736     auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
737     auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx);
738     auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
739     auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy);
740     auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
741     auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, ra, rr, c0);
742     auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ri, c1);
743     rewriter.replaceOp(mulc, r0.getResult());
744     return success();
745   }
746 };
747 
748 /// Inlined complex division
749 struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
750   using FIROpConversion::FIROpConversion;
751 
752   mlir::LogicalResult
753   matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
754                   mlir::ConversionPatternRewriter &rewriter) const override {
755     // TODO: Can we use a call to __divdc3 instead?
756     // Just generate inline code for now.
757     // given: (x + iy) / (x' + iy')
758     // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
759     mlir::Value a = adaptor.getOperands()[0];
760     mlir::Value b = adaptor.getOperands()[1];
761     auto loc = divc.getLoc();
762     auto *ctx = divc.getContext();
763     auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0));
764     auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1));
765     mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
766     mlir::Type ty = convertType(divc.getType());
767     auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0);
768     auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1);
769     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0);
770     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1);
771     auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
772     auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
773     auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
774     auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
775     auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
776     auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
777     auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
778     auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
779     auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
780     auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
781     auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
782     auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
783     auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, ra, rr, c0);
784     auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ri, c1);
785     rewriter.replaceOp(divc, r0.getResult());
786     return success();
787   }
788 };
789 
790 /// Inlined complex negation
791 struct NegcOpConversion : public FIROpConversion<fir::NegcOp> {
792   using FIROpConversion::FIROpConversion;
793 
794   mlir::LogicalResult
795   matchAndRewrite(fir::NegcOp neg, OpAdaptor adaptor,
796                   mlir::ConversionPatternRewriter &rewriter) const override {
797     // given: -(x + iy)
798     // result: -x - iy
799     auto *ctxt = neg.getContext();
800     auto eleTy = convertType(getComplexEleTy(neg.getType()));
801     auto ty = convertType(neg.getType());
802     auto loc = neg.getLoc();
803     mlir::Value o0 = adaptor.getOperands()[0];
804     auto c0 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(0));
805     auto c1 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(1));
806     auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, o0, c0);
807     auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, o0, c1);
808     auto nrp = rewriter.create<mlir::LLVM::FNegOp>(loc, eleTy, rp);
809     auto nip = rewriter.create<mlir::LLVM::FNegOp>(loc, eleTy, ip);
810     auto r = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, o0, nrp, c0);
811     rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(neg, ty, r, nip, c1);
812     return success();
813   }
814 };
815 
816 } // namespace
817 
818 namespace {
819 /// Convert FIR dialect to LLVM dialect
820 ///
821 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An
822 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect.
823 ///
824 /// This pass is not complete yet. We are upstreaming it in small patches.
825 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
826 public:
827   mlir::ModuleOp getModule() { return getOperation(); }
828 
829   void runOnOperation() override final {
830     auto mod = getModule();
831     if (!forcedTargetTriple.empty()) {
832       fir::setTargetTriple(mod, forcedTargetTriple);
833     }
834 
835     auto *context = getModule().getContext();
836     fir::LLVMTypeConverter typeConverter{getModule()};
837     mlir::OwningRewritePatternList pattern(context);
838     pattern.insert<AddcOpConversion, AddrOfOpConversion, CallOpConversion,
839                    ConvertOpConversion, DivcOpConversion,
840                    ExtractValueOpConversion, HasValueOpConversion,
841                    GlobalOpConversion, InsertOnRangeOpConversion,
842                    InsertValueOpConversion, LoadOpConversion, NegcOpConversion,
843                    MulcOpConversion, SelectOpConversion, SelectRankOpConversion,
844                    StoreOpConversion, SubcOpConversion, UndefOpConversion,
845                    UnreachableOpConversion, ZeroOpConversion>(typeConverter);
846     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
847     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
848                                                             pattern);
849     mlir::ConversionTarget target{*context};
850     target.addLegalDialect<mlir::LLVM::LLVMDialect>();
851 
852     // required NOPs for applying a full conversion
853     target.addLegalOp<mlir::ModuleOp>();
854 
855     // apply the patterns
856     if (mlir::failed(mlir::applyFullConversion(getModule(), target,
857                                                std::move(pattern)))) {
858       signalPassFailure();
859     }
860   }
861 };
862 } // namespace
863 
864 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
865   return std::make_unique<FIRToLLVMLowering>();
866 }
867