//===-- TargetRewrite.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Target rewrite: rewriting of ops to make target-specific lowerings manifest.
// LLVM expects different lowering idioms to be used for distinct target
// triples. These distinctions are handled by this pass.
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "Target.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

using namespace fir;

#define DEBUG_TYPE "flang-target-rewrite"

namespace {

/// Fixups for updating a FuncOp's arguments and return values.
struct FixupTy {
  enum class Codes {
    ArgumentAsLoad,
    ArgumentType,
    CharPair,
    ReturnAsStore,
    ReturnType,
    Split,
    Trailing
  };

  FixupTy(Codes code, std::size_t index, std::size_t second = 0)
      : code{code}, index{index}, second{second} {}
  FixupTy(Codes code, std::size_t index,
          std::function<void(mlir::FuncOp)> &&finalizer)
      : code{code}, index{index}, finalizer{finalizer} {}
  FixupTy(Codes code, std::size_t index, std::size_t second,
          std::function<void(mlir::FuncOp)> &&finalizer)
      : code{code}, index{index}, second{second}, finalizer{finalizer} {}

  Codes code;
  std::size_t index;
  std::size_t second{};
  llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{};
};

/// Target-specific rewriting of the FIR. This is a prerequisite pass to code
/// generation that traverses the FIR and modifies types and operations to a
/// form that is appropriate for the specific target. LLVM IR has specific
/// idioms that are used for distinct target processor and ABI combinations.
class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
public:
  TargetRewrite(const TargetRewriteOptions &options) {
    noCharacterConversion = options.noCharacterConversion;
    noComplexConversion = options.noComplexConversion;
  }

  void runOnOperation() override final {
    auto &context = getContext();
    mlir::OpBuilder rewriter(&context);

    auto mod = getModule();
    if (!forcedTargetTriple.empty()) {
      setTargetTriple(mod, forcedTargetTriple);
    }

    auto specifics = CodeGenSpecifics::get(getOperation().getContext(),
                                           getTargetTriple(getOperation()),
                                           getKindMapping(getOperation()));
    setMembers(specifics.get(), &rewriter);

    // Perform type conversion on signatures and call sites.
    if (mlir::failed(convertTypes(mod))) {
      mlir::emitError(mlir::UnknownLoc::get(&context),
                      "error in converting types to target abi");
      signalPassFailure();
    }

    // Convert ops in target-specific patterns.
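    //
    // For illustration only (the exact types are dictated by the target's
    // CodeGenSpecifics), a call such as
    //   %z = fir.call @f(%x) : (!fir.complex<4>) -> !fir.complex<4>
    // may be rewritten below so that the complex argument and result take
    // whatever form the ABI requires: split real/imaginary scalars, a packed
    // register-sized value, or a hidden sret reference.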
    mod.walk([&](mlir::Operation *op) {
      if (auto call = dyn_cast<fir::CallOp>(op)) {
        if (!hasPortableSignature(call.getFunctionType()))
          convertCallOp(call);
      } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
        if (!hasPortableSignature(dispatch.getFunctionType()))
          convertCallOp(dispatch);
      } else if (auto addr = dyn_cast<AddrOfOp>(op)) {
        if (addr.getType().isa<mlir::FunctionType>() &&
            !hasPortableSignature(addr.getType()))
          convertAddrOp(addr);
      }
    });

    clearMembers();
  }

  mlir::ModuleOp getModule() { return getOperation(); }

  template <typename A, typename B, typename C>
  std::function<mlir::Value(mlir::Operation *)>
  rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) {
    auto m = specifics->complexReturnType(ty.getElementType());
    // Currently, targets mandate that COMPLEX is a single aggregate or packed
    // scalar, including the sret case.
    assert(m.size() == 1 && "target lowering of complex return not supported");
    auto resTy = std::get<mlir::Type>(m[0]);
    auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
    auto loc = mlir::UnknownLoc::get(resTy.getContext());
    if (attr.isSRet()) {
      assert(isa_ref_type(resTy));
      mlir::Value stack =
          rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy));
      newInTys.push_back(resTy);
      newOpers.push_back(stack);
      return [=](mlir::Operation *) -> mlir::Value {
        auto memTy = ReferenceType::get(ty);
        auto cast = rewriter->create<ConvertOp>(loc, memTy, stack);
        return rewriter->create<fir::LoadOp>(loc, cast);
      };
    }
    newResTys.push_back(resTy);
    return [=](mlir::Operation *call) -> mlir::Value {
      auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
      rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
      auto memTy = ReferenceType::get(ty);
      auto cast = rewriter->create<ConvertOp>(loc, memTy, mem);
      return rewriter->create<fir::LoadOp>(loc, cast);
    };
  }

  template <typename A, typename B, typename C>
  void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys,
                                   C &newOpers) {
    auto m = specifics->complexArgumentType(ty.getElementType());
    auto *ctx = ty.getContext();
    auto loc = mlir::UnknownLoc::get(ctx);
    if (m.size() == 1) {
      // COMPLEX is a single aggregate
      auto resTy = std::get<mlir::Type>(m[0]);
      auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
      auto oldRefTy = ReferenceType::get(ty);
      if (attr.isByVal()) {
        auto mem = rewriter->create<fir::AllocaOp>(loc, ty);
        rewriter->create<fir::StoreOp>(loc, oper, mem);
        newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem));
      } else {
        auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
        auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem);
        rewriter->create<fir::StoreOp>(loc, oper, cast);
        newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
      }
      newInTys.push_back(resTy);
    } else {
      assert(m.size() == 2);
      // COMPLEX is split into 2 separate arguments
      for (auto e : llvm::enumerate(m)) {
        auto &tup = e.value();
        auto ty = std::get<mlir::Type>(tup);
        auto index = e.index();
        auto idx = rewriter->getIntegerAttr(rewriter->getIndexType(), index);
        auto val = rewriter->create<ExtractValueOp>(
            loc, ty, oper, rewriter->getArrayAttr(idx));
        newInTys.push_back(ty);
        newOpers.push_back(val);
      }
    }
  }

  // Convert fir.call and fir.dispatch Ops.
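  //
  // As a sketch (assuming a 64-bit target where the CHARACTER length is
  // passed as an i64 and appended to the argument list), a call
  //   fir.call @sub(%ch) : (!fir.boxchar<1>) -> ()
  // would roughly become
  //   %unboxed:2 = fir.unboxchar %ch :
  //       (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, i64)
  //   fir.call @sub(%unboxed#0, %unboxed#1) :
  //       (!fir.ref<!fir.char<1,?>>, i64) -> ()
  // The actual buffer and length types come from the target's
  // boxcharArgumentType.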
  template <typename A>
  void convertCallOp(A callOp) {
    auto fnTy = callOp.getFunctionType();
    auto loc = callOp.getLoc();
    rewriter->setInsertionPoint(callOp);
    llvm::SmallVector<mlir::Type> newResTys;
    llvm::SmallVector<mlir::Type> newInTys;
    llvm::SmallVector<mlir::Value> newOpers;

    // If the call is indirect, the first argument must still be the function
    // to call.
    int dropFront = 0;
    if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
      if (!callOp.callee().hasValue()) {
        newInTys.push_back(fnTy.getInput(0));
        newOpers.push_back(callOp.getOperand(0));
        dropFront = 1;
      }
    }

    // Determine the rewrite function, `wrap`, for the result value.
    llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
    if (fnTy.getResults().size() == 1) {
      mlir::Type ty = fnTy.getResult(0);
      llvm::TypeSwitch<mlir::Type>(ty)
          .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
                                                newOpers);
          })
          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
                                                newOpers);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
    } else if (fnTy.getResults().size() > 1) {
      TODO(loc, "multiple results not supported yet");
    }

    llvm::SmallVector<mlir::Type> trailingInTys;
    llvm::SmallVector<mlir::Value> trailingOpers;
    for (auto e : llvm::enumerate(
             llvm::zip(fnTy.getInputs().drop_front(dropFront),
                       callOp.getOperands().drop_front(dropFront)))) {
      mlir::Type ty = std::get<0>(e.value());
      mlir::Value oper = std::get<1>(e.value());
      unsigned index = e.index();
      llvm::TypeSwitch<mlir::Type>(ty)
          .template Case<BoxCharType>([&](BoxCharType boxTy) {
            bool sret;
            if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
              sret = callOp.callee() &&
                     functionArgIsSRet(index,
                                       getModule().lookupSymbol<mlir::FuncOp>(
                                           *callOp.callee()));
            } else {
              // TODO: dispatch case; how do we put arguments on a call?
              // We cannot put both an sret and the dispatch object first.
              sret = false;
              TODO(loc, "dispatch + sret not supported yet");
            }
            auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret);
            auto unbox =
                rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]),
                                              std::get<mlir::Type>(m[1]), oper);
            // unboxed CHARACTER arguments
            for (auto e : llvm::enumerate(m)) {
              unsigned idx = e.index();
              auto attr = std::get<CodeGenSpecifics::Attributes>(e.value());
              auto argTy = std::get<mlir::Type>(e.value());
              if (attr.isAppend()) {
                trailingInTys.push_back(argTy);
                trailingOpers.push_back(unbox.getResult(idx));
              } else {
                newInTys.push_back(argTy);
                newOpers.push_back(unbox.getResult(idx));
              }
            }
          })
          .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
          })
          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
          })
          .Default([&](mlir::Type ty) {
            newInTys.push_back(ty);
            newOpers.push_back(oper);
          });
    }
    newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
    newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
    if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
      fir::CallOp newCall;
      if (callOp.callee().hasValue()) {
        newCall = rewriter->create<A>(loc, callOp.callee().getValue(),
                                      newResTys, newOpers);
      } else {
        // Force new type on the input operand.
        newOpers[0].setType(mlir::FunctionType::get(
            callOp.getContext(),
            mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys));
        newCall = rewriter->create<A>(loc, newResTys, newOpers);
      }
      LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
      if (wrap.hasValue())
        replaceOp(callOp, (*wrap)(newCall.getOperation()));
      else
        replaceOp(callOp, newCall.getResults());
    } else {
      // A is fir::DispatchOp
      TODO(loc, "dispatch not implemented");
    }
  }

  // Result type fixup for fir::ComplexType and mlir::ComplexType
  template <typename A, typename B>
  void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) {
    if (noComplexConversion) {
      newResTys.push_back(cmplx);
    } else {
      for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) {
        auto argTy = std::get<mlir::Type>(tup);
        if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet())
          newInTys.push_back(argTy);
        else
          newResTys.push_back(argTy);
      }
    }
  }

  // Argument type fixup for fir::ComplexType and mlir::ComplexType
  template <typename A, typename B>
  void lowerComplexSignatureArg(A cmplx, B &newInTys) {
    if (noComplexConversion)
      newInTys.push_back(cmplx);
    else
      for (auto &tup : specifics->complexArgumentType(cmplx.getElementType()))
        newInTys.push_back(std::get<mlir::Type>(tup));
  }

  /// Taking the address of a function. Modify the signature as needed.
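  ///
  /// For example (illustrative types; the real lowering is supplied by the
  /// target), an address-of op such as
  ///   %f = fir.address_of(@f) : (!fir.complex<8>) -> !fir.complex<8>
  /// is recreated with a function type whose complex argument and result have
  /// been lowered the same way function signatures are lowered below.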
  void convertAddrOp(AddrOfOp addrOp) {
    rewriter->setInsertionPoint(addrOp);
    auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
    llvm::SmallVector<mlir::Type> newResTys;
    llvm::SmallVector<mlir::Type> newInTys;
    for (mlir::Type ty : addrTy.getResults()) {
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
            lowerComplexSignatureRes(ty, newResTys, newInTys);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
            lowerComplexSignatureRes(ty, newResTys, newInTys);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
    }
    llvm::SmallVector<mlir::Type> trailingInTys;
    for (mlir::Type ty : addrTy.getInputs()) {
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<BoxCharType>([&](BoxCharType box) {
            if (noCharacterConversion) {
              newInTys.push_back(box);
            } else {
              for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
                auto argTy = std::get<mlir::Type>(tup);
                llvm::SmallVector<mlir::Type> &vec =
                    attr.isAppend() ? trailingInTys : newInTys;
                vec.push_back(argTy);
              }
            }
          })
          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
            lowerComplexSignatureArg(ty, newInTys);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
            lowerComplexSignatureArg(ty, newInTys);
          })
          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
    }
    // append trailing input types
    newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
    // replace this op with a new one with the updated signature
    auto newTy = rewriter->getFunctionType(newInTys, newResTys);
    auto newOp =
        rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol());
    replaceOp(addrOp, newOp.getResult());
  }

  /// Convert the type signatures on all the functions present in the module.
  /// As the type signature is being changed, this must also update the
  /// function itself to use any new arguments, etc.
  mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
    for (auto fn : mod.getOps<mlir::FuncOp>())
      convertSignature(fn);
    return mlir::success();
  }

  /// If the signature does not need any special target-specific conversions,
  /// then it is considered portable for any target, and this function will
  /// return `true`. Otherwise, the signature is not portable and `false` is
  /// returned.
  bool hasPortableSignature(mlir::Type signature) {
    assert(signature.isa<mlir::FunctionType>());
    auto func = signature.dyn_cast<mlir::FunctionType>();
    for (auto ty : func.getResults())
      if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
          (isa_complex(ty) && !noComplexConversion)) {
        LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
        return false;
      }
    for (auto ty : func.getInputs())
      if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
          (isa_complex(ty) && !noComplexConversion)) {
        LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
        return false;
      }
    return true;
  }

  /// Rewrite the signatures and body of the `FuncOp`s in the module for
  /// the immediately subsequent target code gen.
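  ///
  /// As a sketch (assuming a target that appends the CHARACTER length as a
  /// trailing i64 argument), a function such as
  ///   func @s(%ch : !fir.boxchar<1>, %i : i32)
  /// would be rewritten to
  ///   func @s(%buf : !fir.ref<!fir.char<1,?>>, %i : i32, %len : i64)
  /// with an fir.emboxchar inserted in the entry block to reconstitute the
  /// original boxchar value for its existing uses.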
  void convertSignature(mlir::FuncOp func) {
    auto funcTy = func.getType().cast<mlir::FunctionType>();
    if (hasPortableSignature(funcTy))
      return;
    llvm::SmallVector<mlir::Type> newResTys;
    llvm::SmallVector<mlir::Type> newInTys;
    llvm::SmallVector<FixupTy> fixups;

    // Convert return value(s)
    for (auto ty : funcTy.getResults())
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            if (noComplexConversion)
              newResTys.push_back(cmplx);
            else
              doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            if (noComplexConversion)
              newResTys.push_back(cmplx);
            else
              doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });

    // Convert arguments
    llvm::SmallVector<mlir::Type> trailingTys;
    for (auto e : llvm::enumerate(funcTy.getInputs())) {
      auto ty = e.value();
      unsigned index = e.index();
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<BoxCharType>([&](BoxCharType boxTy) {
            if (noCharacterConversion) {
              newInTys.push_back(boxTy);
            } else {
              // Convert a CHARACTER argument type. This can involve separating
              // the pointer and the LEN into two arguments and moving the LEN
              // argument to the end of the arg list.
              bool sret = functionArgIsSRet(index, func);
              for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
                       boxTy.getEleTy(), sret))) {
                auto &tup = e.value();
                auto index = e.index();
                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
                auto argTy = std::get<mlir::Type>(tup);
                if (attr.isAppend()) {
                  trailingTys.push_back(argTy);
                } else {
                  if (sret) {
                    fixups.emplace_back(FixupTy::Codes::CharPair,
                                        newInTys.size(), index);
                  } else {
                    fixups.emplace_back(FixupTy::Codes::Trailing,
                                        newInTys.size(), trailingTys.size());
                  }
                  newInTys.push_back(argTy);
                }
              }
            }
          })
          .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            if (noComplexConversion)
              newInTys.push_back(cmplx);
            else
              doComplexArg(func, cmplx, newInTys, fixups);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            if (noComplexConversion)
              newInTys.push_back(cmplx);
            else
              doComplexArg(func, cmplx, newInTys, fixups);
          })
          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
    }

    if (!func.empty()) {
      // If the function has a body, then apply the fixups to the arguments and
      // return ops as required. These fixups are done in place.
      auto loc = func.getLoc();
      const auto fixupSize = fixups.size();
      const auto oldArgTys = func.getType().getInputs();
      int offset = 0;
      for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
        const auto &fixup = fixups[i];
        switch (fixup.code) {
        case FixupTy::Codes::ArgumentAsLoad: {
          // Argument was pass-by-value, but is now pass-by-reference and
          // possibly with a different element type.
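          // Sketch of the rewritten entry block for this fixup (types are
          // illustrative): the new reference argument is converted back to a
          // reference of the original element type and loaded, and the load
          // replaces all uses of the old by-value argument.
          //   ^bb0(%arg : !fir.ref<T_abi>, ...):
          //     %0 = fir.convert %arg : (!fir.ref<T_abi>) -> !fir.ref<T_orig>
          //     %1 = fir.load %0 : !fir.ref<T_orig>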
          auto newArg =
              func.front().insertArgument(fixup.index, newInTys[fixup.index]);
          rewriter->setInsertionPointToStart(&func.front());
          auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
          auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg);
          auto load = rewriter->create<fir::LoadOp>(loc, cast);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        case FixupTy::Codes::ArgumentType: {
          // Argument is pass-by-value, but its type has likely been modified to
          // suit the target ABI convention.
          auto newArg =
              func.front().insertArgument(fixup.index, newInTys[fixup.index]);
          rewriter->setInsertionPointToStart(&func.front());
          auto mem =
              rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
          rewriter->create<fir::StoreOp>(loc, newArg, mem);
          auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
          auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem);
          mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
          func.front().eraseArgument(fixup.index + 1);
          LLVM_DEBUG(llvm::dbgs()
                     << "old argument: " << oldArgTy.getEleTy()
                     << ", repl: " << load << ", new argument: "
                     << func.getArgument(fixup.index).getType() << '\n');
        } break;
        case FixupTy::Codes::CharPair: {
          // The FIR boxchar argument has been split into a pair of distinct
          // arguments that are in juxtaposition to each other.
          auto newArg =
              func.front().insertArgument(fixup.index, newInTys[fixup.index]);
          if (fixup.second == 1) {
            rewriter->setInsertionPointToStart(&func.front());
            auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
            auto box = rewriter->create<EmboxCharOp>(
                loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
            func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
            func.front().eraseArgument(fixup.index + 1);
            offset++;
          }
        } break;
        case FixupTy::Codes::ReturnAsStore: {
          // The value being returned is now being returned in memory (callee
          // stack space) through a hidden reference argument.
          auto newArg =
              func.front().insertArgument(fixup.index, newInTys[fixup.index]);
          offset++;
          func.walk([&](mlir::ReturnOp ret) {
            rewriter->setInsertionPoint(ret);
            auto oldOper = ret.getOperand(0);
            auto oldOperTy = ReferenceType::get(oldOper.getType());
            auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg);
            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
            rewriter->create<mlir::ReturnOp>(loc);
            ret.erase();
          });
        } break;
        case FixupTy::Codes::ReturnType: {
          // The function is still returning a value, but its type has likely
          // changed to suit the target ABI convention.
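          // Sketch of the rewritten return sequence (types are illustrative):
          //   %mem = fir.alloca T_abi
          //   %ref = fir.convert %mem : (!fir.ref<T_abi>) -> !fir.ref<T_orig>
          //   fir.store %oldResult to %ref : !fir.ref<T_orig>
          //   %res = fir.load %mem : !fir.ref<T_abi>
          //   return %res : T_abi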
          func.walk([&](mlir::ReturnOp ret) {
            rewriter->setInsertionPoint(ret);
            auto oldOper = ret.getOperand(0);
            auto oldOperTy = ReferenceType::get(oldOper.getType());
            auto mem =
                rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
            auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem);
            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
            mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
            rewriter->create<mlir::ReturnOp>(loc, load);
            ret.erase();
          });
        } break;
        case FixupTy::Codes::Split: {
          // The FIR argument has been split into a pair of distinct arguments
          // that are in juxtaposition to each other. (For COMPLEX value.)
          auto newArg =
              func.front().insertArgument(fixup.index, newInTys[fixup.index]);
          if (fixup.second == 1) {
            rewriter->setInsertionPointToStart(&func.front());
            auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
            auto undef = rewriter->create<UndefOp>(loc, cplxTy);
            auto zero = rewriter->getIntegerAttr(rewriter->getIndexType(), 0);
            auto one = rewriter->getIntegerAttr(rewriter->getIndexType(), 1);
            auto cplx1 = rewriter->create<InsertValueOp>(
                loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
                rewriter->getArrayAttr(zero));
            auto cplx = rewriter->create<InsertValueOp>(
                loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
            func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
            func.front().eraseArgument(fixup.index + 1);
            offset++;
          }
        } break;
        case FixupTy::Codes::Trailing: {
          // The FIR argument has been split into a pair of distinct arguments.
          // The first part of the pair appears in the original argument
          // position. The second part of the pair is appended after all the
          // original arguments. (Boxchar arguments.)
          auto newBufArg =
              func.front().insertArgument(fixup.index, newInTys[fixup.index]);
          auto newLenArg = func.front().addArgument(trailingTys[fixup.second]);
          auto boxTy = oldArgTys[fixup.index - offset];
          rewriter->setInsertionPointToStart(&func.front());
          auto box =
              rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        }
      }
    }

    // Set the new type and finalize the arguments, etc.
    newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
    auto newFuncTy =
        mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
    LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
    func.setType(newFuncTy);

    for (auto &fixup : fixups)
      if (fixup.finalizer)
        (*fixup.finalizer)(func);
  }

  inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) {
    if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret"))
      return true;
    return false;
  }

  /// Convert a complex return value. This can involve converting the return
  /// value to a "hidden" first argument or packing the complex into a wide
  /// GPR.
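  ///
  /// Which of the two forms is used is decided entirely by the target's
  /// CodeGenSpecifics: for example, a 64-bit target might pack complex(4)
  /// into a single register-sized value, while a 32-bit target might instead
  /// return it through a hidden reference argument marked `llvm.sret`.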
  template <typename A, typename B, typename C>
  void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys,
                       C &fixups) {
    if (noComplexConversion) {
      newResTys.push_back(cmplx);
      return;
    }
    auto m = specifics->complexReturnType(cmplx.getElementType());
    assert(m.size() == 1);
    auto &tup = m[0];
    auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
    auto argTy = std::get<mlir::Type>(tup);
    if (attr.isSRet()) {
      unsigned argNo = newInTys.size();
      fixups.emplace_back(
          FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) {
            func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr());
          });
      newInTys.push_back(argTy);
      return;
    }
    fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
    newResTys.push_back(argTy);
  }

  /// Convert a complex argument value. This can involve storing the value to
  /// a temporary memory location or factoring the value into two distinct
  /// arguments.
  template <typename A, typename B, typename C>
  void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) {
    if (noComplexConversion) {
      newInTys.push_back(cmplx);
      return;
    }
    auto m = specifics->complexArgumentType(cmplx.getElementType());
    const auto fixupCode =
        m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType;
    for (auto e : llvm::enumerate(m)) {
      auto &tup = e.value();
      auto index = e.index();
      auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
      auto argTy = std::get<mlir::Type>(tup);
      auto argNo = newInTys.size();
      if (attr.isByVal()) {
        if (auto align = attr.getAlignment())
          fixups.emplace_back(
              FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) {
                func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr());
                func.setArgAttr(argNo, "llvm.align",
                                rewriter->getIntegerAttr(
                                    rewriter->getIntegerType(32), align));
              });
        else
          fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(),
                              [=](mlir::FuncOp func) {
                                func.setArgAttr(argNo, "llvm.byval",
                                                rewriter->getUnitAttr());
                              });
      } else {
        if (auto align = attr.getAlignment())
          fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) {
            func.setArgAttr(
                argNo, "llvm.align",
                rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
          });
        else
          fixups.emplace_back(fixupCode, argNo, index);
      }
      newInTys.push_back(argTy);
    }
  }

private:
  // Replace all uses of `op`'s results with `newValues`, then erase `op`.
  void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
    op->replaceAllUsesWith(newValues);
    op->dropAllReferences();
    op->erase();
  }

  inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) {
    specifics = s;
    rewriter = r;
  }

  inline void clearMembers() { setMembers(nullptr, nullptr); }

  CodeGenSpecifics *specifics{};
  mlir::OpBuilder *rewriter;
};
} // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
fir::createFirTargetRewritePass(const TargetRewriteOptions &options) {
  return std::make_unique<TargetRewrite>(options);
}