1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/CodeGen/CodeGen.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/Dialect/FIROps.h" 16 #include "flang/Optimizer/Dialect/FIRType.h" 17 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 18 #include "mlir/Conversion/LLVMCommon/Pattern.h" 19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/Pass/Pass.h" 24 #include "llvm/ADT/ArrayRef.h" 25 26 #define DEBUG_TYPE "flang-codegen" 27 28 // fir::LLVMTypeConverter for converting to LLVM IR dialect types. 29 #include "TypeConverter.h" 30 31 namespace { 32 /// FIR conversion pattern template 33 template <typename FromOp> 34 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> { 35 public: 36 explicit FIROpConversion(fir::LLVMTypeConverter &lowering) 37 : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {} 38 39 protected: 40 mlir::Type convertType(mlir::Type ty) const { 41 return lowerTy().convertType(ty); 42 } 43 44 fir::LLVMTypeConverter &lowerTy() const { 45 return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter()); 46 } 47 }; 48 49 /// FIR conversion pattern template 50 template <typename FromOp> 51 class FIROpAndTypeConversion : public FIROpConversion<FromOp> { 52 public: 53 using FIROpConversion<FromOp>::FIROpConversion; 54 using OpAdaptor = typename FromOp::Adaptor; 55 56 mlir::LogicalResult 57 matchAndRewrite(FromOp op, OpAdaptor adaptor, 58 mlir::ConversionPatternRewriter &rewriter) const final { 59 mlir::Type ty = this->convertType(op.getType()); 60 return doRewrite(op, ty, adaptor, rewriter); 61 } 62 63 virtual mlir::LogicalResult 64 doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor, 65 mlir::ConversionPatternRewriter &rewriter) const = 0; 66 }; 67 68 // Lower `fir.address_of` operation to `llvm.address_of` operation. 69 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> { 70 using FIROpConversion::FIROpConversion; 71 72 mlir::LogicalResult 73 matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor, 74 mlir::ConversionPatternRewriter &rewriter) const override { 75 auto ty = convertType(addr.getType()); 76 rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>( 77 addr, ty, addr.symbol().getRootReference().getValue()); 78 return success(); 79 } 80 }; 81 82 // `fir.call` -> `llvm.call` 83 struct CallOpConversion : public FIROpConversion<fir::CallOp> { 84 using FIROpConversion::FIROpConversion; 85 86 mlir::LogicalResult 87 matchAndRewrite(fir::CallOp call, OpAdaptor adaptor, 88 mlir::ConversionPatternRewriter &rewriter) const override { 89 SmallVector<mlir::Type> resultTys; 90 for (auto r : call.getResults()) 91 resultTys.push_back(convertType(r.getType())); 92 rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( 93 call, resultTys, adaptor.getOperands(), call->getAttrs()); 94 return success(); 95 } 96 }; 97 98 /// Lower `fir.has_value` operation to `llvm.return` operation. 99 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> { 100 using FIROpConversion::FIROpConversion; 101 102 mlir::LogicalResult 103 matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor, 104 mlir::ConversionPatternRewriter &rewriter) const override { 105 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands()); 106 return success(); 107 } 108 }; 109 110 /// Lower `fir.global` operation to `llvm.global` operation. 111 /// `fir.insert_on_range` operations are replaced with constant dense attribute 112 /// if they are applied on the full range. 113 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> { 114 using FIROpConversion::FIROpConversion; 115 116 mlir::LogicalResult 117 matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor, 118 mlir::ConversionPatternRewriter &rewriter) const override { 119 auto tyAttr = convertType(global.getType()); 120 if (global.getType().isa<fir::BoxType>()) 121 tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType(); 122 auto loc = global.getLoc(); 123 mlir::Attribute initAttr{}; 124 if (global.initVal()) 125 initAttr = global.initVal().getValue(); 126 auto linkage = convertLinkage(global.linkName()); 127 auto isConst = global.constant().hasValue(); 128 auto g = rewriter.create<mlir::LLVM::GlobalOp>( 129 loc, tyAttr, isConst, linkage, global.sym_name(), initAttr); 130 auto &gr = g.getInitializerRegion(); 131 rewriter.inlineRegionBefore(global.region(), gr, gr.end()); 132 if (!gr.empty()) { 133 // Replace insert_on_range with a constant dense attribute if the 134 // initialization is on the full range. 135 auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>(); 136 for (auto insertOp : insertOnRangeOps) { 137 if (isFullRange(insertOp.coor(), insertOp.getType())) { 138 auto seqTyAttr = convertType(insertOp.getType()); 139 auto *op = insertOp.val().getDefiningOp(); 140 auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op); 141 if (!constant) { 142 auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op); 143 if (!convertOp) 144 continue; 145 constant = cast<mlir::arith::ConstantOp>( 146 convertOp.value().getDefiningOp()); 147 } 148 mlir::Type vecType = mlir::VectorType::get( 149 insertOp.getType().getShape(), constant.getType()); 150 auto denseAttr = mlir::DenseElementsAttr::get( 151 vecType.cast<ShapedType>(), constant.value()); 152 rewriter.setInsertionPointAfter(insertOp); 153 rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( 154 insertOp, seqTyAttr, denseAttr); 155 } 156 } 157 } 158 rewriter.eraseOp(global); 159 return success(); 160 } 161 162 bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const { 163 auto extents = seqTy.getShape(); 164 if (indexes.size() / 2 != extents.size()) 165 return false; 166 for (unsigned i = 0; i < indexes.size(); i += 2) { 167 if (indexes[i].cast<IntegerAttr>().getInt() != 0) 168 return false; 169 if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1) 170 return false; 171 } 172 return true; 173 } 174 175 // TODO: String comparaison should be avoided. Replace linkName with an 176 // enumeration. 177 mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const { 178 if (optLinkage.hasValue()) { 179 auto name = optLinkage.getValue(); 180 if (name == "internal") 181 return mlir::LLVM::Linkage::Internal; 182 if (name == "linkonce") 183 return mlir::LLVM::Linkage::Linkonce; 184 if (name == "common") 185 return mlir::LLVM::Linkage::Common; 186 if (name == "weak") 187 return mlir::LLVM::Linkage::Weak; 188 } 189 return mlir::LLVM::Linkage::External; 190 } 191 }; 192 193 template <typename OP> 194 void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, 195 typename OP::Adaptor adaptor, 196 mlir::ConversionPatternRewriter &rewriter) { 197 unsigned conds = select.getNumConditions(); 198 auto cases = select.getCases().getValue(); 199 mlir::Value selector = adaptor.selector(); 200 auto loc = select.getLoc(); 201 assert(conds > 0 && "select must have cases"); 202 203 llvm::SmallVector<mlir::Block *> destinations; 204 llvm::SmallVector<mlir::ValueRange> destinationsOperands; 205 mlir::Block *defaultDestination; 206 mlir::ValueRange defaultOperands; 207 llvm::SmallVector<int32_t> caseValues; 208 209 for (unsigned t = 0; t != conds; ++t) { 210 mlir::Block *dest = select.getSuccessor(t); 211 auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t); 212 const mlir::Attribute &attr = cases[t]; 213 if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) { 214 destinations.push_back(dest); 215 destinationsOperands.push_back(destOps.hasValue() ? *destOps 216 : ValueRange()); 217 caseValues.push_back(intAttr.getInt()); 218 continue; 219 } 220 assert(attr.template dyn_cast_or_null<mlir::UnitAttr>()); 221 assert((t + 1 == conds) && "unit must be last"); 222 defaultDestination = dest; 223 defaultOperands = destOps.hasValue() ? *destOps : ValueRange(); 224 } 225 226 // LLVM::SwitchOp takes a i32 type for the selector. 227 if (select.getSelector().getType() != rewriter.getI32Type()) 228 selector = 229 rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), selector); 230 231 rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>( 232 select, selector, 233 /*defaultDestination=*/defaultDestination, 234 /*defaultOperands=*/defaultOperands, 235 /*caseValues=*/caseValues, 236 /*caseDestinations=*/destinations, 237 /*caseOperands=*/destinationsOperands, 238 /*branchWeights=*/ArrayRef<int32_t>()); 239 } 240 241 /// conversion of fir::SelectOp to an if-then-else ladder 242 struct SelectOpConversion : public FIROpConversion<fir::SelectOp> { 243 using FIROpConversion::FIROpConversion; 244 245 mlir::LogicalResult 246 matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor, 247 mlir::ConversionPatternRewriter &rewriter) const override { 248 selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter); 249 return success(); 250 } 251 }; 252 253 /// conversion of fir::SelectRankOp to an if-then-else ladder 254 struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> { 255 using FIROpConversion::FIROpConversion; 256 257 mlir::LogicalResult 258 matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor, 259 mlir::ConversionPatternRewriter &rewriter) const override { 260 selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter); 261 return success(); 262 } 263 }; 264 265 // convert to LLVM IR dialect `undef` 266 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> { 267 using FIROpConversion::FIROpConversion; 268 269 mlir::LogicalResult 270 matchAndRewrite(fir::UndefOp undef, OpAdaptor, 271 mlir::ConversionPatternRewriter &rewriter) const override { 272 rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>( 273 undef, convertType(undef.getType())); 274 return success(); 275 } 276 }; 277 278 // convert to LLVM IR dialect `unreachable` 279 struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> { 280 using FIROpConversion::FIROpConversion; 281 282 mlir::LogicalResult 283 matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor, 284 mlir::ConversionPatternRewriter &rewriter) const override { 285 rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach); 286 return success(); 287 } 288 }; 289 290 struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> { 291 using FIROpConversion::FIROpConversion; 292 293 mlir::LogicalResult 294 matchAndRewrite(fir::ZeroOp zero, OpAdaptor, 295 mlir::ConversionPatternRewriter &rewriter) const override { 296 auto ty = convertType(zero.getType()); 297 if (ty.isa<mlir::LLVM::LLVMPointerType>()) { 298 rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty); 299 } else if (ty.isa<mlir::IntegerType>()) { 300 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 301 zero, ty, mlir::IntegerAttr::get(zero.getType(), 0)); 302 } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) { 303 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 304 zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0)); 305 } else { 306 // TODO: create ConstantAggregateZero for FIR aggregate/array types. 307 return rewriter.notifyMatchFailure( 308 zero, 309 "conversion of fir.zero with aggregate type not implemented yet"); 310 } 311 return success(); 312 } 313 }; 314 315 // Code shared between insert_value and extract_value Ops. 316 struct ValueOpCommon { 317 // Translate the arguments pertaining to any multidimensional array to 318 // row-major order for LLVM-IR. 319 static void toRowMajor(SmallVectorImpl<mlir::Attribute> &attrs, 320 mlir::Type ty) { 321 assert(ty && "type is null"); 322 const auto end = attrs.size(); 323 for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) { 324 if (auto seq = ty.dyn_cast<mlir::LLVM::LLVMArrayType>()) { 325 const auto dim = getDimension(seq); 326 if (dim > 1) { 327 auto ub = std::min(i + dim, end); 328 std::reverse(attrs.begin() + i, attrs.begin() + ub); 329 i += dim - 1; 330 } 331 ty = getArrayElementType(seq); 332 } else if (auto st = ty.dyn_cast<mlir::LLVM::LLVMStructType>()) { 333 ty = st.getBody()[attrs[i].cast<mlir::IntegerAttr>().getInt()]; 334 } else { 335 llvm_unreachable("index into invalid type"); 336 } 337 } 338 } 339 340 static llvm::SmallVector<mlir::Attribute> 341 collectIndices(mlir::ConversionPatternRewriter &rewriter, 342 mlir::ArrayAttr arrAttr) { 343 llvm::SmallVector<mlir::Attribute> attrs; 344 for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) { 345 if (i->isa<mlir::IntegerAttr>()) { 346 attrs.push_back(*i); 347 } else { 348 auto fieldName = i->cast<mlir::StringAttr>().getValue(); 349 ++i; 350 auto ty = i->cast<mlir::TypeAttr>().getValue(); 351 auto index = ty.cast<fir::RecordType>().getFieldIndex(fieldName); 352 attrs.push_back(mlir::IntegerAttr::get(rewriter.getI32Type(), index)); 353 } 354 } 355 return attrs; 356 } 357 358 private: 359 static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) { 360 unsigned result = 1; 361 for (auto eleTy = ty.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>(); 362 eleTy; 363 eleTy = eleTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>()) 364 ++result; 365 return result; 366 } 367 368 static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) { 369 auto eleTy = ty.getElementType(); 370 while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>()) 371 eleTy = arrTy.getElementType(); 372 return eleTy; 373 } 374 }; 375 376 /// Extract a subobject value from an ssa-value of aggregate type 377 struct ExtractValueOpConversion 378 : public FIROpAndTypeConversion<fir::ExtractValueOp>, 379 public ValueOpCommon { 380 using FIROpAndTypeConversion::FIROpAndTypeConversion; 381 382 mlir::LogicalResult 383 doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor, 384 mlir::ConversionPatternRewriter &rewriter) const override { 385 auto attrs = collectIndices(rewriter, extractVal.coor()); 386 toRowMajor(attrs, adaptor.getOperands()[0].getType()); 387 auto position = mlir::ArrayAttr::get(extractVal.getContext(), attrs); 388 rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>( 389 extractVal, ty, adaptor.getOperands()[0], position); 390 return success(); 391 } 392 }; 393 394 /// InsertValue is the generalized instruction for the composition of new 395 /// aggregate type values. 396 struct InsertValueOpConversion 397 : public FIROpAndTypeConversion<fir::InsertValueOp>, 398 public ValueOpCommon { 399 using FIROpAndTypeConversion::FIROpAndTypeConversion; 400 401 mlir::LogicalResult 402 doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor, 403 mlir::ConversionPatternRewriter &rewriter) const override { 404 auto attrs = collectIndices(rewriter, insertVal.coor()); 405 toRowMajor(attrs, adaptor.getOperands()[0].getType()); 406 auto position = mlir::ArrayAttr::get(insertVal.getContext(), attrs); 407 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( 408 insertVal, ty, adaptor.getOperands()[0], adaptor.getOperands()[1], 409 position); 410 return success(); 411 } 412 }; 413 414 /// InsertOnRange inserts a value into a sequence over a range of offsets. 415 struct InsertOnRangeOpConversion 416 : public FIROpAndTypeConversion<fir::InsertOnRangeOp> { 417 using FIROpAndTypeConversion::FIROpAndTypeConversion; 418 419 // Increments an array of subscripts in a row major fasion. 420 void incrementSubscripts(const SmallVector<uint64_t> &dims, 421 SmallVector<uint64_t> &subscripts) const { 422 for (size_t i = dims.size(); i > 0; --i) { 423 if (++subscripts[i - 1] < dims[i - 1]) { 424 return; 425 } 426 subscripts[i - 1] = 0; 427 } 428 } 429 430 mlir::LogicalResult 431 doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor, 432 mlir::ConversionPatternRewriter &rewriter) const override { 433 434 llvm::SmallVector<uint64_t> dims; 435 auto type = adaptor.getOperands()[0].getType(); 436 437 // Iteratively extract the array dimensions from the type. 438 while (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) { 439 dims.push_back(t.getNumElements()); 440 type = t.getElementType(); 441 } 442 443 SmallVector<uint64_t> lBounds; 444 SmallVector<uint64_t> uBounds; 445 446 // Extract integer value from the attribute 447 SmallVector<int64_t> coordinates = llvm::to_vector<4>( 448 llvm::map_range(range.coor(), [](Attribute a) -> int64_t { 449 return a.cast<IntegerAttr>().getInt(); 450 })); 451 452 // Unzip the upper and lower bound and convert to a row major format. 453 for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) { 454 uBounds.push_back(*i++); 455 lBounds.push_back(*i); 456 } 457 458 auto &subscripts = lBounds; 459 auto loc = range.getLoc(); 460 mlir::Value lastOp = adaptor.getOperands()[0]; 461 mlir::Value insertVal = adaptor.getOperands()[1]; 462 463 auto i64Ty = rewriter.getI64Type(); 464 while (subscripts != uBounds) { 465 // Convert uint64_t's to Attribute's. 466 SmallVector<mlir::Attribute> subscriptAttrs; 467 for (const auto &subscript : subscripts) 468 subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript)); 469 lastOp = rewriter.create<mlir::LLVM::InsertValueOp>( 470 loc, ty, lastOp, insertVal, 471 ArrayAttr::get(range.getContext(), subscriptAttrs)); 472 473 incrementSubscripts(dims, subscripts); 474 } 475 476 // Convert uint64_t's to Attribute's. 477 SmallVector<mlir::Attribute> subscriptAttrs; 478 for (const auto &subscript : subscripts) 479 subscriptAttrs.push_back( 480 IntegerAttr::get(rewriter.getI64Type(), subscript)); 481 mlir::ArrayRef<mlir::Attribute> arrayRef(subscriptAttrs); 482 483 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( 484 range, ty, lastOp, insertVal, 485 ArrayAttr::get(range.getContext(), arrayRef)); 486 487 return success(); 488 } 489 }; 490 } // namespace 491 492 namespace { 493 /// Convert FIR dialect to LLVM dialect 494 /// 495 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An 496 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect. 497 /// 498 /// This pass is not complete yet. We are upstreaming it in small patches. 499 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> { 500 public: 501 mlir::ModuleOp getModule() { return getOperation(); } 502 503 void runOnOperation() override final { 504 auto *context = getModule().getContext(); 505 fir::LLVMTypeConverter typeConverter{getModule()}; 506 mlir::OwningRewritePatternList pattern(context); 507 pattern.insert< 508 AddrOfOpConversion, CallOpConversion, ExtractValueOpConversion, 509 HasValueOpConversion, GlobalOpConversion, InsertOnRangeOpConversion, 510 InsertValueOpConversion, SelectOpConversion, SelectRankOpConversion, 511 UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>( 512 typeConverter); 513 mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); 514 mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, 515 pattern); 516 mlir::ConversionTarget target{*context}; 517 target.addLegalDialect<mlir::LLVM::LLVMDialect>(); 518 519 // required NOPs for applying a full conversion 520 target.addLegalOp<mlir::ModuleOp>(); 521 522 // apply the patterns 523 if (mlir::failed(mlir::applyFullConversion(getModule(), target, 524 std::move(pattern)))) { 525 signalPassFailure(); 526 } 527 } 528 }; 529 } // namespace 530 531 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() { 532 return std::make_unique<FIRToLLVMLowering>(); 533 } 534