1 //===-- TargetRewrite.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest. 10 // LLVM expects different lowering idioms to be used for distinct target 11 // triples. These distinctions are handled by this pass. 12 // 13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "Target.h" 19 #include "flang/Lower/Todo.h" 20 #include "flang/Optimizer/Builder/Character.h" 21 #include "flang/Optimizer/CodeGen/CodeGen.h" 22 #include "flang/Optimizer/Dialect/FIRDialect.h" 23 #include "flang/Optimizer/Dialect/FIROps.h" 24 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 25 #include "flang/Optimizer/Dialect/FIRType.h" 26 #include "flang/Optimizer/Support/FIRContext.h" 27 #include "mlir/Transforms/DialectConversion.h" 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ADT/TypeSwitch.h" 30 #include "llvm/Support/Debug.h" 31 32 using namespace fir; 33 using namespace mlir; 34 35 #define DEBUG_TYPE "flang-target-rewrite" 36 37 namespace { 38 39 /// Fixups for updating a FuncOp's arguments and return values. 40 struct FixupTy { 41 enum class Codes { 42 ArgumentAsLoad, 43 ArgumentType, 44 CharPair, 45 ReturnAsStore, 46 ReturnType, 47 Split, 48 Trailing, 49 TrailingCharProc 50 }; 51 52 FixupTy(Codes code, std::size_t index, std::size_t second = 0) 53 : code{code}, index{index}, second{second} {} 54 FixupTy(Codes code, std::size_t index, 55 std::function<void(mlir::FuncOp)> &&finalizer) 56 : code{code}, index{index}, finalizer{finalizer} {} 57 FixupTy(Codes code, std::size_t index, std::size_t second, 58 std::function<void(mlir::FuncOp)> &&finalizer) 59 : code{code}, index{index}, second{second}, finalizer{finalizer} {} 60 61 Codes code; 62 std::size_t index; 63 std::size_t second{}; 64 llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{}; 65 }; // namespace 66 67 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 68 /// generation that traverses the FIR and modifies types and operations to a 69 /// form that is appropriate for the specific target. LLVM IR has specific 70 /// idioms that are used for distinct target processor and ABI combinations. 71 class TargetRewrite : public TargetRewriteBase<TargetRewrite> { 72 public: 73 TargetRewrite(const TargetRewriteOptions &options) { 74 noCharacterConversion = options.noCharacterConversion; 75 noComplexConversion = options.noComplexConversion; 76 } 77 78 void runOnOperation() override final { 79 auto &context = getContext(); 80 mlir::OpBuilder rewriter(&context); 81 82 auto mod = getModule(); 83 if (!forcedTargetTriple.empty()) 84 setTargetTriple(mod, forcedTargetTriple); 85 86 auto specifics = CodeGenSpecifics::get(getOperation().getContext(), 87 getTargetTriple(getOperation()), 88 getKindMapping(getOperation())); 89 setMembers(specifics.get(), &rewriter); 90 91 // Perform type conversion on signatures and call sites. 92 if (mlir::failed(convertTypes(mod))) { 93 mlir::emitError(mlir::UnknownLoc::get(&context), 94 "error in converting types to target abi"); 95 signalPassFailure(); 96 } 97 98 // Convert ops in target-specific patterns. 99 mod.walk([&](mlir::Operation *op) { 100 if (auto call = dyn_cast<fir::CallOp>(op)) { 101 if (!hasPortableSignature(call.getFunctionType())) 102 convertCallOp(call); 103 } else if (auto dispatch = dyn_cast<DispatchOp>(op)) { 104 if (!hasPortableSignature(dispatch.getFunctionType())) 105 convertCallOp(dispatch); 106 } else if (auto addr = dyn_cast<AddrOfOp>(op)) { 107 if (addr.getType().isa<mlir::FunctionType>() && 108 !hasPortableSignature(addr.getType())) 109 convertAddrOp(addr); 110 } 111 }); 112 113 clearMembers(); 114 } 115 116 mlir::ModuleOp getModule() { return getOperation(); } 117 118 template <typename A, typename B, typename C> 119 std::function<mlir::Value(mlir::Operation *)> 120 rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) { 121 auto m = specifics->complexReturnType(ty.getElementType()); 122 // Currently targets mandate COMPLEX is a single aggregate or packed 123 // scalar, including the sret case. 124 assert(m.size() == 1 && "target lowering of complex return not supported"); 125 auto resTy = std::get<mlir::Type>(m[0]); 126 auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 127 auto loc = mlir::UnknownLoc::get(resTy.getContext()); 128 if (attr.isSRet()) { 129 assert(isa_ref_type(resTy)); 130 mlir::Value stack = 131 rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy)); 132 newInTys.push_back(resTy); 133 newOpers.push_back(stack); 134 return [=](mlir::Operation *) -> mlir::Value { 135 auto memTy = ReferenceType::get(ty); 136 auto cast = rewriter->create<ConvertOp>(loc, memTy, stack); 137 return rewriter->create<fir::LoadOp>(loc, cast); 138 }; 139 } 140 newResTys.push_back(resTy); 141 return [=](mlir::Operation *call) -> mlir::Value { 142 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 143 rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem); 144 auto memTy = ReferenceType::get(ty); 145 auto cast = rewriter->create<ConvertOp>(loc, memTy, mem); 146 return rewriter->create<fir::LoadOp>(loc, cast); 147 }; 148 } 149 150 template <typename A, typename B, typename C> 151 void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, 152 C &newOpers) { 153 auto m = specifics->complexArgumentType(ty.getElementType()); 154 auto *ctx = ty.getContext(); 155 auto loc = mlir::UnknownLoc::get(ctx); 156 if (m.size() == 1) { 157 // COMPLEX is a single aggregate 158 auto resTy = std::get<mlir::Type>(m[0]); 159 auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 160 auto oldRefTy = ReferenceType::get(ty); 161 if (attr.isByVal()) { 162 auto mem = rewriter->create<fir::AllocaOp>(loc, ty); 163 rewriter->create<fir::StoreOp>(loc, oper, mem); 164 newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem)); 165 } else { 166 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 167 auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem); 168 rewriter->create<fir::StoreOp>(loc, oper, cast); 169 newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem)); 170 } 171 newInTys.push_back(resTy); 172 } else { 173 assert(m.size() == 2); 174 // COMPLEX is split into 2 separate arguments 175 auto iTy = rewriter->getIntegerType(32); 176 for (auto e : llvm::enumerate(m)) { 177 auto &tup = e.value(); 178 auto ty = std::get<mlir::Type>(tup); 179 auto index = e.index(); 180 auto idx = rewriter->getIntegerAttr(iTy, index); 181 auto val = rewriter->create<ExtractValueOp>( 182 loc, ty, oper, rewriter->getArrayAttr(idx)); 183 newInTys.push_back(ty); 184 newOpers.push_back(val); 185 } 186 } 187 } 188 189 // Convert fir.call and fir.dispatch Ops. 190 template <typename A> 191 void convertCallOp(A callOp) { 192 auto fnTy = callOp.getFunctionType(); 193 auto loc = callOp.getLoc(); 194 rewriter->setInsertionPoint(callOp); 195 llvm::SmallVector<mlir::Type> newResTys; 196 llvm::SmallVector<mlir::Type> newInTys; 197 llvm::SmallVector<mlir::Value> newOpers; 198 199 // If the call is indirect, the first argument must still be the function 200 // to call. 201 int dropFront = 0; 202 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 203 if (!callOp.getCallee().hasValue()) { 204 newInTys.push_back(fnTy.getInput(0)); 205 newOpers.push_back(callOp.getOperand(0)); 206 dropFront = 1; 207 } 208 } 209 210 // Determine the rewrite function, `wrap`, for the result value. 211 llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; 212 if (fnTy.getResults().size() == 1) { 213 mlir::Type ty = fnTy.getResult(0); 214 llvm::TypeSwitch<mlir::Type>(ty) 215 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 216 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 217 newOpers); 218 }) 219 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 220 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 221 newOpers); 222 }) 223 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 224 } else if (fnTy.getResults().size() > 1) { 225 TODO(loc, "multiple results not supported yet"); 226 } 227 228 llvm::SmallVector<mlir::Type> trailingInTys; 229 llvm::SmallVector<mlir::Value> trailingOpers; 230 for (auto e : llvm::enumerate( 231 llvm::zip(fnTy.getInputs().drop_front(dropFront), 232 callOp.getOperands().drop_front(dropFront)))) { 233 mlir::Type ty = std::get<0>(e.value()); 234 mlir::Value oper = std::get<1>(e.value()); 235 unsigned index = e.index(); 236 llvm::TypeSwitch<mlir::Type>(ty) 237 .template Case<BoxCharType>([&](BoxCharType boxTy) { 238 bool sret; 239 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 240 sret = callOp.getCallee() && 241 functionArgIsSRet(index, 242 getModule().lookupSymbol<mlir::FuncOp>( 243 *callOp.getCallee())); 244 } else { 245 // TODO: dispatch case; how do we put arguments on a call? 246 // We cannot put both an sret and the dispatch object first. 247 sret = false; 248 TODO(loc, "dispatch + sret not supported yet"); 249 } 250 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); 251 auto unbox = 252 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]), 253 std::get<mlir::Type>(m[1]), oper); 254 // unboxed CHARACTER arguments 255 for (auto e : llvm::enumerate(m)) { 256 unsigned idx = e.index(); 257 auto attr = std::get<CodeGenSpecifics::Attributes>(e.value()); 258 auto argTy = std::get<mlir::Type>(e.value()); 259 if (attr.isAppend()) { 260 trailingInTys.push_back(argTy); 261 trailingOpers.push_back(unbox.getResult(idx)); 262 } else { 263 newInTys.push_back(argTy); 264 newOpers.push_back(unbox.getResult(idx)); 265 } 266 } 267 }) 268 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 269 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 270 }) 271 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 272 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 273 }) 274 .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { 275 if (factory::isCharacterProcedureTuple(tuple)) { 276 mlir::ModuleOp module = getModule(); 277 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 278 if (callOp.getCallee()) { 279 llvm::StringRef charProcAttr = 280 fir::getCharacterProcedureDummyAttrName(); 281 // The charProcAttr attribute is only used as a safety to 282 // confirm that this is a dummy procedure and should be split. 283 // It cannot be used to match because attributes are not 284 // available in case of indirect calls. 285 auto funcOp = 286 module.lookupSymbol<mlir::FuncOp>(*callOp.getCallee()); 287 if (funcOp && 288 !funcOp.template getArgAttrOfType<mlir::UnitAttr>( 289 index, charProcAttr)) 290 mlir::emitError(loc, "tuple argument will be split even " 291 "though it does not have the `" + 292 charProcAttr + "` attribute"); 293 } 294 } 295 mlir::Type funcPointerType = tuple.getType(0); 296 mlir::Type lenType = tuple.getType(1); 297 FirOpBuilder builder(*rewriter, getKindMapping(module)); 298 auto [funcPointer, len] = 299 factory::extractCharacterProcedureTuple(builder, loc, oper); 300 newInTys.push_back(funcPointerType); 301 newOpers.push_back(funcPointer); 302 trailingInTys.push_back(lenType); 303 trailingOpers.push_back(len); 304 } else { 305 newInTys.push_back(tuple); 306 newOpers.push_back(oper); 307 } 308 }) 309 .Default([&](mlir::Type ty) { 310 newInTys.push_back(ty); 311 newOpers.push_back(oper); 312 }); 313 } 314 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 315 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); 316 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 317 fir::CallOp newCall; 318 if (callOp.getCallee().hasValue()) { 319 newCall = rewriter->create<A>(loc, callOp.getCallee().getValue(), 320 newResTys, newOpers); 321 } else { 322 // Force new type on the input operand. 323 newOpers[0].setType(mlir::FunctionType::get( 324 callOp.getContext(), 325 mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); 326 newCall = rewriter->create<A>(loc, newResTys, newOpers); 327 } 328 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); 329 if (wrap.hasValue()) 330 replaceOp(callOp, (*wrap)(newCall.getOperation())); 331 else 332 replaceOp(callOp, newCall.getResults()); 333 } else { 334 // A is fir::DispatchOp 335 TODO(loc, "dispatch not implemented"); 336 } 337 } 338 339 // Result type fixup for fir::ComplexType and mlir::ComplexType 340 template <typename A, typename B> 341 void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) { 342 if (noComplexConversion) { 343 newResTys.push_back(cmplx); 344 } else { 345 for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) { 346 auto argTy = std::get<mlir::Type>(tup); 347 if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet()) 348 newInTys.push_back(argTy); 349 else 350 newResTys.push_back(argTy); 351 } 352 } 353 } 354 355 // Argument type fixup for fir::ComplexType and mlir::ComplexType 356 template <typename A, typename B> 357 void lowerComplexSignatureArg(A cmplx, B &newInTys) { 358 if (noComplexConversion) 359 newInTys.push_back(cmplx); 360 else 361 for (auto &tup : specifics->complexArgumentType(cmplx.getElementType())) 362 newInTys.push_back(std::get<mlir::Type>(tup)); 363 } 364 365 /// Taking the address of a function. Modify the signature as needed. 366 void convertAddrOp(AddrOfOp addrOp) { 367 rewriter->setInsertionPoint(addrOp); 368 auto addrTy = addrOp.getType().cast<mlir::FunctionType>(); 369 llvm::SmallVector<mlir::Type> newResTys; 370 llvm::SmallVector<mlir::Type> newInTys; 371 for (mlir::Type ty : addrTy.getResults()) { 372 llvm::TypeSwitch<mlir::Type>(ty) 373 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 374 lowerComplexSignatureRes(ty, newResTys, newInTys); 375 }) 376 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 377 lowerComplexSignatureRes(ty, newResTys, newInTys); 378 }) 379 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 380 } 381 llvm::SmallVector<mlir::Type> trailingInTys; 382 for (mlir::Type ty : addrTy.getInputs()) { 383 llvm::TypeSwitch<mlir::Type>(ty) 384 .Case<BoxCharType>([&](BoxCharType box) { 385 if (noCharacterConversion) { 386 newInTys.push_back(box); 387 } else { 388 for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) { 389 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 390 auto argTy = std::get<mlir::Type>(tup); 391 llvm::SmallVector<mlir::Type> &vec = 392 attr.isAppend() ? trailingInTys : newInTys; 393 vec.push_back(argTy); 394 } 395 } 396 }) 397 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 398 lowerComplexSignatureArg(ty, newInTys); 399 }) 400 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 401 lowerComplexSignatureArg(ty, newInTys); 402 }) 403 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 404 if (factory::isCharacterProcedureTuple(tuple)) { 405 newInTys.push_back(tuple.getType(0)); 406 trailingInTys.push_back(tuple.getType(1)); 407 } else { 408 newInTys.push_back(ty); 409 } 410 }) 411 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 412 } 413 // append trailing input types 414 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 415 // replace this op with a new one with the updated signature 416 auto newTy = rewriter->getFunctionType(newInTys, newResTys); 417 auto newOp = 418 rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.getSymbol()); 419 replaceOp(addrOp, newOp.getResult()); 420 } 421 422 /// Convert the type signatures on all the functions present in the module. 423 /// As the type signature is being changed, this must also update the 424 /// function itself to use any new arguments, etc. 425 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { 426 for (auto fn : mod.getOps<mlir::FuncOp>()) 427 convertSignature(fn); 428 return mlir::success(); 429 } 430 431 /// If the signature does not need any special target-specific converions, 432 /// then it is considered portable for any target, and this function will 433 /// return `true`. Otherwise, the signature is not portable and `false` is 434 /// returned. 435 bool hasPortableSignature(mlir::Type signature) { 436 assert(signature.isa<mlir::FunctionType>()); 437 auto func = signature.dyn_cast<mlir::FunctionType>(); 438 for (auto ty : func.getResults()) 439 if ((ty.isa<BoxCharType>() && !noCharacterConversion) || 440 (isa_complex(ty) && !noComplexConversion)) { 441 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 442 return false; 443 } 444 for (auto ty : func.getInputs()) 445 if (((ty.isa<BoxCharType>() || factory::isCharacterProcedureTuple(ty)) && 446 !noCharacterConversion) || 447 (isa_complex(ty) && !noComplexConversion)) { 448 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 449 return false; 450 } 451 return true; 452 } 453 454 /// Rewrite the signatures and body of the `FuncOp`s in the module for 455 /// the immediately subsequent target code gen. 456 void convertSignature(mlir::FuncOp func) { 457 auto funcTy = func.getType().cast<mlir::FunctionType>(); 458 if (hasPortableSignature(funcTy)) 459 return; 460 llvm::SmallVector<mlir::Type> newResTys; 461 llvm::SmallVector<mlir::Type> newInTys; 462 llvm::SmallVector<FixupTy> fixups; 463 464 // Convert return value(s) 465 for (auto ty : funcTy.getResults()) 466 llvm::TypeSwitch<mlir::Type>(ty) 467 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 468 if (noComplexConversion) 469 newResTys.push_back(cmplx); 470 else 471 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 472 }) 473 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 474 if (noComplexConversion) 475 newResTys.push_back(cmplx); 476 else 477 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 478 }) 479 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 480 481 // Convert arguments 482 llvm::SmallVector<mlir::Type> trailingTys; 483 for (auto e : llvm::enumerate(funcTy.getInputs())) { 484 auto ty = e.value(); 485 unsigned index = e.index(); 486 llvm::TypeSwitch<mlir::Type>(ty) 487 .Case<BoxCharType>([&](BoxCharType boxTy) { 488 if (noCharacterConversion) { 489 newInTys.push_back(boxTy); 490 } else { 491 // Convert a CHARACTER argument type. This can involve separating 492 // the pointer and the LEN into two arguments and moving the LEN 493 // argument to the end of the arg list. 494 bool sret = functionArgIsSRet(index, func); 495 for (auto e : llvm::enumerate(specifics->boxcharArgumentType( 496 boxTy.getEleTy(), sret))) { 497 auto &tup = e.value(); 498 auto index = e.index(); 499 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 500 auto argTy = std::get<mlir::Type>(tup); 501 if (attr.isAppend()) { 502 trailingTys.push_back(argTy); 503 } else { 504 if (sret) { 505 fixups.emplace_back(FixupTy::Codes::CharPair, 506 newInTys.size(), index); 507 } else { 508 fixups.emplace_back(FixupTy::Codes::Trailing, 509 newInTys.size(), trailingTys.size()); 510 } 511 newInTys.push_back(argTy); 512 } 513 } 514 } 515 }) 516 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 517 if (noComplexConversion) 518 newInTys.push_back(cmplx); 519 else 520 doComplexArg(func, cmplx, newInTys, fixups); 521 }) 522 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 523 if (noComplexConversion) 524 newInTys.push_back(cmplx); 525 else 526 doComplexArg(func, cmplx, newInTys, fixups); 527 }) 528 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 529 if (factory::isCharacterProcedureTuple(tuple)) { 530 fixups.emplace_back(FixupTy::Codes::TrailingCharProc, 531 newInTys.size(), trailingTys.size()); 532 newInTys.push_back(tuple.getType(0)); 533 trailingTys.push_back(tuple.getType(1)); 534 } else { 535 newInTys.push_back(ty); 536 } 537 }) 538 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 539 } 540 541 if (!func.empty()) { 542 // If the function has a body, then apply the fixups to the arguments and 543 // return ops as required. These fixups are done in place. 544 auto loc = func.getLoc(); 545 const auto fixupSize = fixups.size(); 546 const auto oldArgTys = func.getType().getInputs(); 547 int offset = 0; 548 for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) { 549 const auto &fixup = fixups[i]; 550 switch (fixup.code) { 551 case FixupTy::Codes::ArgumentAsLoad: { 552 // Argument was pass-by-value, but is now pass-by-reference and 553 // possibly with a different element type. 554 auto newArg = func.front().insertArgument(fixup.index, 555 newInTys[fixup.index], loc); 556 rewriter->setInsertionPointToStart(&func.front()); 557 auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 558 auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg); 559 auto load = rewriter->create<fir::LoadOp>(loc, cast); 560 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 561 func.front().eraseArgument(fixup.index + 1); 562 } break; 563 case FixupTy::Codes::ArgumentType: { 564 // Argument is pass-by-value, but its type has likely been modified to 565 // suit the target ABI convention. 566 auto newArg = func.front().insertArgument(fixup.index, 567 newInTys[fixup.index], loc); 568 rewriter->setInsertionPointToStart(&func.front()); 569 auto mem = 570 rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]); 571 rewriter->create<fir::StoreOp>(loc, newArg, mem); 572 auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 573 auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem); 574 mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast); 575 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 576 func.front().eraseArgument(fixup.index + 1); 577 LLVM_DEBUG(llvm::dbgs() 578 << "old argument: " << oldArgTy.getEleTy() 579 << ", repl: " << load << ", new argument: " 580 << func.getArgument(fixup.index).getType() << '\n'); 581 } break; 582 case FixupTy::Codes::CharPair: { 583 // The FIR boxchar argument has been split into a pair of distinct 584 // arguments that are in juxtaposition to each other. 585 auto newArg = func.front().insertArgument(fixup.index, 586 newInTys[fixup.index], loc); 587 if (fixup.second == 1) { 588 rewriter->setInsertionPointToStart(&func.front()); 589 auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; 590 auto box = rewriter->create<EmboxCharOp>( 591 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); 592 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 593 func.front().eraseArgument(fixup.index + 1); 594 offset++; 595 } 596 } break; 597 case FixupTy::Codes::ReturnAsStore: { 598 // The value being returned is now being returned in memory (callee 599 // stack space) through a hidden reference argument. 600 auto newArg = func.front().insertArgument(fixup.index, 601 newInTys[fixup.index], loc); 602 offset++; 603 func.walk([&](mlir::func::ReturnOp ret) { 604 rewriter->setInsertionPoint(ret); 605 auto oldOper = ret.getOperand(0); 606 auto oldOperTy = ReferenceType::get(oldOper.getType()); 607 auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg); 608 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 609 rewriter->create<mlir::func::ReturnOp>(loc); 610 ret.erase(); 611 }); 612 } break; 613 case FixupTy::Codes::ReturnType: { 614 // The function is still returning a value, but its type has likely 615 // changed to suit the target ABI convention. 616 func.walk([&](mlir::func::ReturnOp ret) { 617 rewriter->setInsertionPoint(ret); 618 auto oldOper = ret.getOperand(0); 619 auto oldOperTy = ReferenceType::get(oldOper.getType()); 620 auto mem = 621 rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]); 622 auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem); 623 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 624 mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem); 625 rewriter->create<mlir::func::ReturnOp>(loc, load); 626 ret.erase(); 627 }); 628 } break; 629 case FixupTy::Codes::Split: { 630 // The FIR argument has been split into a pair of distinct arguments 631 // that are in juxtaposition to each other. (For COMPLEX value.) 632 auto newArg = func.front().insertArgument(fixup.index, 633 newInTys[fixup.index], loc); 634 if (fixup.second == 1) { 635 rewriter->setInsertionPointToStart(&func.front()); 636 auto cplxTy = oldArgTys[fixup.index - offset - fixup.second]; 637 auto undef = rewriter->create<UndefOp>(loc, cplxTy); 638 auto iTy = rewriter->getIntegerType(32); 639 auto zero = rewriter->getIntegerAttr(iTy, 0); 640 auto one = rewriter->getIntegerAttr(iTy, 1); 641 auto cplx1 = rewriter->create<InsertValueOp>( 642 loc, cplxTy, undef, func.front().getArgument(fixup.index - 1), 643 rewriter->getArrayAttr(zero)); 644 auto cplx = rewriter->create<InsertValueOp>( 645 loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one)); 646 func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx); 647 func.front().eraseArgument(fixup.index + 1); 648 offset++; 649 } 650 } break; 651 case FixupTy::Codes::Trailing: { 652 // The FIR argument has been split into a pair of distinct arguments. 653 // The first part of the pair appears in the original argument 654 // position. The second part of the pair is appended after all the 655 // original arguments. (Boxchar arguments.) 656 auto newBufArg = func.front().insertArgument( 657 fixup.index, newInTys[fixup.index], loc); 658 auto newLenArg = 659 func.front().addArgument(trailingTys[fixup.second], loc); 660 auto boxTy = oldArgTys[fixup.index - offset]; 661 rewriter->setInsertionPointToStart(&func.front()); 662 auto box = 663 rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg); 664 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 665 func.front().eraseArgument(fixup.index + 1); 666 } break; 667 case FixupTy::Codes::TrailingCharProc: { 668 // The FIR character procedure argument tuple has been split into a 669 // pair of distinct arguments. The first part of the pair appears in 670 // the original argument position. The second part of the pair is 671 // appended after all the original arguments. 672 auto newProcPointerArg = func.front().insertArgument( 673 fixup.index, newInTys[fixup.index], loc); 674 auto newLenArg = 675 func.front().addArgument(trailingTys[fixup.second], loc); 676 auto tupleType = oldArgTys[fixup.index - offset]; 677 rewriter->setInsertionPointToStart(&func.front()); 678 FirOpBuilder builder(*rewriter, getKindMapping(getModule())); 679 auto tuple = factory::createCharacterProcedureTuple( 680 builder, loc, tupleType, newProcPointerArg, newLenArg); 681 func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple); 682 func.front().eraseArgument(fixup.index + 1); 683 } break; 684 } 685 } 686 } 687 688 // Set the new type and finalize the arguments, etc. 689 newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); 690 auto newFuncTy = 691 mlir::FunctionType::get(func.getContext(), newInTys, newResTys); 692 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); 693 func.setType(newFuncTy); 694 695 for (auto &fixup : fixups) 696 if (fixup.finalizer) 697 (*fixup.finalizer)(func); 698 } 699 700 inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) { 701 if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret")) 702 return true; 703 return false; 704 } 705 706 /// Convert a complex return value. This can involve converting the return 707 /// value to a "hidden" first argument or packing the complex into a wide 708 /// GPR. 709 template <typename A, typename B, typename C> 710 void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys, 711 C &fixups) { 712 if (noComplexConversion) { 713 newResTys.push_back(cmplx); 714 return; 715 } 716 auto m = specifics->complexReturnType(cmplx.getElementType()); 717 assert(m.size() == 1); 718 auto &tup = m[0]; 719 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 720 auto argTy = std::get<mlir::Type>(tup); 721 if (attr.isSRet()) { 722 unsigned argNo = newInTys.size(); 723 fixups.emplace_back( 724 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) { 725 func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 726 }); 727 newInTys.push_back(argTy); 728 return; 729 } 730 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); 731 newResTys.push_back(argTy); 732 } 733 734 /// Convert a complex argument value. This can involve storing the value to 735 /// a temporary memory location or factoring the value into two distinct 736 /// arguments. 737 template <typename A, typename B, typename C> 738 void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) { 739 if (noComplexConversion) { 740 newInTys.push_back(cmplx); 741 return; 742 } 743 auto m = specifics->complexArgumentType(cmplx.getElementType()); 744 const auto fixupCode = 745 m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; 746 for (auto e : llvm::enumerate(m)) { 747 auto &tup = e.value(); 748 auto index = e.index(); 749 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 750 auto argTy = std::get<mlir::Type>(tup); 751 auto argNo = newInTys.size(); 752 if (attr.isByVal()) { 753 if (auto align = attr.getAlignment()) 754 fixups.emplace_back( 755 FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) { 756 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); 757 func.setArgAttr(argNo, "llvm.align", 758 rewriter->getIntegerAttr( 759 rewriter->getIntegerType(32), align)); 760 }); 761 else 762 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), 763 [=](mlir::FuncOp func) { 764 func.setArgAttr(argNo, "llvm.byval", 765 rewriter->getUnitAttr()); 766 }); 767 } else { 768 if (auto align = attr.getAlignment()) 769 fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) { 770 func.setArgAttr( 771 argNo, "llvm.align", 772 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align)); 773 }); 774 else 775 fixups.emplace_back(fixupCode, argNo, index); 776 } 777 newInTys.push_back(argTy); 778 } 779 } 780 781 private: 782 // Replace `op` and remove it. 783 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 784 op->replaceAllUsesWith(newValues); 785 op->dropAllReferences(); 786 op->erase(); 787 } 788 789 inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) { 790 specifics = s; 791 rewriter = r; 792 } 793 794 inline void clearMembers() { setMembers(nullptr, nullptr); } 795 796 CodeGenSpecifics *specifics{}; 797 mlir::OpBuilder *rewriter; 798 }; // namespace 799 } // namespace 800 801 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 802 fir::createFirTargetRewritePass(const TargetRewriteOptions &options) { 803 return std::make_unique<TargetRewrite>(options); 804 } 805