1 //===-- TargetRewrite.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest.
10 // LLVM expects different lowering idioms to be used for distinct target
11 // triples. These distinctions are handled by this pass.
12 //
13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "PassDetail.h"
#include "Target.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <functional>
#include <utility>
29 
30 using namespace fir;
31 
32 #define DEBUG_TYPE "flang-target-rewrite"
33 
34 namespace {
35 
36 /// Fixups for updating a FuncOp's arguments and return values.
37 struct FixupTy {
38   enum class Codes {
39     ArgumentAsLoad,
40     ArgumentType,
41     CharPair,
42     ReturnAsStore,
43     ReturnType,
44     Split,
45     Trailing
46   };
47 
48   FixupTy(Codes code, std::size_t index, std::size_t second = 0)
49       : code{code}, index{index}, second{second} {}
50   FixupTy(Codes code, std::size_t index,
51           std::function<void(mlir::FuncOp)> &&finalizer)
52       : code{code}, index{index}, finalizer{finalizer} {}
53   FixupTy(Codes code, std::size_t index, std::size_t second,
54           std::function<void(mlir::FuncOp)> &&finalizer)
55       : code{code}, index{index}, second{second}, finalizer{finalizer} {}
56 
57   Codes code;
58   std::size_t index;
59   std::size_t second{};
60   llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{};
61 }; // namespace
62 
63 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
64 /// generation that traverses the FIR and modifies types and operations to a
65 /// form that is appropriate for the specific target. LLVM IR has specific
66 /// idioms that are used for distinct target processor and ABI combinations.
67 class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
68 public:
69   TargetRewrite(const TargetRewriteOptions &options) {
70     noCharacterConversion = options.noCharacterConversion;
71     noComplexConversion = options.noComplexConversion;
72   }
73 
74   void runOnOperation() override final {
75     auto &context = getContext();
76     mlir::OpBuilder rewriter(&context);
77 
78     auto mod = getModule();
79     if (!forcedTargetTriple.empty()) {
80       setTargetTriple(mod, forcedTargetTriple);
81     }
82 
83     auto specifics = CodeGenSpecifics::get(getOperation().getContext(),
84                                            getTargetTriple(getOperation()),
85                                            getKindMapping(getOperation()));
86     setMembers(specifics.get(), &rewriter);
87 
88     // Perform type conversion on signatures and call sites.
89     if (mlir::failed(convertTypes(mod))) {
90       mlir::emitError(mlir::UnknownLoc::get(&context),
91                       "error in converting types to target abi");
92       signalPassFailure();
93     }
94 
95     // Convert ops in target-specific patterns.
96     mod.walk([&](mlir::Operation *op) {
97       if (auto call = dyn_cast<fir::CallOp>(op)) {
98         if (!hasPortableSignature(call.getFunctionType()))
99           convertCallOp(call);
100       } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
101         if (!hasPortableSignature(dispatch.getFunctionType()))
102           convertCallOp(dispatch);
103       }
104     });
105 
106     clearMembers();
107   }
108 
109   mlir::ModuleOp getModule() { return getOperation(); }
110 
111   template <typename A, typename B, typename C>
112   std::function<mlir::Value(mlir::Operation *)>
113   rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) {
114     auto m = specifics->complexReturnType(ty.getElementType());
115     // Currently targets mandate COMPLEX is a single aggregate or packed
116     // scalar, including the sret case.
117     assert(m.size() == 1 && "target lowering of complex return not supported");
118     auto resTy = std::get<mlir::Type>(m[0]);
119     auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
120     auto loc = mlir::UnknownLoc::get(resTy.getContext());
121     if (attr.isSRet()) {
122       assert(isa_ref_type(resTy));
123       mlir::Value stack =
124           rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy));
125       newInTys.push_back(resTy);
126       newOpers.push_back(stack);
127       return [=](mlir::Operation *) -> mlir::Value {
128         auto memTy = ReferenceType::get(ty);
129         auto cast = rewriter->create<ConvertOp>(loc, memTy, stack);
130         return rewriter->create<fir::LoadOp>(loc, cast);
131       };
132     }
133     newResTys.push_back(resTy);
134     return [=](mlir::Operation *call) -> mlir::Value {
135       auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
136       rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
137       auto memTy = ReferenceType::get(ty);
138       auto cast = rewriter->create<ConvertOp>(loc, memTy, mem);
139       return rewriter->create<fir::LoadOp>(loc, cast);
140     };
141   }
142 
143   template <typename A, typename B, typename C>
144   void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys,
145                                    C &newOpers) {
146     auto m = specifics->complexArgumentType(ty.getElementType());
147     auto *ctx = ty.getContext();
148     auto loc = mlir::UnknownLoc::get(ctx);
149     if (m.size() == 1) {
150       // COMPLEX is a single aggregate
151       auto resTy = std::get<mlir::Type>(m[0]);
152       auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
153       auto oldRefTy = ReferenceType::get(ty);
154       if (attr.isByVal()) {
155         auto mem = rewriter->create<fir::AllocaOp>(loc, ty);
156         rewriter->create<fir::StoreOp>(loc, oper, mem);
157         newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem));
158       } else {
159         auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
160         auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem);
161         rewriter->create<fir::StoreOp>(loc, oper, cast);
162         newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
163       }
164       newInTys.push_back(resTy);
165     } else {
166       assert(m.size() == 2);
167       // COMPLEX is split into 2 separate arguments
168       for (auto e : llvm::enumerate(m)) {
169         auto &tup = e.value();
170         auto ty = std::get<mlir::Type>(tup);
171         auto index = e.index();
172         auto idx = rewriter->getIntegerAttr(rewriter->getIndexType(), index);
173         auto val = rewriter->create<ExtractValueOp>(
174             loc, ty, oper, rewriter->getArrayAttr(idx));
175         newInTys.push_back(ty);
176         newOpers.push_back(val);
177       }
178     }
179   }
180 
181   // Convert fir.call and fir.dispatch Ops.
182   template <typename A>
183   void convertCallOp(A callOp) {
184     auto fnTy = callOp.getFunctionType();
185     auto loc = callOp.getLoc();
186     rewriter->setInsertionPoint(callOp);
187     llvm::SmallVector<mlir::Type> newResTys;
188     llvm::SmallVector<mlir::Type> newInTys;
189     llvm::SmallVector<mlir::Value> newOpers;
190 
191     // If the call is indirect, the first argument must still be the function
192     // to call.
193     int dropFront = 0;
194     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
195       if (!callOp.callee().hasValue()) {
196         newInTys.push_back(fnTy.getInput(0));
197         newOpers.push_back(callOp.getOperand(0));
198         dropFront = 1;
199       }
200     }
201 
202     // Determine the rewrite function, `wrap`, for the result value.
203     llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
204     if (fnTy.getResults().size() == 1) {
205       mlir::Type ty = fnTy.getResult(0);
206       llvm::TypeSwitch<mlir::Type>(ty)
207           .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
208             wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
209                                                 newOpers);
210           })
211           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
212             wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
213                                                 newOpers);
214           })
215           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
216     } else if (fnTy.getResults().size() > 1) {
217       TODO(loc, "multiple results not supported yet");
218     }
219 
220     llvm::SmallVector<mlir::Type> trailingInTys;
221     llvm::SmallVector<mlir::Value> trailingOpers;
222     for (auto e : llvm::enumerate(
223              llvm::zip(fnTy.getInputs().drop_front(dropFront),
224                        callOp.getOperands().drop_front(dropFront)))) {
225       mlir::Type ty = std::get<0>(e.value());
226       mlir::Value oper = std::get<1>(e.value());
227       unsigned index = e.index();
228       llvm::TypeSwitch<mlir::Type>(ty)
229           .template Case<BoxCharType>([&](BoxCharType boxTy) {
230             bool sret;
231             if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
232               sret = callOp.callee() &&
233                      functionArgIsSRet(index,
234                                        getModule().lookupSymbol<mlir::FuncOp>(
235                                            *callOp.callee()));
236             } else {
237               // TODO: dispatch case; how do we put arguments on a call?
238               // We cannot put both an sret and the dispatch object first.
239               sret = false;
240               TODO(loc, "dispatch + sret not supported yet");
241             }
242             auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret);
243             auto unbox =
244                 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]),
245                                               std::get<mlir::Type>(m[1]), oper);
246             // unboxed CHARACTER arguments
247             for (auto e : llvm::enumerate(m)) {
248               unsigned idx = e.index();
249               auto attr = std::get<CodeGenSpecifics::Attributes>(e.value());
250               auto argTy = std::get<mlir::Type>(e.value());
251               if (attr.isAppend()) {
252                 trailingInTys.push_back(argTy);
253                 trailingOpers.push_back(unbox.getResult(idx));
254               } else {
255                 newInTys.push_back(argTy);
256                 newOpers.push_back(unbox.getResult(idx));
257               }
258             }
259           })
260           .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
261             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
262           })
263           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
264             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
265           })
266           .Default([&](mlir::Type ty) {
267             newInTys.push_back(ty);
268             newOpers.push_back(oper);
269           });
270     }
271     newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
272     newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
273     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
274       fir::CallOp newCall;
275       if (callOp.callee().hasValue()) {
276         newCall = rewriter->create<A>(loc, callOp.callee().getValue(),
277                                       newResTys, newOpers);
278       } else {
279         // Force new type on the input operand.
280         newOpers[0].setType(mlir::FunctionType::get(
281             callOp.getContext(),
282             mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys));
283         newCall = rewriter->create<A>(loc, newResTys, newOpers);
284       }
285       LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
286       if (wrap.hasValue())
287         replaceOp(callOp, (*wrap)(newCall.getOperation()));
288       else
289         replaceOp(callOp, newCall.getResults());
290     } else {
291       // A is fir::DispatchOp
292       TODO(loc, "dispatch not implemented");
293     }
294   }
295 
296   // Result type fixup for fir::ComplexType and mlir::ComplexType
297   template <typename A, typename B>
298   void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) {
299     if (noComplexConversion) {
300       newResTys.push_back(cmplx);
301     } else {
302       for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) {
303         auto argTy = std::get<mlir::Type>(tup);
304         if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet())
305           newInTys.push_back(argTy);
306         else
307           newResTys.push_back(argTy);
308       }
309     }
310   }
311 
312   // Argument type fixup for fir::ComplexType and mlir::ComplexType
313   template <typename A, typename B>
314   void lowerComplexSignatureArg(A cmplx, B &newInTys) {
315     if (noComplexConversion)
316       newInTys.push_back(cmplx);
317     else
318       for (auto &tup : specifics->complexArgumentType(cmplx.getElementType()))
319         newInTys.push_back(std::get<mlir::Type>(tup));
320   }
321 
322   /// Convert the type signatures on all the functions present in the module.
323   /// As the type signature is being changed, this must also update the
324   /// function itself to use any new arguments, etc.
325   mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
326     for (auto fn : mod.getOps<mlir::FuncOp>())
327       convertSignature(fn);
328     return mlir::success();
329   }
330 
331   /// If the signature does not need any special target-specific converions,
332   /// then it is considered portable for any target, and this function will
333   /// return `true`. Otherwise, the signature is not portable and `false` is
334   /// returned.
335   bool hasPortableSignature(mlir::Type signature) {
336     assert(signature.isa<mlir::FunctionType>());
337     auto func = signature.dyn_cast<mlir::FunctionType>();
338     for (auto ty : func.getResults())
339       if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
340           (isa_complex(ty) && !noComplexConversion)) {
341         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
342         return false;
343       }
344     for (auto ty : func.getInputs())
345       if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
346           (isa_complex(ty) && !noComplexConversion)) {
347         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
348         return false;
349       }
350     return true;
351   }
352 
353   /// Rewrite the signatures and body of the `FuncOp`s in the module for
354   /// the immediately subsequent target code gen.
355   void convertSignature(mlir::FuncOp func) {
356     auto funcTy = func.getType().cast<mlir::FunctionType>();
357     if (hasPortableSignature(funcTy))
358       return;
359     llvm::SmallVector<mlir::Type> newResTys;
360     llvm::SmallVector<mlir::Type> newInTys;
361     llvm::SmallVector<FixupTy> fixups;
362 
363     // Convert return value(s)
364     for (auto ty : funcTy.getResults())
365       llvm::TypeSwitch<mlir::Type>(ty)
366           .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
367             if (noComplexConversion)
368               newResTys.push_back(cmplx);
369             else
370               doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
371           })
372           .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
373             if (noComplexConversion)
374               newResTys.push_back(cmplx);
375             else
376               doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
377           })
378           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
379 
380     // Convert arguments
381     llvm::SmallVector<mlir::Type> trailingTys;
382     for (auto e : llvm::enumerate(funcTy.getInputs())) {
383       auto ty = e.value();
384       unsigned index = e.index();
385       llvm::TypeSwitch<mlir::Type>(ty)
386           .Case<BoxCharType>([&](BoxCharType boxTy) {
387             if (noCharacterConversion) {
388               newInTys.push_back(boxTy);
389             } else {
390               // Convert a CHARACTER argument type. This can involve separating
391               // the pointer and the LEN into two arguments and moving the LEN
392               // argument to the end of the arg list.
393               bool sret = functionArgIsSRet(index, func);
394               for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
395                        boxTy.getEleTy(), sret))) {
396                 auto &tup = e.value();
397                 auto index = e.index();
398                 auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
399                 auto argTy = std::get<mlir::Type>(tup);
400                 if (attr.isAppend()) {
401                   trailingTys.push_back(argTy);
402                 } else {
403                   if (sret) {
404                     fixups.emplace_back(FixupTy::Codes::CharPair,
405                                         newInTys.size(), index);
406                   } else {
407                     fixups.emplace_back(FixupTy::Codes::Trailing,
408                                         newInTys.size(), trailingTys.size());
409                   }
410                   newInTys.push_back(argTy);
411                 }
412               }
413             }
414           })
415           .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
416             if (noComplexConversion)
417               newInTys.push_back(cmplx);
418             else
419               doComplexArg(func, cmplx, newInTys, fixups);
420           })
421           .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
422             if (noComplexConversion)
423               newInTys.push_back(cmplx);
424             else
425               doComplexArg(func, cmplx, newInTys, fixups);
426           })
427           .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
428     }
429 
430     if (!func.empty()) {
431       // If the function has a body, then apply the fixups to the arguments and
432       // return ops as required. These fixups are done in place.
433       auto loc = func.getLoc();
434       const auto fixupSize = fixups.size();
435       const auto oldArgTys = func.getType().getInputs();
436       int offset = 0;
437       for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
438         const auto &fixup = fixups[i];
439         switch (fixup.code) {
440         case FixupTy::Codes::ArgumentAsLoad: {
441           // Argument was pass-by-value, but is now pass-by-reference and
442           // possibly with a different element type.
443           auto newArg =
444               func.front().insertArgument(fixup.index, newInTys[fixup.index]);
445           rewriter->setInsertionPointToStart(&func.front());
446           auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
447           auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg);
448           auto load = rewriter->create<fir::LoadOp>(loc, cast);
449           func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
450           func.front().eraseArgument(fixup.index + 1);
451         } break;
452         case FixupTy::Codes::ArgumentType: {
453           // Argument is pass-by-value, but its type has likely been modified to
454           // suit the target ABI convention.
455           auto newArg =
456               func.front().insertArgument(fixup.index, newInTys[fixup.index]);
457           rewriter->setInsertionPointToStart(&func.front());
458           auto mem =
459               rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
460           rewriter->create<fir::StoreOp>(loc, newArg, mem);
461           auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
462           auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem);
463           mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
464           func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
465           func.front().eraseArgument(fixup.index + 1);
466           LLVM_DEBUG(llvm::dbgs()
467                      << "old argument: " << oldArgTy.getEleTy()
468                      << ", repl: " << load << ", new argument: "
469                      << func.getArgument(fixup.index).getType() << '\n');
470         } break;
471         case FixupTy::Codes::CharPair: {
472           // The FIR boxchar argument has been split into a pair of distinct
473           // arguments that are in juxtaposition to each other.
474           auto newArg =
475               func.front().insertArgument(fixup.index, newInTys[fixup.index]);
476           if (fixup.second == 1) {
477             rewriter->setInsertionPointToStart(&func.front());
478             auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
479             auto box = rewriter->create<EmboxCharOp>(
480                 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
481             func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
482             func.front().eraseArgument(fixup.index + 1);
483             offset++;
484           }
485         } break;
486         case FixupTy::Codes::ReturnAsStore: {
487           // The value being returned is now being returned in memory (callee
488           // stack space) through a hidden reference argument.
489           auto newArg =
490               func.front().insertArgument(fixup.index, newInTys[fixup.index]);
491           offset++;
492           func.walk([&](mlir::ReturnOp ret) {
493             rewriter->setInsertionPoint(ret);
494             auto oldOper = ret.getOperand(0);
495             auto oldOperTy = ReferenceType::get(oldOper.getType());
496             auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg);
497             rewriter->create<fir::StoreOp>(loc, oldOper, cast);
498             rewriter->create<mlir::ReturnOp>(loc);
499             ret.erase();
500           });
501         } break;
502         case FixupTy::Codes::ReturnType: {
503           // The function is still returning a value, but its type has likely
504           // changed to suit the target ABI convention.
505           func.walk([&](mlir::ReturnOp ret) {
506             rewriter->setInsertionPoint(ret);
507             auto oldOper = ret.getOperand(0);
508             auto oldOperTy = ReferenceType::get(oldOper.getType());
509             auto mem =
510                 rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
511             auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem);
512             rewriter->create<fir::StoreOp>(loc, oldOper, cast);
513             mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
514             rewriter->create<mlir::ReturnOp>(loc, load);
515             ret.erase();
516           });
517         } break;
518         case FixupTy::Codes::Split: {
519           // The FIR argument has been split into a pair of distinct arguments
520           // that are in juxtaposition to each other. (For COMPLEX value.)
521           auto newArg =
522               func.front().insertArgument(fixup.index, newInTys[fixup.index]);
523           if (fixup.second == 1) {
524             rewriter->setInsertionPointToStart(&func.front());
525             auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
526             auto undef = rewriter->create<UndefOp>(loc, cplxTy);
527             auto zero = rewriter->getIntegerAttr(rewriter->getIndexType(), 0);
528             auto one = rewriter->getIntegerAttr(rewriter->getIndexType(), 1);
529             auto cplx1 = rewriter->create<InsertValueOp>(
530                 loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
531                 rewriter->getArrayAttr(zero));
532             auto cplx = rewriter->create<InsertValueOp>(
533                 loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
534             func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
535             func.front().eraseArgument(fixup.index + 1);
536             offset++;
537           }
538         } break;
539         case FixupTy::Codes::Trailing: {
540           // The FIR argument has been split into a pair of distinct arguments.
541           // The first part of the pair appears in the original argument
542           // position. The second part of the pair is appended after all the
543           // original arguments. (Boxchar arguments.)
544           auto newBufArg =
545               func.front().insertArgument(fixup.index, newInTys[fixup.index]);
546           auto newLenArg = func.front().addArgument(trailingTys[fixup.second]);
547           auto boxTy = oldArgTys[fixup.index - offset];
548           rewriter->setInsertionPointToStart(&func.front());
549           auto box =
550               rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg);
551           func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
552           func.front().eraseArgument(fixup.index + 1);
553         } break;
554         }
555       }
556     }
557 
558     // Set the new type and finalize the arguments, etc.
559     newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
560     auto newFuncTy =
561         mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
562     LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
563     func.setType(newFuncTy);
564 
565     for (auto &fixup : fixups)
566       if (fixup.finalizer)
567         (*fixup.finalizer)(func);
568   }
569 
570   inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) {
571     if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret"))
572       return true;
573     return false;
574   }
575 
576   /// Convert a complex return value. This can involve converting the return
577   /// value to a "hidden" first argument or packing the complex into a wide
578   /// GPR.
579   template <typename A, typename B, typename C>
580   void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys,
581                        C &fixups) {
582     if (noComplexConversion) {
583       newResTys.push_back(cmplx);
584       return;
585     }
586     auto m = specifics->complexReturnType(cmplx.getElementType());
587     assert(m.size() == 1);
588     auto &tup = m[0];
589     auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
590     auto argTy = std::get<mlir::Type>(tup);
591     if (attr.isSRet()) {
592       unsigned argNo = newInTys.size();
593       fixups.emplace_back(
594           FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) {
595             func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr());
596           });
597       newInTys.push_back(argTy);
598       return;
599     }
600     fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
601     newResTys.push_back(argTy);
602   }
603 
604   /// Convert a complex argument value. This can involve storing the value to
605   /// a temporary memory location or factoring the value into two distinct
606   /// arguments.
607   template <typename A, typename B, typename C>
608   void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) {
609     if (noComplexConversion) {
610       newInTys.push_back(cmplx);
611       return;
612     }
613     auto m = specifics->complexArgumentType(cmplx.getElementType());
614     const auto fixupCode =
615         m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType;
616     for (auto e : llvm::enumerate(m)) {
617       auto &tup = e.value();
618       auto index = e.index();
619       auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
620       auto argTy = std::get<mlir::Type>(tup);
621       auto argNo = newInTys.size();
622       if (attr.isByVal()) {
623         if (auto align = attr.getAlignment())
624           fixups.emplace_back(
625               FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) {
626                 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr());
627                 func.setArgAttr(argNo, "llvm.align",
628                                 rewriter->getIntegerAttr(
629                                     rewriter->getIntegerType(32), align));
630               });
631         else
632           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(),
633                               [=](mlir::FuncOp func) {
634                                 func.setArgAttr(argNo, "llvm.byval",
635                                                 rewriter->getUnitAttr());
636                               });
637       } else {
638         if (auto align = attr.getAlignment())
639           fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) {
640             func.setArgAttr(
641                 argNo, "llvm.align",
642                 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
643           });
644         else
645           fixups.emplace_back(fixupCode, argNo, index);
646       }
647       newInTys.push_back(argTy);
648     }
649   }
650 
651 private:
652   // Replace `op` and remove it.
653   void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
654     op->replaceAllUsesWith(newValues);
655     op->dropAllReferences();
656     op->erase();
657   }
658 
659   inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) {
660     specifics = s;
661     rewriter = r;
662   }
663 
664   inline void clearMembers() { setMembers(nullptr, nullptr); }
665 
666   CodeGenSpecifics *specifics{};
667   mlir::OpBuilder *rewriter;
668 }; // namespace
669 } // namespace
670 
671 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
672 fir::createFirTargetRewritePass(const TargetRewriteOptions &options) {
673   return std::make_unique<TargetRewrite>(options);
674 }
675