1 //===-- TargetRewrite.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest. 10 // LLVM expects different lowering idioms to be used for distinct target 11 // triples. These distinctions are handled by this pass. 12 // 13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "Target.h" 19 #include "flang/Lower/Todo.h" 20 #include "flang/Optimizer/Builder/Character.h" 21 #include "flang/Optimizer/CodeGen/CodeGen.h" 22 #include "flang/Optimizer/Dialect/FIRDialect.h" 23 #include "flang/Optimizer/Dialect/FIROps.h" 24 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 25 #include "flang/Optimizer/Dialect/FIRType.h" 26 #include "flang/Optimizer/Support/FIRContext.h" 27 #include "mlir/Transforms/DialectConversion.h" 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ADT/TypeSwitch.h" 30 #include "llvm/Support/Debug.h" 31 32 using namespace fir; 33 34 #define DEBUG_TYPE "flang-target-rewrite" 35 36 namespace { 37 38 /// Fixups for updating a FuncOp's arguments and return values. 39 struct FixupTy { 40 enum class Codes { 41 ArgumentAsLoad, 42 ArgumentType, 43 CharPair, 44 ReturnAsStore, 45 ReturnType, 46 Split, 47 Trailing, 48 TrailingCharProc 49 }; 50 51 FixupTy(Codes code, std::size_t index, std::size_t second = 0) 52 : code{code}, index{index}, second{second} {} 53 FixupTy(Codes code, std::size_t index, 54 std::function<void(mlir::FuncOp)> &&finalizer) 55 : code{code}, index{index}, finalizer{finalizer} {} 56 FixupTy(Codes code, std::size_t index, std::size_t second, 57 std::function<void(mlir::FuncOp)> &&finalizer) 58 : code{code}, index{index}, second{second}, finalizer{finalizer} {} 59 60 Codes code; 61 std::size_t index; 62 std::size_t second{}; 63 llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{}; 64 }; // namespace 65 66 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 67 /// generation that traverses the FIR and modifies types and operations to a 68 /// form that is appropriate for the specific target. LLVM IR has specific 69 /// idioms that are used for distinct target processor and ABI combinations. 70 class TargetRewrite : public TargetRewriteBase<TargetRewrite> { 71 public: 72 TargetRewrite(const TargetRewriteOptions &options) { 73 noCharacterConversion = options.noCharacterConversion; 74 noComplexConversion = options.noComplexConversion; 75 } 76 77 void runOnOperation() override final { 78 auto &context = getContext(); 79 mlir::OpBuilder rewriter(&context); 80 81 auto mod = getModule(); 82 if (!forcedTargetTriple.empty()) 83 setTargetTriple(mod, forcedTargetTriple); 84 85 auto specifics = CodeGenSpecifics::get(getOperation().getContext(), 86 getTargetTriple(getOperation()), 87 getKindMapping(getOperation())); 88 setMembers(specifics.get(), &rewriter); 89 90 // Perform type conversion on signatures and call sites. 91 if (mlir::failed(convertTypes(mod))) { 92 mlir::emitError(mlir::UnknownLoc::get(&context), 93 "error in converting types to target abi"); 94 signalPassFailure(); 95 } 96 97 // Convert ops in target-specific patterns. 98 mod.walk([&](mlir::Operation *op) { 99 if (auto call = dyn_cast<fir::CallOp>(op)) { 100 if (!hasPortableSignature(call.getFunctionType())) 101 convertCallOp(call); 102 } else if (auto dispatch = dyn_cast<DispatchOp>(op)) { 103 if (!hasPortableSignature(dispatch.getFunctionType())) 104 convertCallOp(dispatch); 105 } else if (auto addr = dyn_cast<AddrOfOp>(op)) { 106 if (addr.getType().isa<mlir::FunctionType>() && 107 !hasPortableSignature(addr.getType())) 108 convertAddrOp(addr); 109 } 110 }); 111 112 clearMembers(); 113 } 114 115 mlir::ModuleOp getModule() { return getOperation(); } 116 117 template <typename A, typename B, typename C> 118 std::function<mlir::Value(mlir::Operation *)> 119 rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) { 120 auto m = specifics->complexReturnType(ty.getElementType()); 121 // Currently targets mandate COMPLEX is a single aggregate or packed 122 // scalar, including the sret case. 123 assert(m.size() == 1 && "target lowering of complex return not supported"); 124 auto resTy = std::get<mlir::Type>(m[0]); 125 auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 126 auto loc = mlir::UnknownLoc::get(resTy.getContext()); 127 if (attr.isSRet()) { 128 assert(isa_ref_type(resTy)); 129 mlir::Value stack = 130 rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy)); 131 newInTys.push_back(resTy); 132 newOpers.push_back(stack); 133 return [=](mlir::Operation *) -> mlir::Value { 134 auto memTy = ReferenceType::get(ty); 135 auto cast = rewriter->create<ConvertOp>(loc, memTy, stack); 136 return rewriter->create<fir::LoadOp>(loc, cast); 137 }; 138 } 139 newResTys.push_back(resTy); 140 return [=](mlir::Operation *call) -> mlir::Value { 141 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 142 rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem); 143 auto memTy = ReferenceType::get(ty); 144 auto cast = rewriter->create<ConvertOp>(loc, memTy, mem); 145 return rewriter->create<fir::LoadOp>(loc, cast); 146 }; 147 } 148 149 template <typename A, typename B, typename C> 150 void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, 151 C &newOpers) { 152 auto m = specifics->complexArgumentType(ty.getElementType()); 153 auto *ctx = ty.getContext(); 154 auto loc = mlir::UnknownLoc::get(ctx); 155 if (m.size() == 1) { 156 // COMPLEX is a single aggregate 157 auto resTy = std::get<mlir::Type>(m[0]); 158 auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 159 auto oldRefTy = ReferenceType::get(ty); 160 if (attr.isByVal()) { 161 auto mem = rewriter->create<fir::AllocaOp>(loc, ty); 162 rewriter->create<fir::StoreOp>(loc, oper, mem); 163 newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem)); 164 } else { 165 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 166 auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem); 167 rewriter->create<fir::StoreOp>(loc, oper, cast); 168 newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem)); 169 } 170 newInTys.push_back(resTy); 171 } else { 172 assert(m.size() == 2); 173 // COMPLEX is split into 2 separate arguments 174 auto iTy = rewriter->getIntegerType(32); 175 for (auto e : llvm::enumerate(m)) { 176 auto &tup = e.value(); 177 auto ty = std::get<mlir::Type>(tup); 178 auto index = e.index(); 179 auto idx = rewriter->getIntegerAttr(iTy, index); 180 auto val = rewriter->create<ExtractValueOp>( 181 loc, ty, oper, rewriter->getArrayAttr(idx)); 182 newInTys.push_back(ty); 183 newOpers.push_back(val); 184 } 185 } 186 } 187 188 // Convert fir.call and fir.dispatch Ops. 189 template <typename A> 190 void convertCallOp(A callOp) { 191 auto fnTy = callOp.getFunctionType(); 192 auto loc = callOp.getLoc(); 193 rewriter->setInsertionPoint(callOp); 194 llvm::SmallVector<mlir::Type> newResTys; 195 llvm::SmallVector<mlir::Type> newInTys; 196 llvm::SmallVector<mlir::Value> newOpers; 197 198 // If the call is indirect, the first argument must still be the function 199 // to call. 200 int dropFront = 0; 201 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 202 if (!callOp.callee().hasValue()) { 203 newInTys.push_back(fnTy.getInput(0)); 204 newOpers.push_back(callOp.getOperand(0)); 205 dropFront = 1; 206 } 207 } 208 209 // Determine the rewrite function, `wrap`, for the result value. 210 llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; 211 if (fnTy.getResults().size() == 1) { 212 mlir::Type ty = fnTy.getResult(0); 213 llvm::TypeSwitch<mlir::Type>(ty) 214 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 215 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 216 newOpers); 217 }) 218 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 219 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 220 newOpers); 221 }) 222 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 223 } else if (fnTy.getResults().size() > 1) { 224 TODO(loc, "multiple results not supported yet"); 225 } 226 227 llvm::SmallVector<mlir::Type> trailingInTys; 228 llvm::SmallVector<mlir::Value> trailingOpers; 229 for (auto e : llvm::enumerate( 230 llvm::zip(fnTy.getInputs().drop_front(dropFront), 231 callOp.getOperands().drop_front(dropFront)))) { 232 mlir::Type ty = std::get<0>(e.value()); 233 mlir::Value oper = std::get<1>(e.value()); 234 unsigned index = e.index(); 235 llvm::TypeSwitch<mlir::Type>(ty) 236 .template Case<BoxCharType>([&](BoxCharType boxTy) { 237 bool sret; 238 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 239 sret = callOp.callee() && 240 functionArgIsSRet(index, 241 getModule().lookupSymbol<mlir::FuncOp>( 242 *callOp.callee())); 243 } else { 244 // TODO: dispatch case; how do we put arguments on a call? 245 // We cannot put both an sret and the dispatch object first. 246 sret = false; 247 TODO(loc, "dispatch + sret not supported yet"); 248 } 249 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); 250 auto unbox = 251 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]), 252 std::get<mlir::Type>(m[1]), oper); 253 // unboxed CHARACTER arguments 254 for (auto e : llvm::enumerate(m)) { 255 unsigned idx = e.index(); 256 auto attr = std::get<CodeGenSpecifics::Attributes>(e.value()); 257 auto argTy = std::get<mlir::Type>(e.value()); 258 if (attr.isAppend()) { 259 trailingInTys.push_back(argTy); 260 trailingOpers.push_back(unbox.getResult(idx)); 261 } else { 262 newInTys.push_back(argTy); 263 newOpers.push_back(unbox.getResult(idx)); 264 } 265 } 266 }) 267 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 268 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 269 }) 270 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 271 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 272 }) 273 .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { 274 if (factory::isCharacterProcedureTuple(tuple)) { 275 mlir::ModuleOp module = getModule(); 276 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 277 if (callOp.callee()) { 278 llvm::StringRef charProcAttr = 279 fir::getCharacterProcedureDummyAttrName(); 280 // The charProcAttr attribute is only used as a safety to 281 // confirm that this is a dummy procedure and should be split. 282 // It cannot be used to match because attributes are not 283 // available in case of indirect calls. 284 auto funcOp = 285 module.lookupSymbol<mlir::FuncOp>(*callOp.callee()); 286 if (funcOp && 287 !funcOp.template getArgAttrOfType<mlir::UnitAttr>( 288 index, charProcAttr)) 289 mlir::emitError(loc, "tuple argument will be split even " 290 "though it does not have the `" + 291 charProcAttr + "` attribute"); 292 } 293 } 294 mlir::Type funcPointerType = tuple.getType(0); 295 mlir::Type lenType = tuple.getType(1); 296 FirOpBuilder builder(*rewriter, getKindMapping(module)); 297 auto [funcPointer, len] = 298 factory::extractCharacterProcedureTuple(builder, loc, oper); 299 newInTys.push_back(funcPointerType); 300 newOpers.push_back(funcPointer); 301 trailingInTys.push_back(lenType); 302 trailingOpers.push_back(len); 303 } else { 304 newInTys.push_back(tuple); 305 newOpers.push_back(oper); 306 } 307 }) 308 .Default([&](mlir::Type ty) { 309 newInTys.push_back(ty); 310 newOpers.push_back(oper); 311 }); 312 } 313 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 314 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); 315 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 316 fir::CallOp newCall; 317 if (callOp.callee().hasValue()) { 318 newCall = rewriter->create<A>(loc, callOp.callee().getValue(), 319 newResTys, newOpers); 320 } else { 321 // Force new type on the input operand. 322 newOpers[0].setType(mlir::FunctionType::get( 323 callOp.getContext(), 324 mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); 325 newCall = rewriter->create<A>(loc, newResTys, newOpers); 326 } 327 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); 328 if (wrap.hasValue()) 329 replaceOp(callOp, (*wrap)(newCall.getOperation())); 330 else 331 replaceOp(callOp, newCall.getResults()); 332 } else { 333 // A is fir::DispatchOp 334 TODO(loc, "dispatch not implemented"); 335 } 336 } 337 338 // Result type fixup for fir::ComplexType and mlir::ComplexType 339 template <typename A, typename B> 340 void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) { 341 if (noComplexConversion) { 342 newResTys.push_back(cmplx); 343 } else { 344 for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) { 345 auto argTy = std::get<mlir::Type>(tup); 346 if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet()) 347 newInTys.push_back(argTy); 348 else 349 newResTys.push_back(argTy); 350 } 351 } 352 } 353 354 // Argument type fixup for fir::ComplexType and mlir::ComplexType 355 template <typename A, typename B> 356 void lowerComplexSignatureArg(A cmplx, B &newInTys) { 357 if (noComplexConversion) 358 newInTys.push_back(cmplx); 359 else 360 for (auto &tup : specifics->complexArgumentType(cmplx.getElementType())) 361 newInTys.push_back(std::get<mlir::Type>(tup)); 362 } 363 364 /// Taking the address of a function. Modify the signature as needed. 365 void convertAddrOp(AddrOfOp addrOp) { 366 rewriter->setInsertionPoint(addrOp); 367 auto addrTy = addrOp.getType().cast<mlir::FunctionType>(); 368 llvm::SmallVector<mlir::Type> newResTys; 369 llvm::SmallVector<mlir::Type> newInTys; 370 for (mlir::Type ty : addrTy.getResults()) { 371 llvm::TypeSwitch<mlir::Type>(ty) 372 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 373 lowerComplexSignatureRes(ty, newResTys, newInTys); 374 }) 375 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 376 lowerComplexSignatureRes(ty, newResTys, newInTys); 377 }) 378 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 379 } 380 llvm::SmallVector<mlir::Type> trailingInTys; 381 for (mlir::Type ty : addrTy.getInputs()) { 382 llvm::TypeSwitch<mlir::Type>(ty) 383 .Case<BoxCharType>([&](BoxCharType box) { 384 if (noCharacterConversion) { 385 newInTys.push_back(box); 386 } else { 387 for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) { 388 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 389 auto argTy = std::get<mlir::Type>(tup); 390 llvm::SmallVector<mlir::Type> &vec = 391 attr.isAppend() ? trailingInTys : newInTys; 392 vec.push_back(argTy); 393 } 394 } 395 }) 396 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 397 lowerComplexSignatureArg(ty, newInTys); 398 }) 399 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 400 lowerComplexSignatureArg(ty, newInTys); 401 }) 402 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 403 if (factory::isCharacterProcedureTuple(tuple)) { 404 newInTys.push_back(tuple.getType(0)); 405 trailingInTys.push_back(tuple.getType(1)); 406 } else { 407 newInTys.push_back(ty); 408 } 409 }) 410 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 411 } 412 // append trailing input types 413 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 414 // replace this op with a new one with the updated signature 415 auto newTy = rewriter->getFunctionType(newInTys, newResTys); 416 auto newOp = 417 rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol()); 418 replaceOp(addrOp, newOp.getResult()); 419 } 420 421 /// Convert the type signatures on all the functions present in the module. 422 /// As the type signature is being changed, this must also update the 423 /// function itself to use any new arguments, etc. 424 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { 425 for (auto fn : mod.getOps<mlir::FuncOp>()) 426 convertSignature(fn); 427 return mlir::success(); 428 } 429 430 /// If the signature does not need any special target-specific converions, 431 /// then it is considered portable for any target, and this function will 432 /// return `true`. Otherwise, the signature is not portable and `false` is 433 /// returned. 434 bool hasPortableSignature(mlir::Type signature) { 435 assert(signature.isa<mlir::FunctionType>()); 436 auto func = signature.dyn_cast<mlir::FunctionType>(); 437 for (auto ty : func.getResults()) 438 if ((ty.isa<BoxCharType>() && !noCharacterConversion) || 439 (isa_complex(ty) && !noComplexConversion)) { 440 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 441 return false; 442 } 443 for (auto ty : func.getInputs()) 444 if (((ty.isa<BoxCharType>() || factory::isCharacterProcedureTuple(ty)) && 445 !noCharacterConversion) || 446 (isa_complex(ty) && !noComplexConversion)) { 447 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 448 return false; 449 } 450 return true; 451 } 452 453 /// Rewrite the signatures and body of the `FuncOp`s in the module for 454 /// the immediately subsequent target code gen. 455 void convertSignature(mlir::FuncOp func) { 456 auto funcTy = func.getType().cast<mlir::FunctionType>(); 457 if (hasPortableSignature(funcTy)) 458 return; 459 llvm::SmallVector<mlir::Type> newResTys; 460 llvm::SmallVector<mlir::Type> newInTys; 461 llvm::SmallVector<FixupTy> fixups; 462 463 // Convert return value(s) 464 for (auto ty : funcTy.getResults()) 465 llvm::TypeSwitch<mlir::Type>(ty) 466 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 467 if (noComplexConversion) 468 newResTys.push_back(cmplx); 469 else 470 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 471 }) 472 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 473 if (noComplexConversion) 474 newResTys.push_back(cmplx); 475 else 476 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 477 }) 478 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 479 480 // Convert arguments 481 llvm::SmallVector<mlir::Type> trailingTys; 482 for (auto e : llvm::enumerate(funcTy.getInputs())) { 483 auto ty = e.value(); 484 unsigned index = e.index(); 485 llvm::TypeSwitch<mlir::Type>(ty) 486 .Case<BoxCharType>([&](BoxCharType boxTy) { 487 if (noCharacterConversion) { 488 newInTys.push_back(boxTy); 489 } else { 490 // Convert a CHARACTER argument type. This can involve separating 491 // the pointer and the LEN into two arguments and moving the LEN 492 // argument to the end of the arg list. 493 bool sret = functionArgIsSRet(index, func); 494 for (auto e : llvm::enumerate(specifics->boxcharArgumentType( 495 boxTy.getEleTy(), sret))) { 496 auto &tup = e.value(); 497 auto index = e.index(); 498 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 499 auto argTy = std::get<mlir::Type>(tup); 500 if (attr.isAppend()) { 501 trailingTys.push_back(argTy); 502 } else { 503 if (sret) { 504 fixups.emplace_back(FixupTy::Codes::CharPair, 505 newInTys.size(), index); 506 } else { 507 fixups.emplace_back(FixupTy::Codes::Trailing, 508 newInTys.size(), trailingTys.size()); 509 } 510 newInTys.push_back(argTy); 511 } 512 } 513 } 514 }) 515 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 516 if (noComplexConversion) 517 newInTys.push_back(cmplx); 518 else 519 doComplexArg(func, cmplx, newInTys, fixups); 520 }) 521 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 522 if (noComplexConversion) 523 newInTys.push_back(cmplx); 524 else 525 doComplexArg(func, cmplx, newInTys, fixups); 526 }) 527 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 528 if (factory::isCharacterProcedureTuple(tuple)) { 529 fixups.emplace_back(FixupTy::Codes::TrailingCharProc, 530 newInTys.size(), trailingTys.size()); 531 newInTys.push_back(tuple.getType(0)); 532 trailingTys.push_back(tuple.getType(1)); 533 } else { 534 newInTys.push_back(ty); 535 } 536 }) 537 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 538 } 539 540 if (!func.empty()) { 541 // If the function has a body, then apply the fixups to the arguments and 542 // return ops as required. These fixups are done in place. 543 auto loc = func.getLoc(); 544 const auto fixupSize = fixups.size(); 545 const auto oldArgTys = func.getType().getInputs(); 546 int offset = 0; 547 for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) { 548 const auto &fixup = fixups[i]; 549 switch (fixup.code) { 550 case FixupTy::Codes::ArgumentAsLoad: { 551 // Argument was pass-by-value, but is now pass-by-reference and 552 // possibly with a different element type. 553 auto newArg = func.front().insertArgument(fixup.index, 554 newInTys[fixup.index], loc); 555 rewriter->setInsertionPointToStart(&func.front()); 556 auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 557 auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg); 558 auto load = rewriter->create<fir::LoadOp>(loc, cast); 559 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 560 func.front().eraseArgument(fixup.index + 1); 561 } break; 562 case FixupTy::Codes::ArgumentType: { 563 // Argument is pass-by-value, but its type has likely been modified to 564 // suit the target ABI convention. 565 auto newArg = func.front().insertArgument(fixup.index, 566 newInTys[fixup.index], loc); 567 rewriter->setInsertionPointToStart(&func.front()); 568 auto mem = 569 rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]); 570 rewriter->create<fir::StoreOp>(loc, newArg, mem); 571 auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 572 auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem); 573 mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast); 574 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 575 func.front().eraseArgument(fixup.index + 1); 576 LLVM_DEBUG(llvm::dbgs() 577 << "old argument: " << oldArgTy.getEleTy() 578 << ", repl: " << load << ", new argument: " 579 << func.getArgument(fixup.index).getType() << '\n'); 580 } break; 581 case FixupTy::Codes::CharPair: { 582 // The FIR boxchar argument has been split into a pair of distinct 583 // arguments that are in juxtaposition to each other. 584 auto newArg = func.front().insertArgument(fixup.index, 585 newInTys[fixup.index], loc); 586 if (fixup.second == 1) { 587 rewriter->setInsertionPointToStart(&func.front()); 588 auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; 589 auto box = rewriter->create<EmboxCharOp>( 590 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); 591 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 592 func.front().eraseArgument(fixup.index + 1); 593 offset++; 594 } 595 } break; 596 case FixupTy::Codes::ReturnAsStore: { 597 // The value being returned is now being returned in memory (callee 598 // stack space) through a hidden reference argument. 599 auto newArg = func.front().insertArgument(fixup.index, 600 newInTys[fixup.index], loc); 601 offset++; 602 func.walk([&](mlir::ReturnOp ret) { 603 rewriter->setInsertionPoint(ret); 604 auto oldOper = ret.getOperand(0); 605 auto oldOperTy = ReferenceType::get(oldOper.getType()); 606 auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg); 607 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 608 rewriter->create<mlir::ReturnOp>(loc); 609 ret.erase(); 610 }); 611 } break; 612 case FixupTy::Codes::ReturnType: { 613 // The function is still returning a value, but its type has likely 614 // changed to suit the target ABI convention. 615 func.walk([&](mlir::ReturnOp ret) { 616 rewriter->setInsertionPoint(ret); 617 auto oldOper = ret.getOperand(0); 618 auto oldOperTy = ReferenceType::get(oldOper.getType()); 619 auto mem = 620 rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]); 621 auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem); 622 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 623 mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem); 624 rewriter->create<mlir::ReturnOp>(loc, load); 625 ret.erase(); 626 }); 627 } break; 628 case FixupTy::Codes::Split: { 629 // The FIR argument has been split into a pair of distinct arguments 630 // that are in juxtaposition to each other. (For COMPLEX value.) 631 auto newArg = func.front().insertArgument(fixup.index, 632 newInTys[fixup.index], loc); 633 if (fixup.second == 1) { 634 rewriter->setInsertionPointToStart(&func.front()); 635 auto cplxTy = oldArgTys[fixup.index - offset - fixup.second]; 636 auto undef = rewriter->create<UndefOp>(loc, cplxTy); 637 auto iTy = rewriter->getIntegerType(32); 638 auto zero = rewriter->getIntegerAttr(iTy, 0); 639 auto one = rewriter->getIntegerAttr(iTy, 1); 640 auto cplx1 = rewriter->create<InsertValueOp>( 641 loc, cplxTy, undef, func.front().getArgument(fixup.index - 1), 642 rewriter->getArrayAttr(zero)); 643 auto cplx = rewriter->create<InsertValueOp>( 644 loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one)); 645 func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx); 646 func.front().eraseArgument(fixup.index + 1); 647 offset++; 648 } 649 } break; 650 case FixupTy::Codes::Trailing: { 651 // The FIR argument has been split into a pair of distinct arguments. 652 // The first part of the pair appears in the original argument 653 // position. The second part of the pair is appended after all the 654 // original arguments. (Boxchar arguments.) 655 auto newBufArg = func.front().insertArgument( 656 fixup.index, newInTys[fixup.index], loc); 657 auto newLenArg = 658 func.front().addArgument(trailingTys[fixup.second], loc); 659 auto boxTy = oldArgTys[fixup.index - offset]; 660 rewriter->setInsertionPointToStart(&func.front()); 661 auto box = 662 rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg); 663 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 664 func.front().eraseArgument(fixup.index + 1); 665 } break; 666 case FixupTy::Codes::TrailingCharProc: { 667 // The FIR character procedure argument tuple has been split into a 668 // pair of distinct arguments. The first part of the pair appears in 669 // the original argument position. The second part of the pair is 670 // appended after all the original arguments. 671 auto newProcPointerArg = func.front().insertArgument( 672 fixup.index, newInTys[fixup.index], loc); 673 auto newLenArg = 674 func.front().addArgument(trailingTys[fixup.second], loc); 675 auto tupleType = oldArgTys[fixup.index - offset]; 676 rewriter->setInsertionPointToStart(&func.front()); 677 FirOpBuilder builder(*rewriter, getKindMapping(getModule())); 678 auto tuple = factory::createCharacterProcedureTuple( 679 builder, loc, tupleType, newProcPointerArg, newLenArg); 680 func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple); 681 func.front().eraseArgument(fixup.index + 1); 682 } break; 683 } 684 } 685 } 686 687 // Set the new type and finalize the arguments, etc. 688 newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); 689 auto newFuncTy = 690 mlir::FunctionType::get(func.getContext(), newInTys, newResTys); 691 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); 692 func.setType(newFuncTy); 693 694 for (auto &fixup : fixups) 695 if (fixup.finalizer) 696 (*fixup.finalizer)(func); 697 } 698 699 inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) { 700 if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret")) 701 return true; 702 return false; 703 } 704 705 /// Convert a complex return value. This can involve converting the return 706 /// value to a "hidden" first argument or packing the complex into a wide 707 /// GPR. 708 template <typename A, typename B, typename C> 709 void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys, 710 C &fixups) { 711 if (noComplexConversion) { 712 newResTys.push_back(cmplx); 713 return; 714 } 715 auto m = specifics->complexReturnType(cmplx.getElementType()); 716 assert(m.size() == 1); 717 auto &tup = m[0]; 718 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 719 auto argTy = std::get<mlir::Type>(tup); 720 if (attr.isSRet()) { 721 unsigned argNo = newInTys.size(); 722 fixups.emplace_back( 723 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) { 724 func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 725 }); 726 newInTys.push_back(argTy); 727 return; 728 } 729 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); 730 newResTys.push_back(argTy); 731 } 732 733 /// Convert a complex argument value. This can involve storing the value to 734 /// a temporary memory location or factoring the value into two distinct 735 /// arguments. 736 template <typename A, typename B, typename C> 737 void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) { 738 if (noComplexConversion) { 739 newInTys.push_back(cmplx); 740 return; 741 } 742 auto m = specifics->complexArgumentType(cmplx.getElementType()); 743 const auto fixupCode = 744 m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; 745 for (auto e : llvm::enumerate(m)) { 746 auto &tup = e.value(); 747 auto index = e.index(); 748 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 749 auto argTy = std::get<mlir::Type>(tup); 750 auto argNo = newInTys.size(); 751 if (attr.isByVal()) { 752 if (auto align = attr.getAlignment()) 753 fixups.emplace_back( 754 FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) { 755 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); 756 func.setArgAttr(argNo, "llvm.align", 757 rewriter->getIntegerAttr( 758 rewriter->getIntegerType(32), align)); 759 }); 760 else 761 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), 762 [=](mlir::FuncOp func) { 763 func.setArgAttr(argNo, "llvm.byval", 764 rewriter->getUnitAttr()); 765 }); 766 } else { 767 if (auto align = attr.getAlignment()) 768 fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) { 769 func.setArgAttr( 770 argNo, "llvm.align", 771 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align)); 772 }); 773 else 774 fixups.emplace_back(fixupCode, argNo, index); 775 } 776 newInTys.push_back(argTy); 777 } 778 } 779 780 private: 781 // Replace `op` and remove it. 782 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 783 op->replaceAllUsesWith(newValues); 784 op->dropAllReferences(); 785 op->erase(); 786 } 787 788 inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) { 789 specifics = s; 790 rewriter = r; 791 } 792 793 inline void clearMembers() { setMembers(nullptr, nullptr); } 794 795 CodeGenSpecifics *specifics{}; 796 mlir::OpBuilder *rewriter; 797 }; // namespace 798 } // namespace 799 800 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 801 fir::createFirTargetRewritePass(const TargetRewriteOptions &options) { 802 return std::make_unique<TargetRewrite>(options); 803 } 804