1 //===-- TargetRewrite.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest. 10 // LLVM expects different lowering idioms to be used for distinct target 11 // triples. These distinctions are handled by this pass. 12 // 13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "Target.h" 19 #include "flang/Optimizer/Builder/Character.h" 20 #include "flang/Optimizer/Builder/FIRBuilder.h" 21 #include "flang/Optimizer/Builder/Todo.h" 22 #include "flang/Optimizer/CodeGen/CodeGen.h" 23 #include "flang/Optimizer/Dialect/FIRDialect.h" 24 #include "flang/Optimizer/Dialect/FIROps.h" 25 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 26 #include "flang/Optimizer/Dialect/FIRType.h" 27 #include "flang/Optimizer/Support/FIRContext.h" 28 #include "mlir/Transforms/DialectConversion.h" 29 #include "llvm/ADT/STLExtras.h" 30 #include "llvm/ADT/TypeSwitch.h" 31 #include "llvm/Support/Debug.h" 32 33 #define DEBUG_TYPE "flang-target-rewrite" 34 35 namespace { 36 37 /// Fixups for updating a FuncOp's arguments and return values. 38 struct FixupTy { 39 enum class Codes { 40 ArgumentAsLoad, 41 ArgumentType, 42 CharPair, 43 ReturnAsStore, 44 ReturnType, 45 Split, 46 Trailing, 47 TrailingCharProc 48 }; 49 50 FixupTy(Codes code, std::size_t index, std::size_t second = 0) 51 : code{code}, index{index}, second{second} {} 52 FixupTy(Codes code, std::size_t index, 53 std::function<void(mlir::func::FuncOp)> &&finalizer) 54 : code{code}, index{index}, finalizer{finalizer} {} 55 FixupTy(Codes code, std::size_t index, std::size_t second, 56 std::function<void(mlir::func::FuncOp)> &&finalizer) 57 : code{code}, index{index}, second{second}, finalizer{finalizer} {} 58 59 Codes code; 60 std::size_t index; 61 std::size_t second{}; 62 llvm::Optional<std::function<void(mlir::func::FuncOp)>> finalizer{}; 63 }; // namespace 64 65 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 66 /// generation that traverses the FIR and modifies types and operations to a 67 /// form that is appropriate for the specific target. LLVM IR has specific 68 /// idioms that are used for distinct target processor and ABI combinations. 69 class TargetRewrite : public fir::TargetRewriteBase<TargetRewrite> { 70 public: 71 TargetRewrite(const fir::TargetRewriteOptions &options) { 72 noCharacterConversion = options.noCharacterConversion; 73 noComplexConversion = options.noComplexConversion; 74 } 75 76 void runOnOperation() override final { 77 auto &context = getContext(); 78 mlir::OpBuilder rewriter(&context); 79 80 auto mod = getModule(); 81 if (!forcedTargetTriple.empty()) 82 fir::setTargetTriple(mod, forcedTargetTriple); 83 84 auto specifics = fir::CodeGenSpecifics::get( 85 mod.getContext(), fir::getTargetTriple(mod), fir::getKindMapping(mod)); 86 setMembers(specifics.get(), &rewriter); 87 88 // Perform type conversion on signatures and call sites. 89 if (mlir::failed(convertTypes(mod))) { 90 mlir::emitError(mlir::UnknownLoc::get(&context), 91 "error in converting types to target abi"); 92 signalPassFailure(); 93 } 94 95 // Convert ops in target-specific patterns. 96 mod.walk([&](mlir::Operation *op) { 97 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) { 98 if (!hasPortableSignature(call.getFunctionType())) 99 convertCallOp(call); 100 } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) { 101 if (!hasPortableSignature(dispatch.getFunctionType())) 102 convertCallOp(dispatch); 103 } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) { 104 if (addr.getType().isa<mlir::FunctionType>() && 105 !hasPortableSignature(addr.getType())) 106 convertAddrOp(addr); 107 } 108 }); 109 110 clearMembers(); 111 } 112 113 mlir::ModuleOp getModule() { return getOperation(); } 114 115 template <typename A, typename B, typename C> 116 std::function<mlir::Value(mlir::Operation *)> 117 rewriteCallComplexResultType(mlir::Location loc, A ty, B &newResTys, 118 B &newInTys, C &newOpers) { 119 auto m = specifics->complexReturnType(loc, ty.getElementType()); 120 // Currently targets mandate COMPLEX is a single aggregate or packed 121 // scalar, including the sret case. 122 assert(m.size() == 1 && "target of complex return not supported"); 123 auto resTy = std::get<mlir::Type>(m[0]); 124 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]); 125 if (attr.isSRet()) { 126 assert(fir::isa_ref_type(resTy) && "must be a memory reference type"); 127 mlir::Value stack = 128 rewriter->create<fir::AllocaOp>(loc, fir::dyn_cast_ptrEleTy(resTy)); 129 newInTys.push_back(resTy); 130 newOpers.push_back(stack); 131 return [=](mlir::Operation *) -> mlir::Value { 132 auto memTy = fir::ReferenceType::get(ty); 133 auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack); 134 return rewriter->create<fir::LoadOp>(loc, cast); 135 }; 136 } 137 newResTys.push_back(resTy); 138 return [=](mlir::Operation *call) -> mlir::Value { 139 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 140 rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem); 141 auto memTy = fir::ReferenceType::get(ty); 142 auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, mem); 143 return rewriter->create<fir::LoadOp>(loc, cast); 144 }; 145 } 146 147 template <typename A, typename B, typename C> 148 void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, 149 C &newOpers) { 150 auto *ctx = ty.getContext(); 151 mlir::Location loc = mlir::UnknownLoc::get(ctx); 152 if (auto *op = oper.getDefiningOp()) 153 loc = op->getLoc(); 154 auto m = specifics->complexArgumentType(loc, ty.getElementType()); 155 if (m.size() == 1) { 156 // COMPLEX is a single aggregate 157 auto resTy = std::get<mlir::Type>(m[0]); 158 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]); 159 auto oldRefTy = fir::ReferenceType::get(ty); 160 if (attr.isByVal()) { 161 auto mem = rewriter->create<fir::AllocaOp>(loc, ty); 162 rewriter->create<fir::StoreOp>(loc, oper, mem); 163 newOpers.push_back(rewriter->create<fir::ConvertOp>(loc, resTy, mem)); 164 } else { 165 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 166 auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem); 167 rewriter->create<fir::StoreOp>(loc, oper, cast); 168 newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem)); 169 } 170 newInTys.push_back(resTy); 171 } else { 172 assert(m.size() == 2); 173 // COMPLEX is split into 2 separate arguments 174 auto iTy = rewriter->getIntegerType(32); 175 for (auto e : llvm::enumerate(m)) { 176 auto &tup = e.value(); 177 auto ty = std::get<mlir::Type>(tup); 178 auto index = e.index(); 179 auto idx = rewriter->getIntegerAttr(iTy, index); 180 auto val = rewriter->create<fir::ExtractValueOp>( 181 loc, ty, oper, rewriter->getArrayAttr(idx)); 182 newInTys.push_back(ty); 183 newOpers.push_back(val); 184 } 185 } 186 } 187 188 // Convert fir.call and fir.dispatch Ops. 189 template <typename A> 190 void convertCallOp(A callOp) { 191 auto fnTy = callOp.getFunctionType(); 192 auto loc = callOp.getLoc(); 193 rewriter->setInsertionPoint(callOp); 194 llvm::SmallVector<mlir::Type> newResTys; 195 llvm::SmallVector<mlir::Type> newInTys; 196 llvm::SmallVector<mlir::Value> newOpers; 197 198 // If the call is indirect, the first argument must still be the function 199 // to call. 200 int dropFront = 0; 201 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 202 if (!callOp.getCallee()) { 203 newInTys.push_back(fnTy.getInput(0)); 204 newOpers.push_back(callOp.getOperand(0)); 205 dropFront = 1; 206 } 207 } 208 209 // Determine the rewrite function, `wrap`, for the result value. 210 llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; 211 if (fnTy.getResults().size() == 1) { 212 mlir::Type ty = fnTy.getResult(0); 213 llvm::TypeSwitch<mlir::Type>(ty) 214 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 215 wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTys, 216 newOpers); 217 }) 218 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 219 wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTys, 220 newOpers); 221 }) 222 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 223 } else if (fnTy.getResults().size() > 1) { 224 TODO(loc, "multiple results not supported yet"); 225 } 226 227 llvm::SmallVector<mlir::Type> trailingInTys; 228 llvm::SmallVector<mlir::Value> trailingOpers; 229 for (auto e : llvm::enumerate( 230 llvm::zip(fnTy.getInputs().drop_front(dropFront), 231 callOp.getOperands().drop_front(dropFront)))) { 232 mlir::Type ty = std::get<0>(e.value()); 233 mlir::Value oper = std::get<1>(e.value()); 234 unsigned index = e.index(); 235 llvm::TypeSwitch<mlir::Type>(ty) 236 .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) { 237 bool sret; 238 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 239 sret = callOp.getCallee() && 240 functionArgIsSRet( 241 index, getModule().lookupSymbol<mlir::func::FuncOp>( 242 *callOp.getCallee())); 243 } else { 244 // TODO: dispatch case; how do we put arguments on a call? 245 // We cannot put both an sret and the dispatch object first. 246 sret = false; 247 TODO(loc, "dispatch + sret not supported yet"); 248 } 249 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); 250 auto unbox = rewriter->create<fir::UnboxCharOp>( 251 loc, std::get<mlir::Type>(m[0]), std::get<mlir::Type>(m[1]), 252 oper); 253 // unboxed CHARACTER arguments 254 for (auto e : llvm::enumerate(m)) { 255 unsigned idx = e.index(); 256 auto attr = 257 std::get<fir::CodeGenSpecifics::Attributes>(e.value()); 258 auto argTy = std::get<mlir::Type>(e.value()); 259 if (attr.isAppend()) { 260 trailingInTys.push_back(argTy); 261 trailingOpers.push_back(unbox.getResult(idx)); 262 } else { 263 newInTys.push_back(argTy); 264 newOpers.push_back(unbox.getResult(idx)); 265 } 266 } 267 }) 268 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 269 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 270 }) 271 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 272 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 273 }) 274 .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { 275 if (fir::isCharacterProcedureTuple(tuple)) { 276 mlir::ModuleOp module = getModule(); 277 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 278 if (callOp.getCallee()) { 279 llvm::StringRef charProcAttr = 280 fir::getCharacterProcedureDummyAttrName(); 281 // The charProcAttr attribute is only used as a safety to 282 // confirm that this is a dummy procedure and should be split. 283 // It cannot be used to match because attributes are not 284 // available in case of indirect calls. 285 auto funcOp = module.lookupSymbol<mlir::func::FuncOp>( 286 *callOp.getCallee()); 287 if (funcOp && 288 !funcOp.template getArgAttrOfType<mlir::UnitAttr>( 289 index, charProcAttr)) 290 mlir::emitError(loc, "tuple argument will be split even " 291 "though it does not have the `" + 292 charProcAttr + "` attribute"); 293 } 294 } 295 mlir::Type funcPointerType = tuple.getType(0); 296 mlir::Type lenType = tuple.getType(1); 297 fir::KindMapping kindMap = fir::getKindMapping(module); 298 fir::FirOpBuilder builder(*rewriter, kindMap); 299 auto [funcPointer, len] = 300 fir::factory::extractCharacterProcedureTuple(builder, loc, 301 oper); 302 newInTys.push_back(funcPointerType); 303 newOpers.push_back(funcPointer); 304 trailingInTys.push_back(lenType); 305 trailingOpers.push_back(len); 306 } else { 307 newInTys.push_back(tuple); 308 newOpers.push_back(oper); 309 } 310 }) 311 .Default([&](mlir::Type ty) { 312 newInTys.push_back(ty); 313 newOpers.push_back(oper); 314 }); 315 } 316 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 317 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); 318 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 319 fir::CallOp newCall; 320 if (callOp.getCallee()) { 321 newCall = rewriter->create<A>(loc, callOp.getCallee().getValue(), 322 newResTys, newOpers); 323 } else { 324 // Force new type on the input operand. 325 newOpers[0].setType(mlir::FunctionType::get( 326 callOp.getContext(), 327 mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); 328 newCall = rewriter->create<A>(loc, newResTys, newOpers); 329 } 330 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); 331 if (wrap) 332 replaceOp(callOp, (*wrap)(newCall.getOperation())); 333 else 334 replaceOp(callOp, newCall.getResults()); 335 } else { 336 // A is fir::DispatchOp 337 TODO(loc, "dispatch not implemented"); 338 } 339 } 340 341 // Result type fixup for fir::ComplexType and mlir::ComplexType 342 template <typename A, typename B> 343 void lowerComplexSignatureRes(mlir::Location loc, A cmplx, B &newResTys, 344 B &newInTys) { 345 if (noComplexConversion) { 346 newResTys.push_back(cmplx); 347 } else { 348 for (auto &tup : 349 specifics->complexReturnType(loc, cmplx.getElementType())) { 350 auto argTy = std::get<mlir::Type>(tup); 351 if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet()) 352 newInTys.push_back(argTy); 353 else 354 newResTys.push_back(argTy); 355 } 356 } 357 } 358 359 // Argument type fixup for fir::ComplexType and mlir::ComplexType 360 template <typename A, typename B> 361 void lowerComplexSignatureArg(mlir::Location loc, A cmplx, B &newInTys) { 362 if (noComplexConversion) 363 newInTys.push_back(cmplx); 364 else 365 for (auto &tup : 366 specifics->complexArgumentType(loc, cmplx.getElementType())) 367 newInTys.push_back(std::get<mlir::Type>(tup)); 368 } 369 370 /// Taking the address of a function. Modify the signature as needed. 371 void convertAddrOp(fir::AddrOfOp addrOp) { 372 rewriter->setInsertionPoint(addrOp); 373 auto addrTy = addrOp.getType().cast<mlir::FunctionType>(); 374 llvm::SmallVector<mlir::Type> newResTys; 375 llvm::SmallVector<mlir::Type> newInTys; 376 auto loc = addrOp.getLoc(); 377 for (mlir::Type ty : addrTy.getResults()) { 378 llvm::TypeSwitch<mlir::Type>(ty) 379 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 380 lowerComplexSignatureRes(loc, ty, newResTys, newInTys); 381 }) 382 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 383 lowerComplexSignatureRes(loc, ty, newResTys, newInTys); 384 }) 385 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 386 } 387 llvm::SmallVector<mlir::Type> trailingInTys; 388 for (mlir::Type ty : addrTy.getInputs()) { 389 llvm::TypeSwitch<mlir::Type>(ty) 390 .Case<fir::BoxCharType>([&](auto box) { 391 if (noCharacterConversion) { 392 newInTys.push_back(box); 393 } else { 394 for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) { 395 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 396 auto argTy = std::get<mlir::Type>(tup); 397 llvm::SmallVector<mlir::Type> &vec = 398 attr.isAppend() ? trailingInTys : newInTys; 399 vec.push_back(argTy); 400 } 401 } 402 }) 403 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 404 lowerComplexSignatureArg(loc, ty, newInTys); 405 }) 406 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 407 lowerComplexSignatureArg(loc, ty, newInTys); 408 }) 409 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 410 if (fir::isCharacterProcedureTuple(tuple)) { 411 newInTys.push_back(tuple.getType(0)); 412 trailingInTys.push_back(tuple.getType(1)); 413 } else { 414 newInTys.push_back(ty); 415 } 416 }) 417 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 418 } 419 // append trailing input types 420 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 421 // replace this op with a new one with the updated signature 422 auto newTy = rewriter->getFunctionType(newInTys, newResTys); 423 auto newOp = rewriter->create<fir::AddrOfOp>(addrOp.getLoc(), newTy, 424 addrOp.getSymbol()); 425 replaceOp(addrOp, newOp.getResult()); 426 } 427 428 /// Convert the type signatures on all the functions present in the module. 429 /// As the type signature is being changed, this must also update the 430 /// function itself to use any new arguments, etc. 431 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { 432 for (auto fn : mod.getOps<mlir::func::FuncOp>()) 433 convertSignature(fn); 434 return mlir::success(); 435 } 436 437 /// If the signature does not need any special target-specific converions, 438 /// then it is considered portable for any target, and this function will 439 /// return `true`. Otherwise, the signature is not portable and `false` is 440 /// returned. 441 bool hasPortableSignature(mlir::Type signature) { 442 assert(signature.isa<mlir::FunctionType>()); 443 auto func = signature.dyn_cast<mlir::FunctionType>(); 444 for (auto ty : func.getResults()) 445 if ((ty.isa<fir::BoxCharType>() && !noCharacterConversion) || 446 (fir::isa_complex(ty) && !noComplexConversion)) { 447 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 448 return false; 449 } 450 for (auto ty : func.getInputs()) 451 if (((ty.isa<fir::BoxCharType>() || fir::isCharacterProcedureTuple(ty)) && 452 !noCharacterConversion) || 453 (fir::isa_complex(ty) && !noComplexConversion)) { 454 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 455 return false; 456 } 457 return true; 458 } 459 460 /// Determine if the signature has host associations. The host association 461 /// argument may need special target specific rewriting. 462 static bool hasHostAssociations(mlir::func::FuncOp func) { 463 std::size_t end = func.getFunctionType().getInputs().size(); 464 for (std::size_t i = 0; i < end; ++i) 465 if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName())) 466 return true; 467 return false; 468 } 469 470 /// Rewrite the signatures and body of the `FuncOp`s in the module for 471 /// the immediately subsequent target code gen. 472 void convertSignature(mlir::func::FuncOp func) { 473 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); 474 if (hasPortableSignature(funcTy) && !hasHostAssociations(func)) 475 return; 476 llvm::SmallVector<mlir::Type> newResTys; 477 llvm::SmallVector<mlir::Type> newInTys; 478 llvm::SmallVector<FixupTy> fixups; 479 480 // Convert return value(s) 481 for (auto ty : funcTy.getResults()) 482 llvm::TypeSwitch<mlir::Type>(ty) 483 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 484 if (noComplexConversion) 485 newResTys.push_back(cmplx); 486 else 487 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 488 }) 489 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 490 if (noComplexConversion) 491 newResTys.push_back(cmplx); 492 else 493 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 494 }) 495 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 496 497 // Convert arguments 498 llvm::SmallVector<mlir::Type> trailingTys; 499 for (auto e : llvm::enumerate(funcTy.getInputs())) { 500 auto ty = e.value(); 501 unsigned index = e.index(); 502 llvm::TypeSwitch<mlir::Type>(ty) 503 .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) { 504 if (noCharacterConversion) { 505 newInTys.push_back(boxTy); 506 } else { 507 // Convert a CHARACTER argument type. This can involve separating 508 // the pointer and the LEN into two arguments and moving the LEN 509 // argument to the end of the arg list. 510 bool sret = functionArgIsSRet(index, func); 511 for (auto e : llvm::enumerate(specifics->boxcharArgumentType( 512 boxTy.getEleTy(), sret))) { 513 auto &tup = e.value(); 514 auto index = e.index(); 515 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 516 auto argTy = std::get<mlir::Type>(tup); 517 if (attr.isAppend()) { 518 trailingTys.push_back(argTy); 519 } else { 520 if (sret) { 521 fixups.emplace_back(FixupTy::Codes::CharPair, 522 newInTys.size(), index); 523 } else { 524 fixups.emplace_back(FixupTy::Codes::Trailing, 525 newInTys.size(), trailingTys.size()); 526 } 527 newInTys.push_back(argTy); 528 } 529 } 530 } 531 }) 532 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 533 if (noComplexConversion) 534 newInTys.push_back(cmplx); 535 else 536 doComplexArg(func, cmplx, newInTys, fixups); 537 }) 538 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 539 if (noComplexConversion) 540 newInTys.push_back(cmplx); 541 else 542 doComplexArg(func, cmplx, newInTys, fixups); 543 }) 544 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 545 if (fir::isCharacterProcedureTuple(tuple)) { 546 fixups.emplace_back(FixupTy::Codes::TrailingCharProc, 547 newInTys.size(), trailingTys.size()); 548 newInTys.push_back(tuple.getType(0)); 549 trailingTys.push_back(tuple.getType(1)); 550 } else { 551 newInTys.push_back(ty); 552 } 553 }) 554 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 555 if (func.getArgAttrOfType<mlir::UnitAttr>(index, 556 fir::getHostAssocAttrName())) { 557 func.setArgAttr(index, "llvm.nest", rewriter->getUnitAttr()); 558 } 559 } 560 561 if (!func.empty()) { 562 // If the function has a body, then apply the fixups to the arguments and 563 // return ops as required. These fixups are done in place. 564 auto loc = func.getLoc(); 565 const auto fixupSize = fixups.size(); 566 const auto oldArgTys = func.getFunctionType().getInputs(); 567 int offset = 0; 568 for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) { 569 const auto &fixup = fixups[i]; 570 switch (fixup.code) { 571 case FixupTy::Codes::ArgumentAsLoad: { 572 // Argument was pass-by-value, but is now pass-by-reference and 573 // possibly with a different element type. 574 auto newArg = func.front().insertArgument(fixup.index, 575 newInTys[fixup.index], loc); 576 rewriter->setInsertionPointToStart(&func.front()); 577 auto oldArgTy = 578 fir::ReferenceType::get(oldArgTys[fixup.index - offset]); 579 auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, newArg); 580 auto load = rewriter->create<fir::LoadOp>(loc, cast); 581 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 582 func.front().eraseArgument(fixup.index + 1); 583 } break; 584 case FixupTy::Codes::ArgumentType: { 585 // Argument is pass-by-value, but its type has likely been modified to 586 // suit the target ABI convention. 587 auto newArg = func.front().insertArgument(fixup.index, 588 newInTys[fixup.index], loc); 589 rewriter->setInsertionPointToStart(&func.front()); 590 auto mem = 591 rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]); 592 rewriter->create<fir::StoreOp>(loc, newArg, mem); 593 auto oldArgTy = 594 fir::ReferenceType::get(oldArgTys[fixup.index - offset]); 595 auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, mem); 596 mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast); 597 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 598 func.front().eraseArgument(fixup.index + 1); 599 LLVM_DEBUG(llvm::dbgs() 600 << "old argument: " << oldArgTy.getEleTy() 601 << ", repl: " << load << ", new argument: " 602 << func.getArgument(fixup.index).getType() << '\n'); 603 } break; 604 case FixupTy::Codes::CharPair: { 605 // The FIR boxchar argument has been split into a pair of distinct 606 // arguments that are in juxtaposition to each other. 607 auto newArg = func.front().insertArgument(fixup.index, 608 newInTys[fixup.index], loc); 609 if (fixup.second == 1) { 610 rewriter->setInsertionPointToStart(&func.front()); 611 auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; 612 auto box = rewriter->create<fir::EmboxCharOp>( 613 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); 614 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 615 func.front().eraseArgument(fixup.index + 1); 616 offset++; 617 } 618 } break; 619 case FixupTy::Codes::ReturnAsStore: { 620 // The value being returned is now being returned in memory (callee 621 // stack space) through a hidden reference argument. 622 auto newArg = func.front().insertArgument(fixup.index, 623 newInTys[fixup.index], loc); 624 offset++; 625 func.walk([&](mlir::func::ReturnOp ret) { 626 rewriter->setInsertionPoint(ret); 627 auto oldOper = ret.getOperand(0); 628 auto oldOperTy = fir::ReferenceType::get(oldOper.getType()); 629 auto cast = 630 rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg); 631 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 632 rewriter->create<mlir::func::ReturnOp>(loc); 633 ret.erase(); 634 }); 635 } break; 636 case FixupTy::Codes::ReturnType: { 637 // The function is still returning a value, but its type has likely 638 // changed to suit the target ABI convention. 639 func.walk([&](mlir::func::ReturnOp ret) { 640 rewriter->setInsertionPoint(ret); 641 auto oldOper = ret.getOperand(0); 642 auto oldOperTy = fir::ReferenceType::get(oldOper.getType()); 643 auto mem = 644 rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]); 645 auto cast = rewriter->create<fir::ConvertOp>(loc, oldOperTy, mem); 646 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 647 mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem); 648 rewriter->create<mlir::func::ReturnOp>(loc, load); 649 ret.erase(); 650 }); 651 } break; 652 case FixupTy::Codes::Split: { 653 // The FIR argument has been split into a pair of distinct arguments 654 // that are in juxtaposition to each other. (For COMPLEX value.) 655 auto newArg = func.front().insertArgument(fixup.index, 656 newInTys[fixup.index], loc); 657 if (fixup.second == 1) { 658 rewriter->setInsertionPointToStart(&func.front()); 659 auto cplxTy = oldArgTys[fixup.index - offset - fixup.second]; 660 auto undef = rewriter->create<fir::UndefOp>(loc, cplxTy); 661 auto iTy = rewriter->getIntegerType(32); 662 auto zero = rewriter->getIntegerAttr(iTy, 0); 663 auto one = rewriter->getIntegerAttr(iTy, 1); 664 auto cplx1 = rewriter->create<fir::InsertValueOp>( 665 loc, cplxTy, undef, func.front().getArgument(fixup.index - 1), 666 rewriter->getArrayAttr(zero)); 667 auto cplx = rewriter->create<fir::InsertValueOp>( 668 loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one)); 669 func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx); 670 func.front().eraseArgument(fixup.index + 1); 671 offset++; 672 } 673 } break; 674 case FixupTy::Codes::Trailing: { 675 // The FIR argument has been split into a pair of distinct arguments. 676 // The first part of the pair appears in the original argument 677 // position. The second part of the pair is appended after all the 678 // original arguments. (Boxchar arguments.) 679 auto newBufArg = func.front().insertArgument( 680 fixup.index, newInTys[fixup.index], loc); 681 auto newLenArg = 682 func.front().addArgument(trailingTys[fixup.second], loc); 683 auto boxTy = oldArgTys[fixup.index - offset]; 684 rewriter->setInsertionPointToStart(&func.front()); 685 auto box = rewriter->create<fir::EmboxCharOp>(loc, boxTy, newBufArg, 686 newLenArg); 687 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 688 func.front().eraseArgument(fixup.index + 1); 689 } break; 690 case FixupTy::Codes::TrailingCharProc: { 691 // The FIR character procedure argument tuple must be split into a 692 // pair of distinct arguments. The first part of the pair appears in 693 // the original argument position. The second part of the pair is 694 // appended after all the original arguments. 695 auto newProcPointerArg = func.front().insertArgument( 696 fixup.index, newInTys[fixup.index], loc); 697 auto newLenArg = 698 func.front().addArgument(trailingTys[fixup.second], loc); 699 auto tupleType = oldArgTys[fixup.index - offset]; 700 rewriter->setInsertionPointToStart(&func.front()); 701 fir::KindMapping kindMap = fir::getKindMapping(getModule()); 702 fir::FirOpBuilder builder(*rewriter, kindMap); 703 auto tuple = fir::factory::createCharacterProcedureTuple( 704 builder, loc, tupleType, newProcPointerArg, newLenArg); 705 func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple); 706 func.front().eraseArgument(fixup.index + 1); 707 } break; 708 } 709 } 710 } 711 712 // Set the new type and finalize the arguments, etc. 713 newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); 714 auto newFuncTy = 715 mlir::FunctionType::get(func.getContext(), newInTys, newResTys); 716 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); 717 func.setType(newFuncTy); 718 719 for (auto &fixup : fixups) 720 if (fixup.finalizer) 721 (*fixup.finalizer)(func); 722 } 723 724 inline bool functionArgIsSRet(unsigned index, mlir::func::FuncOp func) { 725 if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret")) 726 return true; 727 return false; 728 } 729 730 /// Convert a complex return value. This can involve converting the return 731 /// value to a "hidden" first argument or packing the complex into a wide 732 /// GPR. 733 template <typename A, typename B, typename C> 734 void doComplexReturn(mlir::func::FuncOp func, A cmplx, B &newResTys, 735 B &newInTys, C &fixups) { 736 if (noComplexConversion) { 737 newResTys.push_back(cmplx); 738 return; 739 } 740 auto m = 741 specifics->complexReturnType(func.getLoc(), cmplx.getElementType()); 742 assert(m.size() == 1); 743 auto &tup = m[0]; 744 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 745 auto argTy = std::get<mlir::Type>(tup); 746 if (attr.isSRet()) { 747 unsigned argNo = newInTys.size(); 748 if (auto align = attr.getAlignment()) 749 fixups.emplace_back( 750 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) { 751 func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 752 func.setArgAttr(argNo, "llvm.align", 753 rewriter->getIntegerAttr( 754 rewriter->getIntegerType(32), align)); 755 }); 756 else 757 fixups.emplace_back( 758 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) { 759 func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 760 }); 761 newInTys.push_back(argTy); 762 return; 763 } else { 764 if (auto align = attr.getAlignment()) 765 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size(), 766 [=](mlir::func::FuncOp func) { 767 func.setArgAttr( 768 newResTys.size(), "llvm.align", 769 rewriter->getIntegerAttr( 770 rewriter->getIntegerType(32), align)); 771 }); 772 else 773 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); 774 } 775 newResTys.push_back(argTy); 776 } 777 778 /// Convert a complex argument value. This can involve storing the value to 779 /// a temporary memory location or factoring the value into two distinct 780 /// arguments. 781 template <typename A, typename B, typename C> 782 void doComplexArg(mlir::func::FuncOp func, A cmplx, B &newInTys, C &fixups) { 783 if (noComplexConversion) { 784 newInTys.push_back(cmplx); 785 return; 786 } 787 auto m = 788 specifics->complexArgumentType(func.getLoc(), cmplx.getElementType()); 789 const auto fixupCode = 790 m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; 791 for (auto e : llvm::enumerate(m)) { 792 auto &tup = e.value(); 793 auto index = e.index(); 794 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 795 auto argTy = std::get<mlir::Type>(tup); 796 auto argNo = newInTys.size(); 797 if (attr.isByVal()) { 798 if (auto align = attr.getAlignment()) 799 fixups.emplace_back( 800 FixupTy::Codes::ArgumentAsLoad, argNo, 801 [=](mlir::func::FuncOp func) { 802 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); 803 func.setArgAttr(argNo, "llvm.align", 804 rewriter->getIntegerAttr( 805 rewriter->getIntegerType(32), align)); 806 }); 807 else 808 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), 809 [=](mlir::func::FuncOp func) { 810 func.setArgAttr(argNo, "llvm.byval", 811 rewriter->getUnitAttr()); 812 }); 813 } else { 814 if (auto align = attr.getAlignment()) 815 fixups.emplace_back( 816 fixupCode, argNo, index, [=](mlir::func::FuncOp func) { 817 func.setArgAttr(argNo, "llvm.align", 818 rewriter->getIntegerAttr( 819 rewriter->getIntegerType(32), align)); 820 }); 821 else 822 fixups.emplace_back(fixupCode, argNo, index); 823 } 824 newInTys.push_back(argTy); 825 } 826 } 827 828 private: 829 // Replace `op` and remove it. 830 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 831 op->replaceAllUsesWith(newValues); 832 op->dropAllReferences(); 833 op->erase(); 834 } 835 836 inline void setMembers(fir::CodeGenSpecifics *s, mlir::OpBuilder *r) { 837 specifics = s; 838 rewriter = r; 839 } 840 841 inline void clearMembers() { setMembers(nullptr, nullptr); } 842 843 fir::CodeGenSpecifics *specifics = nullptr; 844 mlir::OpBuilder *rewriter = nullptr; 845 }; // namespace 846 } // namespace 847 848 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 849 fir::createFirTargetRewritePass(const fir::TargetRewriteOptions &options) { 850 return std::make_unique<TargetRewrite>(options); 851 } 852