1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/CodeGen/CodeGen.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/Dialect/FIROps.h" 16 #include "flang/Optimizer/Dialect/FIRType.h" 17 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 18 #include "mlir/Conversion/LLVMCommon/Pattern.h" 19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/Pass/Pass.h" 24 #include "llvm/ADT/ArrayRef.h" 25 26 #define DEBUG_TYPE "flang-codegen" 27 28 // fir::LLVMTypeConverter for converting to LLVM IR dialect types. 29 #include "TypeConverter.h" 30 31 namespace { 32 /// FIR conversion pattern template 33 template <typename FromOp> 34 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> { 35 public: 36 explicit FIROpConversion(fir::LLVMTypeConverter &lowering) 37 : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {} 38 39 protected: 40 mlir::Type convertType(mlir::Type ty) const { 41 return lowerTy().convertType(ty); 42 } 43 44 fir::LLVMTypeConverter &lowerTy() const { 45 return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter()); 46 } 47 }; 48 49 /// FIR conversion pattern template 50 template <typename FromOp> 51 class FIROpAndTypeConversion : public FIROpConversion<FromOp> { 52 public: 53 using FIROpConversion<FromOp>::FIROpConversion; 54 using OpAdaptor = typename FromOp::Adaptor; 55 56 mlir::LogicalResult 57 matchAndRewrite(FromOp op, OpAdaptor adaptor, 58 mlir::ConversionPatternRewriter &rewriter) const final { 59 mlir::Type ty = this->convertType(op.getType()); 60 return doRewrite(op, ty, adaptor, rewriter); 61 } 62 63 virtual mlir::LogicalResult 64 doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor, 65 mlir::ConversionPatternRewriter &rewriter) const = 0; 66 }; 67 68 // Lower `fir.address_of` operation to `llvm.address_of` operation. 69 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> { 70 using FIROpConversion::FIROpConversion; 71 72 mlir::LogicalResult 73 matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor, 74 mlir::ConversionPatternRewriter &rewriter) const override { 75 auto ty = convertType(addr.getType()); 76 rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>( 77 addr, ty, addr.symbol().getRootReference().getValue()); 78 return success(); 79 } 80 }; 81 82 /// Lower `fir.has_value` operation to `llvm.return` operation. 83 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> { 84 using FIROpConversion::FIROpConversion; 85 86 mlir::LogicalResult 87 matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor, 88 mlir::ConversionPatternRewriter &rewriter) const override { 89 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands()); 90 return success(); 91 } 92 }; 93 94 /// Lower `fir.global` operation to `llvm.global` operation. 95 /// `fir.insert_on_range` operations are replaced with constant dense attribute 96 /// if they are applied on the full range. 97 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> { 98 using FIROpConversion::FIROpConversion; 99 100 mlir::LogicalResult 101 matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor, 102 mlir::ConversionPatternRewriter &rewriter) const override { 103 auto tyAttr = convertType(global.getType()); 104 if (global.getType().isa<fir::BoxType>()) 105 tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType(); 106 auto loc = global.getLoc(); 107 mlir::Attribute initAttr{}; 108 if (global.initVal()) 109 initAttr = global.initVal().getValue(); 110 auto linkage = convertLinkage(global.linkName()); 111 auto isConst = global.constant().hasValue(); 112 auto g = rewriter.create<mlir::LLVM::GlobalOp>( 113 loc, tyAttr, isConst, linkage, global.sym_name(), initAttr); 114 auto &gr = g.getInitializerRegion(); 115 rewriter.inlineRegionBefore(global.region(), gr, gr.end()); 116 if (!gr.empty()) { 117 // Replace insert_on_range with a constant dense attribute if the 118 // initialization is on the full range. 119 auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>(); 120 for (auto insertOp : insertOnRangeOps) { 121 if (isFullRange(insertOp.coor(), insertOp.getType())) { 122 auto seqTyAttr = convertType(insertOp.getType()); 123 auto *op = insertOp.val().getDefiningOp(); 124 auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op); 125 if (!constant) { 126 auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op); 127 if (!convertOp) 128 continue; 129 constant = cast<mlir::arith::ConstantOp>( 130 convertOp.value().getDefiningOp()); 131 } 132 mlir::Type vecType = mlir::VectorType::get( 133 insertOp.getType().getShape(), constant.getType()); 134 auto denseAttr = mlir::DenseElementsAttr::get( 135 vecType.cast<ShapedType>(), constant.value()); 136 rewriter.setInsertionPointAfter(insertOp); 137 rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( 138 insertOp, seqTyAttr, denseAttr); 139 } 140 } 141 } 142 rewriter.eraseOp(global); 143 return success(); 144 } 145 146 bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const { 147 auto extents = seqTy.getShape(); 148 if (indexes.size() / 2 != extents.size()) 149 return false; 150 for (unsigned i = 0; i < indexes.size(); i += 2) { 151 if (indexes[i].cast<IntegerAttr>().getInt() != 0) 152 return false; 153 if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1) 154 return false; 155 } 156 return true; 157 } 158 159 // TODO: String comparaison should be avoided. Replace linkName with an 160 // enumeration. 161 mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const { 162 if (optLinkage.hasValue()) { 163 auto name = optLinkage.getValue(); 164 if (name == "internal") 165 return mlir::LLVM::Linkage::Internal; 166 if (name == "linkonce") 167 return mlir::LLVM::Linkage::Linkonce; 168 if (name == "common") 169 return mlir::LLVM::Linkage::Common; 170 if (name == "weak") 171 return mlir::LLVM::Linkage::Weak; 172 } 173 return mlir::LLVM::Linkage::External; 174 } 175 }; 176 177 template <typename OP> 178 void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, 179 typename OP::Adaptor adaptor, 180 mlir::ConversionPatternRewriter &rewriter) { 181 unsigned conds = select.getNumConditions(); 182 auto cases = select.getCases().getValue(); 183 mlir::Value selector = adaptor.selector(); 184 auto loc = select.getLoc(); 185 assert(conds > 0 && "select must have cases"); 186 187 llvm::SmallVector<mlir::Block *> destinations; 188 llvm::SmallVector<mlir::ValueRange> destinationsOperands; 189 mlir::Block *defaultDestination; 190 mlir::ValueRange defaultOperands; 191 llvm::SmallVector<int32_t> caseValues; 192 193 for (unsigned t = 0; t != conds; ++t) { 194 mlir::Block *dest = select.getSuccessor(t); 195 auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t); 196 const mlir::Attribute &attr = cases[t]; 197 if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) { 198 destinations.push_back(dest); 199 destinationsOperands.push_back(destOps.hasValue() ? *destOps 200 : ValueRange()); 201 caseValues.push_back(intAttr.getInt()); 202 continue; 203 } 204 assert(attr.template dyn_cast_or_null<mlir::UnitAttr>()); 205 assert((t + 1 == conds) && "unit must be last"); 206 defaultDestination = dest; 207 defaultOperands = destOps.hasValue() ? *destOps : ValueRange(); 208 } 209 210 // LLVM::SwitchOp takes a i32 type for the selector. 211 if (select.getSelector().getType() != rewriter.getI32Type()) 212 selector = 213 rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), selector); 214 215 rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>( 216 select, selector, 217 /*defaultDestination=*/defaultDestination, 218 /*defaultOperands=*/defaultOperands, 219 /*caseValues=*/caseValues, 220 /*caseDestinations=*/destinations, 221 /*caseOperands=*/destinationsOperands, 222 /*branchWeights=*/ArrayRef<int32_t>()); 223 } 224 225 /// conversion of fir::SelectOp to an if-then-else ladder 226 struct SelectOpConversion : public FIROpConversion<fir::SelectOp> { 227 using FIROpConversion::FIROpConversion; 228 229 mlir::LogicalResult 230 matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor, 231 mlir::ConversionPatternRewriter &rewriter) const override { 232 selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter); 233 return success(); 234 } 235 }; 236 237 /// conversion of fir::SelectRankOp to an if-then-else ladder 238 struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> { 239 using FIROpConversion::FIROpConversion; 240 241 mlir::LogicalResult 242 matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor, 243 mlir::ConversionPatternRewriter &rewriter) const override { 244 selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter); 245 return success(); 246 } 247 }; 248 249 // convert to LLVM IR dialect `undef` 250 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> { 251 using FIROpConversion::FIROpConversion; 252 253 mlir::LogicalResult 254 matchAndRewrite(fir::UndefOp undef, OpAdaptor, 255 mlir::ConversionPatternRewriter &rewriter) const override { 256 rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>( 257 undef, convertType(undef.getType())); 258 return success(); 259 } 260 }; 261 262 // convert to LLVM IR dialect `unreachable` 263 struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> { 264 using FIROpConversion::FIROpConversion; 265 266 mlir::LogicalResult 267 matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor, 268 mlir::ConversionPatternRewriter &rewriter) const override { 269 rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach); 270 return success(); 271 } 272 }; 273 274 struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> { 275 using FIROpConversion::FIROpConversion; 276 277 mlir::LogicalResult 278 matchAndRewrite(fir::ZeroOp zero, OpAdaptor, 279 mlir::ConversionPatternRewriter &rewriter) const override { 280 auto ty = convertType(zero.getType()); 281 if (ty.isa<mlir::LLVM::LLVMPointerType>()) { 282 rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty); 283 } else if (ty.isa<mlir::IntegerType>()) { 284 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 285 zero, ty, mlir::IntegerAttr::get(zero.getType(), 0)); 286 } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) { 287 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 288 zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0)); 289 } else { 290 // TODO: create ConstantAggregateZero for FIR aggregate/array types. 291 return rewriter.notifyMatchFailure( 292 zero, 293 "conversion of fir.zero with aggregate type not implemented yet"); 294 } 295 return success(); 296 } 297 }; 298 299 /// InsertOnRange inserts a value into a sequence over a range of offsets. 300 struct InsertOnRangeOpConversion 301 : public FIROpAndTypeConversion<fir::InsertOnRangeOp> { 302 using FIROpAndTypeConversion::FIROpAndTypeConversion; 303 304 // Increments an array of subscripts in a row major fasion. 305 void incrementSubscripts(const SmallVector<uint64_t> &dims, 306 SmallVector<uint64_t> &subscripts) const { 307 for (size_t i = dims.size(); i > 0; --i) { 308 if (++subscripts[i - 1] < dims[i - 1]) { 309 return; 310 } 311 subscripts[i - 1] = 0; 312 } 313 } 314 315 mlir::LogicalResult 316 doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor, 317 mlir::ConversionPatternRewriter &rewriter) const override { 318 319 llvm::SmallVector<uint64_t> dims; 320 auto type = adaptor.getOperands()[0].getType(); 321 322 // Iteratively extract the array dimensions from the type. 323 while (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) { 324 dims.push_back(t.getNumElements()); 325 type = t.getElementType(); 326 } 327 328 SmallVector<uint64_t> lBounds; 329 SmallVector<uint64_t> uBounds; 330 331 // Extract integer value from the attribute 332 SmallVector<int64_t> coordinates = llvm::to_vector<4>( 333 llvm::map_range(range.coor(), [](Attribute a) -> int64_t { 334 return a.cast<IntegerAttr>().getInt(); 335 })); 336 337 // Unzip the upper and lower bound and convert to a row major format. 338 for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) { 339 uBounds.push_back(*i++); 340 lBounds.push_back(*i); 341 } 342 343 auto &subscripts = lBounds; 344 auto loc = range.getLoc(); 345 mlir::Value lastOp = adaptor.getOperands()[0]; 346 mlir::Value insertVal = adaptor.getOperands()[1]; 347 348 auto i64Ty = rewriter.getI64Type(); 349 while (subscripts != uBounds) { 350 // Convert uint64_t's to Attribute's. 351 SmallVector<mlir::Attribute> subscriptAttrs; 352 for (const auto &subscript : subscripts) 353 subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript)); 354 lastOp = rewriter.create<mlir::LLVM::InsertValueOp>( 355 loc, ty, lastOp, insertVal, 356 ArrayAttr::get(range.getContext(), subscriptAttrs)); 357 358 incrementSubscripts(dims, subscripts); 359 } 360 361 // Convert uint64_t's to Attribute's. 362 SmallVector<mlir::Attribute> subscriptAttrs; 363 for (const auto &subscript : subscripts) 364 subscriptAttrs.push_back( 365 IntegerAttr::get(rewriter.getI64Type(), subscript)); 366 mlir::ArrayRef<mlir::Attribute> arrayRef(subscriptAttrs); 367 368 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( 369 range, ty, lastOp, insertVal, 370 ArrayAttr::get(range.getContext(), arrayRef)); 371 372 return success(); 373 } 374 }; 375 } // namespace 376 377 namespace { 378 /// Convert FIR dialect to LLVM dialect 379 /// 380 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An 381 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect. 382 /// 383 /// This pass is not complete yet. We are upstreaming it in small patches. 384 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> { 385 public: 386 mlir::ModuleOp getModule() { return getOperation(); } 387 388 void runOnOperation() override final { 389 auto *context = getModule().getContext(); 390 fir::LLVMTypeConverter typeConverter{getModule()}; 391 mlir::OwningRewritePatternList pattern(context); 392 pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion, 393 InsertOnRangeOpConversion, SelectOpConversion, 394 SelectRankOpConversion, UnreachableOpConversion, 395 ZeroOpConversion, UndefOpConversion>(typeConverter); 396 mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); 397 mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, 398 pattern); 399 mlir::ConversionTarget target{*context}; 400 target.addLegalDialect<mlir::LLVM::LLVMDialect>(); 401 402 // required NOPs for applying a full conversion 403 target.addLegalOp<mlir::ModuleOp>(); 404 405 // apply the patterns 406 if (mlir::failed(mlir::applyFullConversion(getModule(), target, 407 std::move(pattern)))) { 408 signalPassFailure(); 409 } 410 } 411 }; 412 } // namespace 413 414 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() { 415 return std::make_unique<FIRToLLVMLowering>(); 416 } 417