1 //===-- TargetRewrite.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest. 10 // LLVM expects different lowering idioms to be used for distinct target 11 // triples. These distinctions are handled by this pass. 12 // 13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "Target.h" 19 #include "flang/Lower/Todo.h" 20 #include "flang/Optimizer/Builder/Character.h" 21 #include "flang/Optimizer/CodeGen/CodeGen.h" 22 #include "flang/Optimizer/Dialect/FIRDialect.h" 23 #include "flang/Optimizer/Dialect/FIROps.h" 24 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 25 #include "flang/Optimizer/Dialect/FIRType.h" 26 #include "flang/Optimizer/Support/FIRContext.h" 27 #include "mlir/Transforms/DialectConversion.h" 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ADT/TypeSwitch.h" 30 #include "llvm/Support/Debug.h" 31 32 using namespace fir; 33 34 #define DEBUG_TYPE "flang-target-rewrite" 35 36 namespace { 37 38 /// Fixups for updating a FuncOp's arguments and return values. 39 struct FixupTy { 40 enum class Codes { 41 ArgumentAsLoad, 42 ArgumentType, 43 CharPair, 44 ReturnAsStore, 45 ReturnType, 46 Split, 47 Trailing, 48 TrailingCharProc 49 }; 50 51 FixupTy(Codes code, std::size_t index, std::size_t second = 0) 52 : code{code}, index{index}, second{second} {} 53 FixupTy(Codes code, std::size_t index, 54 std::function<void(mlir::FuncOp)> &&finalizer) 55 : code{code}, index{index}, finalizer{finalizer} {} 56 FixupTy(Codes code, std::size_t index, std::size_t second, 57 std::function<void(mlir::FuncOp)> &&finalizer) 58 : code{code}, index{index}, second{second}, finalizer{finalizer} {} 59 60 Codes code; 61 std::size_t index; 62 std::size_t second{}; 63 llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{}; 64 }; // namespace 65 66 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 67 /// generation that traverses the FIR and modifies types and operations to a 68 /// form that is appropriate for the specific target. LLVM IR has specific 69 /// idioms that are used for distinct target processor and ABI combinations. 70 class TargetRewrite : public TargetRewriteBase<TargetRewrite> { 71 public: 72 TargetRewrite(const TargetRewriteOptions &options) { 73 noCharacterConversion = options.noCharacterConversion; 74 noComplexConversion = options.noComplexConversion; 75 } 76 77 void runOnOperation() override final { 78 auto &context = getContext(); 79 mlir::OpBuilder rewriter(&context); 80 81 auto mod = getModule(); 82 if (!forcedTargetTriple.empty()) 83 setTargetTriple(mod, forcedTargetTriple); 84 85 auto specifics = CodeGenSpecifics::get(getOperation().getContext(), 86 getTargetTriple(getOperation()), 87 getKindMapping(getOperation())); 88 setMembers(specifics.get(), &rewriter); 89 90 // Perform type conversion on signatures and call sites. 91 if (mlir::failed(convertTypes(mod))) { 92 mlir::emitError(mlir::UnknownLoc::get(&context), 93 "error in converting types to target abi"); 94 signalPassFailure(); 95 } 96 97 // Convert ops in target-specific patterns. 98 mod.walk([&](mlir::Operation *op) { 99 if (auto call = dyn_cast<fir::CallOp>(op)) { 100 if (!hasPortableSignature(call.getFunctionType())) 101 convertCallOp(call); 102 } else if (auto dispatch = dyn_cast<DispatchOp>(op)) { 103 if (!hasPortableSignature(dispatch.getFunctionType())) 104 convertCallOp(dispatch); 105 } else if (auto addr = dyn_cast<AddrOfOp>(op)) { 106 if (addr.getType().isa<mlir::FunctionType>() && 107 !hasPortableSignature(addr.getType())) 108 convertAddrOp(addr); 109 } 110 }); 111 112 clearMembers(); 113 } 114 115 mlir::ModuleOp getModule() { return getOperation(); } 116 117 template <typename A, typename B, typename C> 118 std::function<mlir::Value(mlir::Operation *)> 119 rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) { 120 auto m = specifics->complexReturnType(ty.getElementType()); 121 // Currently targets mandate COMPLEX is a single aggregate or packed 122 // scalar, including the sret case. 123 assert(m.size() == 1 && "target lowering of complex return not supported"); 124 auto resTy = std::get<mlir::Type>(m[0]); 125 auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 126 auto loc = mlir::UnknownLoc::get(resTy.getContext()); 127 if (attr.isSRet()) { 128 assert(isa_ref_type(resTy)); 129 mlir::Value stack = 130 rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy)); 131 newInTys.push_back(resTy); 132 newOpers.push_back(stack); 133 return [=](mlir::Operation *) -> mlir::Value { 134 auto memTy = ReferenceType::get(ty); 135 auto cast = rewriter->create<ConvertOp>(loc, memTy, stack); 136 return rewriter->create<fir::LoadOp>(loc, cast); 137 }; 138 } 139 newResTys.push_back(resTy); 140 return [=](mlir::Operation *call) -> mlir::Value { 141 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 142 rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem); 143 auto memTy = ReferenceType::get(ty); 144 auto cast = rewriter->create<ConvertOp>(loc, memTy, mem); 145 return rewriter->create<fir::LoadOp>(loc, cast); 146 }; 147 } 148 149 template <typename A, typename B, typename C> 150 void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, 151 C &newOpers) { 152 auto m = specifics->complexArgumentType(ty.getElementType()); 153 auto *ctx = ty.getContext(); 154 auto loc = mlir::UnknownLoc::get(ctx); 155 if (m.size() == 1) { 156 // COMPLEX is a single aggregate 157 auto resTy = std::get<mlir::Type>(m[0]); 158 auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 159 auto oldRefTy = ReferenceType::get(ty); 160 if (attr.isByVal()) { 161 auto mem = rewriter->create<fir::AllocaOp>(loc, ty); 162 rewriter->create<fir::StoreOp>(loc, oper, mem); 163 newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem)); 164 } else { 165 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 166 auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem); 167 rewriter->create<fir::StoreOp>(loc, oper, cast); 168 newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem)); 169 } 170 newInTys.push_back(resTy); 171 } else { 172 assert(m.size() == 2); 173 // COMPLEX is split into 2 separate arguments 174 for (auto e : llvm::enumerate(m)) { 175 auto &tup = e.value(); 176 auto ty = std::get<mlir::Type>(tup); 177 auto index = e.index(); 178 auto idx = rewriter->getIntegerAttr(rewriter->getIndexType(), index); 179 auto val = rewriter->create<ExtractValueOp>( 180 loc, ty, oper, rewriter->getArrayAttr(idx)); 181 newInTys.push_back(ty); 182 newOpers.push_back(val); 183 } 184 } 185 } 186 187 // Convert fir.call and fir.dispatch Ops. 188 template <typename A> 189 void convertCallOp(A callOp) { 190 auto fnTy = callOp.getFunctionType(); 191 auto loc = callOp.getLoc(); 192 rewriter->setInsertionPoint(callOp); 193 llvm::SmallVector<mlir::Type> newResTys; 194 llvm::SmallVector<mlir::Type> newInTys; 195 llvm::SmallVector<mlir::Value> newOpers; 196 197 // If the call is indirect, the first argument must still be the function 198 // to call. 199 int dropFront = 0; 200 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 201 if (!callOp.callee().hasValue()) { 202 newInTys.push_back(fnTy.getInput(0)); 203 newOpers.push_back(callOp.getOperand(0)); 204 dropFront = 1; 205 } 206 } 207 208 // Determine the rewrite function, `wrap`, for the result value. 209 llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; 210 if (fnTy.getResults().size() == 1) { 211 mlir::Type ty = fnTy.getResult(0); 212 llvm::TypeSwitch<mlir::Type>(ty) 213 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 214 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 215 newOpers); 216 }) 217 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 218 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 219 newOpers); 220 }) 221 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 222 } else if (fnTy.getResults().size() > 1) { 223 TODO(loc, "multiple results not supported yet"); 224 } 225 226 llvm::SmallVector<mlir::Type> trailingInTys; 227 llvm::SmallVector<mlir::Value> trailingOpers; 228 for (auto e : llvm::enumerate( 229 llvm::zip(fnTy.getInputs().drop_front(dropFront), 230 callOp.getOperands().drop_front(dropFront)))) { 231 mlir::Type ty = std::get<0>(e.value()); 232 mlir::Value oper = std::get<1>(e.value()); 233 unsigned index = e.index(); 234 llvm::TypeSwitch<mlir::Type>(ty) 235 .template Case<BoxCharType>([&](BoxCharType boxTy) { 236 bool sret; 237 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 238 sret = callOp.callee() && 239 functionArgIsSRet(index, 240 getModule().lookupSymbol<mlir::FuncOp>( 241 *callOp.callee())); 242 } else { 243 // TODO: dispatch case; how do we put arguments on a call? 244 // We cannot put both an sret and the dispatch object first. 245 sret = false; 246 TODO(loc, "dispatch + sret not supported yet"); 247 } 248 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); 249 auto unbox = 250 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]), 251 std::get<mlir::Type>(m[1]), oper); 252 // unboxed CHARACTER arguments 253 for (auto e : llvm::enumerate(m)) { 254 unsigned idx = e.index(); 255 auto attr = std::get<CodeGenSpecifics::Attributes>(e.value()); 256 auto argTy = std::get<mlir::Type>(e.value()); 257 if (attr.isAppend()) { 258 trailingInTys.push_back(argTy); 259 trailingOpers.push_back(unbox.getResult(idx)); 260 } else { 261 newInTys.push_back(argTy); 262 newOpers.push_back(unbox.getResult(idx)); 263 } 264 } 265 }) 266 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 267 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 268 }) 269 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 270 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 271 }) 272 .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { 273 if (factory::isCharacterProcedureTuple(tuple)) { 274 mlir::ModuleOp module = getModule(); 275 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 276 if (callOp.callee()) { 277 llvm::StringRef charProcAttr = 278 fir::getCharacterProcedureDummyAttrName(); 279 // The charProcAttr attribute is only used as a safety to 280 // confirm that this is a dummy procedure and should be split. 281 // It cannot be used to match because attributes are not 282 // available in case of indirect calls. 283 auto funcOp = 284 module.lookupSymbol<mlir::FuncOp>(*callOp.callee()); 285 if (funcOp && 286 !funcOp.template getArgAttrOfType<mlir::UnitAttr>( 287 index, charProcAttr)) 288 mlir::emitError(loc, "tuple argument will be split even " 289 "though it does not have the `" + 290 charProcAttr + "` attribute"); 291 } 292 } 293 mlir::Type funcPointerType = tuple.getType(0); 294 mlir::Type lenType = tuple.getType(1); 295 FirOpBuilder builder(*rewriter, getKindMapping(module)); 296 auto [funcPointer, len] = 297 factory::extractCharacterProcedureTuple(builder, loc, oper); 298 newInTys.push_back(funcPointerType); 299 newOpers.push_back(funcPointer); 300 trailingInTys.push_back(lenType); 301 trailingOpers.push_back(len); 302 } else { 303 newInTys.push_back(tuple); 304 newOpers.push_back(oper); 305 } 306 }) 307 .Default([&](mlir::Type ty) { 308 newInTys.push_back(ty); 309 newOpers.push_back(oper); 310 }); 311 } 312 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 313 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); 314 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 315 fir::CallOp newCall; 316 if (callOp.callee().hasValue()) { 317 newCall = rewriter->create<A>(loc, callOp.callee().getValue(), 318 newResTys, newOpers); 319 } else { 320 // Force new type on the input operand. 321 newOpers[0].setType(mlir::FunctionType::get( 322 callOp.getContext(), 323 mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); 324 newCall = rewriter->create<A>(loc, newResTys, newOpers); 325 } 326 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); 327 if (wrap.hasValue()) 328 replaceOp(callOp, (*wrap)(newCall.getOperation())); 329 else 330 replaceOp(callOp, newCall.getResults()); 331 } else { 332 // A is fir::DispatchOp 333 TODO(loc, "dispatch not implemented"); 334 } 335 } 336 337 // Result type fixup for fir::ComplexType and mlir::ComplexType 338 template <typename A, typename B> 339 void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) { 340 if (noComplexConversion) { 341 newResTys.push_back(cmplx); 342 } else { 343 for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) { 344 auto argTy = std::get<mlir::Type>(tup); 345 if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet()) 346 newInTys.push_back(argTy); 347 else 348 newResTys.push_back(argTy); 349 } 350 } 351 } 352 353 // Argument type fixup for fir::ComplexType and mlir::ComplexType 354 template <typename A, typename B> 355 void lowerComplexSignatureArg(A cmplx, B &newInTys) { 356 if (noComplexConversion) 357 newInTys.push_back(cmplx); 358 else 359 for (auto &tup : specifics->complexArgumentType(cmplx.getElementType())) 360 newInTys.push_back(std::get<mlir::Type>(tup)); 361 } 362 363 /// Taking the address of a function. Modify the signature as needed. 364 void convertAddrOp(AddrOfOp addrOp) { 365 rewriter->setInsertionPoint(addrOp); 366 auto addrTy = addrOp.getType().cast<mlir::FunctionType>(); 367 llvm::SmallVector<mlir::Type> newResTys; 368 llvm::SmallVector<mlir::Type> newInTys; 369 for (mlir::Type ty : addrTy.getResults()) { 370 llvm::TypeSwitch<mlir::Type>(ty) 371 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 372 lowerComplexSignatureRes(ty, newResTys, newInTys); 373 }) 374 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 375 lowerComplexSignatureRes(ty, newResTys, newInTys); 376 }) 377 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 378 } 379 llvm::SmallVector<mlir::Type> trailingInTys; 380 for (mlir::Type ty : addrTy.getInputs()) { 381 llvm::TypeSwitch<mlir::Type>(ty) 382 .Case<BoxCharType>([&](BoxCharType box) { 383 if (noCharacterConversion) { 384 newInTys.push_back(box); 385 } else { 386 for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) { 387 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 388 auto argTy = std::get<mlir::Type>(tup); 389 llvm::SmallVector<mlir::Type> &vec = 390 attr.isAppend() ? trailingInTys : newInTys; 391 vec.push_back(argTy); 392 } 393 } 394 }) 395 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 396 lowerComplexSignatureArg(ty, newInTys); 397 }) 398 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 399 lowerComplexSignatureArg(ty, newInTys); 400 }) 401 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 402 if (factory::isCharacterProcedureTuple(tuple)) { 403 newInTys.push_back(tuple.getType(0)); 404 trailingInTys.push_back(tuple.getType(1)); 405 } else { 406 newInTys.push_back(ty); 407 } 408 }) 409 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 410 } 411 // append trailing input types 412 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 413 // replace this op with a new one with the updated signature 414 auto newTy = rewriter->getFunctionType(newInTys, newResTys); 415 auto newOp = 416 rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol()); 417 replaceOp(addrOp, newOp.getResult()); 418 } 419 420 /// Convert the type signatures on all the functions present in the module. 421 /// As the type signature is being changed, this must also update the 422 /// function itself to use any new arguments, etc. 423 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { 424 for (auto fn : mod.getOps<mlir::FuncOp>()) 425 convertSignature(fn); 426 return mlir::success(); 427 } 428 429 /// If the signature does not need any special target-specific converions, 430 /// then it is considered portable for any target, and this function will 431 /// return `true`. Otherwise, the signature is not portable and `false` is 432 /// returned. 433 bool hasPortableSignature(mlir::Type signature) { 434 assert(signature.isa<mlir::FunctionType>()); 435 auto func = signature.dyn_cast<mlir::FunctionType>(); 436 for (auto ty : func.getResults()) 437 if ((ty.isa<BoxCharType>() && !noCharacterConversion) || 438 (isa_complex(ty) && !noComplexConversion)) { 439 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 440 return false; 441 } 442 for (auto ty : func.getInputs()) 443 if (((ty.isa<BoxCharType>() || factory::isCharacterProcedureTuple(ty)) && 444 !noCharacterConversion) || 445 (isa_complex(ty) && !noComplexConversion)) { 446 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 447 return false; 448 } 449 return true; 450 } 451 452 /// Rewrite the signatures and body of the `FuncOp`s in the module for 453 /// the immediately subsequent target code gen. 454 void convertSignature(mlir::FuncOp func) { 455 auto funcTy = func.getType().cast<mlir::FunctionType>(); 456 if (hasPortableSignature(funcTy)) 457 return; 458 llvm::SmallVector<mlir::Type> newResTys; 459 llvm::SmallVector<mlir::Type> newInTys; 460 llvm::SmallVector<FixupTy> fixups; 461 462 // Convert return value(s) 463 for (auto ty : funcTy.getResults()) 464 llvm::TypeSwitch<mlir::Type>(ty) 465 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 466 if (noComplexConversion) 467 newResTys.push_back(cmplx); 468 else 469 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 470 }) 471 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 472 if (noComplexConversion) 473 newResTys.push_back(cmplx); 474 else 475 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 476 }) 477 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 478 479 // Convert arguments 480 llvm::SmallVector<mlir::Type> trailingTys; 481 for (auto e : llvm::enumerate(funcTy.getInputs())) { 482 auto ty = e.value(); 483 unsigned index = e.index(); 484 llvm::TypeSwitch<mlir::Type>(ty) 485 .Case<BoxCharType>([&](BoxCharType boxTy) { 486 if (noCharacterConversion) { 487 newInTys.push_back(boxTy); 488 } else { 489 // Convert a CHARACTER argument type. This can involve separating 490 // the pointer and the LEN into two arguments and moving the LEN 491 // argument to the end of the arg list. 492 bool sret = functionArgIsSRet(index, func); 493 for (auto e : llvm::enumerate(specifics->boxcharArgumentType( 494 boxTy.getEleTy(), sret))) { 495 auto &tup = e.value(); 496 auto index = e.index(); 497 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 498 auto argTy = std::get<mlir::Type>(tup); 499 if (attr.isAppend()) { 500 trailingTys.push_back(argTy); 501 } else { 502 if (sret) { 503 fixups.emplace_back(FixupTy::Codes::CharPair, 504 newInTys.size(), index); 505 } else { 506 fixups.emplace_back(FixupTy::Codes::Trailing, 507 newInTys.size(), trailingTys.size()); 508 } 509 newInTys.push_back(argTy); 510 } 511 } 512 } 513 }) 514 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 515 if (noComplexConversion) 516 newInTys.push_back(cmplx); 517 else 518 doComplexArg(func, cmplx, newInTys, fixups); 519 }) 520 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 521 if (noComplexConversion) 522 newInTys.push_back(cmplx); 523 else 524 doComplexArg(func, cmplx, newInTys, fixups); 525 }) 526 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 527 if (factory::isCharacterProcedureTuple(tuple)) { 528 fixups.emplace_back(FixupTy::Codes::TrailingCharProc, 529 newInTys.size(), trailingTys.size()); 530 newInTys.push_back(tuple.getType(0)); 531 trailingTys.push_back(tuple.getType(1)); 532 } else { 533 newInTys.push_back(ty); 534 } 535 }) 536 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 537 } 538 539 if (!func.empty()) { 540 // If the function has a body, then apply the fixups to the arguments and 541 // return ops as required. These fixups are done in place. 542 auto loc = func.getLoc(); 543 const auto fixupSize = fixups.size(); 544 const auto oldArgTys = func.getType().getInputs(); 545 int offset = 0; 546 for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) { 547 const auto &fixup = fixups[i]; 548 switch (fixup.code) { 549 case FixupTy::Codes::ArgumentAsLoad: { 550 // Argument was pass-by-value, but is now pass-by-reference and 551 // possibly with a different element type. 552 auto newArg = func.front().insertArgument(fixup.index, 553 newInTys[fixup.index], loc); 554 rewriter->setInsertionPointToStart(&func.front()); 555 auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 556 auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg); 557 auto load = rewriter->create<fir::LoadOp>(loc, cast); 558 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 559 func.front().eraseArgument(fixup.index + 1); 560 } break; 561 case FixupTy::Codes::ArgumentType: { 562 // Argument is pass-by-value, but its type has likely been modified to 563 // suit the target ABI convention. 564 auto newArg = func.front().insertArgument(fixup.index, 565 newInTys[fixup.index], loc); 566 rewriter->setInsertionPointToStart(&func.front()); 567 auto mem = 568 rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]); 569 rewriter->create<fir::StoreOp>(loc, newArg, mem); 570 auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 571 auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem); 572 mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast); 573 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 574 func.front().eraseArgument(fixup.index + 1); 575 LLVM_DEBUG(llvm::dbgs() 576 << "old argument: " << oldArgTy.getEleTy() 577 << ", repl: " << load << ", new argument: " 578 << func.getArgument(fixup.index).getType() << '\n'); 579 } break; 580 case FixupTy::Codes::CharPair: { 581 // The FIR boxchar argument has been split into a pair of distinct 582 // arguments that are in juxtaposition to each other. 583 auto newArg = func.front().insertArgument(fixup.index, 584 newInTys[fixup.index], loc); 585 if (fixup.second == 1) { 586 rewriter->setInsertionPointToStart(&func.front()); 587 auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; 588 auto box = rewriter->create<EmboxCharOp>( 589 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); 590 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 591 func.front().eraseArgument(fixup.index + 1); 592 offset++; 593 } 594 } break; 595 case FixupTy::Codes::ReturnAsStore: { 596 // The value being returned is now being returned in memory (callee 597 // stack space) through a hidden reference argument. 598 auto newArg = func.front().insertArgument(fixup.index, 599 newInTys[fixup.index], loc); 600 offset++; 601 func.walk([&](mlir::ReturnOp ret) { 602 rewriter->setInsertionPoint(ret); 603 auto oldOper = ret.getOperand(0); 604 auto oldOperTy = ReferenceType::get(oldOper.getType()); 605 auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg); 606 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 607 rewriter->create<mlir::ReturnOp>(loc); 608 ret.erase(); 609 }); 610 } break; 611 case FixupTy::Codes::ReturnType: { 612 // The function is still returning a value, but its type has likely 613 // changed to suit the target ABI convention. 614 func.walk([&](mlir::ReturnOp ret) { 615 rewriter->setInsertionPoint(ret); 616 auto oldOper = ret.getOperand(0); 617 auto oldOperTy = ReferenceType::get(oldOper.getType()); 618 auto mem = 619 rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]); 620 auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem); 621 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 622 mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem); 623 rewriter->create<mlir::ReturnOp>(loc, load); 624 ret.erase(); 625 }); 626 } break; 627 case FixupTy::Codes::Split: { 628 // The FIR argument has been split into a pair of distinct arguments 629 // that are in juxtaposition to each other. (For COMPLEX value.) 630 auto newArg = func.front().insertArgument(fixup.index, 631 newInTys[fixup.index], loc); 632 if (fixup.second == 1) { 633 rewriter->setInsertionPointToStart(&func.front()); 634 auto cplxTy = oldArgTys[fixup.index - offset - fixup.second]; 635 auto undef = rewriter->create<UndefOp>(loc, cplxTy); 636 auto zero = rewriter->getIntegerAttr(rewriter->getIndexType(), 0); 637 auto one = rewriter->getIntegerAttr(rewriter->getIndexType(), 1); 638 auto cplx1 = rewriter->create<InsertValueOp>( 639 loc, cplxTy, undef, func.front().getArgument(fixup.index - 1), 640 rewriter->getArrayAttr(zero)); 641 auto cplx = rewriter->create<InsertValueOp>( 642 loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one)); 643 func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx); 644 func.front().eraseArgument(fixup.index + 1); 645 offset++; 646 } 647 } break; 648 case FixupTy::Codes::Trailing: { 649 // The FIR argument has been split into a pair of distinct arguments. 650 // The first part of the pair appears in the original argument 651 // position. The second part of the pair is appended after all the 652 // original arguments. (Boxchar arguments.) 653 auto newBufArg = func.front().insertArgument( 654 fixup.index, newInTys[fixup.index], loc); 655 auto newLenArg = 656 func.front().addArgument(trailingTys[fixup.second], loc); 657 auto boxTy = oldArgTys[fixup.index - offset]; 658 rewriter->setInsertionPointToStart(&func.front()); 659 auto box = 660 rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg); 661 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 662 func.front().eraseArgument(fixup.index + 1); 663 } break; 664 case FixupTy::Codes::TrailingCharProc: { 665 // The FIR character procedure argument tuple has been split into a 666 // pair of distinct arguments. The first part of the pair appears in 667 // the original argument position. The second part of the pair is 668 // appended after all the original arguments. 669 auto newProcPointerArg = func.front().insertArgument( 670 fixup.index, newInTys[fixup.index], loc); 671 auto newLenArg = 672 func.front().addArgument(trailingTys[fixup.second], loc); 673 auto tupleType = oldArgTys[fixup.index - offset]; 674 rewriter->setInsertionPointToStart(&func.front()); 675 FirOpBuilder builder(*rewriter, getKindMapping(getModule())); 676 auto tuple = factory::createCharacterProcedureTuple( 677 builder, loc, tupleType, newProcPointerArg, newLenArg); 678 func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple); 679 func.front().eraseArgument(fixup.index + 1); 680 } break; 681 } 682 } 683 } 684 685 // Set the new type and finalize the arguments, etc. 686 newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); 687 auto newFuncTy = 688 mlir::FunctionType::get(func.getContext(), newInTys, newResTys); 689 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); 690 func.setType(newFuncTy); 691 692 for (auto &fixup : fixups) 693 if (fixup.finalizer) 694 (*fixup.finalizer)(func); 695 } 696 697 inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) { 698 if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret")) 699 return true; 700 return false; 701 } 702 703 /// Convert a complex return value. This can involve converting the return 704 /// value to a "hidden" first argument or packing the complex into a wide 705 /// GPR. 706 template <typename A, typename B, typename C> 707 void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys, 708 C &fixups) { 709 if (noComplexConversion) { 710 newResTys.push_back(cmplx); 711 return; 712 } 713 auto m = specifics->complexReturnType(cmplx.getElementType()); 714 assert(m.size() == 1); 715 auto &tup = m[0]; 716 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 717 auto argTy = std::get<mlir::Type>(tup); 718 if (attr.isSRet()) { 719 unsigned argNo = newInTys.size(); 720 fixups.emplace_back( 721 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) { 722 func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 723 }); 724 newInTys.push_back(argTy); 725 return; 726 } 727 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); 728 newResTys.push_back(argTy); 729 } 730 731 /// Convert a complex argument value. This can involve storing the value to 732 /// a temporary memory location or factoring the value into two distinct 733 /// arguments. 734 template <typename A, typename B, typename C> 735 void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) { 736 if (noComplexConversion) { 737 newInTys.push_back(cmplx); 738 return; 739 } 740 auto m = specifics->complexArgumentType(cmplx.getElementType()); 741 const auto fixupCode = 742 m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; 743 for (auto e : llvm::enumerate(m)) { 744 auto &tup = e.value(); 745 auto index = e.index(); 746 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 747 auto argTy = std::get<mlir::Type>(tup); 748 auto argNo = newInTys.size(); 749 if (attr.isByVal()) { 750 if (auto align = attr.getAlignment()) 751 fixups.emplace_back( 752 FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) { 753 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); 754 func.setArgAttr(argNo, "llvm.align", 755 rewriter->getIntegerAttr( 756 rewriter->getIntegerType(32), align)); 757 }); 758 else 759 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), 760 [=](mlir::FuncOp func) { 761 func.setArgAttr(argNo, "llvm.byval", 762 rewriter->getUnitAttr()); 763 }); 764 } else { 765 if (auto align = attr.getAlignment()) 766 fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) { 767 func.setArgAttr( 768 argNo, "llvm.align", 769 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align)); 770 }); 771 else 772 fixups.emplace_back(fixupCode, argNo, index); 773 } 774 newInTys.push_back(argTy); 775 } 776 } 777 778 private: 779 // Replace `op` and remove it. 780 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 781 op->replaceAllUsesWith(newValues); 782 op->dropAllReferences(); 783 op->erase(); 784 } 785 786 inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) { 787 specifics = s; 788 rewriter = r; 789 } 790 791 inline void clearMembers() { setMembers(nullptr, nullptr); } 792 793 CodeGenSpecifics *specifics{}; 794 mlir::OpBuilder *rewriter; 795 }; // namespace 796 } // namespace 797 798 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 799 fir::createFirTargetRewritePass(const TargetRewriteOptions &options) { 800 return std::make_unique<TargetRewrite>(options); 801 } 802