//===-- TargetRewrite.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Target rewrite: rewriting of ops to make target-specific lowerings manifest.
// LLVM expects different lowering idioms to be used for distinct target
// triples. These distinctions are handled by this pass.
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "Target.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

using namespace fir;

#define DEBUG_TYPE "flang-target-rewrite"

namespace {

/// Fixups for updating a FuncOp's arguments and return values.
struct FixupTy {
  enum class Codes {
    ArgumentAsLoad,
    ArgumentType,
    CharPair,
    ReturnAsStore,
    ReturnType,
    Split,
    Trailing
  };

  FixupTy(Codes code, std::size_t index, std::size_t second = 0)
      : code{code}, index{index}, second{second} {}
  FixupTy(Codes code, std::size_t index,
          std::function<void(mlir::FuncOp)> &&finalizer)
      : code{code}, index{index}, finalizer{finalizer} {}
  FixupTy(Codes code, std::size_t index, std::size_t second,
          std::function<void(mlir::FuncOp)> &&finalizer)
      : code{code}, index{index}, second{second}, finalizer{finalizer} {}

  Codes code;
  std::size_t index;
  std::size_t second{};
  llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{};
};

/// Target-specific rewriting of the FIR. This is a prerequisite pass to code
/// generation that traverses the FIR and modifies types and operations to a
/// form that is appropriate for the specific target. LLVM IR has specific
/// idioms that are used for distinct target processor and ABI combinations.
class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
public:
  TargetRewrite(const TargetRewriteOptions &options) {
    noCharacterConversion = options.noCharacterConversion;
    noComplexConversion = options.noComplexConversion;
  }

  void runOnOperation() override final {
    auto &context = getContext();
    mlir::OpBuilder rewriter(&context);

    auto mod = getModule();
    if (!forcedTargetTriple.empty())
      setTargetTriple(mod, forcedTargetTriple);

    auto specifics = CodeGenSpecifics::get(getOperation().getContext(),
                                           getTargetTriple(getOperation()),
                                           getKindMapping(getOperation()));
    setMembers(specifics.get(), &rewriter);

    // Perform type conversion on signatures and call sites.
    if (mlir::failed(convertTypes(mod))) {
      mlir::emitError(mlir::UnknownLoc::get(&context),
                      "error in converting types to target abi");
      signalPassFailure();
    }

    // Convert ops in target-specific patterns.
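    // For example (illustrative only; the exact rewritten types depend on the
    // target ABI chosen through CodeGenSpecifics), a call such as
    //   %r = fir.call @f(%z) : (!fir.complex<4>) -> !fir.complex<4>
    // is replaced below by a call whose complex operand and result follow the
    // target's C calling convention idiom, and !fir.boxchar<k> arguments are
    // unboxed into a buffer/length pair.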
    mod.walk([&](mlir::Operation *op) {
      if (auto call = dyn_cast<fir::CallOp>(op)) {
        if (!hasPortableSignature(call.getFunctionType()))
          convertCallOp(call);
      } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
        if (!hasPortableSignature(dispatch.getFunctionType()))
          convertCallOp(dispatch);
      } else if (auto addr = dyn_cast<AddrOfOp>(op)) {
        if (addr.getType().isa<mlir::FunctionType>() &&
            !hasPortableSignature(addr.getType()))
          convertAddrOp(addr);
      }
    });

    clearMembers();
  }

  mlir::ModuleOp getModule() { return getOperation(); }

  template <typename A, typename B, typename C>
  std::function<mlir::Value(mlir::Operation *)>
  rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) {
    auto m = specifics->complexReturnType(ty.getElementType());
    // Currently targets mandate COMPLEX is a single aggregate or packed
    // scalar, including the sret case.
    assert(m.size() == 1 && "target lowering of complex return not supported");
    auto resTy = std::get<mlir::Type>(m[0]);
    auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
    auto loc = mlir::UnknownLoc::get(resTy.getContext());
    if (attr.isSRet()) {
      assert(isa_ref_type(resTy));
      mlir::Value stack =
          rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy));
      newInTys.push_back(resTy);
      newOpers.push_back(stack);
      return [=](mlir::Operation *) -> mlir::Value {
        auto memTy = ReferenceType::get(ty);
        auto cast = rewriter->create<ConvertOp>(loc, memTy, stack);
        return rewriter->create<fir::LoadOp>(loc, cast);
      };
    }
    newResTys.push_back(resTy);
    return [=](mlir::Operation *call) -> mlir::Value {
      auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
      rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
      auto memTy = ReferenceType::get(ty);
      auto cast = rewriter->create<ConvertOp>(loc, memTy, mem);
      return rewriter->create<fir::LoadOp>(loc, cast);
    };
  }

  template <typename A, typename B, typename C>
  void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys,
                                   C &newOpers) {
    auto m = specifics->complexArgumentType(ty.getElementType());
    auto *ctx = ty.getContext();
    auto loc = mlir::UnknownLoc::get(ctx);
    if (m.size() == 1) {
      // COMPLEX is a single aggregate
      auto resTy = std::get<mlir::Type>(m[0]);
      auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
      auto oldRefTy = ReferenceType::get(ty);
      if (attr.isByVal()) {
        auto mem = rewriter->create<fir::AllocaOp>(loc, ty);
        rewriter->create<fir::StoreOp>(loc, oper, mem);
        newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem));
      } else {
        auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
        auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem);
        rewriter->create<fir::StoreOp>(loc, oper, cast);
        newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
      }
      newInTys.push_back(resTy);
    } else {
      assert(m.size() == 2);
      // COMPLEX is split into 2 separate arguments
      for (auto e : llvm::enumerate(m)) {
        auto &tup = e.value();
        auto ty = std::get<mlir::Type>(tup);
        auto index = e.index();
        auto idx = rewriter->getIntegerAttr(rewriter->getIndexType(), index);
        auto val = rewriter->create<ExtractValueOp>(
            loc, ty, oper, rewriter->getArrayAttr(idx));
        newInTys.push_back(ty);
        newOpers.push_back(val);
      }
    }
  }

  // Convert fir.call and fir.dispatch Ops.
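  //
  // Illustrative sketch (hypothetical FIR; the marshalled types come from
  // CodeGenSpecifics and vary per target): a CHARACTER argument passed as
  //   fir.call @f(%str) : (!fir.boxchar<1>) -> ()
  // is unboxed with fir.unboxchar into a buffer and a length, and any part
  // marked "append" is queued in trailingOpers and re-attached at the end of
  // the new argument list, e.g.
  //   %u:2 = fir.unboxchar %str
  //   fir.call @f(%u#0, %u#1) : (!fir.ref<!fir.char<1,?>>, i64) -> ()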
  template <typename A>
  void convertCallOp(A callOp) {
    auto fnTy = callOp.getFunctionType();
    auto loc = callOp.getLoc();
    rewriter->setInsertionPoint(callOp);
    llvm::SmallVector<mlir::Type> newResTys;
    llvm::SmallVector<mlir::Type> newInTys;
    llvm::SmallVector<mlir::Value> newOpers;

    // If the call is indirect, the first argument must still be the function
    // to call.
    int dropFront = 0;
    if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
      if (!callOp.callee().hasValue()) {
        newInTys.push_back(fnTy.getInput(0));
        newOpers.push_back(callOp.getOperand(0));
        dropFront = 1;
      }
    }

    // Determine the rewrite function, `wrap`, for the result value.
    llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
    if (fnTy.getResults().size() == 1) {
      mlir::Type ty = fnTy.getResult(0);
      llvm::TypeSwitch<mlir::Type>(ty)
          .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
                                                newOpers);
          })
          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
                                                newOpers);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
    } else if (fnTy.getResults().size() > 1) {
      TODO(loc, "multiple results not supported yet");
    }

    llvm::SmallVector<mlir::Type> trailingInTys;
    llvm::SmallVector<mlir::Value> trailingOpers;
    for (auto e : llvm::enumerate(
             llvm::zip(fnTy.getInputs().drop_front(dropFront),
                       callOp.getOperands().drop_front(dropFront)))) {
      mlir::Type ty = std::get<0>(e.value());
      mlir::Value oper = std::get<1>(e.value());
      unsigned index = e.index();
      llvm::TypeSwitch<mlir::Type>(ty)
          .template Case<BoxCharType>([&](BoxCharType boxTy) {
            bool sret;
            if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
              sret = callOp.callee() &&
                     functionArgIsSRet(index,
                                       getModule().lookupSymbol<mlir::FuncOp>(
                                           *callOp.callee()));
            } else {
              // TODO: dispatch case; how do we put arguments on a call?
              // We cannot put both an sret and the dispatch object first.
              sret = false;
              TODO(loc, "dispatch + sret not supported yet");
            }
            auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret);
            auto unbox =
                rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]),
                                              std::get<mlir::Type>(m[1]), oper);
            // unboxed CHARACTER arguments
            for (auto e : llvm::enumerate(m)) {
              unsigned idx = e.index();
              auto attr = std::get<CodeGenSpecifics::Attributes>(e.value());
              auto argTy = std::get<mlir::Type>(e.value());
              if (attr.isAppend()) {
                trailingInTys.push_back(argTy);
                trailingOpers.push_back(unbox.getResult(idx));
              } else {
                newInTys.push_back(argTy);
                newOpers.push_back(unbox.getResult(idx));
              }
            }
          })
          .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
          })
          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
          })
          .Default([&](mlir::Type ty) {
            newInTys.push_back(ty);
            newOpers.push_back(oper);
          });
    }
    newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
    newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
    if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
      fir::CallOp newCall;
      if (callOp.callee().hasValue()) {
        newCall = rewriter->create<A>(loc, callOp.callee().getValue(),
                                      newResTys, newOpers);
      } else {
        // Force new type on the input operand.
        newOpers[0].setType(mlir::FunctionType::get(
            callOp.getContext(),
            mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys));
        newCall = rewriter->create<A>(loc, newResTys, newOpers);
      }
      LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
      if (wrap.hasValue())
        replaceOp(callOp, (*wrap)(newCall.getOperation()));
      else
        replaceOp(callOp, newCall.getResults());
    } else {
      // A is fir::DispatchOp
      TODO(loc, "dispatch not implemented");
    }
  }

  // Result type fixup for fir::ComplexType and mlir::ComplexType
  template <typename A, typename B>
  void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) {
    if (noComplexConversion) {
      newResTys.push_back(cmplx);
    } else {
      for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) {
        auto argTy = std::get<mlir::Type>(tup);
        if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet())
          newInTys.push_back(argTy);
        else
          newResTys.push_back(argTy);
      }
    }
  }

  // Argument type fixup for fir::ComplexType and mlir::ComplexType
  template <typename A, typename B>
  void lowerComplexSignatureArg(A cmplx, B &newInTys) {
    if (noComplexConversion)
      newInTys.push_back(cmplx);
    else
      for (auto &tup : specifics->complexArgumentType(cmplx.getElementType()))
        newInTys.push_back(std::get<mlir::Type>(tup));
  }

  /// Taking the address of a function. Modify the signature as needed.
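  ///
  /// Illustrative sketch (hypothetical FIR; the converted types depend on the
  /// target): an op such as
  ///   %p = fir.address_of(@f) : (!fir.boxchar<1>) -> ()
  /// is recreated with the converted function type, for example
  ///   %p = fir.address_of(@f) : (!fir.ref<!fir.char<1,?>>, i64) -> ()
  /// so the taken address agrees with the rewritten signature of @f.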
  void convertAddrOp(AddrOfOp addrOp) {
    rewriter->setInsertionPoint(addrOp);
    auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
    llvm::SmallVector<mlir::Type> newResTys;
    llvm::SmallVector<mlir::Type> newInTys;
    for (mlir::Type ty : addrTy.getResults()) {
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
            lowerComplexSignatureRes(ty, newResTys, newInTys);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
            lowerComplexSignatureRes(ty, newResTys, newInTys);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
    }
    llvm::SmallVector<mlir::Type> trailingInTys;
    for (mlir::Type ty : addrTy.getInputs()) {
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<BoxCharType>([&](BoxCharType box) {
            if (noCharacterConversion) {
              newInTys.push_back(box);
            } else {
              for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
                auto argTy = std::get<mlir::Type>(tup);
                llvm::SmallVector<mlir::Type> &vec =
                    attr.isAppend() ? trailingInTys : newInTys;
                vec.push_back(argTy);
              }
            }
          })
          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
            lowerComplexSignatureArg(ty, newInTys);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
            lowerComplexSignatureArg(ty, newInTys);
          })
          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
    }
    // append trailing input types
    newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
    // replace this op with a new one with the updated signature
    auto newTy = rewriter->getFunctionType(newInTys, newResTys);
    auto newOp =
        rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol());
    replaceOp(addrOp, newOp.getResult());
  }

  /// Convert the type signatures on all the functions present in the module.
  /// As the type signature is being changed, this must also update the
  /// function itself to use any new arguments, etc.
  mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
    for (auto fn : mod.getOps<mlir::FuncOp>())
      convertSignature(fn);
    return mlir::success();
  }

  /// If the signature does not need any special target-specific conversions,
  /// then it is considered portable for any target, and this function will
  /// return `true`. Otherwise, the signature is not portable and `false` is
  /// returned.
  bool hasPortableSignature(mlir::Type signature) {
    assert(signature.isa<mlir::FunctionType>());
    auto func = signature.dyn_cast<mlir::FunctionType>();
    for (auto ty : func.getResults())
      if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
          (isa_complex(ty) && !noComplexConversion)) {
        LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
        return false;
      }
    for (auto ty : func.getInputs())
      if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
          (isa_complex(ty) && !noComplexConversion)) {
        LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
        return false;
      }
    return true;
  }

  /// Rewrite the signatures and body of the `FuncOp`s in the module for
  /// the immediately subsequent target code gen.
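  ///
  /// Illustrative sketch (hypothetical types; the actual marshalling is
  /// supplied by CodeGenSpecifics): a function
  ///   func @f(%arg0 : !fir.boxchar<1>) -> !fir.complex<8>
  /// may become, on a target that returns the complex value in memory,
  ///   func @f(%sret : !fir.ref<!fir.complex<8>> {llvm.sret},
  ///           %buf : !fir.ref<!fir.char<1,?>>, %len : i64)
  /// with FixupTy records queued so that the entry block arguments and the
  /// return ops are patched afterwards to match the new signature.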
  void convertSignature(mlir::FuncOp func) {
    auto funcTy = func.getType().cast<mlir::FunctionType>();
    if (hasPortableSignature(funcTy))
      return;
    llvm::SmallVector<mlir::Type> newResTys;
    llvm::SmallVector<mlir::Type> newInTys;
    llvm::SmallVector<FixupTy> fixups;

    // Convert return value(s)
    for (auto ty : funcTy.getResults())
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            if (noComplexConversion)
              newResTys.push_back(cmplx);
            else
              doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            if (noComplexConversion)
              newResTys.push_back(cmplx);
            else
              doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });

    // Convert arguments
    llvm::SmallVector<mlir::Type> trailingTys;
    for (auto e : llvm::enumerate(funcTy.getInputs())) {
      auto ty = e.value();
      unsigned index = e.index();
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<BoxCharType>([&](BoxCharType boxTy) {
            if (noCharacterConversion) {
              newInTys.push_back(boxTy);
            } else {
              // Convert a CHARACTER argument type. This can involve separating
              // the pointer and the LEN into two arguments and moving the LEN
              // argument to the end of the arg list.
              bool sret = functionArgIsSRet(index, func);
              for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
                       boxTy.getEleTy(), sret))) {
                auto &tup = e.value();
                auto index = e.index();
                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
                auto argTy = std::get<mlir::Type>(tup);
                if (attr.isAppend()) {
                  trailingTys.push_back(argTy);
                } else {
                  if (sret) {
                    fixups.emplace_back(FixupTy::Codes::CharPair,
                                        newInTys.size(), index);
                  } else {
                    fixups.emplace_back(FixupTy::Codes::Trailing,
                                        newInTys.size(), trailingTys.size());
                  }
                  newInTys.push_back(argTy);
                }
              }
            }
          })
          .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            if (noComplexConversion)
              newInTys.push_back(cmplx);
            else
              doComplexArg(func, cmplx, newInTys, fixups);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            if (noComplexConversion)
              newInTys.push_back(cmplx);
            else
              doComplexArg(func, cmplx, newInTys, fixups);
          })
          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
    }

    if (!func.empty()) {
      // If the function has a body, then apply the fixups to the arguments and
      // return ops as required. These fixups are done in place.
      auto loc = func.getLoc();
      const auto fixupSize = fixups.size();
      const auto oldArgTys = func.getType().getInputs();
      int offset = 0;
      for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
        const auto &fixup = fixups[i];
        switch (fixup.code) {
        case FixupTy::Codes::ArgumentAsLoad: {
          // Argument was pass-by-value, but is now pass-by-reference and
          // possibly with a different element type.
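          // Sketch of the code inserted at the top of the entry block
          // (illustrative types, with T the original argument type and T' the
          // new reference type):
          //   %cast = fir.convert %newArg : (T') -> !fir.ref<T>
          //   %load = fir.load %cast : !fir.ref<T>
          // after which all uses of the old T-typed argument use %load and the
          // old block argument is erased.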
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          rewriter->setInsertionPointToStart(&func.front());
          auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
          auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg);
          auto load = rewriter->create<fir::LoadOp>(loc, cast);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        case FixupTy::Codes::ArgumentType: {
          // Argument is pass-by-value, but its type has likely been modified to
          // suit the target ABI convention.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          rewriter->setInsertionPointToStart(&func.front());
          auto mem =
              rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
          rewriter->create<fir::StoreOp>(loc, newArg, mem);
          auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
          auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem);
          mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
          func.front().eraseArgument(fixup.index + 1);
          LLVM_DEBUG(llvm::dbgs()
                     << "old argument: " << oldArgTy.getEleTy()
                     << ", repl: " << load << ", new argument: "
                     << func.getArgument(fixup.index).getType() << '\n');
        } break;
        case FixupTy::Codes::CharPair: {
          // The FIR boxchar argument has been split into a pair of distinct
          // arguments that are in juxtaposition to each other.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          if (fixup.second == 1) {
            rewriter->setInsertionPointToStart(&func.front());
            auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
            auto box = rewriter->create<EmboxCharOp>(
                loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
            func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
            func.front().eraseArgument(fixup.index + 1);
            offset++;
          }
        } break;
        case FixupTy::Codes::ReturnAsStore: {
          // The value being returned is now being returned in memory (callee
          // stack space) through a hidden reference argument.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          offset++;
          func.walk([&](mlir::ReturnOp ret) {
            rewriter->setInsertionPoint(ret);
            auto oldOper = ret.getOperand(0);
            auto oldOperTy = ReferenceType::get(oldOper.getType());
            auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg);
            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
            rewriter->create<mlir::ReturnOp>(loc);
            ret.erase();
          });
        } break;
        case FixupTy::Codes::ReturnType: {
          // The function is still returning a value, but its type has likely
          // changed to suit the target ABI convention.
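          // Sketch of the rewrite at each return (illustrative; T is the
          // original result type, T' the ABI result type):
          //   %mem = fir.alloca T'
          //   %cast = fir.convert %mem : (!fir.ref<T'>) -> !fir.ref<T>
          //   fir.store %oldOper to %cast : !fir.ref<T>
          //   %load = fir.load %mem : !fir.ref<T'>
          //   return %load : T'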
          func.walk([&](mlir::ReturnOp ret) {
            rewriter->setInsertionPoint(ret);
            auto oldOper = ret.getOperand(0);
            auto oldOperTy = ReferenceType::get(oldOper.getType());
            auto mem =
                rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
            auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem);
            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
            mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
            rewriter->create<mlir::ReturnOp>(loc, load);
            ret.erase();
          });
        } break;
        case FixupTy::Codes::Split: {
          // The FIR argument has been split into a pair of distinct arguments
          // that are in juxtaposition to each other. (For COMPLEX value.)
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          if (fixup.second == 1) {
            rewriter->setInsertionPointToStart(&func.front());
            auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
            auto undef = rewriter->create<UndefOp>(loc, cplxTy);
            auto zero = rewriter->getIntegerAttr(rewriter->getIndexType(), 0);
            auto one = rewriter->getIntegerAttr(rewriter->getIndexType(), 1);
            auto cplx1 = rewriter->create<InsertValueOp>(
                loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
                rewriter->getArrayAttr(zero));
            auto cplx = rewriter->create<InsertValueOp>(
                loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
            func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
            func.front().eraseArgument(fixup.index + 1);
            offset++;
          }
        } break;
        case FixupTy::Codes::Trailing: {
          // The FIR argument has been split into a pair of distinct arguments.
          // The first part of the pair appears in the original argument
          // position. The second part of the pair is appended after all the
          // original arguments. (Boxchar arguments.)
          auto newBufArg = func.front().insertArgument(
              fixup.index, newInTys[fixup.index], loc);
          auto newLenArg =
              func.front().addArgument(trailingTys[fixup.second], loc);
          auto boxTy = oldArgTys[fixup.index - offset];
          rewriter->setInsertionPointToStart(&func.front());
          auto box =
              rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        }
      }
    }

    // Set the new type and finalize the arguments, etc.
    newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
    auto newFuncTy =
        mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
    LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
    func.setType(newFuncTy);

    for (auto &fixup : fixups)
      if (fixup.finalizer)
        (*fixup.finalizer)(func);
  }

  inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) {
    if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret"))
      return true;
    return false;
  }

  /// Convert a complex return value. This can involve converting the return
  /// value to a "hidden" first argument or packing the complex into a wide
  /// GPR.
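  ///
  /// Sketch of the two cases handled below (which one applies is decided by
  /// the CodeGenSpecifics marshalling for the element type):
  ///   - sret: a ReturnAsStore fixup is queued and a hidden reference
  ///     argument carrying the llvm.sret attribute is added at the current
  ///     front of the input list; the value is stored through it at each
  ///     return.
  ///   - otherwise: a ReturnType fixup is queued and the ABI return type
  ///     (for example, a wider scalar or small aggregate, depending on the
  ///     target) replaces the complex result type.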
  template <typename A, typename B, typename C>
  void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys,
                       C &fixups) {
    if (noComplexConversion) {
      newResTys.push_back(cmplx);
      return;
    }
    auto m = specifics->complexReturnType(cmplx.getElementType());
    assert(m.size() == 1);
    auto &tup = m[0];
    auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
    auto argTy = std::get<mlir::Type>(tup);
    if (attr.isSRet()) {
      unsigned argNo = newInTys.size();
      fixups.emplace_back(
          FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) {
            func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr());
          });
      newInTys.push_back(argTy);
      return;
    }
    fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
    newResTys.push_back(argTy);
  }

  /// Convert a complex argument value. This can involve storing the value to
  /// a temporary memory location or factoring the value into two distinct
  /// arguments.
  template <typename A, typename B, typename C>
  void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) {
    if (noComplexConversion) {
      newInTys.push_back(cmplx);
      return;
    }
    auto m = specifics->complexArgumentType(cmplx.getElementType());
    const auto fixupCode =
        m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType;
    for (auto e : llvm::enumerate(m)) {
      auto &tup = e.value();
      auto index = e.index();
      auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
      auto argTy = std::get<mlir::Type>(tup);
      auto argNo = newInTys.size();
      if (attr.isByVal()) {
        if (auto align = attr.getAlignment())
          fixups.emplace_back(
              FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) {
                func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr());
                func.setArgAttr(argNo, "llvm.align",
                                rewriter->getIntegerAttr(
                                    rewriter->getIntegerType(32), align));
              });
        else
          fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(),
                              [=](mlir::FuncOp func) {
                                func.setArgAttr(argNo, "llvm.byval",
                                                rewriter->getUnitAttr());
                              });
      } else {
        if (auto align = attr.getAlignment())
          fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) {
            func.setArgAttr(
                argNo, "llvm.align",
                rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
          });
        else
          fixups.emplace_back(fixupCode, argNo, index);
      }
      newInTys.push_back(argTy);
    }
  }

private:
  // Replace `op` and remove it.
  void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
    op->replaceAllUsesWith(newValues);
    op->dropAllReferences();
    op->erase();
  }

  inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) {
    specifics = s;
    rewriter = r;
  }

  inline void clearMembers() { setMembers(nullptr, nullptr); }

  CodeGenSpecifics *specifics{};
  mlir::OpBuilder *rewriter;
};
} // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
fir::createFirTargetRewritePass(const TargetRewriteOptions &options) {
  return std::make_unique<TargetRewrite>(options);
}
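
// Usage sketch (assumed client-side code, not part of this file): the pass is
// created through the factory above and scheduled on a ModuleOp, e.g.
//
//   mlir::PassManager pm(&context);
//   fir::TargetRewriteOptions options;
//   options.noCharacterConversion = false;
//   options.noComplexConversion = false;
//   pm.addPass(fir::createFirTargetRewritePass(options));
//   (void)pm.run(module);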