1 //===-- TargetRewrite.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest.
10 // LLVM expects different lowering idioms to be used for distinct target
11 // triples. These distinctions are handled by this pass.
12 //
13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "PassDetail.h"
#include "Target.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <utility>
32 
33 #define DEBUG_TYPE "flang-target-rewrite"
34 
35 namespace {
36 
37 /// Fixups for updating a FuncOp's arguments and return values.
38 struct FixupTy {
39   enum class Codes {
40     ArgumentAsLoad,
41     ArgumentType,
42     CharPair,
43     ReturnAsStore,
44     ReturnType,
45     Split,
46     Trailing,
47     TrailingCharProc
48   };
49 
50   FixupTy(Codes code, std::size_t index, std::size_t second = 0)
51       : code{code}, index{index}, second{second} {}
52   FixupTy(Codes code, std::size_t index,
53           std::function<void(mlir::func::FuncOp)> &&finalizer)
54       : code{code}, index{index}, finalizer{finalizer} {}
55   FixupTy(Codes code, std::size_t index, std::size_t second,
56           std::function<void(mlir::func::FuncOp)> &&finalizer)
57       : code{code}, index{index}, second{second}, finalizer{finalizer} {}
58 
59   Codes code;
60   std::size_t index;
61   std::size_t second{};
62   llvm::Optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
63 }; // namespace
64 
65 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
66 /// generation that traverses the FIR and modifies types and operations to a
67 /// form that is appropriate for the specific target. LLVM IR has specific
68 /// idioms that are used for distinct target processor and ABI combinations.
69 class TargetRewrite : public fir::TargetRewriteBase<TargetRewrite> {
70 public:
71   TargetRewrite(const fir::TargetRewriteOptions &options) {
72     noCharacterConversion = options.noCharacterConversion;
73     noComplexConversion = options.noComplexConversion;
74   }
75 
76   void runOnOperation() override final {
77     auto &context = getContext();
78     mlir::OpBuilder rewriter(&context);
79 
80     auto mod = getModule();
81     if (!forcedTargetTriple.empty())
82       fir::setTargetTriple(mod, forcedTargetTriple);
83 
84     auto specifics = fir::CodeGenSpecifics::get(
85         mod.getContext(), fir::getTargetTriple(mod), fir::getKindMapping(mod));
86     setMembers(specifics.get(), &rewriter);
87 
88     // Perform type conversion on signatures and call sites.
89     if (mlir::failed(convertTypes(mod))) {
90       mlir::emitError(mlir::UnknownLoc::get(&context),
91                       "error in converting types to target abi");
92       signalPassFailure();
93     }
94 
95     // Convert ops in target-specific patterns.
96     mod.walk([&](mlir::Operation *op) {
97       if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
98         if (!hasPortableSignature(call.getFunctionType()))
99           convertCallOp(call);
100       } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
101         if (!hasPortableSignature(dispatch.getFunctionType()))
102           convertCallOp(dispatch);
103       } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
104         if (addr.getType().isa<mlir::FunctionType>() &&
105             !hasPortableSignature(addr.getType()))
106           convertAddrOp(addr);
107       }
108     });
109 
110     clearMembers();
111   }
112 
113   mlir::ModuleOp getModule() { return getOperation(); }
114 
115   template <typename A, typename B, typename C>
116   std::function<mlir::Value(mlir::Operation *)>
117   rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) {
118     auto m = specifics->complexReturnType(ty.getElementType());
119     // Currently targets mandate COMPLEX is a single aggregate or packed
120     // scalar, including the sret case.
121     assert(m.size() == 1 && "target lowering of complex return not supported");
122     auto resTy = std::get<mlir::Type>(m[0]);
123     auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
124     auto loc = mlir::UnknownLoc::get(resTy.getContext());
125     if (attr.isSRet()) {
126       assert(fir::isa_ref_type(resTy) && "must be a memory reference type");
127       mlir::Value stack =
128           rewriter->create<fir::AllocaOp>(loc, fir::dyn_cast_ptrEleTy(resTy));
129       newInTys.push_back(resTy);
130       newOpers.push_back(stack);
131       return [=](mlir::Operation *) -> mlir::Value {
132         auto memTy = fir::ReferenceType::get(ty);
133         auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack);
134         return rewriter->create<fir::LoadOp>(loc, cast);
135       };
136     }
137     newResTys.push_back(resTy);
138     return [=](mlir::Operation *call) -> mlir::Value {
139       auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
140       rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
141       auto memTy = fir::ReferenceType::get(ty);
142       auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, mem);
143       return rewriter->create<fir::LoadOp>(loc, cast);
144     };
145   }
146 
147   template <typename A, typename B, typename C>
148   void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys,
149                                    C &newOpers) {
150     auto m = specifics->complexArgumentType(ty.getElementType());
151     auto *ctx = ty.getContext();
152     auto loc = mlir::UnknownLoc::get(ctx);
153     if (m.size() == 1) {
154       // COMPLEX is a single aggregate
155       auto resTy = std::get<mlir::Type>(m[0]);
156       auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
157       auto oldRefTy = fir::ReferenceType::get(ty);
158       if (attr.isByVal()) {
159         auto mem = rewriter->create<fir::AllocaOp>(loc, ty);
160         rewriter->create<fir::StoreOp>(loc, oper, mem);
161         newOpers.push_back(rewriter->create<fir::ConvertOp>(loc, resTy, mem));
162       } else {
163         auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
164         auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem);
165         rewriter->create<fir::StoreOp>(loc, oper, cast);
166         newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
167       }
168       newInTys.push_back(resTy);
169     } else {
170       assert(m.size() == 2);
171       // COMPLEX is split into 2 separate arguments
172       auto iTy = rewriter->getIntegerType(32);
173       for (auto e : llvm::enumerate(m)) {
174         auto &tup = e.value();
175         auto ty = std::get<mlir::Type>(tup);
176         auto index = e.index();
177         auto idx = rewriter->getIntegerAttr(iTy, index);
178         auto val = rewriter->create<fir::ExtractValueOp>(
179             loc, ty, oper, rewriter->getArrayAttr(idx));
180         newInTys.push_back(ty);
181         newOpers.push_back(val);
182       }
183     }
184   }
185 
186   // Convert fir.call and fir.dispatch Ops.
187   template <typename A>
188   void convertCallOp(A callOp) {
189     auto fnTy = callOp.getFunctionType();
190     auto loc = callOp.getLoc();
191     rewriter->setInsertionPoint(callOp);
192     llvm::SmallVector<mlir::Type> newResTys;
193     llvm::SmallVector<mlir::Type> newInTys;
194     llvm::SmallVector<mlir::Value> newOpers;
195 
196     // If the call is indirect, the first argument must still be the function
197     // to call.
198     int dropFront = 0;
199     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
200       if (!callOp.getCallee().hasValue()) {
201         newInTys.push_back(fnTy.getInput(0));
202         newOpers.push_back(callOp.getOperand(0));
203         dropFront = 1;
204       }
205     }
206 
207     // Determine the rewrite function, `wrap`, for the result value.
208     llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
209     if (fnTy.getResults().size() == 1) {
210       mlir::Type ty = fnTy.getResult(0);
211       llvm::TypeSwitch<mlir::Type>(ty)
212           .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
213             wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
214                                                 newOpers);
215           })
216           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
217             wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
218                                                 newOpers);
219           })
220           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
221     } else if (fnTy.getResults().size() > 1) {
222       TODO(loc, "multiple results not supported yet");
223     }
224 
225     llvm::SmallVector<mlir::Type> trailingInTys;
226     llvm::SmallVector<mlir::Value> trailingOpers;
227     for (auto e : llvm::enumerate(
228              llvm::zip(fnTy.getInputs().drop_front(dropFront),
229                        callOp.getOperands().drop_front(dropFront)))) {
230       mlir::Type ty = std::get<0>(e.value());
231       mlir::Value oper = std::get<1>(e.value());
232       unsigned index = e.index();
233       llvm::TypeSwitch<mlir::Type>(ty)
234           .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
235             bool sret;
236             if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
237               sret = callOp.getCallee() &&
238                      functionArgIsSRet(
239                          index, getModule().lookupSymbol<mlir::func::FuncOp>(
240                                     *callOp.getCallee()));
241             } else {
242               // TODO: dispatch case; how do we put arguments on a call?
243               // We cannot put both an sret and the dispatch object first.
244               sret = false;
245               TODO(loc, "dispatch + sret not supported yet");
246             }
247             auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret);
248             auto unbox = rewriter->create<fir::UnboxCharOp>(
249                 loc, std::get<mlir::Type>(m[0]), std::get<mlir::Type>(m[1]),
250                 oper);
251             // unboxed CHARACTER arguments
252             for (auto e : llvm::enumerate(m)) {
253               unsigned idx = e.index();
254               auto attr =
255                   std::get<fir::CodeGenSpecifics::Attributes>(e.value());
256               auto argTy = std::get<mlir::Type>(e.value());
257               if (attr.isAppend()) {
258                 trailingInTys.push_back(argTy);
259                 trailingOpers.push_back(unbox.getResult(idx));
260               } else {
261                 newInTys.push_back(argTy);
262                 newOpers.push_back(unbox.getResult(idx));
263               }
264             }
265           })
266           .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
267             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
268           })
269           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
270             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
271           })
272           .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
273             if (fir::isCharacterProcedureTuple(tuple)) {
274               mlir::ModuleOp module = getModule();
275               if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
276                 if (callOp.getCallee()) {
277                   llvm::StringRef charProcAttr =
278                       fir::getCharacterProcedureDummyAttrName();
279                   // The charProcAttr attribute is only used as a safety to
280                   // confirm that this is a dummy procedure and should be split.
281                   // It cannot be used to match because attributes are not
282                   // available in case of indirect calls.
283                   auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(
284                       *callOp.getCallee());
285                   if (funcOp &&
286                       !funcOp.template getArgAttrOfType<mlir::UnitAttr>(
287                           index, charProcAttr))
288                     mlir::emitError(loc, "tuple argument will be split even "
289                                          "though it does not have the `" +
290                                              charProcAttr + "` attribute");
291                 }
292               }
293               mlir::Type funcPointerType = tuple.getType(0);
294               mlir::Type lenType = tuple.getType(1);
295               fir::FirOpBuilder builder(*rewriter, fir::getKindMapping(module));
296               auto [funcPointer, len] =
297                   fir::factory::extractCharacterProcedureTuple(builder, loc,
298                                                                oper);
299               newInTys.push_back(funcPointerType);
300               newOpers.push_back(funcPointer);
301               trailingInTys.push_back(lenType);
302               trailingOpers.push_back(len);
303             } else {
304               newInTys.push_back(tuple);
305               newOpers.push_back(oper);
306             }
307           })
308           .Default([&](mlir::Type ty) {
309             newInTys.push_back(ty);
310             newOpers.push_back(oper);
311           });
312     }
313     newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
314     newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
315     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
316       fir::CallOp newCall;
317       if (callOp.getCallee().hasValue()) {
318         newCall = rewriter->create<A>(loc, callOp.getCallee().getValue(),
319                                       newResTys, newOpers);
320       } else {
321         // Force new type on the input operand.
322         newOpers[0].setType(mlir::FunctionType::get(
323             callOp.getContext(),
324             mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys));
325         newCall = rewriter->create<A>(loc, newResTys, newOpers);
326       }
327       LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
328       if (wrap.hasValue())
329         replaceOp(callOp, (*wrap)(newCall.getOperation()));
330       else
331         replaceOp(callOp, newCall.getResults());
332     } else {
333       // A is fir::DispatchOp
334       TODO(loc, "dispatch not implemented");
335     }
336   }
337 
338   // Result type fixup for fir::ComplexType and mlir::ComplexType
339   template <typename A, typename B>
340   void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) {
341     if (noComplexConversion) {
342       newResTys.push_back(cmplx);
343     } else {
344       for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) {
345         auto argTy = std::get<mlir::Type>(tup);
346         if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet())
347           newInTys.push_back(argTy);
348         else
349           newResTys.push_back(argTy);
350       }
351     }
352   }
353 
354   // Argument type fixup for fir::ComplexType and mlir::ComplexType
355   template <typename A, typename B>
356   void lowerComplexSignatureArg(A cmplx, B &newInTys) {
357     if (noComplexConversion)
358       newInTys.push_back(cmplx);
359     else
360       for (auto &tup : specifics->complexArgumentType(cmplx.getElementType()))
361         newInTys.push_back(std::get<mlir::Type>(tup));
362   }
363 
364   /// Taking the address of a function. Modify the signature as needed.
365   void convertAddrOp(fir::AddrOfOp addrOp) {
366     rewriter->setInsertionPoint(addrOp);
367     auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
368     llvm::SmallVector<mlir::Type> newResTys;
369     llvm::SmallVector<mlir::Type> newInTys;
370     for (mlir::Type ty : addrTy.getResults()) {
371       llvm::TypeSwitch<mlir::Type>(ty)
372           .Case<fir::ComplexType>([&](fir::ComplexType ty) {
373             lowerComplexSignatureRes(ty, newResTys, newInTys);
374           })
375           .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
376             lowerComplexSignatureRes(ty, newResTys, newInTys);
377           })
378           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
379     }
380     llvm::SmallVector<mlir::Type> trailingInTys;
381     for (mlir::Type ty : addrTy.getInputs()) {
382       llvm::TypeSwitch<mlir::Type>(ty)
383           .Case<fir::BoxCharType>([&](auto box) {
384             if (noCharacterConversion) {
385               newInTys.push_back(box);
386             } else {
387               for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
388                 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
389                 auto argTy = std::get<mlir::Type>(tup);
390                 llvm::SmallVector<mlir::Type> &vec =
391                     attr.isAppend() ? trailingInTys : newInTys;
392                 vec.push_back(argTy);
393               }
394             }
395           })
396           .Case<fir::ComplexType>([&](fir::ComplexType ty) {
397             lowerComplexSignatureArg(ty, newInTys);
398           })
399           .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
400             lowerComplexSignatureArg(ty, newInTys);
401           })
402           .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
403             if (fir::isCharacterProcedureTuple(tuple)) {
404               newInTys.push_back(tuple.getType(0));
405               trailingInTys.push_back(tuple.getType(1));
406             } else {
407               newInTys.push_back(ty);
408             }
409           })
410           .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
411     }
412     // append trailing input types
413     newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
414     // replace this op with a new one with the updated signature
415     auto newTy = rewriter->getFunctionType(newInTys, newResTys);
416     auto newOp = rewriter->create<fir::AddrOfOp>(addrOp.getLoc(), newTy,
417                                                  addrOp.getSymbol());
418     replaceOp(addrOp, newOp.getResult());
419   }
420 
421   /// Convert the type signatures on all the functions present in the module.
422   /// As the type signature is being changed, this must also update the
423   /// function itself to use any new arguments, etc.
424   mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
425     for (auto fn : mod.getOps<mlir::func::FuncOp>())
426       convertSignature(fn);
427     return mlir::success();
428   }
429 
  /// If the signature does not need any special target-specific conversions,
