//===-- TargetRewrite.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Target rewrite: rewriting of ops to make target-specific lowerings manifest.
// LLVM expects different lowering idioms to be used for distinct target
// triples. These distinctions are handled by this pass.
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "Target.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <utility>

324c263edeSDiana Picus using namespace fir; 334c263edeSDiana Picus 344c263edeSDiana Picus #define DEBUG_TYPE "flang-target-rewrite" 354c263edeSDiana Picus 364c263edeSDiana Picus namespace { 374c263edeSDiana Picus 384c263edeSDiana Picus /// Fixups for updating a FuncOp's arguments and return values. 394c263edeSDiana Picus struct FixupTy { 4065431d3aSDiana Picus enum class Codes { 4165431d3aSDiana Picus ArgumentAsLoad, 4265431d3aSDiana Picus ArgumentType, 4365431d3aSDiana Picus CharPair, 4465431d3aSDiana Picus ReturnAsStore, 4565431d3aSDiana Picus ReturnType, 4665431d3aSDiana Picus Split, 47416e503aSJean Perier Trailing, 48416e503aSJean Perier TrailingCharProc 4965431d3aSDiana Picus }; 504c263edeSDiana Picus 514c263edeSDiana Picus FixupTy(Codes code, std::size_t index, std::size_t second = 0) 524c263edeSDiana Picus : code{code}, index{index}, second{second} {} 534c263edeSDiana Picus FixupTy(Codes code, std::size_t index, 544c263edeSDiana Picus std::function<void(mlir::FuncOp)> &&finalizer) 554c263edeSDiana Picus : code{code}, index{index}, finalizer{finalizer} {} 564c263edeSDiana Picus FixupTy(Codes code, std::size_t index, std::size_t second, 574c263edeSDiana Picus std::function<void(mlir::FuncOp)> &&finalizer) 584c263edeSDiana Picus : code{code}, index{index}, second{second}, finalizer{finalizer} {} 594c263edeSDiana Picus 604c263edeSDiana Picus Codes code; 614c263edeSDiana Picus std::size_t index; 624c263edeSDiana Picus std::size_t second{}; 634c263edeSDiana Picus llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{}; 644c263edeSDiana Picus }; // namespace 654c263edeSDiana Picus 664c263edeSDiana Picus /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 674c263edeSDiana Picus /// generation that traverses the FIR and modifies types and operations to a 684c263edeSDiana Picus /// form that is appropriate for the specific target. 
LLVM IR has specific 694c263edeSDiana Picus /// idioms that are used for distinct target processor and ABI combinations. 704c263edeSDiana Picus class TargetRewrite : public TargetRewriteBase<TargetRewrite> { 714c263edeSDiana Picus public: 724c263edeSDiana Picus TargetRewrite(const TargetRewriteOptions &options) { 734c263edeSDiana Picus noCharacterConversion = options.noCharacterConversion; 7465431d3aSDiana Picus noComplexConversion = options.noComplexConversion; 754c263edeSDiana Picus } 764c263edeSDiana Picus 774c263edeSDiana Picus void runOnOperation() override final { 784c263edeSDiana Picus auto &context = getContext(); 794c263edeSDiana Picus mlir::OpBuilder rewriter(&context); 804c263edeSDiana Picus 814c263edeSDiana Picus auto mod = getModule(); 82010a10b7SValentin Clement if (!forcedTargetTriple.empty()) 834c263edeSDiana Picus setTargetTriple(mod, forcedTargetTriple); 844c263edeSDiana Picus 854c263edeSDiana Picus auto specifics = CodeGenSpecifics::get(getOperation().getContext(), 864c263edeSDiana Picus getTargetTriple(getOperation()), 874c263edeSDiana Picus getKindMapping(getOperation())); 884c263edeSDiana Picus setMembers(specifics.get(), &rewriter); 894c263edeSDiana Picus 904c263edeSDiana Picus // Perform type conversion on signatures and call sites. 914c263edeSDiana Picus if (mlir::failed(convertTypes(mod))) { 924c263edeSDiana Picus mlir::emitError(mlir::UnknownLoc::get(&context), 934c263edeSDiana Picus "error in converting types to target abi"); 944c263edeSDiana Picus signalPassFailure(); 954c263edeSDiana Picus } 964c263edeSDiana Picus 974c263edeSDiana Picus // Convert ops in target-specific patterns. 
984c263edeSDiana Picus mod.walk([&](mlir::Operation *op) { 994c263edeSDiana Picus if (auto call = dyn_cast<fir::CallOp>(op)) { 1004c263edeSDiana Picus if (!hasPortableSignature(call.getFunctionType())) 1014c263edeSDiana Picus convertCallOp(call); 1024c263edeSDiana Picus } else if (auto dispatch = dyn_cast<DispatchOp>(op)) { 1034c263edeSDiana Picus if (!hasPortableSignature(dispatch.getFunctionType())) 1044c263edeSDiana Picus convertCallOp(dispatch); 1053fd250d2SDiana Picus } else if (auto addr = dyn_cast<AddrOfOp>(op)) { 1063fd250d2SDiana Picus if (addr.getType().isa<mlir::FunctionType>() && 1073fd250d2SDiana Picus !hasPortableSignature(addr.getType())) 1083fd250d2SDiana Picus convertAddrOp(addr); 1094c263edeSDiana Picus } 1104c263edeSDiana Picus }); 1114c263edeSDiana Picus 1124c263edeSDiana Picus clearMembers(); 1134c263edeSDiana Picus } 1144c263edeSDiana Picus 1154c263edeSDiana Picus mlir::ModuleOp getModule() { return getOperation(); } 1164c263edeSDiana Picus 11765431d3aSDiana Picus template <typename A, typename B, typename C> 11865431d3aSDiana Picus std::function<mlir::Value(mlir::Operation *)> 11965431d3aSDiana Picus rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) { 12065431d3aSDiana Picus auto m = specifics->complexReturnType(ty.getElementType()); 12165431d3aSDiana Picus // Currently targets mandate COMPLEX is a single aggregate or packed 12265431d3aSDiana Picus // scalar, including the sret case. 
12365431d3aSDiana Picus assert(m.size() == 1 && "target lowering of complex return not supported"); 12465431d3aSDiana Picus auto resTy = std::get<mlir::Type>(m[0]); 12565431d3aSDiana Picus auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 12665431d3aSDiana Picus auto loc = mlir::UnknownLoc::get(resTy.getContext()); 12765431d3aSDiana Picus if (attr.isSRet()) { 12865431d3aSDiana Picus assert(isa_ref_type(resTy)); 12965431d3aSDiana Picus mlir::Value stack = 13065431d3aSDiana Picus rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy)); 13165431d3aSDiana Picus newInTys.push_back(resTy); 13265431d3aSDiana Picus newOpers.push_back(stack); 13365431d3aSDiana Picus return [=](mlir::Operation *) -> mlir::Value { 13465431d3aSDiana Picus auto memTy = ReferenceType::get(ty); 13565431d3aSDiana Picus auto cast = rewriter->create<ConvertOp>(loc, memTy, stack); 13665431d3aSDiana Picus return rewriter->create<fir::LoadOp>(loc, cast); 13765431d3aSDiana Picus }; 13865431d3aSDiana Picus } 13965431d3aSDiana Picus newResTys.push_back(resTy); 14065431d3aSDiana Picus return [=](mlir::Operation *call) -> mlir::Value { 14165431d3aSDiana Picus auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 14265431d3aSDiana Picus rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem); 14365431d3aSDiana Picus auto memTy = ReferenceType::get(ty); 14465431d3aSDiana Picus auto cast = rewriter->create<ConvertOp>(loc, memTy, mem); 14565431d3aSDiana Picus return rewriter->create<fir::LoadOp>(loc, cast); 14665431d3aSDiana Picus }; 14765431d3aSDiana Picus } 14865431d3aSDiana Picus 14965431d3aSDiana Picus template <typename A, typename B, typename C> 15065431d3aSDiana Picus void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, 15165431d3aSDiana Picus C &newOpers) { 15265431d3aSDiana Picus auto m = specifics->complexArgumentType(ty.getElementType()); 15365431d3aSDiana Picus auto *ctx = ty.getContext(); 15465431d3aSDiana Picus auto loc = mlir::UnknownLoc::get(ctx); 
15565431d3aSDiana Picus if (m.size() == 1) { 15665431d3aSDiana Picus // COMPLEX is a single aggregate 15765431d3aSDiana Picus auto resTy = std::get<mlir::Type>(m[0]); 15865431d3aSDiana Picus auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); 15965431d3aSDiana Picus auto oldRefTy = ReferenceType::get(ty); 16065431d3aSDiana Picus if (attr.isByVal()) { 16165431d3aSDiana Picus auto mem = rewriter->create<fir::AllocaOp>(loc, ty); 16265431d3aSDiana Picus rewriter->create<fir::StoreOp>(loc, oper, mem); 16365431d3aSDiana Picus newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem)); 16465431d3aSDiana Picus } else { 16565431d3aSDiana Picus auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); 16665431d3aSDiana Picus auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem); 16765431d3aSDiana Picus rewriter->create<fir::StoreOp>(loc, oper, cast); 16865431d3aSDiana Picus newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem)); 16965431d3aSDiana Picus } 17065431d3aSDiana Picus newInTys.push_back(resTy); 17165431d3aSDiana Picus } else { 17265431d3aSDiana Picus assert(m.size() == 2); 17365431d3aSDiana Picus // COMPLEX is split into 2 separate arguments 174*9b5bb511SValentin Clement auto iTy = rewriter->getIntegerType(32); 17565431d3aSDiana Picus for (auto e : llvm::enumerate(m)) { 17665431d3aSDiana Picus auto &tup = e.value(); 17765431d3aSDiana Picus auto ty = std::get<mlir::Type>(tup); 17865431d3aSDiana Picus auto index = e.index(); 179*9b5bb511SValentin Clement auto idx = rewriter->getIntegerAttr(iTy, index); 18065431d3aSDiana Picus auto val = rewriter->create<ExtractValueOp>( 18165431d3aSDiana Picus loc, ty, oper, rewriter->getArrayAttr(idx)); 18265431d3aSDiana Picus newInTys.push_back(ty); 18365431d3aSDiana Picus newOpers.push_back(val); 18465431d3aSDiana Picus } 18565431d3aSDiana Picus } 18665431d3aSDiana Picus } 18765431d3aSDiana Picus 1884c263edeSDiana Picus // Convert fir.call and fir.dispatch Ops. 
1894c263edeSDiana Picus template <typename A> 1904c263edeSDiana Picus void convertCallOp(A callOp) { 1914c263edeSDiana Picus auto fnTy = callOp.getFunctionType(); 1924c263edeSDiana Picus auto loc = callOp.getLoc(); 1934c263edeSDiana Picus rewriter->setInsertionPoint(callOp); 1944c263edeSDiana Picus llvm::SmallVector<mlir::Type> newResTys; 1954c263edeSDiana Picus llvm::SmallVector<mlir::Type> newInTys; 1964c263edeSDiana Picus llvm::SmallVector<mlir::Value> newOpers; 1974c263edeSDiana Picus 1984c263edeSDiana Picus // If the call is indirect, the first argument must still be the function 1994c263edeSDiana Picus // to call. 2004c263edeSDiana Picus int dropFront = 0; 2014c263edeSDiana Picus if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 2024c263edeSDiana Picus if (!callOp.callee().hasValue()) { 2034c263edeSDiana Picus newInTys.push_back(fnTy.getInput(0)); 2044c263edeSDiana Picus newOpers.push_back(callOp.getOperand(0)); 2054c263edeSDiana Picus dropFront = 1; 2064c263edeSDiana Picus } 2074c263edeSDiana Picus } 2084c263edeSDiana Picus 2094c263edeSDiana Picus // Determine the rewrite function, `wrap`, for the result value. 
2104c263edeSDiana Picus llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; 2114c263edeSDiana Picus if (fnTy.getResults().size() == 1) { 2124c263edeSDiana Picus mlir::Type ty = fnTy.getResult(0); 21365431d3aSDiana Picus llvm::TypeSwitch<mlir::Type>(ty) 21465431d3aSDiana Picus .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 21565431d3aSDiana Picus wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 21665431d3aSDiana Picus newOpers); 21765431d3aSDiana Picus }) 21865431d3aSDiana Picus .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 21965431d3aSDiana Picus wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, 22065431d3aSDiana Picus newOpers); 22165431d3aSDiana Picus }) 22265431d3aSDiana Picus .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 2234c263edeSDiana Picus } else if (fnTy.getResults().size() > 1) { 2244c263edeSDiana Picus TODO(loc, "multiple results not supported yet"); 2254c263edeSDiana Picus } 2264c263edeSDiana Picus 2274c263edeSDiana Picus llvm::SmallVector<mlir::Type> trailingInTys; 2284c263edeSDiana Picus llvm::SmallVector<mlir::Value> trailingOpers; 2294c263edeSDiana Picus for (auto e : llvm::enumerate( 2304c263edeSDiana Picus llvm::zip(fnTy.getInputs().drop_front(dropFront), 2314c263edeSDiana Picus callOp.getOperands().drop_front(dropFront)))) { 2324c263edeSDiana Picus mlir::Type ty = std::get<0>(e.value()); 2334c263edeSDiana Picus mlir::Value oper = std::get<1>(e.value()); 2344c263edeSDiana Picus unsigned index = e.index(); 2354c263edeSDiana Picus llvm::TypeSwitch<mlir::Type>(ty) 2364c263edeSDiana Picus .template Case<BoxCharType>([&](BoxCharType boxTy) { 2374c263edeSDiana Picus bool sret; 2384c263edeSDiana Picus if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 2394c263edeSDiana Picus sret = callOp.callee() && 2404c263edeSDiana Picus functionArgIsSRet(index, 2414c263edeSDiana Picus getModule().lookupSymbol<mlir::FuncOp>( 2424c263edeSDiana Picus 
*callOp.callee())); 2434c263edeSDiana Picus } else { 2444c263edeSDiana Picus // TODO: dispatch case; how do we put arguments on a call? 2454c263edeSDiana Picus // We cannot put both an sret and the dispatch object first. 2464c263edeSDiana Picus sret = false; 2474c263edeSDiana Picus TODO(loc, "dispatch + sret not supported yet"); 2484c263edeSDiana Picus } 2494c263edeSDiana Picus auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); 2504c263edeSDiana Picus auto unbox = 2514c263edeSDiana Picus rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]), 2524c263edeSDiana Picus std::get<mlir::Type>(m[1]), oper); 2534c263edeSDiana Picus // unboxed CHARACTER arguments 2544c263edeSDiana Picus for (auto e : llvm::enumerate(m)) { 2554c263edeSDiana Picus unsigned idx = e.index(); 2564c263edeSDiana Picus auto attr = std::get<CodeGenSpecifics::Attributes>(e.value()); 2574c263edeSDiana Picus auto argTy = std::get<mlir::Type>(e.value()); 2584c263edeSDiana Picus if (attr.isAppend()) { 2594c263edeSDiana Picus trailingInTys.push_back(argTy); 2604c263edeSDiana Picus trailingOpers.push_back(unbox.getResult(idx)); 2614c263edeSDiana Picus } else { 2624c263edeSDiana Picus newInTys.push_back(argTy); 2634c263edeSDiana Picus newOpers.push_back(unbox.getResult(idx)); 2644c263edeSDiana Picus } 2654c263edeSDiana Picus } 2664c263edeSDiana Picus }) 26765431d3aSDiana Picus .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 26865431d3aSDiana Picus rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 26965431d3aSDiana Picus }) 27065431d3aSDiana Picus .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 27165431d3aSDiana Picus rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); 27265431d3aSDiana Picus }) 273416e503aSJean Perier .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { 274416e503aSJean Perier if (factory::isCharacterProcedureTuple(tuple)) { 275416e503aSJean Perier mlir::ModuleOp module = getModule(); 276416e503aSJean Perier 
if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 277416e503aSJean Perier if (callOp.callee()) { 278416e503aSJean Perier llvm::StringRef charProcAttr = 279416e503aSJean Perier fir::getCharacterProcedureDummyAttrName(); 280416e503aSJean Perier // The charProcAttr attribute is only used as a safety to 281416e503aSJean Perier // confirm that this is a dummy procedure and should be split. 282416e503aSJean Perier // It cannot be used to match because attributes are not 283416e503aSJean Perier // available in case of indirect calls. 284416e503aSJean Perier auto funcOp = 285416e503aSJean Perier module.lookupSymbol<mlir::FuncOp>(*callOp.callee()); 286416e503aSJean Perier if (funcOp && 287416e503aSJean Perier !funcOp.template getArgAttrOfType<mlir::UnitAttr>( 288416e503aSJean Perier index, charProcAttr)) 289416e503aSJean Perier mlir::emitError(loc, "tuple argument will be split even " 290416e503aSJean Perier "though it does not have the `" + 291416e503aSJean Perier charProcAttr + "` attribute"); 292416e503aSJean Perier } 293416e503aSJean Perier } 294416e503aSJean Perier mlir::Type funcPointerType = tuple.getType(0); 295416e503aSJean Perier mlir::Type lenType = tuple.getType(1); 296416e503aSJean Perier FirOpBuilder builder(*rewriter, getKindMapping(module)); 297416e503aSJean Perier auto [funcPointer, len] = 298416e503aSJean Perier factory::extractCharacterProcedureTuple(builder, loc, oper); 299416e503aSJean Perier newInTys.push_back(funcPointerType); 300416e503aSJean Perier newOpers.push_back(funcPointer); 301416e503aSJean Perier trailingInTys.push_back(lenType); 302416e503aSJean Perier trailingOpers.push_back(len); 303416e503aSJean Perier } else { 304416e503aSJean Perier newInTys.push_back(tuple); 305416e503aSJean Perier newOpers.push_back(oper); 306416e503aSJean Perier } 307416e503aSJean Perier }) 3084c263edeSDiana Picus .Default([&](mlir::Type ty) { 3094c263edeSDiana Picus newInTys.push_back(ty); 3104c263edeSDiana Picus newOpers.push_back(oper); 
3114c263edeSDiana Picus }); 3124c263edeSDiana Picus } 3134c263edeSDiana Picus newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 3144c263edeSDiana Picus newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); 3154c263edeSDiana Picus if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 3164c263edeSDiana Picus fir::CallOp newCall; 3174c263edeSDiana Picus if (callOp.callee().hasValue()) { 3184c263edeSDiana Picus newCall = rewriter->create<A>(loc, callOp.callee().getValue(), 3194c263edeSDiana Picus newResTys, newOpers); 3204c263edeSDiana Picus } else { 3214c263edeSDiana Picus // Force new type on the input operand. 3224c263edeSDiana Picus newOpers[0].setType(mlir::FunctionType::get( 3234c263edeSDiana Picus callOp.getContext(), 3244c263edeSDiana Picus mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); 3254c263edeSDiana Picus newCall = rewriter->create<A>(loc, newResTys, newOpers); 3264c263edeSDiana Picus } 3274c263edeSDiana Picus LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); 3284c263edeSDiana Picus if (wrap.hasValue()) 3294c263edeSDiana Picus replaceOp(callOp, (*wrap)(newCall.getOperation())); 3304c263edeSDiana Picus else 3314c263edeSDiana Picus replaceOp(callOp, newCall.getResults()); 3324c263edeSDiana Picus } else { 3334c263edeSDiana Picus // A is fir::DispatchOp 3344c263edeSDiana Picus TODO(loc, "dispatch not implemented"); 3354c263edeSDiana Picus } 3364c263edeSDiana Picus } 33765431d3aSDiana Picus 33865431d3aSDiana Picus // Result type fixup for fir::ComplexType and mlir::ComplexType 33965431d3aSDiana Picus template <typename A, typename B> 34065431d3aSDiana Picus void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) { 34165431d3aSDiana Picus if (noComplexConversion) { 34265431d3aSDiana Picus newResTys.push_back(cmplx); 34365431d3aSDiana Picus } else { 34465431d3aSDiana Picus for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) { 
34565431d3aSDiana Picus auto argTy = std::get<mlir::Type>(tup); 34665431d3aSDiana Picus if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet()) 34765431d3aSDiana Picus newInTys.push_back(argTy); 34865431d3aSDiana Picus else 34965431d3aSDiana Picus newResTys.push_back(argTy); 35065431d3aSDiana Picus } 35165431d3aSDiana Picus } 35265431d3aSDiana Picus } 35365431d3aSDiana Picus 35465431d3aSDiana Picus // Argument type fixup for fir::ComplexType and mlir::ComplexType 35565431d3aSDiana Picus template <typename A, typename B> 35665431d3aSDiana Picus void lowerComplexSignatureArg(A cmplx, B &newInTys) { 35765431d3aSDiana Picus if (noComplexConversion) 35865431d3aSDiana Picus newInTys.push_back(cmplx); 35965431d3aSDiana Picus else 36065431d3aSDiana Picus for (auto &tup : specifics->complexArgumentType(cmplx.getElementType())) 36165431d3aSDiana Picus newInTys.push_back(std::get<mlir::Type>(tup)); 36265431d3aSDiana Picus } 36365431d3aSDiana Picus 3643fd250d2SDiana Picus /// Taking the address of a function. Modify the signature as needed. 
3653fd250d2SDiana Picus void convertAddrOp(AddrOfOp addrOp) { 3663fd250d2SDiana Picus rewriter->setInsertionPoint(addrOp); 3673fd250d2SDiana Picus auto addrTy = addrOp.getType().cast<mlir::FunctionType>(); 3683fd250d2SDiana Picus llvm::SmallVector<mlir::Type> newResTys; 3693fd250d2SDiana Picus llvm::SmallVector<mlir::Type> newInTys; 3703fd250d2SDiana Picus for (mlir::Type ty : addrTy.getResults()) { 3713fd250d2SDiana Picus llvm::TypeSwitch<mlir::Type>(ty) 3723fd250d2SDiana Picus .Case<fir::ComplexType>([&](fir::ComplexType ty) { 3733fd250d2SDiana Picus lowerComplexSignatureRes(ty, newResTys, newInTys); 3743fd250d2SDiana Picus }) 3753fd250d2SDiana Picus .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 3763fd250d2SDiana Picus lowerComplexSignatureRes(ty, newResTys, newInTys); 3773fd250d2SDiana Picus }) 3783fd250d2SDiana Picus .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 3793fd250d2SDiana Picus } 3803fd250d2SDiana Picus llvm::SmallVector<mlir::Type> trailingInTys; 3813fd250d2SDiana Picus for (mlir::Type ty : addrTy.getInputs()) { 3823fd250d2SDiana Picus llvm::TypeSwitch<mlir::Type>(ty) 3833fd250d2SDiana Picus .Case<BoxCharType>([&](BoxCharType box) { 3843fd250d2SDiana Picus if (noCharacterConversion) { 3853fd250d2SDiana Picus newInTys.push_back(box); 3863fd250d2SDiana Picus } else { 3873fd250d2SDiana Picus for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) { 3883fd250d2SDiana Picus auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 3893fd250d2SDiana Picus auto argTy = std::get<mlir::Type>(tup); 3903fd250d2SDiana Picus llvm::SmallVector<mlir::Type> &vec = 3913fd250d2SDiana Picus attr.isAppend() ? 
trailingInTys : newInTys; 3923fd250d2SDiana Picus vec.push_back(argTy); 3933fd250d2SDiana Picus } 3943fd250d2SDiana Picus } 3953fd250d2SDiana Picus }) 3963fd250d2SDiana Picus .Case<fir::ComplexType>([&](fir::ComplexType ty) { 3973fd250d2SDiana Picus lowerComplexSignatureArg(ty, newInTys); 3983fd250d2SDiana Picus }) 3993fd250d2SDiana Picus .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 4003fd250d2SDiana Picus lowerComplexSignatureArg(ty, newInTys); 4013fd250d2SDiana Picus }) 402416e503aSJean Perier .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 403416e503aSJean Perier if (factory::isCharacterProcedureTuple(tuple)) { 404416e503aSJean Perier newInTys.push_back(tuple.getType(0)); 405416e503aSJean Perier trailingInTys.push_back(tuple.getType(1)); 406416e503aSJean Perier } else { 407416e503aSJean Perier newInTys.push_back(ty); 408416e503aSJean Perier } 409416e503aSJean Perier }) 4103fd250d2SDiana Picus .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 4113fd250d2SDiana Picus } 4123fd250d2SDiana Picus // append trailing input types 4133fd250d2SDiana Picus newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 4143fd250d2SDiana Picus // replace this op with a new one with the updated signature 4153fd250d2SDiana Picus auto newTy = rewriter->getFunctionType(newInTys, newResTys); 4163fd250d2SDiana Picus auto newOp = 4173fd250d2SDiana Picus rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol()); 4183fd250d2SDiana Picus replaceOp(addrOp, newOp.getResult()); 4193fd250d2SDiana Picus } 4203fd250d2SDiana Picus 4214c263edeSDiana Picus /// Convert the type signatures on all the functions present in the module. 4224c263edeSDiana Picus /// As the type signature is being changed, this must also update the 4234c263edeSDiana Picus /// function itself to use any new arguments, etc. 
4244c263edeSDiana Picus mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { 4254c263edeSDiana Picus for (auto fn : mod.getOps<mlir::FuncOp>()) 4264c263edeSDiana Picus convertSignature(fn); 4274c263edeSDiana Picus return mlir::success(); 4284c263edeSDiana Picus } 4294c263edeSDiana Picus 4304c263edeSDiana Picus /// If the signature does not need any special target-specific converions, 4314c263edeSDiana Picus /// then it is considered portable for any target, and this function will 4324c263edeSDiana Picus /// return `true`. Otherwise, the signature is not portable and `false` is 4334c263edeSDiana Picus /// returned. 4344c263edeSDiana Picus bool hasPortableSignature(mlir::Type signature) { 4354c263edeSDiana Picus assert(signature.isa<mlir::FunctionType>()); 4364c263edeSDiana Picus auto func = signature.dyn_cast<mlir::FunctionType>(); 4374c263edeSDiana Picus for (auto ty : func.getResults()) 43865431d3aSDiana Picus if ((ty.isa<BoxCharType>() && !noCharacterConversion) || 43965431d3aSDiana Picus (isa_complex(ty) && !noComplexConversion)) { 4404c263edeSDiana Picus LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 4414c263edeSDiana Picus return false; 4424c263edeSDiana Picus } 4434c263edeSDiana Picus for (auto ty : func.getInputs()) 444416e503aSJean Perier if (((ty.isa<BoxCharType>() || factory::isCharacterProcedureTuple(ty)) && 445416e503aSJean Perier !noCharacterConversion) || 44665431d3aSDiana Picus (isa_complex(ty) && !noComplexConversion)) { 4474c263edeSDiana Picus LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 4484c263edeSDiana Picus return false; 4494c263edeSDiana Picus } 4504c263edeSDiana Picus return true; 4514c263edeSDiana Picus } 4524c263edeSDiana Picus 4534c263edeSDiana Picus /// Rewrite the signatures and body of the `FuncOp`s in the module for 4544c263edeSDiana Picus /// the immediately subsequent target code gen. 
4554c263edeSDiana Picus void convertSignature(mlir::FuncOp func) { 4564c263edeSDiana Picus auto funcTy = func.getType().cast<mlir::FunctionType>(); 4574c263edeSDiana Picus if (hasPortableSignature(funcTy)) 4584c263edeSDiana Picus return; 4594c263edeSDiana Picus llvm::SmallVector<mlir::Type> newResTys; 4604c263edeSDiana Picus llvm::SmallVector<mlir::Type> newInTys; 4614c263edeSDiana Picus llvm::SmallVector<FixupTy> fixups; 4624c263edeSDiana Picus 4634c263edeSDiana Picus // Convert return value(s) 4644c263edeSDiana Picus for (auto ty : funcTy.getResults()) 46565431d3aSDiana Picus llvm::TypeSwitch<mlir::Type>(ty) 46665431d3aSDiana Picus .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 46765431d3aSDiana Picus if (noComplexConversion) 46865431d3aSDiana Picus newResTys.push_back(cmplx); 46965431d3aSDiana Picus else 47065431d3aSDiana Picus doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 47165431d3aSDiana Picus }) 47265431d3aSDiana Picus .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 47365431d3aSDiana Picus if (noComplexConversion) 47465431d3aSDiana Picus newResTys.push_back(cmplx); 47565431d3aSDiana Picus else 47665431d3aSDiana Picus doComplexReturn(func, cmplx, newResTys, newInTys, fixups); 47765431d3aSDiana Picus }) 47865431d3aSDiana Picus .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 4794c263edeSDiana Picus 4804c263edeSDiana Picus // Convert arguments 4814c263edeSDiana Picus llvm::SmallVector<mlir::Type> trailingTys; 4824c263edeSDiana Picus for (auto e : llvm::enumerate(funcTy.getInputs())) { 4834c263edeSDiana Picus auto ty = e.value(); 4844c263edeSDiana Picus unsigned index = e.index(); 4854c263edeSDiana Picus llvm::TypeSwitch<mlir::Type>(ty) 4864c263edeSDiana Picus .Case<BoxCharType>([&](BoxCharType boxTy) { 4874c263edeSDiana Picus if (noCharacterConversion) { 4884c263edeSDiana Picus newInTys.push_back(boxTy); 4894c263edeSDiana Picus } else { 4904c263edeSDiana Picus // Convert a CHARACTER argument type. 
This can involve separating 4914c263edeSDiana Picus // the pointer and the LEN into two arguments and moving the LEN 4924c263edeSDiana Picus // argument to the end of the arg list. 4934c263edeSDiana Picus bool sret = functionArgIsSRet(index, func); 4944c263edeSDiana Picus for (auto e : llvm::enumerate(specifics->boxcharArgumentType( 4954c263edeSDiana Picus boxTy.getEleTy(), sret))) { 4964c263edeSDiana Picus auto &tup = e.value(); 4974c263edeSDiana Picus auto index = e.index(); 4984c263edeSDiana Picus auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 4994c263edeSDiana Picus auto argTy = std::get<mlir::Type>(tup); 5004c263edeSDiana Picus if (attr.isAppend()) { 5014c263edeSDiana Picus trailingTys.push_back(argTy); 5024c263edeSDiana Picus } else { 5034c263edeSDiana Picus if (sret) { 5044c263edeSDiana Picus fixups.emplace_back(FixupTy::Codes::CharPair, 5054c263edeSDiana Picus newInTys.size(), index); 5064c263edeSDiana Picus } else { 5074c263edeSDiana Picus fixups.emplace_back(FixupTy::Codes::Trailing, 5084c263edeSDiana Picus newInTys.size(), trailingTys.size()); 5094c263edeSDiana Picus } 5104c263edeSDiana Picus newInTys.push_back(argTy); 5114c263edeSDiana Picus } 5124c263edeSDiana Picus } 5134c263edeSDiana Picus } 5144c263edeSDiana Picus }) 51565431d3aSDiana Picus .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { 51665431d3aSDiana Picus if (noComplexConversion) 51765431d3aSDiana Picus newInTys.push_back(cmplx); 51865431d3aSDiana Picus else 51965431d3aSDiana Picus doComplexArg(func, cmplx, newInTys, fixups); 52065431d3aSDiana Picus }) 52165431d3aSDiana Picus .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { 52265431d3aSDiana Picus if (noComplexConversion) 52365431d3aSDiana Picus newInTys.push_back(cmplx); 52465431d3aSDiana Picus else 52565431d3aSDiana Picus doComplexArg(func, cmplx, newInTys, fixups); 52665431d3aSDiana Picus }) 527416e503aSJean Perier .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 528416e503aSJean Perier if 
(factory::isCharacterProcedureTuple(tuple)) { 529416e503aSJean Perier fixups.emplace_back(FixupTy::Codes::TrailingCharProc, 530416e503aSJean Perier newInTys.size(), trailingTys.size()); 531416e503aSJean Perier newInTys.push_back(tuple.getType(0)); 532416e503aSJean Perier trailingTys.push_back(tuple.getType(1)); 533416e503aSJean Perier } else { 534416e503aSJean Perier newInTys.push_back(ty); 535416e503aSJean Perier } 536416e503aSJean Perier }) 5374c263edeSDiana Picus .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 5384c263edeSDiana Picus } 5394c263edeSDiana Picus 5404c263edeSDiana Picus if (!func.empty()) { 5414c263edeSDiana Picus // If the function has a body, then apply the fixups to the arguments and 5424c263edeSDiana Picus // return ops as required. These fixups are done in place. 5434c263edeSDiana Picus auto loc = func.getLoc(); 5444c263edeSDiana Picus const auto fixupSize = fixups.size(); 5454c263edeSDiana Picus const auto oldArgTys = func.getType().getInputs(); 5464c263edeSDiana Picus int offset = 0; 5474c263edeSDiana Picus for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) { 5484c263edeSDiana Picus const auto &fixup = fixups[i]; 5494c263edeSDiana Picus switch (fixup.code) { 55065431d3aSDiana Picus case FixupTy::Codes::ArgumentAsLoad: { 55165431d3aSDiana Picus // Argument was pass-by-value, but is now pass-by-reference and 55265431d3aSDiana Picus // possibly with a different element type. 
553e084679fSRiver Riddle auto newArg = func.front().insertArgument(fixup.index, 554e084679fSRiver Riddle newInTys[fixup.index], loc); 55565431d3aSDiana Picus rewriter->setInsertionPointToStart(&func.front()); 55665431d3aSDiana Picus auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 55765431d3aSDiana Picus auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg); 55865431d3aSDiana Picus auto load = rewriter->create<fir::LoadOp>(loc, cast); 55965431d3aSDiana Picus func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 56065431d3aSDiana Picus func.front().eraseArgument(fixup.index + 1); 56165431d3aSDiana Picus } break; 56265431d3aSDiana Picus case FixupTy::Codes::ArgumentType: { 56365431d3aSDiana Picus // Argument is pass-by-value, but its type has likely been modified to 56465431d3aSDiana Picus // suit the target ABI convention. 565e084679fSRiver Riddle auto newArg = func.front().insertArgument(fixup.index, 566e084679fSRiver Riddle newInTys[fixup.index], loc); 56765431d3aSDiana Picus rewriter->setInsertionPointToStart(&func.front()); 56865431d3aSDiana Picus auto mem = 56965431d3aSDiana Picus rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]); 57065431d3aSDiana Picus rewriter->create<fir::StoreOp>(loc, newArg, mem); 57165431d3aSDiana Picus auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); 57265431d3aSDiana Picus auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem); 57365431d3aSDiana Picus mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast); 57465431d3aSDiana Picus func.getArgument(fixup.index + 1).replaceAllUsesWith(load); 57565431d3aSDiana Picus func.front().eraseArgument(fixup.index + 1); 57665431d3aSDiana Picus LLVM_DEBUG(llvm::dbgs() 57765431d3aSDiana Picus << "old argument: " << oldArgTy.getEleTy() 57865431d3aSDiana Picus << ", repl: " << load << ", new argument: " 57965431d3aSDiana Picus << func.getArgument(fixup.index).getType() << '\n'); 58065431d3aSDiana Picus } break; 5814c263edeSDiana 
Picus case FixupTy::Codes::CharPair: { 5824c263edeSDiana Picus // The FIR boxchar argument has been split into a pair of distinct 5834c263edeSDiana Picus // arguments that are in juxtaposition to each other. 584e084679fSRiver Riddle auto newArg = func.front().insertArgument(fixup.index, 585e084679fSRiver Riddle newInTys[fixup.index], loc); 5864c263edeSDiana Picus if (fixup.second == 1) { 5874c263edeSDiana Picus rewriter->setInsertionPointToStart(&func.front()); 5884c263edeSDiana Picus auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; 5894c263edeSDiana Picus auto box = rewriter->create<EmboxCharOp>( 5904c263edeSDiana Picus loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); 5914c263edeSDiana Picus func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 5924c263edeSDiana Picus func.front().eraseArgument(fixup.index + 1); 5934c263edeSDiana Picus offset++; 5944c263edeSDiana Picus } 5954c263edeSDiana Picus } break; 59665431d3aSDiana Picus case FixupTy::Codes::ReturnAsStore: { 59765431d3aSDiana Picus // The value being returned is now being returned in memory (callee 59865431d3aSDiana Picus // stack space) through a hidden reference argument. 
599e084679fSRiver Riddle auto newArg = func.front().insertArgument(fixup.index, 600e084679fSRiver Riddle newInTys[fixup.index], loc); 60165431d3aSDiana Picus offset++; 60265431d3aSDiana Picus func.walk([&](mlir::ReturnOp ret) { 60365431d3aSDiana Picus rewriter->setInsertionPoint(ret); 60465431d3aSDiana Picus auto oldOper = ret.getOperand(0); 60565431d3aSDiana Picus auto oldOperTy = ReferenceType::get(oldOper.getType()); 60665431d3aSDiana Picus auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg); 60765431d3aSDiana Picus rewriter->create<fir::StoreOp>(loc, oldOper, cast); 60865431d3aSDiana Picus rewriter->create<mlir::ReturnOp>(loc); 60965431d3aSDiana Picus ret.erase(); 61065431d3aSDiana Picus }); 61165431d3aSDiana Picus } break; 61265431d3aSDiana Picus case FixupTy::Codes::ReturnType: { 61365431d3aSDiana Picus // The function is still returning a value, but its type has likely 61465431d3aSDiana Picus // changed to suit the target ABI convention. 61565431d3aSDiana Picus func.walk([&](mlir::ReturnOp ret) { 61665431d3aSDiana Picus rewriter->setInsertionPoint(ret); 61765431d3aSDiana Picus auto oldOper = ret.getOperand(0); 61865431d3aSDiana Picus auto oldOperTy = ReferenceType::get(oldOper.getType()); 61965431d3aSDiana Picus auto mem = 62065431d3aSDiana Picus rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]); 62165431d3aSDiana Picus auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem); 62265431d3aSDiana Picus rewriter->create<fir::StoreOp>(loc, oldOper, cast); 62365431d3aSDiana Picus mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem); 62465431d3aSDiana Picus rewriter->create<mlir::ReturnOp>(loc, load); 62565431d3aSDiana Picus ret.erase(); 62665431d3aSDiana Picus }); 62765431d3aSDiana Picus } break; 62865431d3aSDiana Picus case FixupTy::Codes::Split: { 62965431d3aSDiana Picus // The FIR argument has been split into a pair of distinct arguments 63065431d3aSDiana Picus // that are in juxtaposition to each other. (For COMPLEX value.) 
631e084679fSRiver Riddle auto newArg = func.front().insertArgument(fixup.index, 632e084679fSRiver Riddle newInTys[fixup.index], loc); 63365431d3aSDiana Picus if (fixup.second == 1) { 63465431d3aSDiana Picus rewriter->setInsertionPointToStart(&func.front()); 63565431d3aSDiana Picus auto cplxTy = oldArgTys[fixup.index - offset - fixup.second]; 63665431d3aSDiana Picus auto undef = rewriter->create<UndefOp>(loc, cplxTy); 637*9b5bb511SValentin Clement auto iTy = rewriter->getIntegerType(32); 638*9b5bb511SValentin Clement auto zero = rewriter->getIntegerAttr(iTy, 0); 639*9b5bb511SValentin Clement auto one = rewriter->getIntegerAttr(iTy, 1); 64065431d3aSDiana Picus auto cplx1 = rewriter->create<InsertValueOp>( 64165431d3aSDiana Picus loc, cplxTy, undef, func.front().getArgument(fixup.index - 1), 64265431d3aSDiana Picus rewriter->getArrayAttr(zero)); 64365431d3aSDiana Picus auto cplx = rewriter->create<InsertValueOp>( 64465431d3aSDiana Picus loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one)); 64565431d3aSDiana Picus func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx); 64665431d3aSDiana Picus func.front().eraseArgument(fixup.index + 1); 64765431d3aSDiana Picus offset++; 64865431d3aSDiana Picus } 64965431d3aSDiana Picus } break; 6504c263edeSDiana Picus case FixupTy::Codes::Trailing: { 6514c263edeSDiana Picus // The FIR argument has been split into a pair of distinct arguments. 6524c263edeSDiana Picus // The first part of the pair appears in the original argument 6534c263edeSDiana Picus // position. The second part of the pair is appended after all the 6544c263edeSDiana Picus // original arguments. (Boxchar arguments.) 
655e084679fSRiver Riddle auto newBufArg = func.front().insertArgument( 656e084679fSRiver Riddle fixup.index, newInTys[fixup.index], loc); 657e084679fSRiver Riddle auto newLenArg = 658e084679fSRiver Riddle func.front().addArgument(trailingTys[fixup.second], loc); 6594c263edeSDiana Picus auto boxTy = oldArgTys[fixup.index - offset]; 6604c263edeSDiana Picus rewriter->setInsertionPointToStart(&func.front()); 6614c263edeSDiana Picus auto box = 6624c263edeSDiana Picus rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg); 6634c263edeSDiana Picus func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 6644c263edeSDiana Picus func.front().eraseArgument(fixup.index + 1); 6654c263edeSDiana Picus } break; 666416e503aSJean Perier case FixupTy::Codes::TrailingCharProc: { 667416e503aSJean Perier // The FIR character procedure argument tuple has been split into a 668416e503aSJean Perier // pair of distinct arguments. The first part of the pair appears in 669416e503aSJean Perier // the original argument position. The second part of the pair is 670416e503aSJean Perier // appended after all the original arguments. 
671416e503aSJean Perier auto newProcPointerArg = func.front().insertArgument( 672416e503aSJean Perier fixup.index, newInTys[fixup.index], loc); 673416e503aSJean Perier auto newLenArg = 674416e503aSJean Perier func.front().addArgument(trailingTys[fixup.second], loc); 675416e503aSJean Perier auto tupleType = oldArgTys[fixup.index - offset]; 676416e503aSJean Perier rewriter->setInsertionPointToStart(&func.front()); 677416e503aSJean Perier FirOpBuilder builder(*rewriter, getKindMapping(getModule())); 678416e503aSJean Perier auto tuple = factory::createCharacterProcedureTuple( 679416e503aSJean Perier builder, loc, tupleType, newProcPointerArg, newLenArg); 680416e503aSJean Perier func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple); 681416e503aSJean Perier func.front().eraseArgument(fixup.index + 1); 682416e503aSJean Perier } break; 6834c263edeSDiana Picus } 6844c263edeSDiana Picus } 6854c263edeSDiana Picus } 6864c263edeSDiana Picus 6874c263edeSDiana Picus // Set the new type and finalize the arguments, etc. 6884c263edeSDiana Picus newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); 6894c263edeSDiana Picus auto newFuncTy = 6904c263edeSDiana Picus mlir::FunctionType::get(func.getContext(), newInTys, newResTys); 6914c263edeSDiana Picus LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); 6924c263edeSDiana Picus func.setType(newFuncTy); 6934c263edeSDiana Picus 6944c263edeSDiana Picus for (auto &fixup : fixups) 6954c263edeSDiana Picus if (fixup.finalizer) 6964c263edeSDiana Picus (*fixup.finalizer)(func); 6974c263edeSDiana Picus } 6984c263edeSDiana Picus 6994c263edeSDiana Picus inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) { 7004c263edeSDiana Picus if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret")) 7014c263edeSDiana Picus return true; 7024c263edeSDiana Picus return false; 7034c263edeSDiana Picus } 7044c263edeSDiana Picus 70565431d3aSDiana Picus /// Convert a complex return value. 
This can involve converting the return 70665431d3aSDiana Picus /// value to a "hidden" first argument or packing the complex into a wide 70765431d3aSDiana Picus /// GPR. 70865431d3aSDiana Picus template <typename A, typename B, typename C> 70965431d3aSDiana Picus void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys, 71065431d3aSDiana Picus C &fixups) { 71165431d3aSDiana Picus if (noComplexConversion) { 71265431d3aSDiana Picus newResTys.push_back(cmplx); 71365431d3aSDiana Picus return; 71465431d3aSDiana Picus } 71565431d3aSDiana Picus auto m = specifics->complexReturnType(cmplx.getElementType()); 71665431d3aSDiana Picus assert(m.size() == 1); 71765431d3aSDiana Picus auto &tup = m[0]; 71865431d3aSDiana Picus auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 71965431d3aSDiana Picus auto argTy = std::get<mlir::Type>(tup); 72065431d3aSDiana Picus if (attr.isSRet()) { 72165431d3aSDiana Picus unsigned argNo = newInTys.size(); 72265431d3aSDiana Picus fixups.emplace_back( 72365431d3aSDiana Picus FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) { 72465431d3aSDiana Picus func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); 72565431d3aSDiana Picus }); 72665431d3aSDiana Picus newInTys.push_back(argTy); 72765431d3aSDiana Picus return; 72865431d3aSDiana Picus } 72965431d3aSDiana Picus fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); 73065431d3aSDiana Picus newResTys.push_back(argTy); 73165431d3aSDiana Picus } 73265431d3aSDiana Picus 73365431d3aSDiana Picus /// Convert a complex argument value. This can involve storing the value to 73465431d3aSDiana Picus /// a temporary memory location or factoring the value into two distinct 73565431d3aSDiana Picus /// arguments. 
73665431d3aSDiana Picus template <typename A, typename B, typename C> 73765431d3aSDiana Picus void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) { 73865431d3aSDiana Picus if (noComplexConversion) { 73965431d3aSDiana Picus newInTys.push_back(cmplx); 74065431d3aSDiana Picus return; 74165431d3aSDiana Picus } 74265431d3aSDiana Picus auto m = specifics->complexArgumentType(cmplx.getElementType()); 74365431d3aSDiana Picus const auto fixupCode = 74465431d3aSDiana Picus m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; 74565431d3aSDiana Picus for (auto e : llvm::enumerate(m)) { 74665431d3aSDiana Picus auto &tup = e.value(); 74765431d3aSDiana Picus auto index = e.index(); 74865431d3aSDiana Picus auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 74965431d3aSDiana Picus auto argTy = std::get<mlir::Type>(tup); 75065431d3aSDiana Picus auto argNo = newInTys.size(); 75165431d3aSDiana Picus if (attr.isByVal()) { 75265431d3aSDiana Picus if (auto align = attr.getAlignment()) 75365431d3aSDiana Picus fixups.emplace_back( 75465431d3aSDiana Picus FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) { 75565431d3aSDiana Picus func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); 75665431d3aSDiana Picus func.setArgAttr(argNo, "llvm.align", 75765431d3aSDiana Picus rewriter->getIntegerAttr( 75865431d3aSDiana Picus rewriter->getIntegerType(32), align)); 75965431d3aSDiana Picus }); 76065431d3aSDiana Picus else 76165431d3aSDiana Picus fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), 76265431d3aSDiana Picus [=](mlir::FuncOp func) { 76365431d3aSDiana Picus func.setArgAttr(argNo, "llvm.byval", 76465431d3aSDiana Picus rewriter->getUnitAttr()); 76565431d3aSDiana Picus }); 76665431d3aSDiana Picus } else { 76765431d3aSDiana Picus if (auto align = attr.getAlignment()) 76865431d3aSDiana Picus fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) { 76965431d3aSDiana Picus func.setArgAttr( 77065431d3aSDiana 
Picus argNo, "llvm.align", 77165431d3aSDiana Picus rewriter->getIntegerAttr(rewriter->getIntegerType(32), align)); 77265431d3aSDiana Picus }); 77365431d3aSDiana Picus else 77465431d3aSDiana Picus fixups.emplace_back(fixupCode, argNo, index); 77565431d3aSDiana Picus } 77665431d3aSDiana Picus newInTys.push_back(argTy); 77765431d3aSDiana Picus } 77865431d3aSDiana Picus } 77965431d3aSDiana Picus 7804c263edeSDiana Picus private: 7814c263edeSDiana Picus // Replace `op` and remove it. 7824c263edeSDiana Picus void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 7834c263edeSDiana Picus op->replaceAllUsesWith(newValues); 7844c263edeSDiana Picus op->dropAllReferences(); 7854c263edeSDiana Picus op->erase(); 7864c263edeSDiana Picus } 7874c263edeSDiana Picus 7884c263edeSDiana Picus inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) { 7894c263edeSDiana Picus specifics = s; 7904c263edeSDiana Picus rewriter = r; 7914c263edeSDiana Picus } 7924c263edeSDiana Picus 7934c263edeSDiana Picus inline void clearMembers() { setMembers(nullptr, nullptr); } 7944c263edeSDiana Picus 7954c263edeSDiana Picus CodeGenSpecifics *specifics{}; 7964c263edeSDiana Picus mlir::OpBuilder *rewriter; 7974c263edeSDiana Picus }; // namespace 7984c263edeSDiana Picus } // namespace 7994c263edeSDiana Picus 8004c263edeSDiana Picus std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 8014c263edeSDiana Picus fir::createFirTargetRewritePass(const TargetRewriteOptions &options) { 8024c263edeSDiana Picus return std::make_unique<TargetRewrite>(options); 8034c263edeSDiana Picus } 804