1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/CodeGen/CodeGen.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/Dialect/FIROps.h" 16 #include "flang/Optimizer/Dialect/FIRType.h" 17 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 18 #include "mlir/Conversion/LLVMCommon/Pattern.h" 19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/Pass/Pass.h" 23 #include "llvm/ADT/ArrayRef.h" 24 25 #define DEBUG_TYPE "flang-codegen" 26 27 // fir::LLVMTypeConverter for converting to LLVM IR dialect types. 28 #include "TypeConverter.h" 29 30 namespace { 31 /// FIR conversion pattern template 32 template <typename FromOp> 33 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> { 34 public: 35 explicit FIROpConversion(fir::LLVMTypeConverter &lowering) 36 : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {} 37 38 protected: 39 mlir::Type convertType(mlir::Type ty) const { 40 return lowerTy().convertType(ty); 41 } 42 43 fir::LLVMTypeConverter &lowerTy() const { 44 return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter()); 45 } 46 }; 47 48 // Lower `fir.address_of` operation to `llvm.address_of` operation. 49 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> { 50 using FIROpConversion::FIROpConversion; 51 52 mlir::LogicalResult 53 matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor, 54 mlir::ConversionPatternRewriter &rewriter) const override { 55 auto ty = convertType(addr.getType()); 56 rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>( 57 addr, ty, addr.symbol().getRootReference().getValue()); 58 return success(); 59 } 60 }; 61 62 /// Lower `fir.has_value` operation to `llvm.return` operation. 63 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> { 64 using FIROpConversion::FIROpConversion; 65 66 mlir::LogicalResult 67 matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor, 68 mlir::ConversionPatternRewriter &rewriter) const override { 69 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands()); 70 return success(); 71 } 72 }; 73 74 /// Lower `fir.global` operation to `llvm.global` operation. 75 /// `fir.insert_on_range` operations are replaced with constant dense attribute 76 /// if they are applied on the full range. 77 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> { 78 using FIROpConversion::FIROpConversion; 79 80 mlir::LogicalResult 81 matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor, 82 mlir::ConversionPatternRewriter &rewriter) const override { 83 auto tyAttr = convertType(global.getType()); 84 if (global.getType().isa<fir::BoxType>()) 85 tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType(); 86 auto loc = global.getLoc(); 87 mlir::Attribute initAttr{}; 88 if (global.initVal()) 89 initAttr = global.initVal().getValue(); 90 auto linkage = convertLinkage(global.linkName()); 91 auto isConst = global.constant().hasValue(); 92 auto g = rewriter.create<mlir::LLVM::GlobalOp>( 93 loc, tyAttr, isConst, linkage, global.sym_name(), initAttr); 94 auto &gr = g.getInitializerRegion(); 95 rewriter.inlineRegionBefore(global.region(), gr, gr.end()); 96 if (!gr.empty()) { 97 // Replace insert_on_range with a constant dense attribute if the 98 // initialization is on the full range. 99 auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>(); 100 for (auto insertOp : insertOnRangeOps) { 101 if (isFullRange(insertOp.coor(), insertOp.getType())) { 102 auto seqTyAttr = convertType(insertOp.getType()); 103 auto *op = insertOp.val().getDefiningOp(); 104 auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op); 105 if (!constant) { 106 auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op); 107 if (!convertOp) 108 continue; 109 constant = cast<mlir::arith::ConstantOp>( 110 convertOp.value().getDefiningOp()); 111 } 112 mlir::Type vecType = mlir::VectorType::get( 113 insertOp.getType().getShape(), constant.getType()); 114 auto denseAttr = mlir::DenseElementsAttr::get( 115 vecType.cast<ShapedType>(), constant.value()); 116 rewriter.setInsertionPointAfter(insertOp); 117 rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( 118 insertOp, seqTyAttr, denseAttr); 119 } 120 } 121 } 122 rewriter.eraseOp(global); 123 return success(); 124 } 125 126 bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const { 127 auto extents = seqTy.getShape(); 128 if (indexes.size() / 2 != extents.size()) 129 return false; 130 for (unsigned i = 0; i < indexes.size(); i += 2) { 131 if (indexes[i].cast<IntegerAttr>().getInt() != 0) 132 return false; 133 if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1) 134 return false; 135 } 136 return true; 137 } 138 139 // TODO: String comparaison should be avoided. Replace linkName with an 140 // enumeration. 141 mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const { 142 if (optLinkage.hasValue()) { 143 auto name = optLinkage.getValue(); 144 if (name == "internal") 145 return mlir::LLVM::Linkage::Internal; 146 if (name == "linkonce") 147 return mlir::LLVM::Linkage::Linkonce; 148 if (name == "common") 149 return mlir::LLVM::Linkage::Common; 150 if (name == "weak") 151 return mlir::LLVM::Linkage::Weak; 152 } 153 return mlir::LLVM::Linkage::External; 154 } 155 }; 156 157 // convert to LLVM IR dialect `undef` 158 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> { 159 using FIROpConversion::FIROpConversion; 160 161 mlir::LogicalResult 162 matchAndRewrite(fir::UndefOp undef, OpAdaptor, 163 mlir::ConversionPatternRewriter &rewriter) const override { 164 rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>( 165 undef, convertType(undef.getType())); 166 return success(); 167 } 168 }; 169 170 // convert to LLVM IR dialect `unreachable` 171 struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> { 172 using FIROpConversion::FIROpConversion; 173 174 mlir::LogicalResult 175 matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor, 176 mlir::ConversionPatternRewriter &rewriter) const override { 177 rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach); 178 return success(); 179 } 180 }; 181 182 struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> { 183 using FIROpConversion::FIROpConversion; 184 185 mlir::LogicalResult 186 matchAndRewrite(fir::ZeroOp zero, OpAdaptor, 187 mlir::ConversionPatternRewriter &rewriter) const override { 188 auto ty = convertType(zero.getType()); 189 if (ty.isa<mlir::LLVM::LLVMPointerType>()) { 190 rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty); 191 } else if (ty.isa<mlir::IntegerType>()) { 192 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 193 zero, ty, mlir::IntegerAttr::get(zero.getType(), 0)); 194 } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) { 195 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 196 zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0)); 197 } else { 198 // TODO: create ConstantAggregateZero for FIR aggregate/array types. 199 return rewriter.notifyMatchFailure( 200 zero, 201 "conversion of fir.zero with aggregate type not implemented yet"); 202 } 203 return success(); 204 } 205 }; 206 207 } // namespace 208 209 namespace { 210 /// Convert FIR dialect to LLVM dialect 211 /// 212 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An 213 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect. 214 /// 215 /// This pass is not complete yet. We are upstreaming it in small patches. 216 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> { 217 public: 218 mlir::ModuleOp getModule() { return getOperation(); } 219 220 void runOnOperation() override final { 221 auto *context = getModule().getContext(); 222 fir::LLVMTypeConverter typeConverter{getModule()}; 223 mlir::OwningRewritePatternList pattern(context); 224 pattern 225 .insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion, 226 UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>( 227 typeConverter); 228 mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); 229 mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, 230 pattern); 231 mlir::ConversionTarget target{*context}; 232 target.addLegalDialect<mlir::LLVM::LLVMDialect>(); 233 234 // required NOPs for applying a full conversion 235 target.addLegalOp<mlir::ModuleOp>(); 236 237 // apply the patterns 238 if (mlir::failed(mlir::applyFullConversion(getModule(), target, 239 std::move(pattern)))) { 240 signalPassFailure(); 241 } 242 } 243 }; 244 } // namespace 245 246 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() { 247 return std::make_unique<FIRToLLVMLowering>(); 248 } 249