1 //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/CodeGen/CodeGen.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/Dialect/FIROps.h" 16 #include "flang/Optimizer/Dialect/FIRType.h" 17 #include "flang/Optimizer/Support/FIRContext.h" 18 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 19 #include "mlir/Conversion/LLVMCommon/Pattern.h" 20 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 21 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 22 #include "mlir/IR/BuiltinTypes.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/Pass/Pass.h" 25 #include "llvm/ADT/ArrayRef.h" 26 27 #define DEBUG_TYPE "flang-codegen" 28 29 // fir::LLVMTypeConverter for converting to LLVM IR dialect types. 30 #include "TypeConverter.h" 31 32 namespace { 33 /// FIR conversion pattern template 34 template <typename FromOp> 35 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> { 36 public: 37 explicit FIROpConversion(fir::LLVMTypeConverter &lowering) 38 : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {} 39 40 protected: 41 mlir::Type convertType(mlir::Type ty) const { 42 return lowerTy().convertType(ty); 43 } 44 45 fir::LLVMTypeConverter &lowerTy() const { 46 return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter()); 47 } 48 }; 49 50 /// FIR conversion pattern template 51 template <typename FromOp> 52 class FIROpAndTypeConversion : public FIROpConversion<FromOp> { 53 public: 54 using FIROpConversion<FromOp>::FIROpConversion; 55 using OpAdaptor = typename FromOp::Adaptor; 56 57 mlir::LogicalResult 58 matchAndRewrite(FromOp op, OpAdaptor adaptor, 59 mlir::ConversionPatternRewriter &rewriter) const final { 60 mlir::Type ty = this->convertType(op.getType()); 61 return doRewrite(op, ty, adaptor, rewriter); 62 } 63 64 virtual mlir::LogicalResult 65 doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor, 66 mlir::ConversionPatternRewriter &rewriter) const = 0; 67 }; 68 69 // Lower `fir.address_of` operation to `llvm.address_of` operation. 70 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> { 71 using FIROpConversion::FIROpConversion; 72 73 mlir::LogicalResult 74 matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor, 75 mlir::ConversionPatternRewriter &rewriter) const override { 76 auto ty = convertType(addr.getType()); 77 rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>( 78 addr, ty, addr.symbol().getRootReference().getValue()); 79 return success(); 80 } 81 }; 82 83 // `fir.call` -> `llvm.call` 84 struct CallOpConversion : public FIROpConversion<fir::CallOp> { 85 using FIROpConversion::FIROpConversion; 86 87 mlir::LogicalResult 88 matchAndRewrite(fir::CallOp call, OpAdaptor adaptor, 89 mlir::ConversionPatternRewriter &rewriter) const override { 90 SmallVector<mlir::Type> resultTys; 91 for (auto r : call.getResults()) 92 resultTys.push_back(convertType(r.getType())); 93 rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( 94 call, resultTys, adaptor.getOperands(), call->getAttrs()); 95 return success(); 96 } 97 }; 98 99 static mlir::Type getComplexEleTy(mlir::Type complex) { 100 if (auto cc = complex.dyn_cast<mlir::ComplexType>()) 101 return cc.getElementType(); 102 return complex.cast<fir::ComplexType>().getElementType(); 103 } 104 105 /// convert value of from-type to value of to-type 106 struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> { 107 using FIROpConversion::FIROpConversion; 108 109 static bool isFloatingPointTy(mlir::Type ty) { 110 return ty.isa<mlir::FloatType>(); 111 } 112 113 mlir::LogicalResult 114 matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor, 115 mlir::ConversionPatternRewriter &rewriter) const override { 116 auto fromTy = convertType(convert.value().getType()); 117 auto toTy = convertType(convert.res().getType()); 118 mlir::Value op0 = adaptor.getOperands()[0]; 119 if (fromTy == toTy) { 120 rewriter.replaceOp(convert, op0); 121 return success(); 122 } 123 auto loc = convert.getLoc(); 124 auto convertFpToFp = [&](mlir::Value val, unsigned fromBits, 125 unsigned toBits, mlir::Type toTy) -> mlir::Value { 126 if (fromBits == toBits) { 127 // TODO: Converting between two floating-point representations with the 128 // same bitwidth is not allowed for now. 129 mlir::emitError(loc, 130 "cannot implicitly convert between two floating-point " 131 "representations of the same bitwidth"); 132 return {}; 133 } 134 if (fromBits > toBits) 135 return rewriter.create<mlir::LLVM::FPTruncOp>(loc, toTy, val); 136 return rewriter.create<mlir::LLVM::FPExtOp>(loc, toTy, val); 137 }; 138 // Complex to complex conversion. 139 if (fir::isa_complex(convert.value().getType()) && 140 fir::isa_complex(convert.res().getType())) { 141 // Special case: handle the conversion of a complex such that both the 142 // real and imaginary parts are converted together. 143 auto zero = mlir::ArrayAttr::get(convert.getContext(), 144 rewriter.getI32IntegerAttr(0)); 145 auto one = mlir::ArrayAttr::get(convert.getContext(), 146 rewriter.getI32IntegerAttr(1)); 147 auto ty = convertType(getComplexEleTy(convert.value().getType())); 148 auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, zero); 149 auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, one); 150 auto nt = convertType(getComplexEleTy(convert.res().getType())); 151 auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); 152 auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(nt); 153 auto rc = convertFpToFp(rp, fromBits, toBits, nt); 154 auto ic = convertFpToFp(ip, fromBits, toBits, nt); 155 auto un = rewriter.create<mlir::LLVM::UndefOp>(loc, toTy); 156 auto i1 = 157 rewriter.create<mlir::LLVM::InsertValueOp>(loc, toTy, un, rc, zero); 158 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(convert, toTy, i1, 159 ic, one); 160 return mlir::success(); 161 } 162 // Floating point to floating point conversion. 163 if (isFloatingPointTy(fromTy)) { 164 if (isFloatingPointTy(toTy)) { 165 auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy); 166 auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy); 167 auto v = convertFpToFp(op0, fromBits, toBits, toTy); 168 rewriter.replaceOp(convert, v); 169 return mlir::success(); 170 } 171 if (toTy.isa<mlir::IntegerType>()) { 172 rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(convert, toTy, op0); 173 return mlir::success(); 174 } 175 } else if (fromTy.isa<mlir::IntegerType>()) { 176 // Integer to integer conversion. 177 if (toTy.isa<mlir::IntegerType>()) { 178 auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy); 179 auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy); 180 assert(fromBits != toBits); 181 if (fromBits > toBits) { 182 rewriter.replaceOpWithNewOp<mlir::LLVM::TruncOp>(convert, toTy, op0); 183 return mlir::success(); 184 } 185 rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(convert, toTy, op0); 186 return mlir::success(); 187 } 188 // Integer to floating point conversion. 189 if (isFloatingPointTy(toTy)) { 190 rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(convert, toTy, op0); 191 return mlir::success(); 192 } 193 // Integer to pointer conversion. 194 if (toTy.isa<mlir::LLVM::LLVMPointerType>()) { 195 rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(convert, toTy, op0); 196 return mlir::success(); 197 } 198 } else if (fromTy.isa<mlir::LLVM::LLVMPointerType>()) { 199 // Pointer to integer conversion. 200 if (toTy.isa<mlir::IntegerType>()) { 201 rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(convert, toTy, op0); 202 return mlir::success(); 203 } 204 // Pointer to pointer conversion. 205 if (toTy.isa<mlir::LLVM::LLVMPointerType>()) { 206 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(convert, toTy, op0); 207 return mlir::success(); 208 } 209 } 210 return emitError(loc) << "cannot convert " << fromTy << " to " << toTy; 211 } 212 }; 213 214 /// Lower `fir.has_value` operation to `llvm.return` operation. 215 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> { 216 using FIROpConversion::FIROpConversion; 217 218 mlir::LogicalResult 219 matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor, 220 mlir::ConversionPatternRewriter &rewriter) const override { 221 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands()); 222 return success(); 223 } 224 }; 225 226 /// Lower `fir.global` operation to `llvm.global` operation. 227 /// `fir.insert_on_range` operations are replaced with constant dense attribute 228 /// if they are applied on the full range. 229 struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> { 230 using FIROpConversion::FIROpConversion; 231 232 mlir::LogicalResult 233 matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor, 234 mlir::ConversionPatternRewriter &rewriter) const override { 235 auto tyAttr = convertType(global.getType()); 236 if (global.getType().isa<fir::BoxType>()) 237 tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType(); 238 auto loc = global.getLoc(); 239 mlir::Attribute initAttr{}; 240 if (global.initVal()) 241 initAttr = global.initVal().getValue(); 242 auto linkage = convertLinkage(global.linkName()); 243 auto isConst = global.constant().hasValue(); 244 auto g = rewriter.create<mlir::LLVM::GlobalOp>( 245 loc, tyAttr, isConst, linkage, global.sym_name(), initAttr); 246 auto &gr = g.getInitializerRegion(); 247 rewriter.inlineRegionBefore(global.region(), gr, gr.end()); 248 if (!gr.empty()) { 249 // Replace insert_on_range with a constant dense attribute if the 250 // initialization is on the full range. 251 auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>(); 252 for (auto insertOp : insertOnRangeOps) { 253 if (isFullRange(insertOp.coor(), insertOp.getType())) { 254 auto seqTyAttr = convertType(insertOp.getType()); 255 auto *op = insertOp.val().getDefiningOp(); 256 auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op); 257 if (!constant) { 258 auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op); 259 if (!convertOp) 260 continue; 261 constant = cast<mlir::arith::ConstantOp>( 262 convertOp.value().getDefiningOp()); 263 } 264 mlir::Type vecType = mlir::VectorType::get( 265 insertOp.getType().getShape(), constant.getType()); 266 auto denseAttr = mlir::DenseElementsAttr::get( 267 vecType.cast<ShapedType>(), constant.value()); 268 rewriter.setInsertionPointAfter(insertOp); 269 rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( 270 insertOp, seqTyAttr, denseAttr); 271 } 272 } 273 } 274 rewriter.eraseOp(global); 275 return success(); 276 } 277 278 bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const { 279 auto extents = seqTy.getShape(); 280 if (indexes.size() / 2 != extents.size()) 281 return false; 282 for (unsigned i = 0; i < indexes.size(); i += 2) { 283 if (indexes[i].cast<IntegerAttr>().getInt() != 0) 284 return false; 285 if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1) 286 return false; 287 } 288 return true; 289 } 290 291 // TODO: String comparaison should be avoided. Replace linkName with an 292 // enumeration. 293 mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const { 294 if (optLinkage.hasValue()) { 295 auto name = optLinkage.getValue(); 296 if (name == "internal") 297 return mlir::LLVM::Linkage::Internal; 298 if (name == "linkonce") 299 return mlir::LLVM::Linkage::Linkonce; 300 if (name == "common") 301 return mlir::LLVM::Linkage::Common; 302 if (name == "weak") 303 return mlir::LLVM::Linkage::Weak; 304 } 305 return mlir::LLVM::Linkage::External; 306 } 307 }; 308 309 template <typename OP> 310 void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, 311 typename OP::Adaptor adaptor, 312 mlir::ConversionPatternRewriter &rewriter) { 313 unsigned conds = select.getNumConditions(); 314 auto cases = select.getCases().getValue(); 315 mlir::Value selector = adaptor.selector(); 316 auto loc = select.getLoc(); 317 assert(conds > 0 && "select must have cases"); 318 319 llvm::SmallVector<mlir::Block *> destinations; 320 llvm::SmallVector<mlir::ValueRange> destinationsOperands; 321 mlir::Block *defaultDestination; 322 mlir::ValueRange defaultOperands; 323 llvm::SmallVector<int32_t> caseValues; 324 325 for (unsigned t = 0; t != conds; ++t) { 326 mlir::Block *dest = select.getSuccessor(t); 327 auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t); 328 const mlir::Attribute &attr = cases[t]; 329 if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) { 330 destinations.push_back(dest); 331 destinationsOperands.push_back(destOps.hasValue() ? *destOps 332 : ValueRange()); 333 caseValues.push_back(intAttr.getInt()); 334 continue; 335 } 336 assert(attr.template dyn_cast_or_null<mlir::UnitAttr>()); 337 assert((t + 1 == conds) && "unit must be last"); 338 defaultDestination = dest; 339 defaultOperands = destOps.hasValue() ? *destOps : ValueRange(); 340 } 341 342 // LLVM::SwitchOp takes a i32 type for the selector. 343 if (select.getSelector().getType() != rewriter.getI32Type()) 344 selector = 345 rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), selector); 346 347 rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>( 348 select, selector, 349 /*defaultDestination=*/defaultDestination, 350 /*defaultOperands=*/defaultOperands, 351 /*caseValues=*/caseValues, 352 /*caseDestinations=*/destinations, 353 /*caseOperands=*/destinationsOperands, 354 /*branchWeights=*/ArrayRef<int32_t>()); 355 } 356 357 /// conversion of fir::SelectOp to an if-then-else ladder 358 struct SelectOpConversion : public FIROpConversion<fir::SelectOp> { 359 using FIROpConversion::FIROpConversion; 360 361 mlir::LogicalResult 362 matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor, 363 mlir::ConversionPatternRewriter &rewriter) const override { 364 selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter); 365 return success(); 366 } 367 }; 368 369 /// `fir.load` --> `llvm.load` 370 struct LoadOpConversion : public FIROpConversion<fir::LoadOp> { 371 using FIROpConversion::FIROpConversion; 372 373 mlir::LogicalResult 374 matchAndRewrite(fir::LoadOp load, OpAdaptor adaptor, 375 mlir::ConversionPatternRewriter &rewriter) const override { 376 // fir.box is a special case because it is considered as an ssa values in 377 // fir, but it is lowered as a pointer to a descriptor. So fir.ref<fir.box> 378 // and fir.box end up being the same llvm types and loading a 379 // fir.ref<fir.box> is actually a no op in LLVM. 380 if (load.getType().isa<fir::BoxType>()) { 381 rewriter.replaceOp(load, adaptor.getOperands()[0]); 382 } else { 383 mlir::Type ty = convertType(load.getType()); 384 ArrayRef<NamedAttribute> at = load->getAttrs(); 385 rewriter.replaceOpWithNewOp<mlir::LLVM::LoadOp>( 386 load, ty, adaptor.getOperands(), at); 387 } 388 return success(); 389 } 390 }; 391 392 /// conversion of fir::SelectRankOp to an if-then-else ladder 393 struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> { 394 using FIROpConversion::FIROpConversion; 395 396 mlir::LogicalResult 397 matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor, 398 mlir::ConversionPatternRewriter &rewriter) const override { 399 selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter); 400 return success(); 401 } 402 }; 403 404 /// `fir.store` --> `llvm.store` 405 struct StoreOpConversion : public FIROpConversion<fir::StoreOp> { 406 using FIROpConversion::FIROpConversion; 407 408 mlir::LogicalResult 409 matchAndRewrite(fir::StoreOp store, OpAdaptor adaptor, 410 mlir::ConversionPatternRewriter &rewriter) const override { 411 if (store.value().getType().isa<fir::BoxType>()) { 412 // fir.box value is actually in memory, load it first before storing it. 413 mlir::Location loc = store.getLoc(); 414 mlir::Type boxPtrTy = adaptor.getOperands()[0].getType(); 415 auto val = rewriter.create<mlir::LLVM::LoadOp>( 416 loc, boxPtrTy.cast<mlir::LLVM::LLVMPointerType>().getElementType(), 417 adaptor.getOperands()[0]); 418 rewriter.replaceOpWithNewOp<mlir::LLVM::StoreOp>( 419 store, val, adaptor.getOperands()[1]); 420 } else { 421 rewriter.replaceOpWithNewOp<mlir::LLVM::StoreOp>( 422 store, adaptor.getOperands()[0], adaptor.getOperands()[1]); 423 } 424 return success(); 425 } 426 }; 427 428 /// convert to LLVM IR dialect `undef` 429 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> { 430 using FIROpConversion::FIROpConversion; 431 432 mlir::LogicalResult 433 matchAndRewrite(fir::UndefOp undef, OpAdaptor, 434 mlir::ConversionPatternRewriter &rewriter) const override { 435 rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>( 436 undef, convertType(undef.getType())); 437 return success(); 438 } 439 }; 440 441 /// `fir.unreachable` --> `llvm.unreachable` 442 struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> { 443 using FIROpConversion::FIROpConversion; 444 445 mlir::LogicalResult 446 matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor, 447 mlir::ConversionPatternRewriter &rewriter) const override { 448 rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach); 449 return success(); 450 } 451 }; 452 453 struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> { 454 using FIROpConversion::FIROpConversion; 455 456 mlir::LogicalResult 457 matchAndRewrite(fir::ZeroOp zero, OpAdaptor, 458 mlir::ConversionPatternRewriter &rewriter) const override { 459 auto ty = convertType(zero.getType()); 460 if (ty.isa<mlir::LLVM::LLVMPointerType>()) { 461 rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty); 462 } else if (ty.isa<mlir::IntegerType>()) { 463 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 464 zero, ty, mlir::IntegerAttr::get(zero.getType(), 0)); 465 } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) { 466 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 467 zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0)); 468 } else { 469 // TODO: create ConstantAggregateZero for FIR aggregate/array types. 470 return rewriter.notifyMatchFailure( 471 zero, 472 "conversion of fir.zero with aggregate type not implemented yet"); 473 } 474 return success(); 475 } 476 }; 477 478 // Code shared between insert_value and extract_value Ops. 479 struct ValueOpCommon { 480 // Translate the arguments pertaining to any multidimensional array to 481 // row-major order for LLVM-IR. 482 static void toRowMajor(SmallVectorImpl<mlir::Attribute> &attrs, 483 mlir::Type ty) { 484 assert(ty && "type is null"); 485 const auto end = attrs.size(); 486 for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) { 487 if (auto seq = ty.dyn_cast<mlir::LLVM::LLVMArrayType>()) { 488 const auto dim = getDimension(seq); 489 if (dim > 1) { 490 auto ub = std::min(i + dim, end); 491 std::reverse(attrs.begin() + i, attrs.begin() + ub); 492 i += dim - 1; 493 } 494 ty = getArrayElementType(seq); 495 } else if (auto st = ty.dyn_cast<mlir::LLVM::LLVMStructType>()) { 496 ty = st.getBody()[attrs[i].cast<mlir::IntegerAttr>().getInt()]; 497 } else { 498 llvm_unreachable("index into invalid type"); 499 } 500 } 501 } 502 503 static llvm::SmallVector<mlir::Attribute> 504 collectIndices(mlir::ConversionPatternRewriter &rewriter, 505 mlir::ArrayAttr arrAttr) { 506 llvm::SmallVector<mlir::Attribute> attrs; 507 for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) { 508 if (i->isa<mlir::IntegerAttr>()) { 509 attrs.push_back(*i); 510 } else { 511 auto fieldName = i->cast<mlir::StringAttr>().getValue(); 512 ++i; 513 auto ty = i->cast<mlir::TypeAttr>().getValue(); 514 auto index = ty.cast<fir::RecordType>().getFieldIndex(fieldName); 515 attrs.push_back(mlir::IntegerAttr::get(rewriter.getI32Type(), index)); 516 } 517 } 518 return attrs; 519 } 520 521 private: 522 static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) { 523 unsigned result = 1; 524 for (auto eleTy = ty.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>(); 525 eleTy; 526 eleTy = eleTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>()) 527 ++result; 528 return result; 529 } 530 531 static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) { 532 auto eleTy = ty.getElementType(); 533 while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>()) 534 eleTy = arrTy.getElementType(); 535 return eleTy; 536 } 537 }; 538 539 /// Extract a subobject value from an ssa-value of aggregate type 540 struct ExtractValueOpConversion 541 : public FIROpAndTypeConversion<fir::ExtractValueOp>, 542 public ValueOpCommon { 543 using FIROpAndTypeConversion::FIROpAndTypeConversion; 544 545 mlir::LogicalResult 546 doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor, 547 mlir::ConversionPatternRewriter &rewriter) const override { 548 auto attrs = collectIndices(rewriter, extractVal.coor()); 549 toRowMajor(attrs, adaptor.getOperands()[0].getType()); 550 auto position = mlir::ArrayAttr::get(extractVal.getContext(), attrs); 551 rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>( 552 extractVal, ty, adaptor.getOperands()[0], position); 553 return success(); 554 } 555 }; 556 557 /// InsertValue is the generalized instruction for the composition of new 558 /// aggregate type values. 559 struct InsertValueOpConversion 560 : public FIROpAndTypeConversion<fir::InsertValueOp>, 561 public ValueOpCommon { 562 using FIROpAndTypeConversion::FIROpAndTypeConversion; 563 564 mlir::LogicalResult 565 doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor, 566 mlir::ConversionPatternRewriter &rewriter) const override { 567 auto attrs = collectIndices(rewriter, insertVal.coor()); 568 toRowMajor(attrs, adaptor.getOperands()[0].getType()); 569 auto position = mlir::ArrayAttr::get(insertVal.getContext(), attrs); 570 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( 571 insertVal, ty, adaptor.getOperands()[0], adaptor.getOperands()[1], 572 position); 573 return success(); 574 } 575 }; 576 577 /// InsertOnRange inserts a value into a sequence over a range of offsets. 578 struct InsertOnRangeOpConversion 579 : public FIROpAndTypeConversion<fir::InsertOnRangeOp> { 580 using FIROpAndTypeConversion::FIROpAndTypeConversion; 581 582 // Increments an array of subscripts in a row major fasion. 583 void incrementSubscripts(const SmallVector<uint64_t> &dims, 584 SmallVector<uint64_t> &subscripts) const { 585 for (size_t i = dims.size(); i > 0; --i) { 586 if (++subscripts[i - 1] < dims[i - 1]) { 587 return; 588 } 589 subscripts[i - 1] = 0; 590 } 591 } 592 593 mlir::LogicalResult 594 doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor, 595 mlir::ConversionPatternRewriter &rewriter) const override { 596 597 llvm::SmallVector<uint64_t> dims; 598 auto type = adaptor.getOperands()[0].getType(); 599 600 // Iteratively extract the array dimensions from the type. 601 while (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) { 602 dims.push_back(t.getNumElements()); 603 type = t.getElementType(); 604 } 605 606 SmallVector<uint64_t> lBounds; 607 SmallVector<uint64_t> uBounds; 608 609 // Extract integer value from the attribute 610 SmallVector<int64_t> coordinates = llvm::to_vector<4>( 611 llvm::map_range(range.coor(), [](Attribute a) -> int64_t { 612 return a.cast<IntegerAttr>().getInt(); 613 })); 614 615 // Unzip the upper and lower bound and convert to a row major format. 616 for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) { 617 uBounds.push_back(*i++); 618 lBounds.push_back(*i); 619 } 620 621 auto &subscripts = lBounds; 622 auto loc = range.getLoc(); 623 mlir::Value lastOp = adaptor.getOperands()[0]; 624 mlir::Value insertVal = adaptor.getOperands()[1]; 625 626 auto i64Ty = rewriter.getI64Type(); 627 while (subscripts != uBounds) { 628 // Convert uint64_t's to Attribute's. 629 SmallVector<mlir::Attribute> subscriptAttrs; 630 for (const auto &subscript : subscripts) 631 subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript)); 632 lastOp = rewriter.create<mlir::LLVM::InsertValueOp>( 633 loc, ty, lastOp, insertVal, 634 ArrayAttr::get(range.getContext(), subscriptAttrs)); 635 636 incrementSubscripts(dims, subscripts); 637 } 638 639 // Convert uint64_t's to Attribute's. 640 SmallVector<mlir::Attribute> subscriptAttrs; 641 for (const auto &subscript : subscripts) 642 subscriptAttrs.push_back( 643 IntegerAttr::get(rewriter.getI64Type(), subscript)); 644 mlir::ArrayRef<mlir::Attribute> arrayRef(subscriptAttrs); 645 646 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( 647 range, ty, lastOp, insertVal, 648 ArrayAttr::get(range.getContext(), arrayRef)); 649 650 return success(); 651 } 652 }; 653 654 // 655 // Primitive operations on Complex types 656 // 657 658 /// Generate inline code for complex addition/subtraction 659 template <typename LLVMOP, typename OPTY> 660 mlir::LLVM::InsertValueOp complexSum(OPTY sumop, mlir::ValueRange opnds, 661 mlir::ConversionPatternRewriter &rewriter, 662 fir::LLVMTypeConverter &lowering) { 663 mlir::Value a = opnds[0]; 664 mlir::Value b = opnds[1]; 665 auto loc = sumop.getLoc(); 666 auto ctx = sumop.getContext(); 667 auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0)); 668 auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1)); 669 mlir::Type eleTy = lowering.convertType(getComplexEleTy(sumop.getType())); 670 mlir::Type ty = lowering.convertType(sumop.getType()); 671 auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0); 672 auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1); 673 auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0); 674 auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1); 675 auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1); 676 auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1); 677 auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty); 678 auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r0, rx, c0); 679 return rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ry, c1); 680 } 681 682 struct AddcOpConversion : public FIROpConversion<fir::AddcOp> { 683 using FIROpConversion::FIROpConversion; 684 685 mlir::LogicalResult 686 matchAndRewrite(fir::AddcOp addc, OpAdaptor adaptor, 687 mlir::ConversionPatternRewriter &rewriter) const override { 688 // given: (x + iy) + (x' + iy') 689 // result: (x + x') + i(y + y') 690 auto r = complexSum<mlir::LLVM::FAddOp>(addc, adaptor.getOperands(), 691 rewriter, lowerTy()); 692 rewriter.replaceOp(addc, r.getResult()); 693 return success(); 694 } 695 }; 696 697 struct SubcOpConversion : public FIROpConversion<fir::SubcOp> { 698 using FIROpConversion::FIROpConversion; 699 700 mlir::LogicalResult 701 matchAndRewrite(fir::SubcOp subc, OpAdaptor adaptor, 702 mlir::ConversionPatternRewriter &rewriter) const override { 703 // given: (x + iy) - (x' + iy') 704 // result: (x - x') + i(y - y') 705 auto r = complexSum<mlir::LLVM::FSubOp>(subc, adaptor.getOperands(), 706 rewriter, lowerTy()); 707 rewriter.replaceOp(subc, r.getResult()); 708 return success(); 709 } 710 }; 711 712 /// Inlined complex multiply 713 struct MulcOpConversion : public FIROpConversion<fir::MulcOp> { 714 using FIROpConversion::FIROpConversion; 715 716 mlir::LogicalResult 717 matchAndRewrite(fir::MulcOp mulc, OpAdaptor adaptor, 718 mlir::ConversionPatternRewriter &rewriter) const override { 719 // TODO: Can we use a call to __muldc3 ? 720 // given: (x + iy) * (x' + iy') 721 // result: (xx'-yy')+i(xy'+yx') 722 mlir::Value a = adaptor.getOperands()[0]; 723 mlir::Value b = adaptor.getOperands()[1]; 724 auto loc = mulc.getLoc(); 725 auto *ctx = mulc.getContext(); 726 auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0)); 727 auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1)); 728 mlir::Type eleTy = convertType(getComplexEleTy(mulc.getType())); 729 mlir::Type ty = convertType(mulc.getType()); 730 auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0); 731 auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1); 732 auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0); 733 auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1); 734 auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1); 735 auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1); 736 auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1); 737 auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx); 738 auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1); 739 auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy); 740 auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty); 741 auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, ra, rr, c0); 742 auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ri, c1); 743 rewriter.replaceOp(mulc, r0.getResult()); 744 return success(); 745 } 746 }; 747 748 /// Inlined complex division 749 struct DivcOpConversion : public FIROpConversion<fir::DivcOp> { 750 using FIROpConversion::FIROpConversion; 751 752 mlir::LogicalResult 753 matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor, 754 mlir::ConversionPatternRewriter &rewriter) const override { 755 // TODO: Can we use a call to __divdc3 instead? 756 // Just generate inline code for now. 757 // given: (x + iy) / (x' + iy') 758 // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y' 759 mlir::Value a = adaptor.getOperands()[0]; 760 mlir::Value b = adaptor.getOperands()[1]; 761 auto loc = divc.getLoc(); 762 auto *ctx = divc.getContext(); 763 auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0)); 764 auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1)); 765 mlir::Type eleTy = convertType(getComplexEleTy(divc.getType())); 766 mlir::Type ty = convertType(divc.getType()); 767 auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c0); 768 auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, a, c1); 769 auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c0); 770 auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, b, c1); 771 auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1); 772 auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1); 773 auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1); 774 auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1); 775 auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1); 776 auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1); 777 auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1); 778 auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy); 779 auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy); 780 auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d); 781 auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d); 782 auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty); 783 auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, ra, rr, c0); 784 auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ri, c1); 785 rewriter.replaceOp(divc, r0.getResult()); 786 return success(); 787 } 788 }; 789 790 /// Inlined complex negation 791 struct NegcOpConversion : public FIROpConversion<fir::NegcOp> { 792 using FIROpConversion::FIROpConversion; 793 794 mlir::LogicalResult 795 matchAndRewrite(fir::NegcOp neg, OpAdaptor adaptor, 796 mlir::ConversionPatternRewriter &rewriter) const override { 797 // given: -(x + iy) 798 // result: -x - iy 799 auto *ctxt = neg.getContext(); 800 auto eleTy = convertType(getComplexEleTy(neg.getType())); 801 auto ty = convertType(neg.getType()); 802 auto loc = neg.getLoc(); 803 mlir::Value o0 = adaptor.getOperands()[0]; 804 auto c0 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(0)); 805 auto c1 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(1)); 806 auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, o0, c0); 807 auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, o0, c1); 808 auto nrp = rewriter.create<mlir::LLVM::FNegOp>(loc, eleTy, rp); 809 auto nip = rewriter.create<mlir::LLVM::FNegOp>(loc, eleTy, ip); 810 auto r = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, o0, nrp, c0); 811 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(neg, ty, r, nip, c1); 812 return success(); 813 } 814 }; 815 816 } // namespace 817 818 namespace { 819 /// Convert FIR dialect to LLVM dialect 820 /// 821 /// This pass lowers all FIR dialect operations to LLVM IR dialect. An 822 /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect. 823 /// 824 /// This pass is not complete yet. We are upstreaming it in small patches. 825 class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> { 826 public: 827 mlir::ModuleOp getModule() { return getOperation(); } 828 829 void runOnOperation() override final { 830 auto mod = getModule(); 831 if (!forcedTargetTriple.empty()) { 832 fir::setTargetTriple(mod, forcedTargetTriple); 833 } 834 835 auto *context = getModule().getContext(); 836 fir::LLVMTypeConverter typeConverter{getModule()}; 837 mlir::OwningRewritePatternList pattern(context); 838 pattern.insert<AddcOpConversion, AddrOfOpConversion, CallOpConversion, 839 ConvertOpConversion, DivcOpConversion, 840 ExtractValueOpConversion, HasValueOpConversion, 841 GlobalOpConversion, InsertOnRangeOpConversion, 842 InsertValueOpConversion, LoadOpConversion, NegcOpConversion, 843 MulcOpConversion, SelectOpConversion, SelectRankOpConversion, 844 StoreOpConversion, SubcOpConversion, UndefOpConversion, 845 UnreachableOpConversion, ZeroOpConversion>(typeConverter); 846 mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); 847 mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, 848 pattern); 849 mlir::ConversionTarget target{*context}; 850 target.addLegalDialect<mlir::LLVM::LLVMDialect>(); 851 852 // required NOPs for applying a full conversion 853 target.addLegalOp<mlir::ModuleOp>(); 854 855 // apply the patterns 856 if (mlir::failed(mlir::applyFullConversion(getModule(), target, 857 std::move(pattern)))) { 858 signalPassFailure(); 859 } 860 } 861 }; 862 } // namespace 863 864 std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() { 865 return std::make_unique<FIRToLLVMLowering>(); 866 } 867