431   /// then it is considered portable for any target, and this function will
432   /// return `true`. Otherwise, the signature is not portable and `false` is
433   /// returned.
434   bool hasPortableSignature(mlir::Type signature) {
435     assert(signature.isa<mlir::FunctionType>());
436     auto func = signature.dyn_cast<mlir::FunctionType>();
437     for (auto ty : func.getResults())
438       if ((ty.isa<fir::BoxCharType>() && !noCharacterConversion) ||
439           (fir::isa_complex(ty) && !noComplexConversion)) {
440         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
441         return false;
442       }
443     for (auto ty : func.getInputs())
444       if (((ty.isa<fir::BoxCharType>() || fir::isCharacterProcedureTuple(ty)) &&
445            !noCharacterConversion) ||
446           (fir::isa_complex(ty) && !noComplexConversion)) {
447         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
448         return false;
449       }
450     return true;
451   }
452 
453   /// Determine if the signature has host associations. The host association
454   /// argument may need special target specific rewriting.
455   static bool hasHostAssociations(mlir::func::FuncOp func) {
456     std::size_t end = func.getFunctionType().getInputs().size();
457     for (std::size_t i = 0; i < end; ++i)
458       if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
459         return true;
460     return false;
461   }
462 
463   /// Rewrite the signatures and body of the `FuncOp`s in the module for
464   /// the immediately subsequent target code gen.
465   void convertSignature(mlir::func::FuncOp func) {
466     auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
467     if (hasPortableSignature(funcTy) && !hasHostAssociations(func))
468       return;
469     llvm::SmallVector<mlir::Type> newResTys;
470     llvm::SmallVector<mlir::Type> newInTys;
471     llvm::SmallVector<FixupTy> fixups;
472 
473     // Convert return value(s)
474     for (auto ty : funcTy.getResults())
475       llvm::TypeSwitch<mlir::Type>(ty)
476           .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
477             if (noComplexConversion)
478               newResTys.push_back(cmplx);
479             else
480               doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
481           })
482           .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
483             if (noComplexConversion)
484               newResTys.push_back(cmplx);
485             else
486               doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
487           })
488           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
489 
490     // Convert arguments
491     llvm::SmallVector<mlir::Type> trailingTys;
492     for (auto e : llvm::enumerate(funcTy.getInputs())) {
493       auto ty = e.value();
494       unsigned index = e.index();
495       llvm::TypeSwitch<mlir::Type>(ty)
496           .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
497             if (noCharacterConversion) {
498               newInTys.push_back(boxTy);
499             } else {
500               // Convert a CHARACTER argument type. This can involve separating
501               // the pointer and the LEN into two arguments and moving the LEN
502               // argument to the end of the arg list.
503               bool sret = functionArgIsSRet(index, func);
504               for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
505                        boxTy.getEleTy(), sret))) {
506                 auto &tup = e.value();
507                 auto index = e.index();
508                 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
509                 auto argTy = std::get<mlir::Type>(tup);
510                 if (attr.isAppend()) {
511                   trailingTys.push_back(argTy);
512                 } else {
513                   if (sret) {
514                     fixups.emplace_back(FixupTy::Codes::CharPair,
515                                         newInTys.size(), index);
516                   } else {
517                     fixups.emplace_back(FixupTy::Codes::Trailing,
518                                         newInTys.size(), trailingTys.size());
519                   }
520                   newInTys.push_back(argTy);
521                 }
522               }
523             }
524           })
525           .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
526             if (noComplexConversion)
527               newInTys.push_back(cmplx);
528             else
529               doComplexArg(func, cmplx, newInTys, fixups);
530           })
531           .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
532             if (noComplexConversion)
533               newInTys.push_back(cmplx);
534             else
535               doComplexArg(func, cmplx, newInTys, fixups);
536           })
537           .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
538             if (fir::isCharacterProcedureTuple(tuple)) {
539               fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
540                                   newInTys.size(), trailingTys.size());
541               newInTys.push_back(tuple.getType(0));
542               trailingTys.push_back(tuple.getType(1));
543             } else {
544               newInTys.push_back(ty);
545             }
546           })
547           .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
548       if (func.getArgAttrOfType<mlir::UnitAttr>(index,
549                                                 fir::getHostAssocAttrName())) {
550         func.setArgAttr(index, "llvm.nest", rewriter->getUnitAttr());
551       }
552     }
553 
554     if (!func.empty()) {
555       // If the function has a body, then apply the fixups to the arguments and
556       // return ops as required. These fixups are done in place.
557       auto loc = func.getLoc();
558       const auto fixupSize = fixups.size();
559       const auto oldArgTys = func.getFunctionType().getInputs();
560       int offset = 0;
561       for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
562         const auto &fixup = fixups[i];
563         switch (fixup.code) {
564         case FixupTy::Codes::ArgumentAsLoad: {
565           // Argument was pass-by-value, but is now pass-by-reference and
566           // possibly with a different element type.
567           auto newArg = func.front().insertArgument(fixup.index,
568                                                     newInTys[fixup.index], loc);
569           rewriter->setInsertionPointToStart(&func.front());
570           auto oldArgTy =
571               fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
572           auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, newArg);
573           auto load = rewriter->create<fir::LoadOp>(loc, cast);
574           func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
575           func.front().eraseArgument(fixup.index + 1);
576         } break;
577         case FixupTy::Codes::ArgumentType: {
578           // Argument is pass-by-value, but its type has likely been modified to
579           // suit the target ABI convention.
580           auto newArg = func.front().insertArgument(fixup.index,
581                                                     newInTys[fixup.index], loc);
582           rewriter->setInsertionPointToStart(&func.front());
583           auto mem =
584               rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
585           rewriter->create<fir::StoreOp>(loc, newArg, mem);
586           auto oldArgTy =
587               fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
588           auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, mem);
589           mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
590           func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
591           func.front().eraseArgument(fixup.index + 1);
592           LLVM_DEBUG(llvm::dbgs()
593                      << "old argument: " << oldArgTy.getEleTy()
594                      << ", repl: " << load << ", new argument: "
595                      << func.getArgument(fixup.index).getType() << '\n');
596         } break;
597         case FixupTy::Codes::CharPair: {
598           // The FIR boxchar argument has been split into a pair of distinct
599           // arguments that are in juxtaposition to each other.
600           auto newArg = func.front().insertArgument(fixup.index,
601                                                     newInTys[fixup.index], loc);
602           if (fixup.second == 1) {
603             rewriter->setInsertionPointToStart(&func.front());
604             auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
605             auto box = rewriter->create<fir::EmboxCharOp>(
606                 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
607             func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
608             func.front().eraseArgument(fixup.index + 1);
609             offset++;
610           }
611         } break;
612         case FixupTy::Codes::ReturnAsStore: {
613           // The value being returned is now being returned in memory (callee
614           // stack space) through a hidden reference argument.
615           auto newArg = func.front().insertArgument(fixup.index,
616                                                     newInTys[fixup.index], loc);
617           offset++;
618           func.walk([&](mlir::func::ReturnOp ret) {
619             rewriter->setInsertionPoint(ret);
620             auto oldOper = ret.getOperand(0);
621             auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
622             auto cast =
623                 rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
624             rewriter->create<fir::StoreOp>(loc, oldOper, cast);
625             rewriter->create<mlir::func::ReturnOp>(loc);
626             ret.erase();
627           });
628         } break;
629         case FixupTy::Codes::ReturnType: {
630           // The function is still returning a value, but its type has likely
631           // changed to suit the target ABI convention.
632           func.walk([&](mlir::func::ReturnOp ret) {
633             rewriter->setInsertionPoint(ret);
634             auto oldOper = ret.getOperand(0);
635             auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
636             auto mem =
637                 rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
638             auto cast = rewriter->create<fir::ConvertOp>(loc, oldOperTy, mem);
639             rewriter->create<fir::StoreOp>(loc, oldOper, cast);
640             mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
641             rewriter->create<mlir::func::ReturnOp>(loc, load);
642             ret.erase();
643           });
644         } break;
645         case FixupTy::Codes::Split: {
646           // The FIR argument has been split into a pair of distinct arguments
647           // that are in juxtaposition to each other. (For COMPLEX value.)
648           auto newArg = func.front().insertArgument(fixup.index,
649                                                     newInTys[fixup.index], loc);
650           if (fixup.second == 1) {
651             rewriter->setInsertionPointToStart(&func.front());
652             auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
653             auto undef = rewriter->create<fir::UndefOp>(loc, cplxTy);
654             auto iTy = rewriter->getIntegerType(32);
655             auto zero = rewriter->getIntegerAttr(iTy, 0);
656             auto one = rewriter->getIntegerAttr(iTy, 1);
657             auto cplx1 = rewriter->create<fir::InsertValueOp>(
658                 loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
659                 rewriter->getArrayAttr(zero));
660             auto cplx = rewriter->create<fir::InsertValueOp>(
661                 loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
662             func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
663             func.front().eraseArgument(fixup.index + 1);
664             offset++;
665           }
666         } break;
667         case FixupTy::Codes::Trailing: {
668           // The FIR argument has been split into a pair of distinct arguments.
669           // The first part of the pair appears in the original argument
670           // position. The second part of the pair is appended after all the
671           // original arguments. (Boxchar arguments.)
672           auto newBufArg = func.front().insertArgument(
673               fixup.index, newInTys[fixup.index], loc);
674           auto newLenArg =
675               func.front().addArgument(trailingTys[fixup.second], loc);
676           auto boxTy = oldArgTys[fixup.index - offset];
677           rewriter->setInsertionPointToStart(&func.front());
678           auto box = rewriter->create<fir::EmboxCharOp>(loc, boxTy, newBufArg,
679                                                         newLenArg);
680           func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
681           func.front().eraseArgument(fixup.index + 1);
682         } break;
683         case FixupTy::Codes::TrailingCharProc: {
684           // The FIR character procedure argument tuple must be split into a
685           // pair of distinct arguments. The first part of the pair appears in
686           // the original argument position. The second part of the pair is
687           // appended after all the original arguments.
688           auto newProcPointerArg = func.front().insertArgument(
689               fixup.index, newInTys[fixup.index], loc);
690           auto newLenArg =
691               func.front().addArgument(trailingTys[fixup.second], loc);
692           auto tupleType = oldArgTys[fixup.index - offset];
693           rewriter->setInsertionPointToStart(&func.front());
694           fir::FirOpBuilder builder(*rewriter,
695                                     fir::getKindMapping(getModule()));
696           auto tuple = fir::factory::createCharacterProcedureTuple(
697               builder, loc, tupleType, newProcPointerArg, newLenArg);
698           func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple);
699           func.front().eraseArgument(fixup.index + 1);
700         } break;
701         }
702       }
703     }
704 
705     // Set the new type and finalize the arguments, etc.
706     newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
707     auto newFuncTy =
708         mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
709     LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
710     func.setType(newFuncTy);
711 
712     for (auto &fixup : fixups)
713       if (fixup.finalizer)
714         (*fixup.finalizer)(func);
715   }
716 
717   inline bool functionArgIsSRet(unsigned index, mlir::func::FuncOp func) {
718     if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret"))
719       return true;
720     return false;
721   }
722 
723   /// Convert a complex return value. This can involve converting the return
724   /// value to a "hidden" first argument or packing the complex into a wide
725   /// GPR.
726   template <typename A, typename B, typename C>
727   void doComplexReturn(mlir::func::FuncOp func, A cmplx, B &newResTys,
728                        B &newInTys, C &fixups) {
729     if (noComplexConversion) {
730       newResTys.push_back(cmplx);
731       return;
732     }
733     auto m = specifics->complexReturnType(cmplx.getElementType());
734     assert(m.size() == 1);
735     auto &tup = m[0];
736     auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
737     auto argTy = std::get<mlir::Type>(tup);
738     if (attr.isSRet()) {
739       unsigned argNo = newInTys.size();
740       fixups.emplace_back(
741           FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
742             func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr());
743           });
744       newInTys.push_back(argTy);
745       return;
746     }
747     fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
748     newResTys.push_back(argTy);
749   }
750 
751   /// Convert a complex argument value. This can involve storing the value to
752   /// a temporary memory location or factoring the value into two distinct
753   /// arguments.
754   template <typename A, typename B, typename C>
755   void doComplexArg(mlir::func::FuncOp func, A cmplx, B &newInTys, C &fixups) {
756     if (noComplexConversion) {
757       newInTys.push_back(cmplx);
758       return;
759     }
760     auto m = specifics->complexArgumentType(cmplx.getElementType());
761     const auto fixupCode =
762         m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType;
763     for (auto e : llvm::enumerate(m)) {
764       auto &tup = e.value();
765       auto index = e.index();
766       auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
767       auto argTy = std::get<mlir::Type>(tup);
768       auto argNo = newInTys.size();
769       if (attr.isByVal()) {
770         if (auto align = attr.getAlignment())
771           fixups.emplace_back(
772               FixupTy::Codes::ArgumentAsLoad, argNo,
773               [=](mlir::func::FuncOp func) {
774                 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr());
775                 func.setArgAttr(argNo, "llvm.align",
776                                 rewriter->getIntegerAttr(
777                                     rewriter->getIntegerType(32), align));
778               });
779         else
780           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(),
781                               [=](mlir::func::FuncOp func) {
782                                 func.setArgAttr(argNo, "llvm.byval",
783                                                 rewriter->getUnitAttr());
784                               });
785       } else {
786         if (auto align = attr.getAlignment())
787           fixups.emplace_back(
788               fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
789                 func.setArgAttr(argNo, "llvm.align",
790                                 rewriter->getIntegerAttr(
791                                     rewriter->getIntegerType(32), align));
792               });
793         else
794           fixups.emplace_back(fixupCode, argNo, index);
795       }
796       newInTys.push_back(argTy);
797     }
798   }
799 
800 private:
  // Replace `op` and remove it: the uses of `op`'s results are replaced with
  // `newValues`, then `op` drops its references and is erased.
  // The three calls are order-dependent; uses must be replaced before the
  // operation is erased.
  void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
    op->replaceAllUsesWith(newValues);
    op->dropAllReferences();
    op->erase();
  }
807 
808   inline void setMembers(fir::CodeGenSpecifics *s, mlir::OpBuilder *r) {
809     specifics = s;
810     rewriter = r;
811   }
812 
813   inline void clearMembers() { setMembers(nullptr, nullptr); }
814 
  // Pointers installed by setMembers() and reset to null by clearMembers().
  // doComplexReturn() and doComplexArg() dereference them without a null
  // check, so they must be set before those helpers run.
  // NOTE(review): presumably non-owning; confirm against the code that calls
  // setMembers().
  fir::CodeGenSpecifics *specifics = nullptr;
  mlir::OpBuilder *rewriter = nullptr;
}; // class TargetRewrite
818 } // namespace
819 
820 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
821 fir::createFirTargetRewritePass(const fir::TargetRewriteOptions &options) {
822   return std::make_unique<TargetRewrite>(options);
823 }
824