1 //===-- TargetRewrite.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest.
10 // LLVM expects different lowering idioms to be used for distinct target
11 // triples. These distinctions are handled by this pass.
12 //
13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "PassDetail.h"
#include "Target.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <functional>
#include <utility>
29 
30 using namespace fir;
31 
32 #define DEBUG_TYPE "flang-target-rewrite"
33 
34 namespace {
35 
36 /// Fixups for updating a FuncOp's arguments and return values.
37 struct FixupTy {
38   enum class Codes {
39     ArgumentAsLoad,
40     ArgumentType,
41     CharPair,
42     ReturnAsStore,
43     ReturnType,
44     Split,
45     Trailing
46   };
47 
48   FixupTy(Codes code, std::size_t index, std::size_t second = 0)
49       : code{code}, index{index}, second{second} {}
50   FixupTy(Codes code, std::size_t index,
51           std::function<void(mlir::FuncOp)> &&finalizer)
52       : code{code}, index{index}, finalizer{finalizer} {}
53   FixupTy(Codes code, std::size_t index, std::size_t second,
54           std::function<void(mlir::FuncOp)> &&finalizer)
55       : code{code}, index{index}, second{second}, finalizer{finalizer} {}
56 
57   Codes code;
58   std::size_t index;
59   std::size_t second{};
60   llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{};
61 }; // namespace
62 
63 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
64 /// generation that traverses the FIR and modifies types and operations to a
65 /// form that is appropriate for the specific target. LLVM IR has specific
66 /// idioms that are used for distinct target processor and ABI combinations.
67 class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
68 public:
69   TargetRewrite(const TargetRewriteOptions &options) {
70     noCharacterConversion = options.noCharacterConversion;
71     noComplexConversion = options.noComplexConversion;
72   }
73 
74   void runOnOperation() override final {
75     auto &context = getContext();
76     mlir::OpBuilder rewriter(&context);
77 
78     auto mod = getModule();
79     if (!forcedTargetTriple.empty())
80       setTargetTriple(mod, forcedTargetTriple);
81 
82     auto specifics = CodeGenSpecifics::get(getOperation().getContext(),
83                                            getTargetTriple(getOperation()),
84                                            getKindMapping(getOperation()));
85     setMembers(specifics.get(), &rewriter);
86 
87     // Perform type conversion on signatures and call sites.
88     if (mlir::failed(convertTypes(mod))) {
89       mlir::emitError(mlir::UnknownLoc::get(&context),
90                       "error in converting types to target abi");
91       signalPassFailure();
92     }
93 
94     // Convert ops in target-specific patterns.
95     mod.walk([&](mlir::Operation *op) {
96       if (auto call = dyn_cast<fir::CallOp>(op)) {
97         if (!hasPortableSignature(call.getFunctionType()))
98           convertCallOp(call);
99       } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
100         if (!hasPortableSignature(dispatch.getFunctionType()))
101           convertCallOp(dispatch);
102       } else if (auto addr = dyn_cast<AddrOfOp>(op)) {
103         if (addr.getType().isa<mlir::FunctionType>() &&
104             !hasPortableSignature(addr.getType()))
105           convertAddrOp(addr);
106       }
107     });
108 
109     clearMembers();
110   }
111 
  /// Return the operation this pass runs on, as an `mlir::ModuleOp`.
  mlir::ModuleOp getModule() { return getOperation(); }
