1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/CodeGen/CodeGen.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/Dialect/FIROps.h" 16 #include "flang/Optimizer/Dialect/FIRType.h" 17 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 18 #include "mlir/Conversion/LLVMCommon/Pattern.h" 19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/Pass/Pass.h" 23 #include "llvm/ADT/ArrayRef.h" 24 25 #define DEBUG_TYPE "flang-codegen" 26 27 // fir::LLVMTypeConverter for converting to LLVM IR dialect types. 28 #include "TypeConverter.h" 29 30 namespace { 31 /// FIR conversion pattern template 32 template <typename FromOp> 33 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> { 34 public: 35 explicit FIROpConversion(fir::LLVMTypeConverter &lowering) 36 : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {} 37 38 protected: 39 mlir::Type convertType(mlir::Type ty) const { 40 return lowerTy().convertType(ty); 41 } 42 43 fir::LLVMTypeConverter &lowerTy() const { 44 return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter()); 45 } 46 }; 47 } // namespace 48 49 namespace { 50 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> { 51 using FIROpConversion::FIROpConversion; 52 53 mlir::LogicalResult 54 matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor, 55 mlir::ConversionPatternRewriter &rewriter) const override { 56 auto ty = convertType(addr.getType()); 57 rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>( 58 addr, ty, addr.symbol().getRootReference().getValue()); 59 return success(); 60 } 61 }; 62 63 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> { 64 using FIROpConversion::FIROpConversion; 65 66 mlir::LogicalResult 67 matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor, 68 mlir::ConversionPatternRewriter &rewriter) const override { 69 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands()); 70 return success(); 71 } 72 }; 73 74 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> { 75 using FIROpConversion::FIROpConversion; 76 77 mlir::LogicalResult 78 matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor, 79 mlir::ConversionPatternRewriter &rewriter) const override { 80 auto tyAttr = convertType(global.getType()); 81 if (global.getType().isa<fir::BoxType>()) 82 tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType(); 83 auto loc = global.getLoc(); 84 mlir::Attribute initAttr{}; 85 if (global.initVal()) 86 initAttr = global.initVal().getValue(); 87 auto linkage = convertLinkage(global.linkName()); 88 auto isConst = global.constant().hasValue(); 89 auto g = rewriter.create<mlir::LLVM::GlobalOp>( 90 loc, tyAttr, isConst, linkage, global.sym_name(), initAttr); 91 auto &gr = g.getInitializerRegion(); 92 rewriter.inlineRegionBefore(global.region(), gr, gr.end()); 93 if (!gr.empty()) { 94 // Replace insert_on_range with a constant dense attribute if the 95 // initialization is on the full range. 96 auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>(); 97 for (auto insertOp : insertOnRangeOps) { 98 if (isFullRange(insertOp.coor(), insertOp.getType())) { 99 auto seqTyAttr = convertType(insertOp.getType()); 100 auto *op = insertOp.val().getDefiningOp(); 101 auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op); 102 if (!constant) { 103 auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op); 104 if (!convertOp) 105 continue; 106 constant = cast<mlir::arith::ConstantOp>( 107 convertOp.value().getDefiningOp()); 108 } 109 mlir::Type vecType = mlir::VectorType::get( 110 insertOp.getType().getShape(), constant.getType()); 111 auto denseAttr = mlir::DenseElementsAttr::get( 112 vecType.cast<ShapedType>(), constant.value()); 113 rewriter.setInsertionPointAfter(insertOp); 114 rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( 115 insertOp, seqTyAttr, denseAttr); 116 } 117 } 118 } 119 rewriter.eraseOp(global); 120 return success(); 121 } 122 123 bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const { 124 auto extents = seqTy.getShape(); 125 if (indexes.size() / 2 != extents.size()) 126 return false; 127 for (unsigned i = 0; i < indexes.size(); i += 2) { 128 if (indexes[i].cast<IntegerAttr>().getInt() != 0) 129 return false; 130 if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1) 131 return false; 132 } 133 return true; 134 } 135 136 mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const { 137 if (optLinkage.hasValue()) { 138 auto name = optLinkage.getValue(); 139 if (name == "internal") 140 return mlir::LLVM::Linkage::Internal; 141 if (name == "linkonce") 142 return mlir::LLVM::Linkage::Linkonce; 143 if (name == "common") 144 return mlir::LLVM::Linkage::Common; 145 if (name == "weak") 146 return mlir::LLVM::Linkage::Weak; 147 } 148 return mlir::LLVM::Linkage::External; 149 } 150 }; 151 152 // convert to LLVM IR dialect `undef` 153 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> { 154 using FIROpConversion::FIROpConversion; 155 156 mlir::LogicalResult 157 matchAndRewrite(fir::UndefOp undef, OpAdaptor, 158 mlir::ConversionPatternRewriter &rewriter) const override { 159 rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>( 160 undef, convertType(undef.getType())); 161 return success(); 162 } 163 }; 164 } // namespace 165 166 namespace { 167 /// Convert FIR dialect to LLVM dialect 168 /// 169 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An 170 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect. 171 /// 172 /// This pass is not complete yet. We are upstreaming it in small patches. 173 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> { 174 public: 175 mlir::ModuleOp getModule() { return getOperation(); } 176 177 void runOnOperation() override final { 178 auto *context = getModule().getContext(); 179 fir::LLVMTypeConverter typeConverter{getModule()}; 180 auto loc = mlir::UnknownLoc::get(context); 181 mlir::OwningRewritePatternList pattern(context); 182 pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion, 183 UndefOpConversion>(typeConverter); 184 mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); 185 mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, 186 pattern); 187 mlir::ConversionTarget target{*context}; 188 target.addLegalDialect<mlir::LLVM::LLVMDialect>(); 189 190 // required NOPs for applying a full conversion 191 target.addLegalOp<mlir::ModuleOp>(); 192 193 // apply the patterns 194 if (mlir::failed(mlir::applyFullConversion(getModule(), target, 195 std::move(pattern)))) { 196 mlir::emitError(loc, "error in converting to LLVM-IR dialect\n"); 197 signalPassFailure(); 198 } 199 } 200 }; 201 } // namespace 202 203 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() { 204 return std::make_unique<FIRToLLVMLowering>(); 205 } 206