//===-- TargetRewrite.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Target rewrite: rewriting of ops to make target-specific lowerings manifest.
// LLVM expects different lowering idioms to be used for distinct target
// triples. These distinctions are handled by this pass.
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "Target.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "flang-target-rewrite"

namespace {

/// Fixups for updating a FuncOp's arguments and return values.
38 struct FixupTy { 39 enum class Codes { 40 ArgumentAsLoad, 41 ArgumentType, 42 CharPair, 43 ReturnAsStore, 44 ReturnType, 45 Split, 46 Trailing, 47 TrailingCharProc 48 }; 49 50 FixupTy(Codes code, std::size_t index, std::size_t second = 0) 51 : code{code}, index{index}, second{second} {} 52 FixupTy(Codes code, std::size_t index, 53 std::function<void(mlir::func::FuncOp)> &&finalizer) 54 : code{code}, index{index}, finalizer{finalizer} {} 55 FixupTy(Codes code, std::size_t index, std::size_t second, 56 std::function<void(mlir::func::FuncOp)> &&finalizer) 57 : code{code}, index{index}, second{second}, finalizer{finalizer} {} 58 59 Codes code; 60 std::size_t index; 61 std::size_t second{}; 62 llvm::Optional<std::function<void(mlir::func::FuncOp)>> finalizer{}; 63 }; // namespace 64 65 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 66 /// generation that traverses the FIR and modifies types and operations to a 67 /// form that is appropriate for the specific target. LLVM IR has specific 68 /// idioms that are used for distinct target processor and ABI combinations. 69 class TargetRewrite : public fir::TargetRewriteBase<TargetRewrite> { 70 public: 71 TargetRewrite(const fir::TargetRewriteOptions &options) { 72 noCharacterConversion = options.noCharacterConversion; 73 noComplexConversion = options.noComplexConversion; 74 } 75 76 void runOnOperation() override final { 77 auto &context = getContext(); 78 mlir::OpBuilder rewriter(&context); 79 80 auto mod = getModule(); 81 if (!forcedTargetTriple.empty()) 82 fir::setTargetTriple(mod, forcedTargetTriple); 83 84 auto specifics = fir::CodeGenSpecifics::get( 85 mod.getContext(), fir::getTargetTriple(mod), fir::getKindMapping(mod)); 86 setMembers(specifics.get(), &rewriter); 87 88 // Perform type conversion on signatures and call sites. 
89 if (mlir::failed(convertTypes(mod))) { 90 mlir::emitError(mlir::UnknownLoc::get(&context), 91 "error in converting types to target abi"); 92 signalPassFailure(); 93 } 94 95 // Convert ops in target-specific patterns. 96 mod.walk([&](mlir::Operation *op) { 97 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) { 98 if (!hasPortableSignature(call.getFunctionType())) 99 convertCallOp(call); 100 } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) { 101 if (!hasPortableSignature(dispatch.getFunctionType())) 102 convertCallOp(dispatch); 103 } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) { 104 if (addr.getType().isa<mlir::FunctionType>() && 105 !hasPortableSignature(addr.getType())) 106 convertAddrOp(addr); 107 } 108 }); 109 110 clearMembers(); 111 } 112 113 mlir::ModuleOp getModule() { return getOperation(); } 114 115 template <typename A, typename B, typename C> 116 std::function<mlir::Value(mlir::Operation *)> 117 rewriteCallComplexResultType(mlir::Location loc, A ty, B &newResTys, 118 B &newInTys, C &newOpers) { 119 auto m = specifics->complexReturnType(loc, ty.getElementType()); 120 // Currently targets mandate COMPLEX is a single aggregate or packed 121 // scalar, including the sret case. 
122 assert(m.size() == 1 && "target lowering of complex return not supported"); 123 auto resTy = std::get<mlir::Type>(m[0]); 124 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]); 125 if (attr.isSRet()) { 126 assert(fir::isa_ref_type(resTy) && "must be a memory reference type"); 127 mlir::Value stack = 128 rewriter->create<fir::AllocaOp>(loc, fir::dyn_cast_ptrEleTy(resTy)); 129 newInTys.push_back(resTy); 130 newOpers.push_back(stack); 131 return [=](mlir::Operation *) -> mlir::Value { 132 auto memTy = fir::ReferenceType::get(ty); 133 auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack); 134 return rewriter->create<fir::LoadOp>(loc, cast); 135 }; 136 } 137 newResTys.push_back(resTy); 138 return [=](mlir::Operation *call) -> mlir::Value { 139 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 140 rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem); 141 auto memTy = fir::ReferenceType::get(ty); 142 auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, mem); 143 return rewriter->create<fir::LoadOp>(loc, cast); 144 }; 145 } 146 147 template <typename A, typename B, typename C> 148 void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, 149 C &newOpers) { 150 auto *ctx = ty.getContext(); 151 mlir::Location loc = mlir::UnknownLoc::get(ctx); 152 if (auto *op = oper.getDefiningOp()) 153 loc = op->getLoc(); 154 auto m = specifics->complexArgumentType(loc, ty.getElementType()); 155 if (m.size() == 1) { 156 // COMPLEX is a single aggregate 157 auto resTy = std::get<mlir::Type>(m[0]); 158 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]); 159 auto oldRefTy = fir::ReferenceType::get(ty); 160 if (attr.isByVal()) { 161 auto mem = rewriter->create<fir::AllocaOp>(loc, ty); 162 rewriter->create<fir::StoreOp>(loc, oper, mem); 163 newOpers.push_back(rewriter->create<fir::ConvertOp>(loc, resTy, mem)); 164 } else { 165 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 166 auto cast = 
rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem); 167 rewriter->create<fir::StoreOp>(loc, oper, cast); 168 newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem)); 169 } 170 newInTys.push_back(resTy); 171 } else { 172 assert(m.size() == 2); 173 // COMPLEX is split into 2 separate arguments 174 auto iTy = rewriter->getIntegerType(32); 175 for (auto e : llvm::enumerate(m)) { 176 auto &tup = e.value(); 177 auto ty = std::get<mlir::Type>(tup); 178 auto index = e.index(); 179 auto idx = rewriter->getIntegerAttr(iTy, index); 180 auto val = rewriter->create<fir::ExtractValueOp>( 181 loc, ty, oper, rewriter->getArrayAttr(idx)); 182 newInTys.push_back(ty); 183 newOpers.push_back(val); 184 } 185 } 186 } 187 188 // Convert fir.call and fir.dispatch Ops. 189 template <typename A> 190 void convertCallOp(A callOp) { 191 auto fnTy = callOp.getFunctionType(); 192 auto loc = callOp.getLoc(); 193 rewriter->setInsertionPoint(callOp); 194 llvm::SmallVector<mlir::Type> newResTys; 195 llvm::SmallVector<mlir::Type> newInTys; 196 llvm::SmallVector<mlir::Value> newOpers; 197 198 // If the call is indirect, the first argument must still be the function 199 // to call. 200 int dropFront = 0; 201 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 202 if (!callOp.getCallee()) { 203 newInTys.push_back(fnTy.getInput(0)); 204 newOpers.push_back(callOp.getOperand(0)); 205 dropFront = 1; 206 } 207 } 208 209 // Determine the rewrite function, `wrap`, for the result value. 
210 llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; 211 if (fnTy.getResults().size() == 1) { 212 mlir::Type ty = fnTy.getResult(0); 213 llvm::TypeSwitch<mlir::Type>(ty) 214 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 215 wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTys, 216 newOpers); 217 }) 218 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 219 wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTys, 220 newOpers); 221 }) 222 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 223 } else if (fnTy.getResults().size() > 1) { 224 TODO(loc, "multiple results not supported yet"); 225 } 226 227 llvm::SmallVector<mlir::Type> trailingInTys; 228 llvm::SmallVector<mlir::Value> trailingOpers; 229 for (auto e : llvm::enumerate( 230 llvm::zip(fnTy.getInputs().drop_front(dropFront), 231 callOp.getOperands().drop_front(dropFront)))) { 232 mlir::Type ty = std::get<0>(e.value()); 233 mlir::Value oper = std::get<1>(e.value()); 234 unsigned index = e.index(); 235 llvm::TypeSwitch<mlir::Type>(ty) 236 .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) { 237 bool sret; 238 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 239 sret = callOp.getCallee() && 240 functionArgIsSRet( 241 index, getModule().lookupSymbol<mlir::func::FuncOp>( 242 *callOp.getCallee())); 243 } else { 244 // TODO: dispatch case; how do we put arguments on a call? 245 // We cannot put both an sret and the dispatch object first. 
246 sret = false; 247 TODO(loc, "dispatch + sret not supported yet"); 248 } 249 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); 250 auto unbox = rewriter->create<fir::UnboxCharOp>( 251 loc, std::get<mlir::Type>(m[0]), std::get<mlir::Type>(m[1]), 252 oper); 253 // unboxed CHARACTER arguments 254 for (auto e : llvm::enumerate(m)) { 255 unsigned idx = e.index(); 256 auto attr = 257 std::get<fir::CodeGenSpecifics::Attributes>(e.value()); 258 auto argTy = std::get<mlir::Type>(e.value()); 259 if (attr.isAppend()) { 260 trailingInTys.push_back(argTy); 261 trailingOpers.push_back(unbox.getResult(idx)); 262 } else { 263 newInTys.push_back(argTy); 264 newOpers.push_back(unbox.getResult(idx)); 265 } 266 } 267 }) 268 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 269 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 270 }) 271 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 272 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 273 }) 274 .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { 275 if (fir::isCharacterProcedureTuple(tuple)) { 276 mlir::ModuleOp module = getModule(); 277 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 278 if (callOp.getCallee()) { 279 llvm::StringRef charProcAttr = 280 fir::getCharacterProcedureDummyAttrName(); 281 // The charProcAttr attribute is only used as a safety to 282 // confirm that this is a dummy procedure and should be split. 283 // It cannot be used to match because attributes are not 284 // available in case of indirect calls. 
285 auto funcOp = module.lookupSymbol<mlir::func::FuncOp>( 286 *callOp.getCallee()); 287 if (funcOp && 288 !funcOp.template getArgAttrOfType<mlir::UnitAttr>( 289 index, charProcAttr)) 290 mlir::emitError(loc, "tuple argument will be split even " 291 "though it does not have the `" + 292 charProcAttr + "` attribute"); 293 } 294 } 295 mlir::Type funcPointerType = tuple.getType(0); 296 mlir::Type lenType = tuple.getType(1); 297 fir::FirOpBuilder builder(*rewriter, fir::getKindMapping(module)); 298 auto [funcPointer, len] = 299 fir::factory::extractCharacterProcedureTuple(builder, loc, 300 oper); 301 newInTys.push_back(funcPointerType); 302 newOpers.push_back(funcPointer); 303 trailingInTys.push_back(lenType); 304 trailingOpers.push_back(len); 305 } else { 306 newInTys.push_back(tuple); 307 newOpers.push_back(oper); 308 } 309 }) 310 .Default([&](mlir::Type ty) { 311 newInTys.push_back(ty); 312 newOpers.push_back(oper); 313 }); 314 } 315 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 316 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); 317 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 318 fir::CallOp newCall; 319 if (callOp.getCallee()) { 320 newCall = 321 rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers); 322 } else { 323 // Force new type on the input operand. 
324 newOpers[0].setType(mlir::FunctionType::get( 325 callOp.getContext(), 326 mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); 327 newCall = rewriter->create<A>(loc, newResTys, newOpers); 328 } 329 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); 330 if (wrap) 331 replaceOp(callOp, (*wrap)(newCall.getOperation())); 332 else 333 replaceOp(callOp, newCall.getResults()); 334 } else { 335 // A is fir::DispatchOp 336 TODO(loc, "dispatch not implemented"); 337 } 338 } 339 340 // Result type fixup for fir::ComplexType and mlir::ComplexType 341 template <typename A, typename B> 342 void lowerComplexSignatureRes(mlir::Location loc, A cmplx, B &newResTys, 343 B &newInTys) { 344 if (noComplexConversion) { 345 newResTys.push_back(cmplx); 346 } else { 347 for (auto &tup : 348 specifics->complexReturnType(loc, cmplx.getElementType())) { 349 auto argTy = std::get<mlir::Type>(tup); 350 if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet()) 351 newInTys.push_back(argTy); 352 else 353 newResTys.push_back(argTy); 354 } 355 } 356 } 357 358 // Argument type fixup for fir::ComplexType and mlir::ComplexType 359 template <typename A, typename B> 360 void lowerComplexSignatureArg(mlir::Location loc, A cmplx, B &newInTys) { 361 if (noComplexConversion) 362 newInTys.push_back(cmplx); 363 else 364 for (auto &tup : 365 specifics->complexArgumentType(loc, cmplx.getElementType())) 366 newInTys.push_back(std::get<mlir::Type>(tup)); 367 } 368 369 /// Taking the address of a function. Modify the signature as needed. 
370 void convertAddrOp(fir::AddrOfOp addrOp) { 371 rewriter->setInsertionPoint(addrOp); 372 auto addrTy = addrOp.getType().cast<mlir::FunctionType>(); 373 llvm::SmallVector<mlir::Type> newResTys; 374 llvm::SmallVector<mlir::Type> newInTys; 375 auto loc = addrOp.getLoc(); 376 for (mlir::Type ty : addrTy.getResults()) { 377 llvm::TypeSwitch<mlir::Type>(ty) 378 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 379 lowerComplexSignatureRes(loc, ty, newResTys, newInTys); 380 }) 381 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 382 lowerComplexSignatureRes(loc, ty, newResTys, newInTys); 383 }) 384 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 385 } 386 llvm::SmallVector<mlir::Type> trailingInTys; 387 for (mlir::Type ty : addrTy.getInputs()) { 388 llvm::TypeSwitch<mlir::Type>(ty) 389 .Case<fir::BoxCharType>([&](auto box) { 390 if (noCharacterConversion) { 391 newInTys.push_back(box); 392 } else { 393 for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) { 394 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 395 auto argTy = std::get<mlir::Type>(tup); 396 llvm::SmallVector<mlir::Type> &vec = 397 attr.isAppend() ? 
trailingInTys : newInTys; 398 vec.push_back(argTy); 399 } 400 } 401 }) 402 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 403 lowerComplexSignatureArg(loc, ty, newInTys); 404 }) 405 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 406 lowerComplexSignatureArg(loc, ty, newInTys); 407 }) 408 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 409 if (fir::isCharacterProcedureTuple(tuple)) { 410 newInTys.push_back(tuple.getType(0)); 411 trailingInTys.push_back(tuple.getType(1)); 412 } else { 413 newInTys.push_back(ty); 414 } 415 }) 416 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 417 } 418 // append trailing input types 419 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 420 // replace this op with a new one with the updated signature 421 auto newTy = rewriter->getFunctionType(newInTys, newResTys); 422 auto newOp = rewriter->create<fir::AddrOfOp>(addrOp.getLoc(), newTy, 423 addrOp.getSymbol()); 424 replaceOp(addrOp, newOp.getResult()); 425 } 426 427 /// Convert the type signatures on all the functions present in the module. 428 /// As the type signature is being changed, this must also update the 429 /// function itself to use any new arguments, etc. 430 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { 431 for (auto fn : mod.getOps<mlir::func::FuncOp>()) 432 convertSignature(fn); 433 return mlir::success(); 434 } 435 436 /// If the signature does not need any special target-specific converions, 437 /// then it is considered portable for any target, and this function will 438 /// return `true`. Otherwise, the signature is not portable and `false` is 439 /// returned. 
440 bool hasPortableSignature(mlir::Type signature) { 441 assert(signature.isa<mlir::FunctionType>()); 442 auto func = signature.dyn_cast<mlir::FunctionType>(); 443 for (auto ty : func.getResults()) 444 if ((ty.isa<fir::BoxCharType>() && !noCharacterConversion) || 445 (fir::isa_complex(ty) && !noComplexConversion)) { 446 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 447 return false; 448 } 449 for (auto ty : func.getInputs()) 450 if (((ty.isa<fir::BoxCharType>() || fir::isCharacterProcedureTuple(ty)) && 451 !noCharacterConversion) || 452 (fir::isa_complex(ty) && !noComplexConversion)) { 453 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 454 return false; 455 } 456 return true; 457 } 458 459 /// Determine if the signature has host associations. The host association 460 /// argument may need special target specific rewriting. 461 static bool hasHostAssociations(mlir::func::FuncOp func) { 462 std::size_t end = func.getFunctionType().getInputs().size(); 463 for (std::size_t i = 0; i < end; ++i) 464 if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName())) 465 return true; 466 return false; 467 } 468 469 /// Rewrite the signatures and body of the `FuncOp`s in the module for 470 /// the immediately subsequent target code gen. 
void convertSignature(mlir::func::FuncOp func) {
  // Nothing to do unless the signature is non-portable or one of the
  // arguments is a host association.
  auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
  if (hasPortableSignature(funcTy) && !hasHostAssociations(func))
    return;
  llvm::SmallVector<mlir::Type> newResTys;
  llvm::SmallVector<mlir::Type> newInTys;
  // Edits to apply to the body (if any) once the new types are known.
  llvm::SmallVector<FixupTy> fixups;

  // Convert return value(s)
  for (auto ty : funcTy.getResults())
    llvm::TypeSwitch<mlir::Type>(ty)
        .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
          if (noComplexConversion)
            newResTys.push_back(cmplx);
          else
            doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
        })
        .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
          if (noComplexConversion)
            newResTys.push_back(cmplx);
          else
            doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
        })
        .Default([&](mlir::Type ty) { newResTys.push_back(ty); });

  // Convert arguments
  // Types collected here are appended after all entries of `newInTys`.
  llvm::SmallVector<mlir::Type> trailingTys;
  for (auto e : llvm::enumerate(funcTy.getInputs())) {
    auto ty = e.value();
    unsigned index = e.index();
    llvm::TypeSwitch<mlir::Type>(ty)
        .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
          if (noCharacterConversion) {
            newInTys.push_back(boxTy);
          } else {
            // Convert a CHARACTER argument type. This can involve separating
            // the pointer and the LEN into two arguments and moving the LEN
            // argument to the end of the arg list.
            bool sret = functionArgIsSRet(index, func);
            // NOTE: `e` and `index` below shadow the outer loop's variables;
            // inside this loop they refer to the position within the pair.
            for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
                     boxTy.getEleTy(), sret))) {
              auto &tup = e.value();
              auto index = e.index();
              auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
              auto argTy = std::get<mlir::Type>(tup);
              if (attr.isAppend()) {
                trailingTys.push_back(argTy);
              } else {
                if (sret) {
                  fixups.emplace_back(FixupTy::Codes::CharPair,
                                      newInTys.size(), index);
                } else {
                  // `trailingTys.size()` is recorded as the slot of the
                  // trailing type that goes with this argument.
                  fixups.emplace_back(FixupTy::Codes::Trailing,
                                      newInTys.size(), trailingTys.size());
                }
                newInTys.push_back(argTy);
              }
            }
          }
        })
        .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
          if (noComplexConversion)
            newInTys.push_back(cmplx);
          else
            doComplexArg(func, cmplx, newInTys, fixups);
        })
        .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
          if (noComplexConversion)
            newInTys.push_back(cmplx);
          else
            doComplexArg(func, cmplx, newInTys, fixups);
        })
        .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
          if (fir::isCharacterProcedureTuple(tuple)) {
            // Element 0 stays at this position, element 1 becomes a trailing
            // argument.
            fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
                                newInTys.size(), trailingTys.size());
            newInTys.push_back(tuple.getType(0));
            trailingTys.push_back(tuple.getType(1));
          } else {
            newInTys.push_back(ty);
          }
        })
        .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
    // Host association arguments are tagged with "llvm.nest".
    // NOTE(review): `index` is the position in the original signature;
    // confirm this is still the right position when earlier arguments were
    // added or split.
    if (func.getArgAttrOfType<mlir::UnitAttr>(index,
                                              fir::getHostAssocAttrName())) {
      func.setArgAttr(index, "llvm.nest", rewriter->getUnitAttr());
    }
  }

  if (!func.empty()) {
    // If the function has a body, then apply the fixups to the arguments and
    // return ops as required. These fixups are done in place.
    auto loc = func.getLoc();
    const auto fixupSize = fixups.size();
    const auto oldArgTys = func.getFunctionType().getInputs();
    // `offset` is subtracted from a new argument position to index
    // `oldArgTys`. It is incremented by ReturnAsStore and by the second half
    // of a CharPair or Split, i.e. whenever a fixup adds an argument slot.
    int offset = 0;
    for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
      const auto &fixup = fixups[i];
      // Common pattern below: the new block argument is inserted at
      // `fixup.index`, which moves the old argument to `fixup.index + 1`;
      // the old argument's uses are replaced and it is then erased.
      switch (fixup.code) {
      case FixupTy::Codes::ArgumentAsLoad: {
        // Argument was pass-by-value, but is now pass-by-reference and
        // possibly with a different element type.
        auto newArg = func.front().insertArgument(fixup.index,
                                                  newInTys[fixup.index], loc);
        rewriter->setInsertionPointToStart(&func.front());
        auto oldArgTy =
            fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
        auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, newArg);
        auto load = rewriter->create<fir::LoadOp>(loc, cast);
        func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
        func.front().eraseArgument(fixup.index + 1);
      } break;
      case FixupTy::Codes::ArgumentType: {
        // Argument is pass-by-value, but its type has likely been modified to
        // suit the target ABI convention.
        auto newArg = func.front().insertArgument(fixup.index,
                                                  newInTys[fixup.index], loc);
        rewriter->setInsertionPointToStart(&func.front());
        // Store the new-typed value and reload it through a reference to the
        // old type.
        auto mem =
            rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
        rewriter->create<fir::StoreOp>(loc, newArg, mem);
        auto oldArgTy =
            fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
        auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, mem);
        mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
        func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
        func.front().eraseArgument(fixup.index + 1);
        LLVM_DEBUG(llvm::dbgs()
                   << "old argument: " << oldArgTy.getEleTy()
                   << ", repl: " << load << ", new argument: "
                   << func.getArgument(fixup.index).getType() << '\n');
      } break;
      case FixupTy::Codes::CharPair: {
        // The FIR boxchar argument has been split into a pair of distinct
        // arguments that are in juxtaposition to each other.
        auto newArg = func.front().insertArgument(fixup.index,
                                                  newInTys[fixup.index], loc);
        // Only the second half of the pair rebuilds the boxchar, from the
        // argument inserted just before it and the one inserted now.
        if (fixup.second == 1) {
          rewriter->setInsertionPointToStart(&func.front());
          auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
          auto box = rewriter->create<fir::EmboxCharOp>(
              loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
          func.front().eraseArgument(fixup.index + 1);
          offset++;
        }
      } break;
      case FixupTy::Codes::ReturnAsStore: {
        // The value being returned is now being returned in memory (callee
        // stack space) through a hidden reference argument.
        auto newArg = func.front().insertArgument(fixup.index,
                                                  newInTys[fixup.index], loc);
        offset++;
        // Every return becomes a store through the hidden argument followed
        // by a return without operands.
        func.walk([&](mlir::func::ReturnOp ret) {
          rewriter->setInsertionPoint(ret);
          auto oldOper = ret.getOperand(0);
          auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
          auto cast =
              rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
          rewriter->create<fir::StoreOp>(loc, oldOper, cast);
          rewriter->create<mlir::func::ReturnOp>(loc);
          ret.erase();
        });
      } break;
      case FixupTy::Codes::ReturnType: {
        // The function is still returning a value, but its type has likely
        // changed to suit the target ABI convention.
        func.walk([&](mlir::func::ReturnOp ret) {
          rewriter->setInsertionPoint(ret);
          auto oldOper = ret.getOperand(0);
          auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
          // Store the old-typed value through a converted reference and
          // reload it with the new result type.
          auto mem =
              rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
          auto cast = rewriter->create<fir::ConvertOp>(loc, oldOperTy, mem);
          rewriter->create<fir::StoreOp>(loc, oldOper, cast);
          mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
          rewriter->create<mlir::func::ReturnOp>(loc, load);
          ret.erase();
        });
      } break;
      case FixupTy::Codes::Split: {
        // The FIR argument has been split into a pair of distinct arguments
        // that are in juxtaposition to each other. (For COMPLEX value.)
        auto newArg = func.front().insertArgument(fixup.index,
                                                  newInTys[fixup.index], loc);
        // Only the second half of the pair rebuilds the value: both parts
        // are inserted into an undef of the old type at positions 0 and 1.
        if (fixup.second == 1) {
          rewriter->setInsertionPointToStart(&func.front());
          auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
          auto undef = rewriter->create<fir::UndefOp>(loc, cplxTy);
          auto iTy = rewriter->getIntegerType(32);
          auto zero = rewriter->getIntegerAttr(iTy, 0);
          auto one = rewriter->getIntegerAttr(iTy, 1);
          auto cplx1 = rewriter->create<fir::InsertValueOp>(
              loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
              rewriter->getArrayAttr(zero));
          auto cplx = rewriter->create<fir::InsertValueOp>(
              loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
          func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
          func.front().eraseArgument(fixup.index + 1);
          offset++;
        }
      } break;
      case FixupTy::Codes::Trailing: {
        // The FIR argument has been split into a pair of distinct arguments.
        // The first part of the pair appears in the original argument
        // position. The second part of the pair is appended after all the
        // original arguments. (Boxchar arguments.)
        auto newBufArg = func.front().insertArgument(
            fixup.index, newInTys[fixup.index], loc);
        auto newLenArg =
            func.front().addArgument(trailingTys[fixup.second], loc);
        auto boxTy = oldArgTys[fixup.index - offset];
        rewriter->setInsertionPointToStart(&func.front());
        auto box = rewriter->create<fir::EmboxCharOp>(loc, boxTy, newBufArg,
                                                      newLenArg);
        func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
        func.front().eraseArgument(fixup.index + 1);
      } break;
      case FixupTy::Codes::TrailingCharProc: {
        // The FIR character procedure argument tuple must be split into a
        // pair of distinct arguments. The first part of the pair appears in
        // the original argument position. The second part of the pair is
        // appended after all the original arguments.
        auto newProcPointerArg = func.front().insertArgument(
            fixup.index, newInTys[fixup.index], loc);
        auto newLenArg =
            func.front().addArgument(trailingTys[fixup.second], loc);
        auto tupleType = oldArgTys[fixup.index - offset];
        rewriter->setInsertionPointToStart(&func.front());
        fir::FirOpBuilder builder(*rewriter,
                                  fir::getKindMapping(getModule()));
        auto tuple = fir::factory::createCharacterProcedureTuple(
            builder, loc, tupleType, newProcPointerArg, newLenArg);
        func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple);
        func.front().eraseArgument(fixup.index + 1);
      } break;
      }
    }
  }

  // Set the new type and finalize the arguments, etc.
  newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
  auto newFuncTy =
      mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
  LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
  func.setType(newFuncTy);

  // Finalizers run last, after the function has its new type.
  for (auto &fixup : fixups)
    if (fixup.finalizer)
      (*fixup.finalizer)(func);
}