113 
114   template <typename A, typename B, typename C>
115   std::function<mlir::Value(mlir::Operation *)>
116   rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) {
117     auto m = specifics->complexReturnType(ty.getElementType());
118     // Currently targets mandate COMPLEX is a single aggregate or packed
119     // scalar, including the sret case.
120     assert(m.size() == 1 && "target lowering of complex return not supported");
121     auto resTy = std::get<mlir::Type>(m[0]);
122     auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
123     auto loc = mlir::UnknownLoc::get(resTy.getContext());
124     if (attr.isSRet()) {
125       assert(isa_ref_type(resTy));
126       mlir::Value stack =
127           rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy));
128       newInTys.push_back(resTy);
129       newOpers.push_back(stack);
130       return [=](mlir::Operation *) -> mlir::Value {
131         auto memTy = ReferenceType::get(ty);
132         auto cast = rewriter->create<ConvertOp>(loc, memTy, stack);
133         return rewriter->create<fir::LoadOp>(loc, cast);
134       };
135     }
136     newResTys.push_back(resTy);
137     return [=](mlir::Operation *call) -> mlir::Value {
138       auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
139       rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
140       auto memTy = ReferenceType::get(ty);
141       auto cast = rewriter->create<ConvertOp>(loc, memTy, mem);
142       return rewriter->create<fir::LoadOp>(loc, cast);
143     };
144   }
145 
146   template <typename A, typename B, typename C>
147   void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys,
148                                    C &newOpers) {
149     auto m = specifics->complexArgumentType(ty.getElementType());
150     auto *ctx = ty.getContext();
151     auto loc = mlir::UnknownLoc::get(ctx);
152     if (m.size() == 1) {
153       // COMPLEX is a single aggregate
154       auto resTy = std::get<mlir::Type>(m[0]);
155       auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
156       auto oldRefTy = ReferenceType::get(ty);
157       if (attr.isByVal()) {
158         auto mem = rewriter->create<fir::AllocaOp>(loc, ty);
159         rewriter->create<fir::StoreOp>(loc, oper, mem);
160         newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem));
161       } else {
162         auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
163         auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem);
164         rewriter->create<fir::StoreOp>(loc, oper, cast);
165         newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
166       }
167       newInTys.push_back(resTy);
168     } else {
169       assert(m.size() == 2);
170       // COMPLEX is split into 2 separate arguments
171       for (auto e : llvm::enumerate(m)) {
172         auto &tup = e.value();
173         auto ty = std::get<mlir::Type>(tup);
174         auto index = e.index();
175         auto idx = rewriter->getIntegerAttr(rewriter->getIndexType(), index);
176         auto val = rewriter->create<ExtractValueOp>(
177             loc, ty, oper, rewriter->getArrayAttr(idx));
178         newInTys.push_back(ty);
179         newOpers.push_back(val);
180       }
181     }
182   }
183 
184   // Convert fir.call and fir.dispatch Ops.
185   template <typename A>
186   void convertCallOp(A callOp) {
187     auto fnTy = callOp.getFunctionType();
188     auto loc = callOp.getLoc();
189     rewriter->setInsertionPoint(callOp);
190     llvm::SmallVector<mlir::Type> newResTys;
191     llvm::SmallVector<mlir::Type> newInTys;
192     llvm::SmallVector<mlir::Value> newOpers;
193 
194     // If the call is indirect, the first argument must still be the function
195     // to call.
196     int dropFront = 0;
197     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
198       if (!callOp.callee().hasValue()) {
199         newInTys.push_back(fnTy.getInput(0));
200         newOpers.push_back(callOp.getOperand(0));
201         dropFront = 1;
202       }
203     }
204 
205     // Determine the rewrite function, `wrap`, for the result value.
206     llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
207     if (fnTy.getResults().size() == 1) {
208       mlir::Type ty = fnTy.getResult(0);
209       llvm::TypeSwitch<mlir::Type>(ty)
210           .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
211             wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
212                                                 newOpers);
213           })
214           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
215             wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
216                                                 newOpers);
217           })
218           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
219     } else if (fnTy.getResults().size() > 1) {
220       TODO(loc, "multiple results not supported yet");
221     }
222 
223     llvm::SmallVector<mlir::Type> trailingInTys;
224     llvm::SmallVector<mlir::Value> trailingOpers;
225     for (auto e : llvm::enumerate(
226              llvm::zip(fnTy.getInputs().drop_front(dropFront),
227                        callOp.getOperands().drop_front(dropFront)))) {
228       mlir::Type ty = std::get<0>(e.value());
229       mlir::Value oper = std::get<1>(e.value());
230       unsigned index = e.index();
231       llvm::TypeSwitch<mlir::Type>(ty)
232           .template Case<BoxCharType>([&](BoxCharType boxTy) {
233             bool sret;
234             if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
235               sret = callOp.callee() &&
236                      functionArgIsSRet(index,
237                                        getModule().lookupSymbol<mlir::FuncOp>(
238                                            *callOp.callee()));
239             } else {
240               // TODO: dispatch case; how do we put arguments on a call?
241               // We cannot put both an sret and the dispatch object first.
242               sret = false;
243               TODO(loc, "dispatch + sret not supported yet");
244             }
245             auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret);
246             auto unbox =
247                 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]),
248                                               std::get<mlir::Type>(m[1]), oper);
249             // unboxed CHARACTER arguments
250             for (auto e : llvm::enumerate(m)) {
251               unsigned idx = e.index();
252               auto attr = std::get<CodeGenSpecifics::Attributes>(e.value());
253               auto argTy = std::get<mlir::Type>(e.value());
254               if (attr.isAppend()) {
255                 trailingInTys.push_back(argTy);
256                 trailingOpers.push_back(unbox.getResult(idx));
257               } else {
258                 newInTys.push_back(argTy);
259                 newOpers.push_back(unbox.getResult(idx));
260               }
261             }
262           })
263           .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
264             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
265           })
266           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
267             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
268           })
269           .Default([&](mlir::Type ty) {
270             newInTys.push_back(ty);
271             newOpers.push_back(oper);
272           });
273     }
274     newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
275     newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
276     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
277       fir::CallOp newCall;
278       if (callOp.callee().hasValue()) {
279         newCall = rewriter->create<A>(loc, callOp.callee().getValue(),
280                                       newResTys, newOpers);
281       } else {
282         // Force new type on the input operand.
283         newOpers[0].setType(mlir::FunctionType::get(
284             callOp.getContext(),
285             mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys));
286         newCall = rewriter->create<A>(loc, newResTys, newOpers);
287       }
288       LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
289       if (wrap.hasValue())
290         replaceOp(callOp, (*wrap)(newCall.getOperation()));
291       else
292         replaceOp(callOp, newCall.getResults());
293     } else {
294       // A is fir::DispatchOp
295       TODO(loc, "dispatch not implemented");
296     }
297   }
298 
299   // Result type fixup for fir::ComplexType and mlir::ComplexType
300   template <typename A, typename B>
301   void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) {
302     if (noComplexConversion) {
303       newResTys.push_back(cmplx);
304     } else {
305       for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) {
306         auto argTy = std::get<mlir::Type>(tup);
307         if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet())
308           newInTys.push_back(argTy);
309         else
310           newResTys.push_back(argTy);
311       }
312     }
313   }
314 
315   // Argument type fixup for fir::ComplexType and mlir::ComplexType
316   template <typename A, typename B>
317   void lowerComplexSignatureArg(A cmplx, B &newInTys) {
318     if (noComplexConversion)
319       newInTys.push_back(cmplx);
320     else
321       for (auto &tup : specifics->complexArgumentType(cmplx.getElementType()))
322         newInTys.push_back(std::get<mlir::Type>(tup));
323   }
324 
325   /// Taking the address of a function. Modify the signature as needed.
326   void convertAddrOp(AddrOfOp addrOp) {
327     rewriter->setInsertionPoint(addrOp);
328     auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
329     llvm::SmallVector<mlir::Type> newResTys;
330     llvm::SmallVector<mlir::Type> newInTys;
331     for (mlir::Type ty : addrTy.getResults()) {
332       llvm::TypeSwitch<mlir::Type>(ty)
333           .Case<fir::ComplexType>([&](fir::ComplexType ty) {
334             lowerComplexSignatureRes(ty, newResTys, newInTys);
335           })
336           .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
337             lowerComplexSignatureRes(ty, newResTys, newInTys);
338           })
339           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
340     }
341     llvm::SmallVector<mlir::Type> trailingInTys;
342     for (mlir::Type ty : addrTy.getInputs()) {
343       llvm::TypeSwitch<mlir::Type>(ty)
344           .Case<BoxCharType>([&](BoxCharType box) {
345             if (noCharacterConversion) {
346               newInTys.push_back(box);
347             } else {
348               for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
349                 auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
350                 auto argTy = std::get<mlir::Type>(tup);
351                 llvm::SmallVector<mlir::Type> &vec =
352                     attr.isAppend() ? trailingInTys : newInTys;
353                 vec.push_back(argTy);
354               }
355             }
356           })
357           .Case<fir::ComplexType>([&](fir::ComplexType ty) {
358             lowerComplexSignatureArg(ty, newInTys);
359           })
360           .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
361             lowerComplexSignatureArg(ty, newInTys);
362           })
363           .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
364     }
365     // append trailing input types
366     newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
367     // replace this op with a new one with the updated signature
368     auto newTy = rewriter->getFunctionType(newInTys, newResTys);
369     auto newOp =
370         rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol());
371     replaceOp(addrOp, newOp.getResult());
372   }
373 
374   /// Convert the type signatures on all the functions present in the module.
375   /// As the type signature is being changed, this must also update the
376   /// function itself to use any new arguments, etc.
377   mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
378     for (auto fn : mod.getOps<mlir::FuncOp>())
379       convertSignature(fn);
380     return mlir::success();
381   }
382 
383   /// If the signature does not need any special target-specific converions,
384   /// then it is considered portable for any target, and this function will
385   /// return `true`. Otherwise, the signature is not portable and `false` is
386   /// returned.
387   bool hasPortableSignature(mlir::Type signature) {
388     assert(signature.isa<mlir::FunctionType>());
389     auto func = signature.dyn_cast<mlir::FunctionType>();
390     for (auto ty : func.getResults())
391       if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
392           (isa_complex(ty) && !noComplexConversion)) {
393         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
394         return false;
395       }
396     for (auto ty : func.getInputs())
397       if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
398           (isa_complex(ty) && !noComplexConversion)) {
399         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
400         return false;
401       }
402     return true;
403   }
404 
  /// Rewrite the signatures and body of the `FuncOp`s in the module for
  /// the immediately subsequent target code gen.
  ///
  /// Works in three steps for one function `func`:
  ///  1. compute the new result and argument type lists, recording a fixup
  ///     for every entry that changes;
  ///  2. if the function has a body, apply the fixups in order to the entry
  ///     block arguments and the return ops;
  ///  3. set the new function type and run the fixups' finalizers.
  /// Functions whose signature is already portable are left untouched.
  void convertSignature(mlir::FuncOp func) {
    auto funcTy = func.getType().cast<mlir::FunctionType>();
    if (hasPortableSignature(funcTy))
      return;
    llvm::SmallVector<mlir::Type> newResTys;
    llvm::SmallVector<mlir::Type> newInTys;
    llvm::SmallVector<FixupTy> fixups;

    // Convert return value(s)
    for (auto ty : funcTy.getResults())
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            if (noComplexConversion)
              newResTys.push_back(cmplx);
            else
              doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            if (noComplexConversion)
              newResTys.push_back(cmplx);
            else
              doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });

    // Convert arguments
    // Types that must come after all other arguments are gathered in
    // `trailingTys` and appended to `newInTys` once the fixups have run.
    llvm::SmallVector<mlir::Type> trailingTys;
    for (auto e : llvm::enumerate(funcTy.getInputs())) {
      auto ty = e.value();
      unsigned index = e.index();
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<BoxCharType>([&](BoxCharType boxTy) {
            if (noCharacterConversion) {
              newInTys.push_back(boxTy);
            } else {
              // Convert a CHARACTER argument type. This can involve separating
              // the pointer and the LEN into two arguments and moving the LEN
              // argument to the end of the arg list.
              bool sret = functionArgIsSRet(index, func);
              for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
                       boxTy.getEleTy(), sret))) {
                auto &tup = e.value();
                auto index = e.index();
                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
                auto argTy = std::get<mlir::Type>(tup);
                if (attr.isAppend()) {
                  trailingTys.push_back(argTy);
                } else {
                  if (sret) {
                    fixups.emplace_back(FixupTy::Codes::CharPair,
                                        newInTys.size(), index);
                  } else {
                    // `second` is the slot in `trailingTys` that the
                    // appended part of this argument will occupy.
                    fixups.emplace_back(FixupTy::Codes::Trailing,
                                        newInTys.size(), trailingTys.size());
                  }
                  newInTys.push_back(argTy);
                }
              }
            }
          })
          .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            if (noComplexConversion)
              newInTys.push_back(cmplx);
            else
              doComplexArg(func, cmplx, newInTys, fixups);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            if (noComplexConversion)
              newInTys.push_back(cmplx);
            else
              doComplexArg(func, cmplx, newInTys, fixups);
          })
          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
    }

    if (!func.empty()) {
      // If the function has a body, then apply the fixups to the arguments and
      // return ops as required. These fixups are done in place.
      auto loc = func.getLoc();
      const auto fixupSize = fixups.size();
      const auto oldArgTys = func.getType().getInputs();
      // `offset` is the number of new arguments introduced so far that have
      // no counterpart of their own in `oldArgTys`; it maps a new argument
      // index back to the index of the original argument.
      int offset = 0;
      for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
        const auto &fixup = fixups[i];
        switch (fixup.code) {
        case FixupTy::Codes::ArgumentAsLoad: {
          // Argument was pass-by-value, but is now pass-by-reference and
          // possibly with a different element type.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          rewriter->setInsertionPointToStart(&func.front());
          auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
          auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg);
          auto load = rewriter->create<fir::LoadOp>(loc, cast);
          // The old argument now sits right after the inserted one.
          func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        case FixupTy::Codes::ArgumentType: {
          // Argument is pass-by-value, but its type has likely been modified to
          // suit the target ABI convention.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          rewriter->setInsertionPointToStart(&func.front());
          auto mem =
              rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
          rewriter->create<fir::StoreOp>(loc, newArg, mem);
          auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
          auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem);
          mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
          func.front().eraseArgument(fixup.index + 1);
          LLVM_DEBUG(llvm::dbgs()
                     << "old argument: " << oldArgTy.getEleTy()
                     << ", repl: " << load << ", new argument: "
                     << func.getArgument(fixup.index).getType() << '\n');
        } break;
        case FixupTy::Codes::CharPair: {
          // The FIR boxchar argument has been split into a pair of distinct
          // arguments that are in juxtaposition to each other.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          // Only the second half of the pair rebuilds the boxchar; the first
          // half merely inserts its argument.
          if (fixup.second == 1) {
            rewriter->setInsertionPointToStart(&func.front());
            auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
            auto box = rewriter->create<EmboxCharOp>(
                loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
            func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
            func.front().eraseArgument(fixup.index + 1);
            offset++;
          }
        } break;
        case FixupTy::Codes::ReturnAsStore: {
          // The value being returned is now being returned in memory (callee
          // stack space) through a hidden reference argument.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          offset++;
          func.walk([&](mlir::ReturnOp ret) {
            rewriter->setInsertionPoint(ret);
            auto oldOper = ret.getOperand(0);
            auto oldOperTy = ReferenceType::get(oldOper.getType());
            auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg);
            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
            rewriter->create<mlir::ReturnOp>(loc);
            ret.erase();
          });
        } break;
        case FixupTy::Codes::ReturnType: {
          // The function is still returning a value, but its type has likely
          // changed to suit the target ABI convention.
          func.walk([&](mlir::ReturnOp ret) {
            rewriter->setInsertionPoint(ret);
            auto oldOper = ret.getOperand(0);
            auto oldOperTy = ReferenceType::get(oldOper.getType());
            auto mem =
                rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
            auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem);
            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
            mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
            rewriter->create<mlir::ReturnOp>(loc, load);
            ret.erase();
          });
        } break;
        case FixupTy::Codes::Split: {
          // The FIR argument has been split into a pair of distinct arguments
          // that are in juxtaposition to each other. (For COMPLEX value.)
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          // As for CharPair, the original value is reassembled only once the
          // second half of the pair has been inserted.
          if (fixup.second == 1) {
            rewriter->setInsertionPointToStart(&func.front());
            auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
            auto undef = rewriter->create<UndefOp>(loc, cplxTy);
            auto zero = rewriter->getIntegerAttr(rewriter->getIndexType(), 0);
            auto one = rewriter->getIntegerAttr(rewriter->getIndexType(), 1);
            auto cplx1 = rewriter->create<InsertValueOp>(
                loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
                rewriter->getArrayAttr(zero));
            auto cplx = rewriter->create<InsertValueOp>(
                loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
            func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
            func.front().eraseArgument(fixup.index + 1);
            offset++;
          }
        } break;
        case FixupTy::Codes::Trailing: {
          // The FIR argument has been split into a pair of distinct arguments.
          // The first part of the pair appears in the original argument
          // position. The second part of the pair is appended after all the
          // original arguments. (Boxchar arguments.)
          auto newBufArg = func.front().insertArgument(
              fixup.index, newInTys[fixup.index], loc);
          auto newLenArg =
              func.front().addArgument(trailingTys[fixup.second], loc);
          auto boxTy = oldArgTys[fixup.index - offset];
          rewriter->setInsertionPointToStart(&func.front());
          auto box =
              rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        }
      }
    }

    // Set the new type and finalize the arguments, etc.
    newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
    auto newFuncTy =
        mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
    LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
    func.setType(newFuncTy);

    // Finalizers run after the type is set; the ones created in this file
    // attach argument attributes.
    for (auto &fixup : fixups)
      if (fixup.finalizer)
        (*fixup.finalizer)(func);
  }
