1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/CodeGen/CodeGen.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/Dialect/FIROps.h" 16 #include "flang/Optimizer/Dialect/FIRType.h" 17 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 18 #include "mlir/Conversion/LLVMCommon/Pattern.h" 19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/Pass/Pass.h" 24 #include "llvm/ADT/ArrayRef.h" 25 26 #define DEBUG_TYPE "flang-codegen" 27 28 // fir::LLVMTypeConverter for converting to LLVM IR dialect types. 29 #include "TypeConverter.h" 30 31 namespace { 32 /// FIR conversion pattern template 33 template <typename FromOp> 34 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> { 35 public: 36 explicit FIROpConversion(fir::LLVMTypeConverter &lowering) 37 : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {} 38 39 protected: 40 mlir::Type convertType(mlir::Type ty) const { 41 return lowerTy().convertType(ty); 42 } 43 44 fir::LLVMTypeConverter &lowerTy() const { 45 return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter()); 46 } 47 }; 48 49 /// FIR conversion pattern template 50 template <typename FromOp> 51 class FIROpAndTypeConversion : public FIROpConversion<FromOp> { 52 public: 53 using FIROpConversion<FromOp>::FIROpConversion; 54 using OpAdaptor = typename FromOp::Adaptor; 55 56 mlir::LogicalResult 57 matchAndRewrite(FromOp op, OpAdaptor adaptor, 58 mlir::ConversionPatternRewriter &rewriter) const final { 59 mlir::Type ty = this->convertType(op.getType()); 60 return doRewrite(op, ty, adaptor, rewriter); 61 } 62 63 virtual mlir::LogicalResult 64 doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor, 65 mlir::ConversionPatternRewriter &rewriter) const = 0; 66 }; 67 68 // Lower `fir.address_of` operation to `llvm.address_of` operation. 69 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> { 70 using FIROpConversion::FIROpConversion; 71 72 mlir::LogicalResult 73 matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor, 74 mlir::ConversionPatternRewriter &rewriter) const override { 75 auto ty = convertType(addr.getType()); 76 rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>( 77 addr, ty, addr.symbol().getRootReference().getValue()); 78 return success(); 79 } 80 }; 81 82 /// Lower `fir.has_value` operation to `llvm.return` operation. 83 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> { 84 using FIROpConversion::FIROpConversion; 85 86 mlir::LogicalResult 87 matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor, 88 mlir::ConversionPatternRewriter &rewriter) const override { 89 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands()); 90 return success(); 91 } 92 }; 93 94 /// Lower `fir.global` operation to `llvm.global` operation. 95 /// `fir.insert_on_range` operations are replaced with constant dense attribute 96 /// if they are applied on the full range. 97 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> { 98 using FIROpConversion::FIROpConversion; 99 100 mlir::LogicalResult 101 matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor, 102 mlir::ConversionPatternRewriter &rewriter) const override { 103 auto tyAttr = convertType(global.getType()); 104 if (global.getType().isa<fir::BoxType>()) 105 tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType(); 106 auto loc = global.getLoc(); 107 mlir::Attribute initAttr{}; 108 if (global.initVal()) 109 initAttr = global.initVal().getValue(); 110 auto linkage = convertLinkage(global.linkName()); 111 auto isConst = global.constant().hasValue(); 112 auto g = rewriter.create<mlir::LLVM::GlobalOp>( 113 loc, tyAttr, isConst, linkage, global.sym_name(), initAttr); 114 auto &gr = g.getInitializerRegion(); 115 rewriter.inlineRegionBefore(global.region(), gr, gr.end()); 116 if (!gr.empty()) { 117 // Replace insert_on_range with a constant dense attribute if the 118 // initialization is on the full range. 119 auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>(); 120 for (auto insertOp : insertOnRangeOps) { 121 if (isFullRange(insertOp.coor(), insertOp.getType())) { 122 auto seqTyAttr = convertType(insertOp.getType()); 123 auto *op = insertOp.val().getDefiningOp(); 124 auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op); 125 if (!constant) { 126 auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op); 127 if (!convertOp) 128 continue; 129 constant = cast<mlir::arith::ConstantOp>( 130 convertOp.value().getDefiningOp()); 131 } 132 mlir::Type vecType = mlir::VectorType::get( 133 insertOp.getType().getShape(), constant.getType()); 134 auto denseAttr = mlir::DenseElementsAttr::get( 135 vecType.cast<ShapedType>(), constant.value()); 136 rewriter.setInsertionPointAfter(insertOp); 137 rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( 138 insertOp, seqTyAttr, denseAttr); 139 } 140 } 141 } 142 rewriter.eraseOp(global); 143 return success(); 144 } 145 146 bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const { 147 auto extents = seqTy.getShape(); 148 if (indexes.size() / 2 != extents.size()) 149 return false; 150 for (unsigned i = 0; i < indexes.size(); i += 2) { 151 if (indexes[i].cast<IntegerAttr>().getInt() != 0) 152 return false; 153 if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1) 154 return false; 155 } 156 return true; 157 } 158 159 // TODO: String comparaison should be avoided. Replace linkName with an 160 // enumeration. 161 mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const { 162 if (optLinkage.hasValue()) { 163 auto name = optLinkage.getValue(); 164 if (name == "internal") 165 return mlir::LLVM::Linkage::Internal; 166 if (name == "linkonce") 167 return mlir::LLVM::Linkage::Linkonce; 168 if (name == "common") 169 return mlir::LLVM::Linkage::Common; 170 if (name == "weak") 171 return mlir::LLVM::Linkage::Weak; 172 } 173 return mlir::LLVM::Linkage::External; 174 } 175 }; 176 177 // convert to LLVM IR dialect `undef` 178 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> { 179 using FIROpConversion::FIROpConversion; 180 181 mlir::LogicalResult 182 matchAndRewrite(fir::UndefOp undef, OpAdaptor, 183 mlir::ConversionPatternRewriter &rewriter) const override { 184 rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>( 185 undef, convertType(undef.getType())); 186 return success(); 187 } 188 }; 189 190 // convert to LLVM IR dialect `unreachable` 191 struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> { 192 using FIROpConversion::FIROpConversion; 193 194 mlir::LogicalResult 195 matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor, 196 mlir::ConversionPatternRewriter &rewriter) const override { 197 rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach); 198 return success(); 199 } 200 }; 201 202 struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> { 203 using FIROpConversion::FIROpConversion; 204 205 mlir::LogicalResult 206 matchAndRewrite(fir::ZeroOp zero, OpAdaptor, 207 mlir::ConversionPatternRewriter &rewriter) const override { 208 auto ty = convertType(zero.getType()); 209 if (ty.isa<mlir::LLVM::LLVMPointerType>()) { 210 rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty); 211 } else if (ty.isa<mlir::IntegerType>()) { 212 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 213 zero, ty, mlir::IntegerAttr::get(zero.getType(), 0)); 214 } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) { 215 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 216 zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0)); 217 } else { 218 // TODO: create ConstantAggregateZero for FIR aggregate/array types. 219 return rewriter.notifyMatchFailure( 220 zero, 221 "conversion of fir.zero with aggregate type not implemented yet"); 222 } 223 return success(); 224 } 225 }; 226 227 /// InsertOnRange inserts a value into a sequence over a range of offsets. 228 struct InsertOnRangeOpConversion 229 : public FIROpAndTypeConversion<fir::InsertOnRangeOp> { 230 using FIROpAndTypeConversion::FIROpAndTypeConversion; 231 232 // Increments an array of subscripts in a row major fasion. 233 void incrementSubscripts(const SmallVector<uint64_t> &dims, 234 SmallVector<uint64_t> &subscripts) const { 235 for (size_t i = dims.size(); i > 0; --i) { 236 if (++subscripts[i - 1] < dims[i - 1]) { 237 return; 238 } 239 subscripts[i - 1] = 0; 240 } 241 } 242 243 mlir::LogicalResult 244 doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor, 245 mlir::ConversionPatternRewriter &rewriter) const override { 246 247 llvm::SmallVector<uint64_t> dims; 248 auto type = adaptor.getOperands()[0].getType(); 249 250 // Iteratively extract the array dimensions from the type. 251 while (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) { 252 dims.push_back(t.getNumElements()); 253 type = t.getElementType(); 254 } 255 256 SmallVector<uint64_t> lBounds; 257 SmallVector<uint64_t> uBounds; 258 259 // Extract integer value from the attribute 260 SmallVector<int64_t> coordinates = llvm::to_vector<4>( 261 llvm::map_range(range.coor(), [](Attribute a) -> int64_t { 262 return a.cast<IntegerAttr>().getInt(); 263 })); 264 265 // Unzip the upper and lower bound and convert to a row major format. 266 for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) { 267 uBounds.push_back(*i++); 268 lBounds.push_back(*i); 269 } 270 271 auto &subscripts = lBounds; 272 auto loc = range.getLoc(); 273 mlir::Value lastOp = adaptor.getOperands()[0]; 274 mlir::Value insertVal = adaptor.getOperands()[1]; 275 276 auto i64Ty = rewriter.getI64Type(); 277 while (subscripts != uBounds) { 278 // Convert uint64_t's to Attribute's. 279 SmallVector<mlir::Attribute> subscriptAttrs; 280 for (const auto &subscript : subscripts) 281 subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript)); 282 lastOp = rewriter.create<mlir::LLVM::InsertValueOp>( 283 loc, ty, lastOp, insertVal, 284 ArrayAttr::get(range.getContext(), subscriptAttrs)); 285 286 incrementSubscripts(dims, subscripts); 287 } 288 289 // Convert uint64_t's to Attribute's. 290 SmallVector<mlir::Attribute> subscriptAttrs; 291 for (const auto &subscript : subscripts) 292 subscriptAttrs.push_back( 293 IntegerAttr::get(rewriter.getI64Type(), subscript)); 294 mlir::ArrayRef<mlir::Attribute> arrayRef(subscriptAttrs); 295 296 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( 297 range, ty, lastOp, insertVal, 298 ArrayAttr::get(range.getContext(), arrayRef)); 299 300 return success(); 301 } 302 }; 303 } // namespace 304 305 namespace { 306 /// Convert FIR dialect to LLVM dialect 307 /// 308 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An 309 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect. 310 /// 311 /// This pass is not complete yet. We are upstreaming it in small patches. 312 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> { 313 public: 314 mlir::ModuleOp getModule() { return getOperation(); } 315 316 void runOnOperation() override final { 317 auto *context = getModule().getContext(); 318 fir::LLVMTypeConverter typeConverter{getModule()}; 319 mlir::OwningRewritePatternList pattern(context); 320 pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion, 321 InsertOnRangeOpConversion, UndefOpConversion, 322 UnreachableOpConversion, ZeroOpConversion>(typeConverter); 323 mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); 324 mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, 325 pattern); 326 mlir::ConversionTarget target{*context}; 327 target.addLegalDialect<mlir::LLVM::LLVMDialect>(); 328 329 // required NOPs for applying a full conversion 330 target.addLegalOp<mlir::ModuleOp>(); 331 332 // apply the patterns 333 if (mlir::failed(mlir::applyFullConversion(getModule(), target, 334 std::move(pattern)))) { 335 signalPassFailure(); 336 } 337 } 338 }; 339 } // namespace 340 341 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() { 342 return std::make_unique<FIRToLLVMLowering>(); 343 } 344