1 //===-- TargetRewrite.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest. 10 // LLVM expects different lowering idioms to be used for distinct target 11 // triples. These distinctions are handled by this pass. 12 // 13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "Target.h" 19 #include "flang/Lower/Todo.h" 20 #include "flang/Optimizer/CodeGen/CodeGen.h" 21 #include "flang/Optimizer/Dialect/FIRDialect.h" 22 #include "flang/Optimizer/Dialect/FIROps.h" 23 #include "flang/Optimizer/Dialect/FIRType.h" 24 #include "flang/Optimizer/Support/FIRContext.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "llvm/ADT/STLExtras.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 #include "llvm/Support/Debug.h" 29 30 using namespace fir; 31 32 #define DEBUG_TYPE "flang-target-rewrite" 33 34 namespace { 35 36 /// Fixups for updating a FuncOp's arguments and return values. 37 struct FixupTy { 38 enum class Codes { 39 ArgumentAsLoad, 40 ArgumentType, 41 CharPair, 42 ReturnAsStore, 43 ReturnType, 44 Split, 45 Trailing 46 }; 47 48 FixupTy(Codes code, std::size_t index, std::size_t second = 0) 49 : code{code}, index{index}, second{second} {} 50 FixupTy(Codes code, std::size_t index, 51 std::function<void(mlir::FuncOp)> &&finalizer) 52 : code{code}, index{index}, finalizer{finalizer} {} 53 FixupTy(Codes code, std::size_t index, std::size_t second, 54 std::function<void(mlir::FuncOp)> &&finalizer) 55 : code{code}, index{index}, second{second}, finalizer{finalizer} {} 56 57 Codes code; 58 std::size_t index; 59 std::size_t second{}; 60 llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{}; 61 }; // namespace 62 63 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 64 /// generation that traverses the FIR and modifies types and operations to a 65 /// form that is appropriate for the specific target. LLVM IR has specific 66 /// idioms that are used for distinct target processor and ABI combinations. 67 class TargetRewrite : public TargetRewriteBase<TargetRewrite> { 68 public: 69 TargetRewrite(const TargetRewriteOptions &options) { 70 noCharacterConversion = options.noCharacterConversion; 71 noComplexConversion = options.noComplexConversion; 72 } 73 74 void runOnOperation() override final { 75 auto &context = getContext(); 76 mlir::OpBuilder rewriter(&context); 77 78 auto mod = getModule(); 79 if (!forcedTargetTriple.empty()) { 80 setTargetTriple(mod, forcedTargetTriple); 81 } 82 83 auto specifics = CodeGenSpecifics::get(getOperation().getContext(), 84 getTargetTriple(getOperation()), 85 getKindMapping(getOperation())); 86 setMembers(specifics.get(), &rewriter); 87 88 // Perform type conversion on signatures and call sites. 89 if (mlir::failed(convertTypes(mod))) { 90 mlir::emitError(mlir::UnknownLoc::get(&context), 91 "error in converting types to target abi"); 92 signalPassFailure(); 93 } 94 95 // Convert ops in target-specific patterns. 96 mod.walk([&](mlir::Operation *op) { 97 if (auto call = dyn_cast<fir::CallOp>(op)) { 98 if (!hasPortableSignature(call.getFunctionType())) 99 convertCallOp(call); 100 } else if (auto dispatch = dyn_cast<DispatchOp>(op)) { 101 if (!hasPortableSignature(dispatch.getFunctionType())) 102 convertCallOp(dispatch); 103 } 104 }); 105 106 clearMembers(); 107 } 108 109 mlir::ModuleOp getModule() { return getOperation(); } 110 111 template <typename A, typename B, typename C> 112 std::function<mlir::Value(mlir::Operation *)> 113 rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) { 114 auto m = specifics->complexReturnType(ty.getElementType()); 115 // Currently targets mandate COMPLEX is a single aggregate or packed 116 // scalar, including the sret case. 117 assert(m.size() == 1 && "target lowering of complex return not supported"); 118 auto resTy = std::get<mlir::Type>(m[0]); 119 auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 120 auto loc = mlir::UnknownLoc::get(resTy.getContext()); 121 if (attr.isSRet()) { 122 assert(isa_ref_type(resTy)); 123 mlir::Value stack = 124 rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy)); 125 newInTys.push_back(resTy); 126 newOpers.push_back(stack); 127 return [=](mlir::Operation *) -> mlir::Value { 128 auto memTy = ReferenceType::get(ty); 129 auto cast = rewriter->create<ConvertOp>(loc, memTy, stack); 130 return rewriter->create<fir::LoadOp>(loc, cast); 131 }; 132 } 133 newResTys.push_back(resTy); 134 return [=](mlir::Operation *call) -> mlir::Value { 135 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 136 rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem); 137 auto memTy = ReferenceType::get(ty); 138 auto cast = rewriter->create<ConvertOp>(loc, memTy, mem); 139 return rewriter->create<fir::LoadOp>(loc, cast); 140 }; 141 } 142 143 template <typename A, typename B, typename C> 144 void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, 145 C &newOpers) { 146 auto m = specifics->complexArgumentType(ty.getElementType()); 147 auto *ctx = ty.getContext(); 148 auto loc = mlir::UnknownLoc::get(ctx); 149 if (m.size() == 1) { 150 // COMPLEX is a single aggregate 151 auto resTy = std::get<mlir::Type>(m[0]); 152 auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 153 auto oldRefTy = ReferenceType::get(ty); 154 if (attr.isByVal()) { 155 auto mem = rewriter->create<fir::AllocaOp>(loc, ty); 156 rewriter->create<fir::StoreOp>(loc, oper, mem); 157 newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem)); 158 } else { 159 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 160 auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem); 161 rewriter->create<fir::StoreOp>(loc, oper, cast); 162 newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem)); 163 } 164 newInTys.push_back(resTy); 165 } else { 166 assert(m.size() == 2); 167 // COMPLEX is split into 2 separate arguments 168 for (auto e : llvm::enumerate(m)) { 169 auto &tup = e.value(); 170 auto ty = std::get<mlir::Type>(tup); 171 auto index = e.index(); 172 auto idx = rewriter->getIntegerAttr(rewriter->getIndexType(), index); 173 auto val = rewriter->create<ExtractValueOp>( 174 loc, ty, oper, rewriter->getArrayAttr(idx)); 175 newInTys.push_back(ty); 176 newOpers.push_back(val); 177 } 178 } 179 } 180 181 // Convert fir.call and fir.dispatch Ops. 182 template <typename A> 183 void convertCallOp(A callOp) { 184 auto fnTy = callOp.getFunctionType(); 185 auto loc = callOp.getLoc(); 186 rewriter->setInsertionPoint(callOp); 187 llvm::SmallVector<mlir::Type> newResTys; 188 llvm::SmallVector<mlir::Type> newInTys; 189 llvm::SmallVector<mlir::Value> newOpers; 190 191 // If the call is indirect, the first argument must still be the function 192 // to call. 193 int dropFront = 0; 194 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 195 if (!callOp.callee().hasValue()) { 196 newInTys.push_back(fnTy.getInput(0)); 197 newOpers.push_back(callOp.getOperand(0)); 198 dropFront = 1; 199 } 200 } 201 202 // Determine the rewrite function, `wrap`, for the result value. 203 llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; 204 if (fnTy.getResults().size() == 1) { 205 mlir::Type ty = fnTy.getResult(0); 206 llvm::TypeSwitch<mlir::Type>(ty) 207 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 208 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 209 newOpers); 210 }) 211 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 212 wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 213 newOpers); 214 }) 215 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 216 } else if (fnTy.getResults().size() > 1) { 217 TODO(loc, "multiple results not supported yet"); 218 } 219 220 llvm::SmallVector<mlir::Type> trailingInTys; 221 llvm::SmallVector<mlir::Value> trailingOpers; 222 for (auto e : llvm::enumerate( 223 llvm::zip(fnTy.getInputs().drop_front(dropFront), 224 callOp.getOperands().drop_front(dropFront)))) { 225 mlir::Type ty = std::get<0>(e.value()); 226 mlir::Value oper = std::get<1>(e.value()); 227 unsigned index = e.index(); 228 llvm::TypeSwitch<mlir::Type>(ty) 229 .template Case<BoxCharType>([&](BoxCharType boxTy) { 230 bool sret; 231 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 232 sret = callOp.callee() && 233 functionArgIsSRet(index, 234 getModule().lookupSymbol<mlir::FuncOp>( 235 *callOp.callee())); 236 } else { 237 // TODO: dispatch case; how do we put arguments on a call? 238 // We cannot put both an sret and the dispatch object first. 239 sret = false; 240 TODO(loc, "dispatch + sret not supported yet"); 241 } 242 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); 243 auto unbox = 244 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]), 245 std::get<mlir::Type>(m[1]), oper); 246 // unboxed CHARACTER arguments 247 for (auto e : llvm::enumerate(m)) { 248 unsigned idx = e.index(); 249 auto attr = std::get<CodeGenSpecifics::Attributes>(e.value()); 250 auto argTy = std::get<mlir::Type>(e.value()); 251 if (attr.isAppend()) { 252 trailingInTys.push_back(argTy); 253 trailingOpers.push_back(unbox.getResult(idx)); 254 } else { 255 newInTys.push_back(argTy); 256 newOpers.push_back(unbox.getResult(idx)); 257 } 258 } 259 }) 260 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 261 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 262 }) 263 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 264 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 265 }) 266 .Default([&](mlir::Type ty) { 267 newInTys.push_back(ty); 268 newOpers.push_back(oper); 269 }); 270 } 271 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 272 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); 273 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 274 fir::CallOp newCall; 275 if (callOp.callee().hasValue()) { 276 newCall = rewriter->create<A>(loc, callOp.callee().getValue(), 277 newResTys, newOpers); 278 } else { 279 // Force new type on the input operand. 280 newOpers[0].setType(mlir::FunctionType::get( 281 callOp.getContext(), 282 mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); 283 newCall = rewriter->create<A>(loc, newResTys, newOpers); 284 } 285 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); 286 if (wrap.hasValue()) 287 replaceOp(callOp, (*wrap)(newCall.getOperation())); 288 else 289 replaceOp(callOp, newCall.getResults()); 290 } else { 291 // A is fir::DispatchOp 292 TODO(loc, "dispatch not implemented"); 293 } 294 } 295 296 // Result type fixup for fir::ComplexType and mlir::ComplexType 297 template <typename A, typename B> 298 void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) { 299 if (noComplexConversion) { 300 newResTys.push_back(cmplx); 301 } else { 302 for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) { 303 auto argTy = std::get<mlir::Type>(tup); 304 if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet()) 305 newInTys.push_back(argTy); 306 else 307 newResTys.push_back(argTy); 308 } 309 } 310 } 311 312 // Argument type fixup for fir::ComplexType and mlir::ComplexType 313 template <typename A, typename B> 314 void lowerComplexSignatureArg(A cmplx, B &newInTys) { 315 if (noComplexConversion) 316 newInTys.push_back(cmplx); 317 else 318 for (auto &tup : specifics->complexArgumentType(cmplx.getElementType())) 319 newInTys.push_back(std::get<mlir::Type>(tup)); 320 } 321 322 /// Convert the type signatures on all the functions present in the module. 323 /// As the type signature is being changed, this must also update the 324 /// function itself to use any new arguments, etc. 325 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { 326 for (auto fn : mod.getOps<mlir::FuncOp>()) 327 convertSignature(fn); 328 return mlir::success(); 329 } 330 331 /// If the signature does not need any special target-specific converions, 332 /// then it is considered portable for any target, and this function will 333 /// return `true`. Otherwise, the signature is not portable and `false` is 334 /// returned. 335 bool hasPortableSignature(mlir::Type signature) { 336 assert(signature.isa<mlir::FunctionType>()); 337 auto func = signature.dyn_cast<mlir::FunctionType>(); 338 for (auto ty : func.getResults()) 339 if ((ty.isa<BoxCharType>() && !noCharacterConversion) || 340 (isa_complex(ty) && !noComplexConversion)) { 341 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 342 return false; 343 } 344 for (auto ty : func.getInputs()) 345 if ((ty.isa<BoxCharType>() && !noCharacterConversion) || 346 (isa_complex(ty) && !noComplexConversion)) { 347 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 348 return false; 349 } 350 return true; 351 } 352 353 /// Rewrite the signatures and body of the `FuncOp`s in the module for 354 /// the immediately subsequent target code gen. 355 void convertSignature(mlir::FuncOp func) { 356 auto funcTy = func.getType().cast<mlir::FunctionType>(); 357 if (hasPortableSignature(funcTy)) 358 return; 359 llvm::SmallVector<mlir::Type> newResTys; 360 llvm::SmallVector<mlir::Type> newInTys; 361 llvm::SmallVector<FixupTy> fixups; 362 363 // Convert return value(s) 364 for (auto ty : funcTy.getResults()) 365 llvm::TypeSwitch<mlir::Type>(ty) 366 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 367 if (noComplexConversion) 368 newResTys.push_back(cmplx); 369 else 370 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 371 }) 372 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 373 if (noComplexConversion) 374 newResTys.push_back(cmplx); 375 else 376 doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 377 }) 378 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 379 380 // Convert arguments 381 llvm::SmallVector<mlir::Type> trailingTys; 382 for (auto e : llvm::enumerate(funcTy.getInputs())) { 383 auto ty = e.value(); 384 unsigned index = e.index(); 385 llvm::TypeSwitch<mlir::Type>(ty) 386 .Case<BoxCharType>([&](BoxCharType boxTy) { 387 if (noCharacterConversion) { 388 newInTys.push_back(boxTy); 389 } else { 390 // Convert a CHARACTER argument type. This can involve separating 391 // the pointer and the LEN into two arguments and moving the LEN 392 // argument to the end of the arg list. 393 bool sret = functionArgIsSRet(index, func); 394 for (auto e : llvm::enumerate(specifics->boxcharArgumentType( 395 boxTy.getEleTy(), sret))) { 396 auto &tup = e.value(); 397 auto index = e.index(); 398 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 399 auto argTy = std::get<mlir::Type>(tup); 400 if (attr.isAppend()) { 401 trailingTys.push_back(argTy); 402 } else { 403 if (sret) { 404 fixups.emplace_back(FixupTy::Codes::CharPair, 405 newInTys.size(), index); 406 } else { 407 fixups.emplace_back(FixupTy::Codes::Trailing, 408 newInTys.size(), trailingTys.size()); 409 } 410 newInTys.push_back(argTy); 411 } 412 } 413 } 414 }) 415 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 416 if (noComplexConversion) 417 newInTys.push_back(cmplx); 418 else 419 doComplexArg(func, cmplx, newInTys, fixups); 420 }) 421 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 422 if (noComplexConversion) 423 newInTys.push_back(cmplx); 424 else 425 doComplexArg(func, cmplx, newInTys, fixups); 426 }) 427 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 428 } 429 430 if (!func.empty()) { 431 // If the function has a body, then apply the fixups to the arguments and 432 // return ops as required. These fixups are done in place. 433 auto loc = func.getLoc(); 434 const auto fixupSize = fixups.size(); 435 const auto oldArgTys = func.getType().getInputs(); 436 int offset = 0; 437 for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) { 438 const auto &fixup = fixups[i]; 439 switch (fixup.code) { 440 case FixupTy::Codes::ArgumentAsLoad: { 441 // Argument was pass-by-value, but is now pass-by-reference and 442 // possibly with a different element type. 443 auto newArg = 444 func.front().insertArgument(fixup.index, newInTys[fixup.index]); 445 rewriter->setInsertionPointToStart(&func.front()); 446 auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 447 auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg); 448 auto load = rewriter->create<fir::LoadOp>(loc, cast); 449 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 450 func.front().eraseArgument(fixup.index + 1); 451 } break; 452 case FixupTy::Codes::ArgumentType: { 453 // Argument is pass-by-value, but its type has likely been modified to 454 // suit the target ABI convention. 455 auto newArg = 456 func.front().insertArgument(fixup.index, newInTys[fixup.index]); 457 rewriter->setInsertionPointToStart(&func.front()); 458 auto mem = 459 rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]); 460 rewriter->create<fir::StoreOp>(loc, newArg, mem); 461 auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 462 auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem); 463 mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast); 464 func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 465 func.front().eraseArgument(fixup.index + 1); 466 LLVM_DEBUG(llvm::dbgs() 467 << "old argument: " << oldArgTy.getEleTy() 468 << ", repl: " << load << ", new argument: " 469 << func.getArgument(fixup.index).getType() << '\n'); 470 } break; 471 case FixupTy::Codes::CharPair: { 472 // The FIR boxchar argument has been split into a pair of distinct 473 // arguments that are in juxtaposition to each other. 474 auto newArg = 475 func.front().insertArgument(fixup.index, newInTys[fixup.index]); 476 if (fixup.second == 1) { 477 rewriter->setInsertionPointToStart(&func.front()); 478 auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; 479 auto box = rewriter->create<EmboxCharOp>( 480 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); 481 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 482 func.front().eraseArgument(fixup.index + 1); 483 offset++; 484 } 485 } break; 486 case FixupTy::Codes::ReturnAsStore: { 487 // The value being returned is now being returned in memory (callee 488 // stack space) through a hidden reference argument. 489 auto newArg = 490 func.front().insertArgument(fixup.index, newInTys[fixup.index]); 491 offset++; 492 func.walk([&](mlir::ReturnOp ret) { 493 rewriter->setInsertionPoint(ret); 494 auto oldOper = ret.getOperand(0); 495 auto oldOperTy = ReferenceType::get(oldOper.getType()); 496 auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg); 497 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 498 rewriter->create<mlir::ReturnOp>(loc); 499 ret.erase(); 500 }); 501 } break; 502 case FixupTy::Codes::ReturnType: { 503 // The function is still returning a value, but its type has likely 504 // changed to suit the target ABI convention. 505 func.walk([&](mlir::ReturnOp ret) { 506 rewriter->setInsertionPoint(ret); 507 auto oldOper = ret.getOperand(0); 508 auto oldOperTy = ReferenceType::get(oldOper.getType()); 509 auto mem = 510 rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]); 511 auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem); 512 rewriter->create<fir::StoreOp>(loc, oldOper, cast); 513 mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem); 514 rewriter->create<mlir::ReturnOp>(loc, load); 515 ret.erase(); 516 }); 517 } break; 518 case FixupTy::Codes::Split: { 519 // The FIR argument has been split into a pair of distinct arguments 520 // that are in juxtaposition to each other. (For COMPLEX value.) 521 auto newArg = 522 func.front().insertArgument(fixup.index, newInTys[fixup.index]); 523 if (fixup.second == 1) { 524 rewriter->setInsertionPointToStart(&func.front()); 525 auto cplxTy = oldArgTys[fixup.index - offset - fixup.second]; 526 auto undef = rewriter->create<UndefOp>(loc, cplxTy); 527 auto zero = rewriter->getIntegerAttr(rewriter->getIndexType(), 0); 528 auto one = rewriter->getIntegerAttr(rewriter->getIndexType(), 1); 529 auto cplx1 = rewriter->create<InsertValueOp>( 530 loc, cplxTy, undef, func.front().getArgument(fixup.index - 1), 531 rewriter->getArrayAttr(zero)); 532 auto cplx = rewriter->create<InsertValueOp>( 533 loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one)); 534 func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx); 535 func.front().eraseArgument(fixup.index + 1); 536 offset++; 537 } 538 } break; 539 case FixupTy::Codes::Trailing: { 540 // The FIR argument has been split into a pair of distinct arguments. 541 // The first part of the pair appears in the original argument 542 // position. The second part of the pair is appended after all the 543 // original arguments. (Boxchar arguments.) 544 auto newBufArg = 545 func.front().insertArgument(fixup.index, newInTys[fixup.index]); 546 auto newLenArg = func.front().addArgument(trailingTys[fixup.second]); 547 auto boxTy = oldArgTys[fixup.index - offset]; 548 rewriter->setInsertionPointToStart(&func.front()); 549 auto box = 550 rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg); 551 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 552 func.front().eraseArgument(fixup.index + 1); 553 } break; 554 } 555 } 556 } 557 558 // Set the new type and finalize the arguments, etc. 559 newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); 560 auto newFuncTy = 561 mlir::FunctionType::get(func.getContext(), newInTys, newResTys); 562 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); 563 func.setType(newFuncTy); 564 565 for (auto &fixup : fixups) 566 if (fixup.finalizer) 567 (*fixup.finalizer)(func); 568 } 569 570 inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) { 571 if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret")) 572 return true; 573 return false; 574 } 575 576 /// Convert a complex return value. This can involve converting the return 577 /// value to a "hidden" first argument or packing the complex into a wide 578 /// GPR. 579 template <typename A, typename B, typename C> 580 void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys, 581 C &fixups) { 582 if (noComplexConversion) { 583 newResTys.push_back(cmplx); 584 return; 585 } 586 auto m = specifics->complexReturnType(cmplx.getElementType()); 587 assert(m.size() == 1); 588 auto &tup = m[0]; 589 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 590 auto argTy = std::get<mlir::Type>(tup); 591 if (attr.isSRet()) { 592 unsigned argNo = newInTys.size(); 593 fixups.emplace_back( 594 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) { 595 func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 596 }); 597 newInTys.push_back(argTy); 598 return; 599 } 600 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); 601 newResTys.push_back(argTy); 602 } 603 604 /// Convert a complex argument value. This can involve storing the value to 605 /// a temporary memory location or factoring the value into two distinct 606 /// arguments. 607 template <typename A, typename B, typename C> 608 void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) { 609 if (noComplexConversion) { 610 newInTys.push_back(cmplx); 611 return; 612 } 613 auto m = specifics->complexArgumentType(cmplx.getElementType()); 614 const auto fixupCode = 615 m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; 616 for (auto e : llvm::enumerate(m)) { 617 auto &tup = e.value(); 618 auto index = e.index(); 619 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 620 auto argTy = std::get<mlir::Type>(tup); 621 auto argNo = newInTys.size(); 622 if (attr.isByVal()) { 623 if (auto align = attr.getAlignment()) 624 fixups.emplace_back( 625 FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) { 626 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); 627 func.setArgAttr(argNo, "llvm.align", 628 rewriter->getIntegerAttr( 629 rewriter->getIntegerType(32), align)); 630 }); 631 else 632 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), 633 [=](mlir::FuncOp func) { 634 func.setArgAttr(argNo, "llvm.byval", 635 rewriter->getUnitAttr()); 636 }); 637 } else { 638 if (auto align = attr.getAlignment()) 639 fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) { 640 func.setArgAttr( 641 argNo, "llvm.align", 642 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align)); 643 }); 644 else 645 fixups.emplace_back(fixupCode, argNo, index); 646 } 647 newInTys.push_back(argTy); 648 } 649 } 650 651 private: 652 // Replace `op` and remove it. 653 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 654 op->replaceAllUsesWith(newValues); 655 op->dropAllReferences(); 656 op->erase(); 657 } 658 659 inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) { 660 specifics = s; 661 rewriter = r; 662 } 663 664 inline void clearMembers() { setMembers(nullptr, nullptr); } 665 666 CodeGenSpecifics *specifics{}; 667 mlir::OpBuilder *rewriter; 668 }; // namespace 669 } // namespace 670 671 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 672 fir::createFirTargetRewritePass(const TargetRewriteOptions &options) { 673 return std::make_unique<TargetRewrite>(options); 674 } 675