/// Return `true` if argument number \p index of \p func carries the
/// "llvm.sret" unit attribute.
inline bool functionArgIsSRet(unsigned index, mlir::func::FuncOp func) {
  if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret"))
    return true;
  return false;
}

/// Convert a complex return value. This can involve converting the return
/// value to a "hidden" first argument or packing the complex into a wide
/// GPR.
732 template <typename A, typename B, typename C> 733 void doComplexReturn(mlir::func::FuncOp func, A cmplx, B &newResTys, 734 B &newInTys, C &fixups) { 735 if (noComplexConversion) { 736 newResTys.push_back(cmplx); 737 return; 738 } 739 auto m = 740 specifics->complexReturnType(func.getLoc(), cmplx.getElementType()); 741 assert(m.size() == 1); 742 auto &tup = m[0]; 743 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 744 auto argTy = std::get<mlir::Type>(tup); 745 if (attr.isSRet()) { 746 unsigned argNo = newInTys.size(); 747 if (auto align = attr.getAlignment()) 748 fixups.emplace_back( 749 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) { 750 func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 751 func.setArgAttr(argNo, "llvm.align", 752 rewriter->getIntegerAttr( 753 rewriter->getIntegerType(32), align)); 754 }); 755 else 756 fixups.emplace_back( 757 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) { 758 func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 759 }); 760 newInTys.push_back(argTy); 761 return; 762 } else { 763 if (auto align = attr.getAlignment()) 764 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size(), 765 [=](mlir::func::FuncOp func) { 766 func.setArgAttr( 767 newResTys.size(), "llvm.align", 768 rewriter->getIntegerAttr( 769 rewriter->getIntegerType(32), align)); 770 }); 771 else 772 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); 773 } 774 newResTys.push_back(argTy); 775 } 776 777 /// Convert a complex argument value. This can involve storing the value to 778 /// a temporary memory location or factoring the value into two distinct 779 /// arguments. 
780 template <typename A, typename B, typename C> 781 void doComplexArg(mlir::func::FuncOp func, A cmplx, B &newInTys, C &fixups) { 782 if (noComplexConversion) { 783 newInTys.push_back(cmplx); 784 return; 785 } 786 auto m = 787 specifics->complexArgumentType(func.getLoc(), cmplx.getElementType()); 788 const auto fixupCode = 789 m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; 790 for (auto e : llvm::enumerate(m)) { 791 auto &tup = e.value(); 792 auto index = e.index(); 793 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 794 auto argTy = std::get<mlir::Type>(tup); 795 auto argNo = newInTys.size(); 796 if (attr.isByVal()) { 797 if (auto align = attr.getAlignment()) 798 fixups.emplace_back( 799 FixupTy::Codes::ArgumentAsLoad, argNo, 800 [=](mlir::func::FuncOp func) { 801 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); 802 func.setArgAttr(argNo, "llvm.align", 803 rewriter->getIntegerAttr( 804 rewriter->getIntegerType(32), align)); 805 }); 806 else 807 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), 808 [=](mlir::func::FuncOp func) { 809 func.setArgAttr(argNo, "llvm.byval", 810 rewriter->getUnitAttr()); 811 }); 812 } else { 813 if (auto align = attr.getAlignment()) 814 fixups.emplace_back( 815 fixupCode, argNo, index, [=](mlir::func::FuncOp func) { 816 func.setArgAttr(argNo, "llvm.align", 817 rewriter->getIntegerAttr( 818 rewriter->getIntegerType(32), align)); 819 }); 820 else 821 fixups.emplace_back(fixupCode, argNo, index); 822 } 823 newInTys.push_back(argTy); 824 } 825 } 826 827 private: 828 // Replace `op` and remove it. 
829 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 830 op->replaceAllUsesWith(newValues); 831 op->dropAllReferences(); 832 op->erase(); 833 } 834 835 inline void setMembers(fir::CodeGenSpecifics *s, mlir::OpBuilder *r) { 836 specifics = s; 837 rewriter = r; 838 } 839 840 inline void clearMembers() { setMembers(nullptr, nullptr); } 841 842 fir::CodeGenSpecifics *specifics = nullptr; 843 mlir::OpBuilder *rewriter = nullptr; 844 }; // namespace 845 } // namespace 846 847 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 848 fir::createFirTargetRewritePass(const fir::TargetRewriteOptions &options) { 849 return std::make_unique<TargetRewrite>(options); 850 } 851