1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/CodeGen/CodeGen.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Dialect/FIRType.h"
17 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/Pass/Pass.h"
24 #include "llvm/ADT/ArrayRef.h"
25 
26 #define DEBUG_TYPE "flang-codegen"
27 
28 // fir::LLVMTypeConverter for converting to LLVM IR dialect types.
29 #include "TypeConverter.h"
30 
31 namespace {
32 /// FIR conversion pattern template
33 template <typename FromOp>
34 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
35 public:
36   explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
37       : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}
38 
39 protected:
40   mlir::Type convertType(mlir::Type ty) const {
41     return lowerTy().convertType(ty);
42   }
43 
44   fir::LLVMTypeConverter &lowerTy() const {
45     return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
46   }
47 };
48 
49 /// FIR conversion pattern template
50 template <typename FromOp>
51 class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
52 public:
53   using FIROpConversion<FromOp>::FIROpConversion;
54   using OpAdaptor = typename FromOp::Adaptor;
55 
56   mlir::LogicalResult
57   matchAndRewrite(FromOp op, OpAdaptor adaptor,
58                   mlir::ConversionPatternRewriter &rewriter) const final {
59     mlir::Type ty = this->convertType(op.getType());
60     return doRewrite(op, ty, adaptor, rewriter);
61   }
62 
63   virtual mlir::LogicalResult
64   doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor,
65             mlir::ConversionPatternRewriter &rewriter) const = 0;
66 };
67 
68 // Lower `fir.address_of` operation to `llvm.address_of` operation.
69 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
70   using FIROpConversion::FIROpConversion;
71 
72   mlir::LogicalResult
73   matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
74                   mlir::ConversionPatternRewriter &rewriter) const override {
75     auto ty = convertType(addr.getType());
76     rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
77         addr, ty, addr.symbol().getRootReference().getValue());
78     return success();
79   }
80 };
81 
82 // `fir.call` -> `llvm.call`
83 struct CallOpConversion : public FIROpConversion<fir::CallOp> {
84   using FIROpConversion::FIROpConversion;
85 
86   mlir::LogicalResult
87   matchAndRewrite(fir::CallOp call, OpAdaptor adaptor,
88                   mlir::ConversionPatternRewriter &rewriter) const override {
89     SmallVector<mlir::Type> resultTys;
90     for (auto r : call.getResults())
91       resultTys.push_back(convertType(r.getType()));
92     rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
93         call, resultTys, adaptor.getOperands(), call->getAttrs());
94     return success();
95   }
96 };
97 
98 /// Lower `fir.has_value` operation to `llvm.return` operation.
99 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
100   using FIROpConversion::FIROpConversion;
101 
102   mlir::LogicalResult
103   matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
104                   mlir::ConversionPatternRewriter &rewriter) const override {
105     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
106     return success();
107   }
108 };
109 
110 /// Lower `fir.global` operation to `llvm.global` operation.
111 /// `fir.insert_on_range` operations are replaced with constant dense attribute
112 /// if they are applied on the full range.
113 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
114   using FIROpConversion::FIROpConversion;
115 
116   mlir::LogicalResult
117   matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
118                   mlir::ConversionPatternRewriter &rewriter) const override {
119     auto tyAttr = convertType(global.getType());
120     if (global.getType().isa<fir::BoxType>())
121       tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
122     auto loc = global.getLoc();
123     mlir::Attribute initAttr{};
124     if (global.initVal())
125       initAttr = global.initVal().getValue();
126     auto linkage = convertLinkage(global.linkName());
127     auto isConst = global.constant().hasValue();
128     auto g = rewriter.create<mlir::LLVM::GlobalOp>(
129         loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
130     auto &gr = g.getInitializerRegion();
131     rewriter.inlineRegionBefore(global.region(), gr, gr.end());
132     if (!gr.empty()) {
133       // Replace insert_on_range with a constant dense attribute if the
134       // initialization is on the full range.
135       auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
136       for (auto insertOp : insertOnRangeOps) {
137         if (isFullRange(insertOp.coor(), insertOp.getType())) {
138           auto seqTyAttr = convertType(insertOp.getType());
139           auto *op = insertOp.val().getDefiningOp();
140           auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
141           if (!constant) {
142             auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op);
143             if (!convertOp)
144               continue;
145             constant = cast<mlir::arith::ConstantOp>(
146                 convertOp.value().getDefiningOp());
147           }
148           mlir::Type vecType = mlir::VectorType::get(
149               insertOp.getType().getShape(), constant.getType());
150           auto denseAttr = mlir::DenseElementsAttr::get(
151               vecType.cast<ShapedType>(), constant.value());
152           rewriter.setInsertionPointAfter(insertOp);
153           rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
154               insertOp, seqTyAttr, denseAttr);
155         }
156       }
157     }
158     rewriter.eraseOp(global);
159     return success();
160   }
161 
162   bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
163     auto extents = seqTy.getShape();
164     if (indexes.size() / 2 != extents.size())
165       return false;
166     for (unsigned i = 0; i < indexes.size(); i += 2) {
167       if (indexes[i].cast<IntegerAttr>().getInt() != 0)
168         return false;
169       if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
170         return false;
171     }
172     return true;
173   }
174 
175   // TODO: String comparaison should be avoided. Replace linkName with an
176   // enumeration.
177   mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const {
178     if (optLinkage.hasValue()) {
179       auto name = optLinkage.getValue();
180       if (name == "internal")
181         return mlir::LLVM::Linkage::Internal;
182       if (name == "linkonce")
183         return mlir::LLVM::Linkage::Linkonce;
184       if (name == "common")
185         return mlir::LLVM::Linkage::Common;
186       if (name == "weak")
187         return mlir::LLVM::Linkage::Weak;
188     }
189     return mlir::LLVM::Linkage::External;
190   }
191 };
192 
193 template <typename OP>
194 void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
195                            typename OP::Adaptor adaptor,
196                            mlir::ConversionPatternRewriter &rewriter) {
197   unsigned conds = select.getNumConditions();
198   auto cases = select.getCases().getValue();
199   mlir::Value selector = adaptor.selector();
200   auto loc = select.getLoc();
201   assert(conds > 0 && "select must have cases");
202 
203   llvm::SmallVector<mlir::Block *> destinations;
204   llvm::SmallVector<mlir::ValueRange> destinationsOperands;
205   mlir::Block *defaultDestination;
206   mlir::ValueRange defaultOperands;
207   llvm::SmallVector<int32_t> caseValues;
208 
209   for (unsigned t = 0; t != conds; ++t) {
210     mlir::Block *dest = select.getSuccessor(t);
211     auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
212     const mlir::Attribute &attr = cases[t];
213     if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) {
214       destinations.push_back(dest);
215       destinationsOperands.push_back(destOps.hasValue() ? *destOps
216                                                         : ValueRange());
217       caseValues.push_back(intAttr.getInt());
218       continue;
219     }
220     assert(attr.template dyn_cast_or_null<mlir::UnitAttr>());
221     assert((t + 1 == conds) && "unit must be last");
222     defaultDestination = dest;
223     defaultOperands = destOps.hasValue() ? *destOps : ValueRange();
224   }
225 
226   // LLVM::SwitchOp takes a i32 type for the selector.
227   if (select.getSelector().getType() != rewriter.getI32Type())
228     selector =
229         rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), selector);
230 
231   rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
232       select, selector,
233       /*defaultDestination=*/defaultDestination,
234       /*defaultOperands=*/defaultOperands,
235       /*caseValues=*/caseValues,
236       /*caseDestinations=*/destinations,
237       /*caseOperands=*/destinationsOperands,
238       /*branchWeights=*/ArrayRef<int32_t>());
239 }
240 
241 /// conversion of fir::SelectOp to an if-then-else ladder
242 struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
243   using FIROpConversion::FIROpConversion;
244 
245   mlir::LogicalResult
246   matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
247                   mlir::ConversionPatternRewriter &rewriter) const override {
248     selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
249     return success();
250   }
251 };
252 
253 /// conversion of fir::SelectRankOp to an if-then-else ladder
254 struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> {
255   using FIROpConversion::FIROpConversion;
256 
257   mlir::LogicalResult
258   matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
259                   mlir::ConversionPatternRewriter &rewriter) const override {
260     selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
261     return success();
262   }
263 };
264 
265 // convert to LLVM IR dialect `undef`
266 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
267   using FIROpConversion::FIROpConversion;
268 
269   mlir::LogicalResult
270   matchAndRewrite(fir::UndefOp undef, OpAdaptor,
271                   mlir::ConversionPatternRewriter &rewriter) const override {
272     rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
273         undef, convertType(undef.getType()));
274     return success();
275   }
276 };
277 
278 // convert to LLVM IR dialect `unreachable`
279 struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> {
280   using FIROpConversion::FIROpConversion;
281 
282   mlir::LogicalResult
283   matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor,
284                   mlir::ConversionPatternRewriter &rewriter) const override {
285     rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach);
286     return success();
287   }
288 };
289 
290 struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
291   using FIROpConversion::FIROpConversion;
292 
293   mlir::LogicalResult
294   matchAndRewrite(fir::ZeroOp zero, OpAdaptor,
295                   mlir::ConversionPatternRewriter &rewriter) const override {
296     auto ty = convertType(zero.getType());
297     if (ty.isa<mlir::LLVM::LLVMPointerType>()) {
298       rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty);
299     } else if (ty.isa<mlir::IntegerType>()) {
300       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
301           zero, ty, mlir::IntegerAttr::get(zero.getType(), 0));
302     } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) {
303       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
304           zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0));
305     } else {
306       // TODO: create ConstantAggregateZero for FIR aggregate/array types.
307       return rewriter.notifyMatchFailure(
308           zero,
309           "conversion of fir.zero with aggregate type not implemented yet");
310     }
311     return success();
312   }
313 };
314 
315 // Code shared between insert_value and extract_value Ops.
316 struct ValueOpCommon {
317   // Translate the arguments pertaining to any multidimensional array to
318   // row-major order for LLVM-IR.
319   static void toRowMajor(SmallVectorImpl<mlir::Attribute> &attrs,
320                          mlir::Type ty) {
321     assert(ty && "type is null");
322     const auto end = attrs.size();
323     for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) {
324       if (auto seq = ty.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
325         const auto dim = getDimension(seq);
326         if (dim > 1) {
327           auto ub = std::min(i + dim, end);
328           std::reverse(attrs.begin() + i, attrs.begin() + ub);
329           i += dim - 1;
330         }
331         ty = getArrayElementType(seq);
332       } else if (auto st = ty.dyn_cast<mlir::LLVM::LLVMStructType>()) {
333         ty = st.getBody()[attrs[i].cast<mlir::IntegerAttr>().getInt()];
334       } else {
335         llvm_unreachable("index into invalid type");
336       }
337     }
338   }
339 
340   static llvm::SmallVector<mlir::Attribute>
341   collectIndices(mlir::ConversionPatternRewriter &rewriter,
342                  mlir::ArrayAttr arrAttr) {
343     llvm::SmallVector<mlir::Attribute> attrs;
344     for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) {
345       if (i->isa<mlir::IntegerAttr>()) {
346         attrs.push_back(*i);
347       } else {
348         auto fieldName = i->cast<mlir::StringAttr>().getValue();
349         ++i;
350         auto ty = i->cast<mlir::TypeAttr>().getValue();
351         auto index = ty.cast<fir::RecordType>().getFieldIndex(fieldName);
352         attrs.push_back(mlir::IntegerAttr::get(rewriter.getI32Type(), index));
353       }
354     }
355     return attrs;
356   }
357 
358 private:
359   static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
360     unsigned result = 1;
361     for (auto eleTy = ty.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>();
362          eleTy;
363          eleTy = eleTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>())
364       ++result;
365     return result;
366   }
367 
368   static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) {
369     auto eleTy = ty.getElementType();
370     while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
371       eleTy = arrTy.getElementType();
372     return eleTy;
373   }
374 };
375 
376 /// Extract a subobject value from an ssa-value of aggregate type
377 struct ExtractValueOpConversion
378     : public FIROpAndTypeConversion<fir::ExtractValueOp>,
379       public ValueOpCommon {
380   using FIROpAndTypeConversion::FIROpAndTypeConversion;
381 
382   mlir::LogicalResult
383   doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor,
384             mlir::ConversionPatternRewriter &rewriter) const override {
385     auto attrs = collectIndices(rewriter, extractVal.coor());
386     toRowMajor(attrs, adaptor.getOperands()[0].getType());
387     auto position = mlir::ArrayAttr::get(extractVal.getContext(), attrs);
388     rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
389         extractVal, ty, adaptor.getOperands()[0], position);
390     return success();
391   }
392 };
393 
394 /// InsertValue is the generalized instruction for the composition of new
395 /// aggregate type values.
396 struct InsertValueOpConversion
397     : public FIROpAndTypeConversion<fir::InsertValueOp>,
398       public ValueOpCommon {
399   using FIROpAndTypeConversion::FIROpAndTypeConversion;
400 
401   mlir::LogicalResult
402   doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
403             mlir::ConversionPatternRewriter &rewriter) const override {
404     auto attrs = collectIndices(rewriter, insertVal.coor());
405     toRowMajor(attrs, adaptor.getOperands()[0].getType());
406     auto position = mlir::ArrayAttr::get(insertVal.getContext(), attrs);
407     rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
408         insertVal, ty, adaptor.getOperands()[0], adaptor.getOperands()[1],
409         position);
410     return success();
411   }
412 };
413 
414 /// InsertOnRange inserts a value into a sequence over a range of offsets.
415 struct InsertOnRangeOpConversion
416     : public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
417   using FIROpAndTypeConversion::FIROpAndTypeConversion;
418 
419   // Increments an array of subscripts in a row major fasion.
420   void incrementSubscripts(const SmallVector<uint64_t> &dims,
421                            SmallVector<uint64_t> &subscripts) const {
422     for (size_t i = dims.size(); i > 0; --i) {
423       if (++subscripts[i - 1] < dims[i - 1]) {
424         return;
425       }
426       subscripts[i - 1] = 0;
427     }
428   }
429 
430   mlir::LogicalResult
431   doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
432             mlir::ConversionPatternRewriter &rewriter) const override {
433 
434     llvm::SmallVector<uint64_t> dims;
435     auto type = adaptor.getOperands()[0].getType();
436 
437     // Iteratively extract the array dimensions from the type.
438     while (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
439       dims.push_back(t.getNumElements());
440       type = t.getElementType();
441     }
442 
443     SmallVector<uint64_t> lBounds;
444     SmallVector<uint64_t> uBounds;
445 
446     // Extract integer value from the attribute
447     SmallVector<int64_t> coordinates = llvm::to_vector<4>(
448         llvm::map_range(range.coor(), [](Attribute a) -> int64_t {
449           return a.cast<IntegerAttr>().getInt();
450         }));
451 
452     // Unzip the upper and lower bound and convert to a row major format.
453     for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) {
454       uBounds.push_back(*i++);
455       lBounds.push_back(*i);
456     }
457 
458     auto &subscripts = lBounds;
459     auto loc = range.getLoc();
460     mlir::Value lastOp = adaptor.getOperands()[0];
461     mlir::Value insertVal = adaptor.getOperands()[1];
462 
463     auto i64Ty = rewriter.getI64Type();
464     while (subscripts != uBounds) {
465       // Convert uint64_t's to Attribute's.
466       SmallVector<mlir::Attribute> subscriptAttrs;
467       for (const auto &subscript : subscripts)
468         subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript));
469       lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
470           loc, ty, lastOp, insertVal,
471           ArrayAttr::get(range.getContext(), subscriptAttrs));
472 
473       incrementSubscripts(dims, subscripts);
474     }
475 
476     // Convert uint64_t's to Attribute's.
477     SmallVector<mlir::Attribute> subscriptAttrs;
478     for (const auto &subscript : subscripts)
479       subscriptAttrs.push_back(
480           IntegerAttr::get(rewriter.getI64Type(), subscript));
481     mlir::ArrayRef<mlir::Attribute> arrayRef(subscriptAttrs);
482 
483     rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
484         range, ty, lastOp, insertVal,
485         ArrayAttr::get(range.getContext(), arrayRef));
486 
487     return success();
488   }
489 };
490 } // namespace
491 
492 namespace {
493 /// Convert FIR dialect to LLVM dialect
494 ///
495 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An
496 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect.
497 ///
498 /// This pass is not complete yet. We are upstreaming it in small patches.
499 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
500 public:
501   mlir::ModuleOp getModule() { return getOperation(); }
502 
503   void runOnOperation() override final {
504     auto *context = getModule().getContext();
505     fir::LLVMTypeConverter typeConverter{getModule()};
506     mlir::OwningRewritePatternList pattern(context);
507     pattern.insert<
508         AddrOfOpConversion, CallOpConversion, ExtractValueOpConversion,
509         HasValueOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
510         InsertValueOpConversion, SelectOpConversion, SelectRankOpConversion,
511         UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>(
512         typeConverter);
513     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
514     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
515                                                             pattern);
516     mlir::ConversionTarget target{*context};
517     target.addLegalDialect<mlir::LLVM::LLVMDialect>();
518 
519     // required NOPs for applying a full conversion
520     target.addLegalOp<mlir::ModuleOp>();
521 
522     // apply the patterns
523     if (mlir::failed(mlir::applyFullConversion(getModule(), target,
524                                                std::move(pattern)))) {
525       signalPassFailure();
526     }
527   }
528 };
529 } // namespace
530 
531 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
532   return std::make_unique<FIRToLLVMLowering>();
533 }
534