622 
623   inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) {
624     if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret"))
625       return true;
626     return false;
627   }
628 
629   /// Convert a complex return value. This can involve converting the return
630   /// value to a "hidden" first argument or packing the complex into a wide
631   /// GPR.
632   template <typename A, typename B, typename C>
633   void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys,
634                        C &fixups) {
635     if (noComplexConversion) {
636       newResTys.push_back(cmplx);
637       return;
638     }
639     auto m = specifics->complexReturnType(cmplx.getElementType());
640     assert(m.size() == 1);
641     auto &tup = m[0];
642     auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
643     auto argTy = std::get<mlir::Type>(tup);
644     if (attr.isSRet()) {
645       unsigned argNo = newInTys.size();
646       fixups.emplace_back(
647           FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) {
648             func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr());
649           });
650       newInTys.push_back(argTy);
651       return;
652     }
653     fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
654     newResTys.push_back(argTy);
655   }
656 
657   /// Convert a complex argument value. This can involve storing the value to
658   /// a temporary memory location or factoring the value into two distinct
659   /// arguments.
660   template <typename A, typename B, typename C>
661   void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) {
662     if (noComplexConversion) {
663       newInTys.push_back(cmplx);
664       return;
665     }
666     auto m = specifics->complexArgumentType(cmplx.getElementType());
667     const auto fixupCode =
668         m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType;
669     for (auto e : llvm::enumerate(m)) {
670       auto &tup = e.value();
671       auto index = e.index();
672       auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
673       auto argTy = std::get<mlir::Type>(tup);
674       auto argNo = newInTys.size();
675       if (attr.isByVal()) {
676         if (auto align = attr.getAlignment())
677           fixups.emplace_back(
678               FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) {
679                 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr());
680                 func.setArgAttr(argNo, "llvm.align",
681                                 rewriter->getIntegerAttr(
682                                     rewriter->getIntegerType(32), align));
683               });
684         else
685           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(),
686                               [=](mlir::FuncOp func) {
687                                 func.setArgAttr(argNo, "llvm.byval",
688                                                 rewriter->getUnitAttr());
689                               });
690       } else {
691         if (auto align = attr.getAlignment())
692           fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) {
693             func.setArgAttr(
694                 argNo, "llvm.align",
695                 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
696           });
697         else
698           fixups.emplace_back(fixupCode, argNo, index);
699       }
700       newInTys.push_back(argTy);
701     }
702   }
703 
704 private:
  // Replace `op` and remove it: all uses of `op`'s results are redirected to
  // `newValues`, then `op` drops its references and is erased.
  void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
    op->replaceAllUsesWith(newValues);
    op->dropAllReferences();
    op->erase();
  }
711 
  // Point the `specifics` and `rewriter` members at `s` and `r`. The pointers
  // are stored as given; nothing is copied.
  inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) {
    specifics = s;
    rewriter = r;
  }
716 
  // Reset both `specifics` and `rewriter` to null.
  inline void clearMembers() { setMembers(nullptr, nullptr); }
718 
719   CodeGenSpecifics *specifics{};
720   mlir::OpBuilder *rewriter;
}; // class TargetRewrite
722 } // namespace
723 
/// Create a `TargetRewrite` pass instance configured from `options` and
/// return it as a pass that operates on `mlir::ModuleOp`.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
fir::createFirTargetRewritePass(const TargetRewriteOptions &options) {
  return std::make_unique<TargetRewrite>(options);
}
728