1 //===-- TargetRewrite.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest. 10 // LLVM expects different lowering idioms to be used for distinct target 11 // triples. These distinctions are handled by this pass. 12 // 13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "Target.h" 19 #include "flang/Lower/Todo.h" 20 #include "flang/Optimizer/Builder/Character.h" 21 #include "flang/Optimizer/Builder/FIRBuilder.h" 22 #include "flang/Optimizer/CodeGen/CodeGen.h" 23 #include "flang/Optimizer/Dialect/FIRDialect.h" 24 #include "flang/Optimizer/Dialect/FIROps.h" 25 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 26 #include "flang/Optimizer/Dialect/FIRType.h" 27 #include "flang/Optimizer/Support/FIRContext.h" 28 #include "mlir/Transforms/DialectConversion.h" 29 #include "llvm/ADT/STLExtras.h" 30 #include "llvm/ADT/TypeSwitch.h" 31 #include "llvm/Support/Debug.h" 32 33 using namespace fir; 34 using namespace mlir; 35 36 #define DEBUG_TYPE "flang-target-rewrite" 37 38 namespace { 39 40 /// Fixups for updating a FuncOp's arguments and return values. 41 struct FixupTy { 42 enum class Codes { 43 ArgumentAsLoad, 44 ArgumentType, 45 CharPair, 46 ReturnAsStore, 47 ReturnType, 48 Split, 49 Trailing, 50 TrailingCharProc 51 }; 52 53 FixupTy(Codes code, std::size_t index, std::size_t second = 0) 54 : code{code}, index{index}, second{second} {} 55 FixupTy(Codes code, std::size_t index, 56 std::function<void(mlir::FuncOp)> &&finalizer) 57 : code{code}, index{index}, finalizer{finalizer} {} 58 FixupTy(Codes code, std::size_t index, std::size_t second, 59 std::function<void(mlir::FuncOp)> &&finalizer) 60 : code{code}, index{index}, second{second}, finalizer{finalizer} {} 61 62 Codes code; 63 std::size_t index; 64 std::size_t second{}; 65 llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{}; 66 }; // namespace 67 68 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 69 /// generation that traverses the FIR and modifies types and operations to a 70 /// form that is appropriate for the specific target. LLVM IR has specific 71 /// idioms that are used for distinct target processor and ABI combinations. 72 class TargetRewrite : public TargetRewriteBase<TargetRewrite> { 73 public: 74 TargetRewrite(const TargetRewriteOptions &options) { 75 noCharacterConversion = options.noCharacterConversion; 76 noComplexConversion = options.noComplexConversion; 77 } 78 79 void runOnOperation() override final { 80 auto &context = getContext(); 81 mlir::OpBuilder rewriter(&context); 82 83 auto mod = getModule(); 84 if (!forcedTargetTriple.empty()) 85 setTargetTriple(mod, forcedTargetTriple); 86 87 auto specifics = CodeGenSpecifics::get( 88 mod.getContext(), getTargetTriple(mod), getKindMapping(mod)); 89 setMembers(specifics.get(), &rewriter); 90 91 // Perform type conversion on signatures and call sites. 92 if (mlir::failed(convertTypes(mod))) { 93 mlir::emitError(mlir::UnknownLoc::get(&context), 94 "error in converting types to target abi"); 95 signalPassFailure(); 96 } 97 98 // Convert ops in target-specific patterns. 99 mod.walk([&](mlir::Operation *op) { 100 if (auto call = dyn_cast<fir::CallOp>(op)) { 101 if (!hasPortableSignature(call.getFunctionType())) 102 convertCallOp(call); 103 } else if (auto dispatch = dyn_cast<DispatchOp>(op)) { 104 if (!hasPortableSignature(dispatch.getFunctionType())) 105 convertCallOp(dispatch); 106 } else if (auto addr = dyn_cast<AddrOfOp>(op)) { 107 if (addr.getType().isa<mlir::FunctionType>() && 108 !hasPortableSignature(addr.getType())) 109 convertAddrOp(addr); 110 } 111 }); 112 113 clearMembers(); 114 } 115 116 mlir::ModuleOp getModule() { return getOperation(); } 117 118 template <typename A, typename B, typename C> 119 std::function<mlir::Value(mlir::Operation *)> 120 rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) { 121 auto m = specifics->complexReturnType(ty.getElementType()); 122 // Currently targets mandate COMPLEX is a single aggregate or packed 123 // scalar, including the sret case. 124 assert(m.size() == 1 && "target lowering of complex return not supported"); 125 auto resTy = std::get<mlir::Type>(m[0]); 126 auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 127 auto loc = mlir::UnknownLoc::get(resTy.getContext()); 128 if (attr.isSRet()) { 129 assert(isa_ref_type(resTy)); 130 mlir::Value stack = 131 rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy)); 132 newInTys.push_back(resTy); 133 newOpers.push_back(stack); 134 return [=](mlir::Operation *) -> mlir::Value { 135 auto memTy = ReferenceType::get(ty); 136 auto cast = rewriter->create<ConvertOp>(loc, memTy, stack); 137 return rewriter->create<fir::LoadOp>(loc, cast); 138 }; 139 } 140 newResTys.push_back(resTy); 141 return [=](mlir::Operation *call) -> mlir::Value { 142 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 143 rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem); 144 auto memTy = ReferenceType::get(ty); 145 auto cast = rewriter->create<ConvertOp>(loc, memTy, mem); 146 return rewriter->create<fir::LoadOp>(loc, cast); 147 }; 148 } 149 150 template <typename A, typename B, typename C> 151 void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, 152 C &newOpers) { 153 auto m = specifics->complexArgumentType(ty.getElementType()); 154 auto *ctx = ty.getContext(); 155 auto loc = mlir::UnknownLoc::get(ctx); 156 if (m.size() == 1) { 157 // COMPLEX is a single aggregate 158 auto resTy = std::get<mlir::Type>(m[0]); 159 auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 160 auto oldRefTy = ReferenceType::get(ty); 161 if (attr.isByVal()) { 162 auto mem = rewriter->create<fir::AllocaOp>(loc, ty); 163 rewriter->create<fir::StoreOp>(loc, oper, mem); 164 newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem)); 165 } else { 166 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 167 auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem); 168 rewriter->create<fir::StoreOp>(loc, oper, cast); 169 newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem)); 170 } 171 newInTys.push_back(resTy); 172 } else { 173 assert(m.size() == 2); 174 // COMPLEX is split into 2 separate arguments 175 auto iTy = rewriter->getIntegerType(32); 176 for (auto e : llvm::enumerate(m)) { 177 auto &tup = e.value(); 178 auto ty = std::get<mlir::Type>(tup); 179 auto index = e.index(); 180 auto idx = rewriter->getIntegerAttr(iTy, index); 181 auto val = rewriter->create<ExtractValueOp>( 182 loc, ty, oper, rewriter->getArrayAttr(idx)); 183 newInTys.push_back(ty); 184 newOpers.push_back(val); 185 } 186 } 187 } 188 189 // Convert fir.call and fir.dispatch Ops. 190 template <typename A> 191 void convertCallOp(A callOp) { 192 auto fnTy = callOp.getFunctionType(); 193 auto loc = callOp.getLoc(); 194 rewriter->setInsertionPoint(callOp); 195 llvm::SmallVector<mlir::Type> newResTys; 196 llvm::SmallVector<mlir::Type> newInTys; 197 llvm::SmallVector<mlir::Value> newOpers; 198 199 // If the call is indirect, the first argument must still be the function 200 // to call. 201 int dropFront = 0; 202 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 203 if (!callOp.getCallee().hasValue()) { 204 newInTys.push_back(fnTy.getInput(0)); 205 newOpers.push_back(callOp.getOperand(0)); 206 dropFront = 1; 207 } 208 } 209 210 // Determine the rewrite function, `wrap`, for the result value. 211 llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; 212 if (fnTy.getResults().size() == 1) { 213 mlir::Type ty = fnTy.getResult(0); 214 llvm::TypeSwitch<mlir::Type>(ty) 215 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 216 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 217 newOpers); 218 }) 219 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 220 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 221 newOpers); 222 }) 223 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 224 } else if (fnTy.getResults().size() > 1) { 225 TODO(loc, "multiple results not supported yet"); 226 } 227 228 llvm::SmallVector<mlir::Type> trailingInTys; 229 llvm::SmallVector<mlir::Value> trailingOpers; 230 for (auto e : llvm::enumerate( 231 llvm::zip(fnTy.getInputs().drop_front(dropFront), 232 callOp.getOperands().drop_front(dropFront)))) { 233 mlir::Type ty = std::get<0>(e.value()); 234 mlir::Value oper = std::get<1>(e.value()); 235 unsigned index = e.index(); 236 llvm::TypeSwitch<mlir::Type>(ty) 237 .template Case<BoxCharType>([&](BoxCharType boxTy) { 238 bool sret; 239 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 240 sret = callOp.getCallee() && 241 functionArgIsSRet(index, 242 getModule().lookupSymbol<mlir::FuncOp>( 243 *callOp.getCallee())); 244 } else { 245 // TODO: dispatch case; how do we put arguments on a call? 246 // We cannot put both an sret and the dispatch object first. 247 sret = false; 248 TODO(loc, "dispatch + sret not supported yet"); 249 } 250 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); 251 auto unbox = 252 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]), 253 std::get<mlir::Type>(m[1]), oper); 254 // unboxed CHARACTER arguments 255 for (auto e : llvm::enumerate(m)) { 256 unsigned idx = e.index(); 257 auto attr = std::get<CodeGenSpecifics::Attributes>(e.value()); 258 auto argTy = std::get<mlir::Type>(e.value()); 259 if (attr.isAppend()) { 260 trailingInTys.push_back(argTy); 261 trailingOpers.push_back(unbox.getResult(idx)); 262 } else { 263 newInTys.push_back(argTy); 264 newOpers.push_back(unbox.getResult(idx)); 265 } 266 } 267 }) 268 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 269 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 270 }) 271 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 272 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 273 }) 274 .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { 275 if (isCharacterProcedureTuple(tuple)) { 276 mlir::ModuleOp module = getModule(); 277 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 278 if (callOp.getCallee()) { 279 llvm::StringRef charProcAttr = 280 getCharacterProcedureDummyAttrName(); 281 // The charProcAttr attribute is only used as a safety to 282 // confirm that this is a dummy procedure and should be split. 283 // It cannot be used to match because attributes are not 284 // available in case of indirect calls. 285 auto funcOp = 286 module.lookupSymbol<mlir::FuncOp>(*callOp.getCallee()); 287 if (funcOp && 288 !funcOp.template getArgAttrOfType<mlir::UnitAttr>( 289 index, charProcAttr)) 290 mlir::emitError(loc, "tuple argument will be split even " 291 "though it does not have the `" + 292 charProcAttr + "` attribute"); 293 } 294 } 295 mlir::Type funcPointerType = tuple.getType(0); 296 mlir::Type lenType = tuple.getType(1); 297 FirOpBuilder builder(*rewriter, getKindMapping(module)); 298 auto [funcPointer, len] = 299 factory::extractCharacterProcedureTuple(builder, loc, oper); 300 newInTys.push_back(funcPointerType); 301 newOpers.push_back(funcPointer); 302 trailingInTys.push_back(lenType); 303 trailingOpers.push_back(len); 304 } else { 305 newInTys.push_back(tuple); 306 newOpers.push_back(oper); 307 } 308 }) 309 .Default([&](mlir::Type ty) { 310 newInTys.push_back(ty); 311 newOpers.push_back(oper); 312 }); 313 } 314 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 315 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); 316 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 317 fir::CallOp newCall; 318 if (callOp.getCallee().hasValue()) { 319 newCall = rewriter->create<A>(loc, callOp.getCallee().getValue(), 320 newResTys, newOpers); 321 } else { 322 // Force new type on the input operand. 323 newOpers[0].setType(mlir::FunctionType::get( 324 callOp.getContext(), 325 mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); 326 newCall = rewriter->create<A>(loc, newResTys, newOpers); 327 } 328 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); 329 if (wrap.hasValue()) 330 replaceOp(callOp, (*wrap)(newCall.getOperation())); 331 else 332 replaceOp(callOp, newCall.getResults()); 333 } else { 334 // A is fir::DispatchOp 335 TODO(loc, "dispatch not implemented"); 336 } 337 } 338 339 // Result type fixup for fir::ComplexType and mlir::ComplexType 340 template <typename A, typename B> 341 void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) { 342 if (noComplexConversion) { 343 newResTys.push_back(cmplx); 344 } else { 345 for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) { 346 auto argTy = std::get<mlir::Type>(tup); 347 if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet()) 348 newInTys.push_back(argTy); 349 else 350 newResTys.push_back(argTy); 351 } 352 } 353 } 354 355 // Argument type fixup for fir::ComplexType and mlir::ComplexType 356 template <typename A, typename B> 357 void lowerComplexSignatureArg(A cmplx, B &newInTys) { 358 if (noComplexConversion) 359 newInTys.push_back(cmplx); 360 else 361 for (auto &tup : specifics->complexArgumentType(cmplx.getElementType())) 362 newInTys.push_back(std::get<mlir::Type>(tup)); 363 } 364 365 /// Taking the address of a function. Modify the signature as needed. 366 void convertAddrOp(AddrOfOp addrOp) { 367 rewriter->setInsertionPoint(addrOp); 368 auto addrTy = addrOp.getType().cast<mlir::FunctionType>(); 369 llvm::SmallVector<mlir::Type> newResTys; 370 llvm::SmallVector<mlir::Type> newInTys; 371 for (mlir::Type ty : addrTy.getResults()) { 372 llvm::TypeSwitch<mlir::Type>(ty) 373 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 374 lowerComplexSignatureRes(ty, newResTys, newInTys); 375 }) 376 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 377 lowerComplexSignatureRes(ty, newResTys, newInTys); 378 }) 379 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 380 } 381 llvm::SmallVector<mlir::Type> trailingInTys; 382 for (mlir::Type ty : addrTy.getInputs()) { 383 llvm::TypeSwitch<mlir::Type>(ty) 384 .Case<BoxCharType>([&](BoxCharType box) { 385 if (noCharacterConversion) { 386 newInTys.push_back(box); 387 } else { 388 for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) { 389 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 390 auto argTy = std::get<mlir::Type>(tup); 391 llvm::SmallVector<mlir::Type> &vec = 392 attr.isAppend() ? trailingInTys : newInTys; 393 vec.push_back(argTy); 394 } 395 } 396 }) 397 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 398 lowerComplexSignatureArg(ty, newInTys); 399 }) 400 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 401 lowerComplexSignatureArg(ty, newInTys); 402 }) 403 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 404 if (isCharacterProcedureTuple(tuple)) { 405 newInTys.push_back(tuple.getType(0)); 406 trailingInTys.push_back(tuple.getType(1)); 407 } else { 408 newInTys.push_back(ty); 409 } 410 }) 411 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 412 } 413 // append trailing input types 414 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 415 // replace this op with a new one with the updated signature 416 auto newTy = rewriter->getFunctionType(newInTys, newResTys); 417 auto newOp = 418 rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.getSymbol()); 419 replaceOp(addrOp, newOp.getResult()); 420 } 421 422 /// Convert the type signatures on all the functions present in the module. 423 /// As the type signature is being changed, this must also update the 424 /// function itself to use any new arguments, etc. 425 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { 426 for (auto fn : mod.getOps<mlir::FuncOp>()) 427 convertSignature(fn); 428 return mlir::success(); 429 } 430 431 /// If the signature does not need any special target-specific converions, 432 /// then it is considered portable for any target, and this function will 433 /// return `true`. Otherwise, the signature is not portable and `false` is 434 /// returned. 435 bool hasPortableSignature(mlir::Type signature) { 436 assert(signature.isa<mlir::FunctionType>()); 437 auto func = signature.dyn_cast<mlir::FunctionType>(); 438 for (auto ty : func.getResults()) 439 if ((ty.isa<BoxCharType>() && !noCharacterConversion) || 440 (isa_complex(ty) && !noComplexConversion)) { 441 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 442 return false; 443 } 444 for (auto ty : func.getInputs()) 445 if (((ty.isa<BoxCharType>() || isCharacterProcedureTuple(ty)) && 446 !noCharacterConversion) || 447 (isa_complex(ty) && !noComplexConversion)) { 448 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 449 return false; 450 } 451 return true; 452 } 453 454 /// Determine if the signature has host associations. The host association 455 /// argument may need special target specific rewriting. 456 static bool hasHostAssociations(mlir::FuncOp func) { 457 std::size_t end = func.getFunctionType().getInputs().size(); 458 for (std::size_t i = 0; i < end; ++i) 459 if (func.getArgAttrOfType<mlir::UnitAttr>(i, getHostAssocAttrName())) 460 return true; 461 return false; 462 } 463 464 /// Rewrite the signatures and body of the `FuncOp`s in the module for 465 /// the immediately subsequent target code gen. 466 void convertSignature(mlir::FuncOp func) { 467 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); 468 if (hasPortableSignature(funcTy) && !hasHostAssociations(func)) 469 return; 470 llvm::SmallVector<mlir::Type> newResTys; 471 llvm::SmallVector<mlir::Type> newInTys; 472 llvm::SmallVector<FixupTy> fixups; 473 474 // Convert return value(s) 475 for (auto ty : funcTy.getResults()) 476 llvm::TypeSwitch<mlir::Type>(ty) 477 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 478 if (noComplexConversion) 479 newResTys.push_back(cmplx); 480 else 481 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 482 }) 483 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 484 if (noComplexConversion) 485 newResTys.push_back(cmplx); 486 else 487 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 488 }) 489 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 490 491 // Convert arguments 492 llvm::SmallVector<mlir::Type> trailingTys; 493 for (auto e : llvm::enumerate(funcTy.getInputs())) { 494 auto ty = e.value(); 495 unsigned index = e.index(); 496 llvm::TypeSwitch<mlir::Type>(ty) 497 .Case<BoxCharType>([&](BoxCharType boxTy) { 498 if (noCharacterConversion) { 499 newInTys.push_back(boxTy); 500 } else { 501 // Convert a CHARACTER argument type. This can involve separating 502 // the pointer and the LEN into two arguments and moving the LEN 503 // argument to the end of the arg list. 504 bool sret = functionArgIsSRet(index, func); 505 for (auto e : llvm::enumerate(specifics->boxcharArgumentType( 506 boxTy.getEleTy(), sret))) { 507 auto &tup = e.value(); 508 auto index = e.index(); 509 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 510 auto argTy = std::get<mlir::Type>(tup); 511 if (attr.isAppend()) { 512 trailingTys.push_back(argTy); 513 } else { 514 if (sret) { 515 fixups.emplace_back(FixupTy::Codes::CharPair, 516 newInTys.size(), index); 517 } else { 518 fixups.emplace_back(FixupTy::Codes::Trailing, 519 newInTys.size(), trailingTys.size()); 520 } 521 newInTys.push_back(argTy); 522 } 523 } 524 } 525 }) 526 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 527 if (noComplexConversion) 528 newInTys.push_back(cmplx); 529 else 530 doComplexArg(func, cmplx, newInTys, fixups); 531 }) 532 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 533 if (noComplexConversion) 534 newInTys.push_back(cmplx); 535 else 536 doComplexArg(func, cmplx, newInTys, fixups); 537 }) 538 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 539 if (isCharacterProcedureTuple(tuple)) { 540 fixups.emplace_back(FixupTy::Codes::TrailingCharProc, 541 newInTys.size(), trailingTys.size()); 542 newInTys.push_back(tuple.getType(0)); 543 trailingTys.push_back(tuple.getType(1)); 544 } else { 545 newInTys.push_back(ty); 546 } 547 }) 548 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 549 if (func.getArgAttrOfType<mlir::UnitAttr>(index, 550 getHostAssocAttrName())) { 551 func.setArgAttr(index, "llvm.nest", rewriter->getUnitAttr()); 552 } 553 } 554 555 if (!func.empty()) { 556 // If the function has a body, then apply the fixups to the arguments and 557 // return ops as required. These fixups are done in place. 558 auto loc = func.getLoc(); 559 const auto fixupSize = fixups.size(); 560 const auto oldArgTys = func.getFunctionType().getInputs(); 561 int offset = 0; 562 for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) { 563 const auto &fixup = fixups[i]; 564 switch (fixup.code) { 565 case FixupTy::Codes::ArgumentAsLoad: { 566 // Argument was pass-by-value, but is now pass-by-reference and 567 // possibly with a different element type. 568 auto newArg = func.front().insertArgument(fixup.index, 569 newInTys[fixup.index], loc); 570 rewriter->setInsertionPointToStart(&func.front()); 571 auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 572 auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg); 573 auto load = rewriter->create<fir::LoadOp>(loc, cast); 574 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 575 func.front().eraseArgument(fixup.index + 1); 576 } break; 577 case FixupTy::Codes::ArgumentType: { 578 // Argument is pass-by-value, but its type has likely been modified to 579 // suit the target ABI convention. 580 auto newArg = func.front().insertArgument(fixup.index, 581 newInTys[fixup.index], loc); 582 rewriter->setInsertionPointToStart(&func.front()); 583 auto mem = 584 rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]); 585 rewriter->create<fir::StoreOp>(loc, newArg, mem); 586 auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 587 auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem); 588 mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast); 589 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 590 func.front().eraseArgument(fixup.index + 1); 591 LLVM_DEBUG(llvm::dbgs() 592 << "old argument: " << oldArgTy.getEleTy() 593 << ", repl: " << load << ", new argument: " 594 << func.getArgument(fixup.index).getType() << '\n'); 595 } break; 596 case FixupTy::Codes::CharPair: { 597 // The FIR boxchar argument has been split into a pair of distinct 598 // arguments that are in juxtaposition to each other. 599 auto newArg = func.front().insertArgument(fixup.index, 600 newInTys[fixup.index], loc); 601 if (fixup.second == 1) { 602 rewriter->setInsertionPointToStart(&func.front()); 603 auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; 604 auto box = rewriter->create<EmboxCharOp>( 605 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); 606 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 607 func.front().eraseArgument(fixup.index + 1); 608 offset++; 609 } 610 } break; 611 case FixupTy::Codes::ReturnAsStore: { 612 // The value being returned is now being returned in memory (callee 613 // stack space) through a hidden reference argument. 614 auto newArg = func.front().insertArgument(fixup.index, 615 newInTys[fixup.index], loc); 616 offset++; 617 func.walk([&](mlir::func::ReturnOp ret) { 618 rewriter->setInsertionPoint(ret); 619 auto oldOper = ret.getOperand(0); 620 auto oldOperTy = ReferenceType::get(oldOper.getType()); 621 auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg); 622 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 623 rewriter->create<mlir::func::ReturnOp>(loc); 624 ret.erase(); 625 }); 626 } break; 627 case FixupTy::Codes::ReturnType: { 628 // The function is still returning a value, but its type has likely 629 // changed to suit the target ABI convention. 630 func.walk([&](mlir::func::ReturnOp ret) { 631 rewriter->setInsertionPoint(ret); 632 auto oldOper = ret.getOperand(0); 633 auto oldOperTy = ReferenceType::get(oldOper.getType()); 634 auto mem = 635 rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]); 636 auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem); 637 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 638 mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem); 639 rewriter->create<mlir::func::ReturnOp>(loc, load); 640 ret.erase(); 641 }); 642 } break; 643 case FixupTy::Codes::Split: { 644 // The FIR argument has been split into a pair of distinct arguments 645 // that are in juxtaposition to each other. (For COMPLEX value.) 646 auto newArg = func.front().insertArgument(fixup.index, 647 newInTys[fixup.index], loc); 648 if (fixup.second == 1) { 649 rewriter->setInsertionPointToStart(&func.front()); 650 auto cplxTy = oldArgTys[fixup.index - offset - fixup.second]; 651 auto undef = rewriter->create<UndefOp>(loc, cplxTy); 652 auto iTy = rewriter->getIntegerType(32); 653 auto zero = rewriter->getIntegerAttr(iTy, 0); 654 auto one = rewriter->getIntegerAttr(iTy, 1); 655 auto cplx1 = rewriter->create<InsertValueOp>( 656 loc, cplxTy, undef, func.front().getArgument(fixup.index - 1), 657 rewriter->getArrayAttr(zero)); 658 auto cplx = rewriter->create<InsertValueOp>( 659 loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one)); 660 func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx); 661 func.front().eraseArgument(fixup.index + 1); 662 offset++; 663 } 664 } break; 665 case FixupTy::Codes::Trailing: { 666 // The FIR argument has been split into a pair of distinct arguments. 667 // The first part of the pair appears in the original argument 668 // position. The second part of the pair is appended after all the 669 // original arguments. (Boxchar arguments.) 670 auto newBufArg = func.front().insertArgument( 671 fixup.index, newInTys[fixup.index], loc); 672 auto newLenArg = 673 func.front().addArgument(trailingTys[fixup.second], loc); 674 auto boxTy = oldArgTys[fixup.index - offset]; 675 rewriter->setInsertionPointToStart(&func.front()); 676 auto box = 677 rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg); 678 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 679 func.front().eraseArgument(fixup.index + 1); 680 } break; 681 case FixupTy::Codes::TrailingCharProc: { 682 // The FIR character procedure argument tuple must be split into a 683 // pair of distinct arguments. The first part of the pair appears in 684 // the original argument position. The second part of the pair is 685 // appended after all the original arguments. 686 auto newProcPointerArg = func.front().insertArgument( 687 fixup.index, newInTys[fixup.index], loc); 688 auto newLenArg = 689 func.front().addArgument(trailingTys[fixup.second], loc); 690 auto tupleType = oldArgTys[fixup.index - offset]; 691 rewriter->setInsertionPointToStart(&func.front()); 692 FirOpBuilder builder(*rewriter, getKindMapping(getModule())); 693 auto tuple = factory::createCharacterProcedureTuple( 694 builder, loc, tupleType, newProcPointerArg, newLenArg); 695 func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple); 696 func.front().eraseArgument(fixup.index + 1); 697 } break; 698 } 699 } 700 } 701 702 // Set the new type and finalize the arguments, etc. 703 newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); 704 auto newFuncTy = 705 mlir::FunctionType::get(func.getContext(), newInTys, newResTys); 706 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); 707 func.setType(newFuncTy); 708 709 for (auto &fixup : fixups) 710 if (fixup.finalizer) 711 (*fixup.finalizer)(func); 712 } 713 714 inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) { 715 if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret")) 716 return true; 717 return false; 718 } 719 720 /// Convert a complex return value. This can involve converting the return 721 /// value to a "hidden" first argument or packing the complex into a wide 722 /// GPR. 723 template <typename A, typename B, typename C> 724 void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys, 725 C &fixups) { 726 if (noComplexConversion) { 727 newResTys.push_back(cmplx); 728 return; 729 } 730 auto m = specifics->complexReturnType(cmplx.getElementType()); 731 assert(m.size() == 1); 732 auto &tup = m[0]; 733 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 734 auto argTy = std::get<mlir::Type>(tup); 735 if (attr.isSRet()) { 736 unsigned argNo = newInTys.size(); 737 fixups.emplace_back( 738 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) { 739 func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 740 }); 741 newInTys.push_back(argTy); 742 return; 743 } 744 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); 745 newResTys.push_back(argTy); 746 } 747 748 /// Convert a complex argument value. This can involve storing the value to 749 /// a temporary memory location or factoring the value into two distinct 750 /// arguments. 751 template <typename A, typename B, typename C> 752 void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) { 753 if (noComplexConversion) { 754 newInTys.push_back(cmplx); 755 return; 756 } 757 auto m = specifics->complexArgumentType(cmplx.getElementType()); 758 const auto fixupCode = 759 m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; 760 for (auto e : llvm::enumerate(m)) { 761 auto &tup = e.value(); 762 auto index = e.index(); 763 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 764 auto argTy = std::get<mlir::Type>(tup); 765 auto argNo = newInTys.size(); 766 if (attr.isByVal()) { 767 if (auto align = attr.getAlignment()) 768 fixups.emplace_back( 769 FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) { 770 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); 771 func.setArgAttr(argNo, "llvm.align", 772 rewriter->getIntegerAttr( 773 rewriter->getIntegerType(32), align)); 774 }); 775 else 776 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), 777 [=](mlir::FuncOp func) { 778 func.setArgAttr(argNo, "llvm.byval", 779 rewriter->getUnitAttr()); 780 }); 781 } else { 782 if (auto align = attr.getAlignment()) 783 fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) { 784 func.setArgAttr( 785 argNo, "llvm.align", 786 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align)); 787 }); 788 else 789 fixups.emplace_back(fixupCode, argNo, index); 790 } 791 newInTys.push_back(argTy); 792 } 793 } 794 795 private: 796 // Replace `op` and remove it. 797 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 798 op->replaceAllUsesWith(newValues); 799 op->dropAllReferences(); 800 op->erase(); 801 } 802 803 inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) { 804 specifics = s; 805 rewriter = r; 806 } 807 808 inline void clearMembers() { setMembers(nullptr, nullptr); } 809 810 CodeGenSpecifics *specifics{}; 811 mlir::OpBuilder *rewriter; 812 }; // namespace 813 } // namespace 814 815 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 816 fir::createFirTargetRewritePass(const TargetRewriteOptions &options) { 817 return std::make_unique<TargetRewrite>(options); 818 } 819