1 //===-- TargetRewrite.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest. 10 // LLVM expects different lowering idioms to be used for distinct target 11 // triples. These distinctions are handled by this pass. 12 // 13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "Target.h" 19 #include "flang/Optimizer/Builder/Character.h" 20 #include "flang/Optimizer/Builder/FIRBuilder.h" 21 #include "flang/Optimizer/Builder/Todo.h" 22 #include "flang/Optimizer/CodeGen/CodeGen.h" 23 #include "flang/Optimizer/Dialect/FIRDialect.h" 24 #include "flang/Optimizer/Dialect/FIROps.h" 25 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 26 #include "flang/Optimizer/Dialect/FIRType.h" 27 #include "flang/Optimizer/Support/FIRContext.h" 28 #include "mlir/Transforms/DialectConversion.h" 29 #include "llvm/ADT/STLExtras.h" 30 #include "llvm/ADT/TypeSwitch.h" 31 #include "llvm/Support/Debug.h" 32 33 #define DEBUG_TYPE "flang-target-rewrite" 34 35 namespace { 36 37 /// Fixups for updating a FuncOp's arguments and return values. 38 struct FixupTy { 39 enum class Codes { 40 ArgumentAsLoad, 41 ArgumentType, 42 CharPair, 43 ReturnAsStore, 44 ReturnType, 45 Split, 46 Trailing, 47 TrailingCharProc 48 }; 49 50 FixupTy(Codes code, std::size_t index, std::size_t second = 0) 51 : code{code}, index{index}, second{second} {} 52 FixupTy(Codes code, std::size_t index, 53 std::function<void(mlir::func::FuncOp)> &&finalizer) 54 : code{code}, index{index}, finalizer{finalizer} {} 55 FixupTy(Codes code, std::size_t index, std::size_t second, 56 std::function<void(mlir::func::FuncOp)> &&finalizer) 57 : code{code}, index{index}, second{second}, finalizer{finalizer} {} 58 59 Codes code; 60 std::size_t index; 61 std::size_t second{}; 62 llvm::Optional<std::function<void(mlir::func::FuncOp)>> finalizer{}; 63 }; // namespace 64 65 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 66 /// generation that traverses the FIR and modifies types and operations to a 67 /// form that is appropriate for the specific target. LLVM IR has specific 68 /// idioms that are used for distinct target processor and ABI combinations. 69 class TargetRewrite : public fir::TargetRewriteBase<TargetRewrite> { 70 public: 71 TargetRewrite(const fir::TargetRewriteOptions &options) { 72 noCharacterConversion = options.noCharacterConversion; 73 noComplexConversion = options.noComplexConversion; 74 } 75 76 void runOnOperation() override final { 77 auto &context = getContext(); 78 mlir::OpBuilder rewriter(&context); 79 80 auto mod = getModule(); 81 if (!forcedTargetTriple.empty()) 82 fir::setTargetTriple(mod, forcedTargetTriple); 83 84 auto specifics = fir::CodeGenSpecifics::get( 85 mod.getContext(), fir::getTargetTriple(mod), fir::getKindMapping(mod)); 86 setMembers(specifics.get(), &rewriter); 87 88 // Perform type conversion on signatures and call sites. 89 if (mlir::failed(convertTypes(mod))) { 90 mlir::emitError(mlir::UnknownLoc::get(&context), 91 "error in converting types to target abi"); 92 signalPassFailure(); 93 } 94 95 // Convert ops in target-specific patterns. 96 mod.walk([&](mlir::Operation *op) { 97 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) { 98 if (!hasPortableSignature(call.getFunctionType())) 99 convertCallOp(call); 100 } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) { 101 if (!hasPortableSignature(dispatch.getFunctionType())) 102 convertCallOp(dispatch); 103 } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) { 104 if (addr.getType().isa<mlir::FunctionType>() && 105 !hasPortableSignature(addr.getType())) 106 convertAddrOp(addr); 107 } 108 }); 109 110 clearMembers(); 111 } 112 113 mlir::ModuleOp getModule() { return getOperation(); } 114 115 template <typename A, typename B, typename C> 116 std::function<mlir::Value(mlir::Operation *)> 117 rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) { 118 auto m = specifics->complexReturnType(ty.getElementType()); 119 // Currently targets mandate COMPLEX is a single aggregate or packed 120 // scalar, including the sret case. 121 assert(m.size() == 1 && "target lowering of complex return not supported"); 122 auto resTy = std::get<mlir::Type>(m[0]); 123 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]); 124 auto loc = mlir::UnknownLoc::get(resTy.getContext()); 125 if (attr.isSRet()) { 126 assert(fir::isa_ref_type(resTy) && "must be a memory reference type"); 127 mlir::Value stack = 128 rewriter->create<fir::AllocaOp>(loc, fir::dyn_cast_ptrEleTy(resTy)); 129 newInTys.push_back(resTy); 130 newOpers.push_back(stack); 131 return [=](mlir::Operation *) -> mlir::Value { 132 auto memTy = fir::ReferenceType::get(ty); 133 auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack); 134 return rewriter->create<fir::LoadOp>(loc, cast); 135 }; 136 } 137 newResTys.push_back(resTy); 138 return [=](mlir::Operation *call) -> mlir::Value { 139 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 140 rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem); 141 auto memTy = fir::ReferenceType::get(ty); 142 auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, mem); 143 return rewriter->create<fir::LoadOp>(loc, cast); 144 }; 145 } 146 147 template <typename A, typename B, typename C> 148 void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, 149 C &newOpers) { 150 auto m = specifics->complexArgumentType(ty.getElementType()); 151 auto *ctx = ty.getContext(); 152 auto loc = mlir::UnknownLoc::get(ctx); 153 if (m.size() == 1) { 154 // COMPLEX is a single aggregate 155 auto resTy = std::get<mlir::Type>(m[0]); 156 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]); 157 auto oldRefTy = fir::ReferenceType::get(ty); 158 if (attr.isByVal()) { 159 auto mem = rewriter->create<fir::AllocaOp>(loc, ty); 160 rewriter->create<fir::StoreOp>(loc, oper, mem); 161 newOpers.push_back(rewriter->create<fir::ConvertOp>(loc, resTy, mem)); 162 } else { 163 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 164 auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem); 165 rewriter->create<fir::StoreOp>(loc, oper, cast); 166 newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem)); 167 } 168 newInTys.push_back(resTy); 169 } else { 170 assert(m.size() == 2); 171 // COMPLEX is split into 2 separate arguments 172 auto iTy = rewriter->getIntegerType(32); 173 for (auto e : llvm::enumerate(m)) { 174 auto &tup = e.value(); 175 auto ty = std::get<mlir::Type>(tup); 176 auto index = e.index(); 177 auto idx = rewriter->getIntegerAttr(iTy, index); 178 auto val = rewriter->create<fir::ExtractValueOp>( 179 loc, ty, oper, rewriter->getArrayAttr(idx)); 180 newInTys.push_back(ty); 181 newOpers.push_back(val); 182 } 183 } 184 } 185 186 // Convert fir.call and fir.dispatch Ops. 187 template <typename A> 188 void convertCallOp(A callOp) { 189 auto fnTy = callOp.getFunctionType(); 190 auto loc = callOp.getLoc(); 191 rewriter->setInsertionPoint(callOp); 192 llvm::SmallVector<mlir::Type> newResTys; 193 llvm::SmallVector<mlir::Type> newInTys; 194 llvm::SmallVector<mlir::Value> newOpers; 195 196 // If the call is indirect, the first argument must still be the function 197 // to call. 198 int dropFront = 0; 199 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 200 if (!callOp.getCallee().hasValue()) { 201 newInTys.push_back(fnTy.getInput(0)); 202 newOpers.push_back(callOp.getOperand(0)); 203 dropFront = 1; 204 } 205 } 206 207 // Determine the rewrite function, `wrap`, for the result value. 208 llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; 209 if (fnTy.getResults().size() == 1) { 210 mlir::Type ty = fnTy.getResult(0); 211 llvm::TypeSwitch<mlir::Type>(ty) 212 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 213 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 214 newOpers); 215 }) 216 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 217 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 218 newOpers); 219 }) 220 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 221 } else if (fnTy.getResults().size() > 1) { 222 TODO(loc, "multiple results not supported yet"); 223 } 224 225 llvm::SmallVector<mlir::Type> trailingInTys; 226 llvm::SmallVector<mlir::Value> trailingOpers; 227 for (auto e : llvm::enumerate( 228 llvm::zip(fnTy.getInputs().drop_front(dropFront), 229 callOp.getOperands().drop_front(dropFront)))) { 230 mlir::Type ty = std::get<0>(e.value()); 231 mlir::Value oper = std::get<1>(e.value()); 232 unsigned index = e.index(); 233 llvm::TypeSwitch<mlir::Type>(ty) 234 .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) { 235 bool sret; 236 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 237 sret = callOp.getCallee() && 238 functionArgIsSRet( 239 index, getModule().lookupSymbol<mlir::func::FuncOp>( 240 *callOp.getCallee())); 241 } else { 242 // TODO: dispatch case; how do we put arguments on a call? 243 // We cannot put both an sret and the dispatch object first. 244 sret = false; 245 TODO(loc, "dispatch + sret not supported yet"); 246 } 247 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); 248 auto unbox = rewriter->create<fir::UnboxCharOp>( 249 loc, std::get<mlir::Type>(m[0]), std::get<mlir::Type>(m[1]), 250 oper); 251 // unboxed CHARACTER arguments 252 for (auto e : llvm::enumerate(m)) { 253 unsigned idx = e.index(); 254 auto attr = 255 std::get<fir::CodeGenSpecifics::Attributes>(e.value()); 256 auto argTy = std::get<mlir::Type>(e.value()); 257 if (attr.isAppend()) { 258 trailingInTys.push_back(argTy); 259 trailingOpers.push_back(unbox.getResult(idx)); 260 } else { 261 newInTys.push_back(argTy); 262 newOpers.push_back(unbox.getResult(idx)); 263 } 264 } 265 }) 266 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 267 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 268 }) 269 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 270 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 271 }) 272 .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { 273 if (fir::isCharacterProcedureTuple(tuple)) { 274 mlir::ModuleOp module = getModule(); 275 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 276 if (callOp.getCallee()) { 277 llvm::StringRef charProcAttr = 278 fir::getCharacterProcedureDummyAttrName(); 279 // The charProcAttr attribute is only used as a safety to 280 // confirm that this is a dummy procedure and should be split. 281 // It cannot be used to match because attributes are not 282 // available in case of indirect calls. 283 auto funcOp = module.lookupSymbol<mlir::func::FuncOp>( 284 *callOp.getCallee()); 285 if (funcOp && 286 !funcOp.template getArgAttrOfType<mlir::UnitAttr>( 287 index, charProcAttr)) 288 mlir::emitError(loc, "tuple argument will be split even " 289 "though it does not have the `" + 290 charProcAttr + "` attribute"); 291 } 292 } 293 mlir::Type funcPointerType = tuple.getType(0); 294 mlir::Type lenType = tuple.getType(1); 295 fir::FirOpBuilder builder(*rewriter, fir::getKindMapping(module)); 296 auto [funcPointer, len] = 297 fir::factory::extractCharacterProcedureTuple(builder, loc, 298 oper); 299 newInTys.push_back(funcPointerType); 300 newOpers.push_back(funcPointer); 301 trailingInTys.push_back(lenType); 302 trailingOpers.push_back(len); 303 } else { 304 newInTys.push_back(tuple); 305 newOpers.push_back(oper); 306 } 307 }) 308 .Default([&](mlir::Type ty) { 309 newInTys.push_back(ty); 310 newOpers.push_back(oper); 311 }); 312 } 313 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 314 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); 315 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 316 fir::CallOp newCall; 317 if (callOp.getCallee().hasValue()) { 318 newCall = rewriter->create<A>(loc, callOp.getCallee().getValue(), 319 newResTys, newOpers); 320 } else { 321 // Force new type on the input operand. 322 newOpers[0].setType(mlir::FunctionType::get( 323 callOp.getContext(), 324 mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); 325 newCall = rewriter->create<A>(loc, newResTys, newOpers); 326 } 327 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); 328 if (wrap.hasValue()) 329 replaceOp(callOp, (*wrap)(newCall.getOperation())); 330 else 331 replaceOp(callOp, newCall.getResults()); 332 } else { 333 // A is fir::DispatchOp 334 TODO(loc, "dispatch not implemented"); 335 } 336 } 337 338 // Result type fixup for fir::ComplexType and mlir::ComplexType 339 template <typename A, typename B> 340 void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) { 341 if (noComplexConversion) { 342 newResTys.push_back(cmplx); 343 } else { 344 for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) { 345 auto argTy = std::get<mlir::Type>(tup); 346 if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet()) 347 newInTys.push_back(argTy); 348 else 349 newResTys.push_back(argTy); 350 } 351 } 352 } 353 354 // Argument type fixup for fir::ComplexType and mlir::ComplexType 355 template <typename A, typename B> 356 void lowerComplexSignatureArg(A cmplx, B &newInTys) { 357 if (noComplexConversion) 358 newInTys.push_back(cmplx); 359 else 360 for (auto &tup : specifics->complexArgumentType(cmplx.getElementType())) 361 newInTys.push_back(std::get<mlir::Type>(tup)); 362 } 363 364 /// Taking the address of a function. Modify the signature as needed. 365 void convertAddrOp(fir::AddrOfOp addrOp) { 366 rewriter->setInsertionPoint(addrOp); 367 auto addrTy = addrOp.getType().cast<mlir::FunctionType>(); 368 llvm::SmallVector<mlir::Type> newResTys; 369 llvm::SmallVector<mlir::Type> newInTys; 370 for (mlir::Type ty : addrTy.getResults()) { 371 llvm::TypeSwitch<mlir::Type>(ty) 372 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 373 lowerComplexSignatureRes(ty, newResTys, newInTys); 374 }) 375 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 376 lowerComplexSignatureRes(ty, newResTys, newInTys); 377 }) 378 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 379 } 380 llvm::SmallVector<mlir::Type> trailingInTys; 381 for (mlir::Type ty : addrTy.getInputs()) { 382 llvm::TypeSwitch<mlir::Type>(ty) 383 .Case<fir::BoxCharType>([&](auto box) { 384 if (noCharacterConversion) { 385 newInTys.push_back(box); 386 } else { 387 for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) { 388 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 389 auto argTy = std::get<mlir::Type>(tup); 390 llvm::SmallVector<mlir::Type> &vec = 391 attr.isAppend() ? trailingInTys : newInTys; 392 vec.push_back(argTy); 393 } 394 } 395 }) 396 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 397 lowerComplexSignatureArg(ty, newInTys); 398 }) 399 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 400 lowerComplexSignatureArg(ty, newInTys); 401 }) 402 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 403 if (fir::isCharacterProcedureTuple(tuple)) { 404 newInTys.push_back(tuple.getType(0)); 405 trailingInTys.push_back(tuple.getType(1)); 406 } else { 407 newInTys.push_back(ty); 408 } 409 }) 410 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 411 } 412 // append trailing input types 413 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 414 // replace this op with a new one with the updated signature 415 auto newTy = rewriter->getFunctionType(newInTys, newResTys); 416 auto newOp = rewriter->create<fir::AddrOfOp>(addrOp.getLoc(), newTy, 417 addrOp.getSymbol()); 418 replaceOp(addrOp, newOp.getResult()); 419 } 420 421 /// Convert the type signatures on all the functions present in the module. 422 /// As the type signature is being changed, this must also update the 423 /// function itself to use any new arguments, etc. 424 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { 425 for (auto fn : mod.getOps<mlir::func::FuncOp>()) 426 convertSignature(fn); 427 return mlir::success(); 428 } 429 430 /// If the signature does not need any special target-specific converions, 431 /// then it is considered portable for any target, and this function will 432 /// return `true`. Otherwise, the signature is not portable and `false` is 433 /// returned. 434 bool hasPortableSignature(mlir::Type signature) { 435 assert(signature.isa<mlir::FunctionType>()); 436 auto func = signature.dyn_cast<mlir::FunctionType>(); 437 for (auto ty : func.getResults()) 438 if ((ty.isa<fir::BoxCharType>() && !noCharacterConversion) || 439 (fir::isa_complex(ty) && !noComplexConversion)) { 440 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 441 return false; 442 } 443 for (auto ty : func.getInputs()) 444 if (((ty.isa<fir::BoxCharType>() || fir::isCharacterProcedureTuple(ty)) && 445 !noCharacterConversion) || 446 (fir::isa_complex(ty) && !noComplexConversion)) { 447 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 448 return false; 449 } 450 return true; 451 } 452 453 /// Determine if the signature has host associations. The host association 454 /// argument may need special target specific rewriting. 455 static bool hasHostAssociations(mlir::func::FuncOp func) { 456 std::size_t end = func.getFunctionType().getInputs().size(); 457 for (std::size_t i = 0; i < end; ++i) 458 if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName())) 459 return true; 460 return false; 461 } 462 463 /// Rewrite the signatures and body of the `FuncOp`s in the module for 464 /// the immediately subsequent target code gen. 465 void convertSignature(mlir::func::FuncOp func) { 466 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); 467 if (hasPortableSignature(funcTy) && !hasHostAssociations(func)) 468 return; 469 llvm::SmallVector<mlir::Type> newResTys; 470 llvm::SmallVector<mlir::Type> newInTys; 471 llvm::SmallVector<FixupTy> fixups; 472 473 // Convert return value(s) 474 for (auto ty : funcTy.getResults()) 475 llvm::TypeSwitch<mlir::Type>(ty) 476 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 477 if (noComplexConversion) 478 newResTys.push_back(cmplx); 479 else 480 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 481 }) 482 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 483 if (noComplexConversion) 484 newResTys.push_back(cmplx); 485 else 486 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 487 }) 488 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 489 490 // Convert arguments 491 llvm::SmallVector<mlir::Type> trailingTys; 492 for (auto e : llvm::enumerate(funcTy.getInputs())) { 493 auto ty = e.value(); 494 unsigned index = e.index(); 495 llvm::TypeSwitch<mlir::Type>(ty) 496 .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) { 497 if (noCharacterConversion) { 498 newInTys.push_back(boxTy); 499 } else { 500 // Convert a CHARACTER argument type. This can involve separating 501 // the pointer and the LEN into two arguments and moving the LEN 502 // argument to the end of the arg list. 503 bool sret = functionArgIsSRet(index, func); 504 for (auto e : llvm::enumerate(specifics->boxcharArgumentType( 505 boxTy.getEleTy(), sret))) { 506 auto &tup = e.value(); 507 auto index = e.index(); 508 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 509 auto argTy = std::get<mlir::Type>(tup); 510 if (attr.isAppend()) { 511 trailingTys.push_back(argTy); 512 } else { 513 if (sret) { 514 fixups.emplace_back(FixupTy::Codes::CharPair, 515 newInTys.size(), index); 516 } else { 517 fixups.emplace_back(FixupTy::Codes::Trailing, 518 newInTys.size(), trailingTys.size()); 519 } 520 newInTys.push_back(argTy); 521 } 522 } 523 } 524 }) 525 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 526 if (noComplexConversion) 527 newInTys.push_back(cmplx); 528 else 529 doComplexArg(func, cmplx, newInTys, fixups); 530 }) 531 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 532 if (noComplexConversion) 533 newInTys.push_back(cmplx); 534 else 535 doComplexArg(func, cmplx, newInTys, fixups); 536 }) 537 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 538 if (fir::isCharacterProcedureTuple(tuple)) { 539 fixups.emplace_back(FixupTy::Codes::TrailingCharProc, 540 newInTys.size(), trailingTys.size()); 541 newInTys.push_back(tuple.getType(0)); 542 trailingTys.push_back(tuple.getType(1)); 543 } else { 544 newInTys.push_back(ty); 545 } 546 }) 547 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 548 if (func.getArgAttrOfType<mlir::UnitAttr>(index, 549 fir::getHostAssocAttrName())) { 550 func.setArgAttr(index, "llvm.nest", rewriter->getUnitAttr()); 551 } 552 } 553 554 if (!func.empty()) { 555 // If the function has a body, then apply the fixups to the arguments and 556 // return ops as required. These fixups are done in place. 557 auto loc = func.getLoc(); 558 const auto fixupSize = fixups.size(); 559 const auto oldArgTys = func.getFunctionType().getInputs(); 560 int offset = 0; 561 for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) { 562 const auto &fixup = fixups[i]; 563 switch (fixup.code) { 564 case FixupTy::Codes::ArgumentAsLoad: { 565 // Argument was pass-by-value, but is now pass-by-reference and 566 // possibly with a different element type. 567 auto newArg = func.front().insertArgument(fixup.index, 568 newInTys[fixup.index], loc); 569 rewriter->setInsertionPointToStart(&func.front()); 570 auto oldArgTy = 571 fir::ReferenceType::get(oldArgTys[fixup.index - offset]); 572 auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, newArg); 573 auto load = rewriter->create<fir::LoadOp>(loc, cast); 574 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 575 func.front().eraseArgument(fixup.index + 1); 576 } break; 577 case FixupTy::Codes::ArgumentType: { 578 // Argument is pass-by-value, but its type has likely been modified to 579 // suit the target ABI convention. 580 auto newArg = func.front().insertArgument(fixup.index, 581 newInTys[fixup.index], loc); 582 rewriter->setInsertionPointToStart(&func.front()); 583 auto mem = 584 rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]); 585 rewriter->create<fir::StoreOp>(loc, newArg, mem); 586 auto oldArgTy = 587 fir::ReferenceType::get(oldArgTys[fixup.index - offset]); 588 auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, mem); 589 mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast); 590 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 591 func.front().eraseArgument(fixup.index + 1); 592 LLVM_DEBUG(llvm::dbgs() 593 << "old argument: " << oldArgTy.getEleTy() 594 << ", repl: " << load << ", new argument: " 595 << func.getArgument(fixup.index).getType() << '\n'); 596 } break; 597 case FixupTy::Codes::CharPair: { 598 // The FIR boxchar argument has been split into a pair of distinct 599 // arguments that are in juxtaposition to each other. 600 auto newArg = func.front().insertArgument(fixup.index, 601 newInTys[fixup.index], loc); 602 if (fixup.second == 1) { 603 rewriter->setInsertionPointToStart(&func.front()); 604 auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; 605 auto box = rewriter->create<fir::EmboxCharOp>( 606 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); 607 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 608 func.front().eraseArgument(fixup.index + 1); 609 offset++; 610 } 611 } break; 612 case FixupTy::Codes::ReturnAsStore: { 613 // The value being returned is now being returned in memory (callee 614 // stack space) through a hidden reference argument. 615 auto newArg = func.front().insertArgument(fixup.index, 616 newInTys[fixup.index], loc); 617 offset++; 618 func.walk([&](mlir::func::ReturnOp ret) { 619 rewriter->setInsertionPoint(ret); 620 auto oldOper = ret.getOperand(0); 621 auto oldOperTy = fir::ReferenceType::get(oldOper.getType()); 622 auto cast = 623 rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg); 624 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 625 rewriter->create<mlir::func::ReturnOp>(loc); 626 ret.erase(); 627 }); 628 } break; 629 case FixupTy::Codes::ReturnType: { 630 // The function is still returning a value, but its type has likely 631 // changed to suit the target ABI convention. 632 func.walk([&](mlir::func::ReturnOp ret) { 633 rewriter->setInsertionPoint(ret); 634 auto oldOper = ret.getOperand(0); 635 auto oldOperTy = fir::ReferenceType::get(oldOper.getType()); 636 auto mem = 637 rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]); 638 auto cast = rewriter->create<fir::ConvertOp>(loc, oldOperTy, mem); 639 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 640 mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem); 641 rewriter->create<mlir::func::ReturnOp>(loc, load); 642 ret.erase(); 643 }); 644 } break; 645 case FixupTy::Codes::Split: { 646 // The FIR argument has been split into a pair of distinct arguments 647 // that are in juxtaposition to each other. (For COMPLEX value.) 648 auto newArg = func.front().insertArgument(fixup.index, 649 newInTys[fixup.index], loc); 650 if (fixup.second == 1) { 651 rewriter->setInsertionPointToStart(&func.front()); 652 auto cplxTy = oldArgTys[fixup.index - offset - fixup.second]; 653 auto undef = rewriter->create<fir::UndefOp>(loc, cplxTy); 654 auto iTy = rewriter->getIntegerType(32); 655 auto zero = rewriter->getIntegerAttr(iTy, 0); 656 auto one = rewriter->getIntegerAttr(iTy, 1); 657 auto cplx1 = rewriter->create<fir::InsertValueOp>( 658 loc, cplxTy, undef, func.front().getArgument(fixup.index - 1), 659 rewriter->getArrayAttr(zero)); 660 auto cplx = rewriter->create<fir::InsertValueOp>( 661 loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one)); 662 func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx); 663 func.front().eraseArgument(fixup.index + 1); 664 offset++; 665 } 666 } break; 667 case FixupTy::Codes::Trailing: { 668 // The FIR argument has been split into a pair of distinct arguments. 669 // The first part of the pair appears in the original argument 670 // position. The second part of the pair is appended after all the 671 // original arguments. (Boxchar arguments.) 672 auto newBufArg = func.front().insertArgument( 673 fixup.index, newInTys[fixup.index], loc); 674 auto newLenArg = 675 func.front().addArgument(trailingTys[fixup.second], loc); 676 auto boxTy = oldArgTys[fixup.index - offset]; 677 rewriter->setInsertionPointToStart(&func.front()); 678 auto box = rewriter->create<fir::EmboxCharOp>(loc, boxTy, newBufArg, 679 newLenArg); 680 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 681 func.front().eraseArgument(fixup.index + 1); 682 } break; 683 case FixupTy::Codes::TrailingCharProc: { 684 // The FIR character procedure argument tuple must be split into a 685 // pair of distinct arguments. The first part of the pair appears in 686 // the original argument position. The second part of the pair is 687 // appended after all the original arguments. 688 auto newProcPointerArg = func.front().insertArgument( 689 fixup.index, newInTys[fixup.index], loc); 690 auto newLenArg = 691 func.front().addArgument(trailingTys[fixup.second], loc); 692 auto tupleType = oldArgTys[fixup.index - offset]; 693 rewriter->setInsertionPointToStart(&func.front()); 694 fir::FirOpBuilder builder(*rewriter, 695 fir::getKindMapping(getModule())); 696 auto tuple = fir::factory::createCharacterProcedureTuple( 697 builder, loc, tupleType, newProcPointerArg, newLenArg); 698 func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple); 699 func.front().eraseArgument(fixup.index + 1); 700 } break; 701 } 702 } 703 } 704 705 // Set the new type and finalize the arguments, etc. 706 newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); 707 auto newFuncTy = 708 mlir::FunctionType::get(func.getContext(), newInTys, newResTys); 709 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); 710 func.setType(newFuncTy); 711 712 for (auto &fixup : fixups) 713 if (fixup.finalizer) 714 (*fixup.finalizer)(func); 715 } 716 717 inline bool functionArgIsSRet(unsigned index, mlir::func::FuncOp func) { 718 if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret")) 719 return true; 720 return false; 721 } 722 723 /// Convert a complex return value. This can involve converting the return 724 /// value to a "hidden" first argument or packing the complex into a wide 725 /// GPR. 726 template <typename A, typename B, typename C> 727 void doComplexReturn(mlir::func::FuncOp func, A cmplx, B &newResTys, 728 B &newInTys, C &fixups) { 729 if (noComplexConversion) { 730 newResTys.push_back(cmplx); 731 return; 732 } 733 auto m = specifics->complexReturnType(cmplx.getElementType()); 734 assert(m.size() == 1); 735 auto &tup = m[0]; 736 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 737 auto argTy = std::get<mlir::Type>(tup); 738 if (attr.isSRet()) { 739 unsigned argNo = newInTys.size(); 740 fixups.emplace_back( 741 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) { 742 func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 743 }); 744 newInTys.push_back(argTy); 745 return; 746 } 747 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); 748 newResTys.push_back(argTy); 749 } 750 751 /// Convert a complex argument value. This can involve storing the value to 752 /// a temporary memory location or factoring the value into two distinct 753 /// arguments. 754 template <typename A, typename B, typename C> 755 void doComplexArg(mlir::func::FuncOp func, A cmplx, B &newInTys, C &fixups) { 756 if (noComplexConversion) { 757 newInTys.push_back(cmplx); 758 return; 759 } 760 auto m = specifics->complexArgumentType(cmplx.getElementType()); 761 const auto fixupCode = 762 m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; 763 for (auto e : llvm::enumerate(m)) { 764 auto &tup = e.value(); 765 auto index = e.index(); 766 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 767 auto argTy = std::get<mlir::Type>(tup); 768 auto argNo = newInTys.size(); 769 if (attr.isByVal()) { 770 if (auto align = attr.getAlignment()) 771 fixups.emplace_back( 772 FixupTy::Codes::ArgumentAsLoad, argNo, 773 [=](mlir::func::FuncOp func) { 774 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); 775 func.setArgAttr(argNo, "llvm.align", 776 rewriter->getIntegerAttr( 777 rewriter->getIntegerType(32), align)); 778 }); 779 else 780 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), 781 [=](mlir::func::FuncOp func) { 782 func.setArgAttr(argNo, "llvm.byval", 783 rewriter->getUnitAttr()); 784 }); 785 } else { 786 if (auto align = attr.getAlignment()) 787 fixups.emplace_back( 788 fixupCode, argNo, index, [=](mlir::func::FuncOp func) { 789 func.setArgAttr(argNo, "llvm.align", 790 rewriter->getIntegerAttr( 791 rewriter->getIntegerType(32), align)); 792 }); 793 else 794 fixups.emplace_back(fixupCode, argNo, index); 795 } 796 newInTys.push_back(argTy); 797 } 798 } 799 800 private: 801 // Replace `op` and remove it. 802 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 803 op->replaceAllUsesWith(newValues); 804 op->dropAllReferences(); 805 op->erase(); 806 } 807 808 inline void setMembers(fir::CodeGenSpecifics *s, mlir::OpBuilder *r) { 809 specifics = s; 810 rewriter = r; 811 } 812 813 inline void clearMembers() { setMembers(nullptr, nullptr); } 814 815 fir::CodeGenSpecifics *specifics = nullptr; 816 mlir::OpBuilder *rewriter = nullptr; 817 }; // namespace 818 } // namespace 819 820 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 821 fir::createFirTargetRewritePass(const fir::TargetRewriteOptions &options) { 822 return std::make_unique<TargetRewrite>(options); 823 } 824